jacobmahon's picture
Add inference script for vulnerability scanning
fd6d0c6 verified
"""
Zero-Day Exploit Scanner & Fixer - Inference Script
====================================================
Scans code for vulnerabilities and generates fixes.
Usage:
python inference.py --code "int main() { char buf[10]; gets(buf); }"
python inference.py --file vulnerable_code.c
python inference.py --interactive
Requirements:
pip install transformers peft torch bitsandbytes accelerate
"""
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
# Default identifier used by load_model() for both the tokenizer and the
# adapter passed to PeftModel.from_pretrained.
MODEL_ID = "jacobmahon/zero-day-exploit-scanner-fixer"
# Default identifier of the causal-LM checkpoint that load_model() loads
# (quantized) before attaching the adapter.
BASE_MODEL = "Qwen/Qwen2.5-Coder-7B-Instruct"
# System message that scan_code() places at the start of every conversation.
# It spells out the section headings the reply is expected to use.
SYSTEM_PROMPT = """You are a world-class security expert specializing in zero-day vulnerability detection and remediation. When given code, you will:
1. SCAN: Determine if the code contains a security vulnerability
2. IDENTIFY: If vulnerable, identify the CWE type and CVE ID if known
3. EXPLAIN: Provide a clear explanation of the vulnerability mechanism, attack vector, and potential impact
4. FIX: Provide the corrected code that patches the vulnerability
Always respond in the following structured format:
## SCAN RESULT
[VULNERABLE / SAFE]
## VULNERABILITY DETAILS
- **CWE**: [CWE ID and name]
- **CVE**: [CVE ID if known, otherwise "N/A"]
- **Severity**: [CRITICAL / HIGH / MEDIUM / LOW]
## EXPLANATION
[Detailed explanation of the vulnerability]
## VULNERABLE LINES
[Specific lines or patterns that are vulnerable]
## FIXED CODE
```
[Corrected code]
```
## FIX EXPLANATION
[What was changed and why]"""
def load_model(model_id=MODEL_ID, base_model=BASE_MODEL, device="auto"):
    """Load the 4-bit quantized base model, attach the adapter, and return both.

    Args:
        model_id: Identifier handed to the tokenizer loader and to
            ``PeftModel.from_pretrained`` as the adapter source.
        base_model: Identifier of the causal-LM checkpoint loaded with the
            quantization config.
        device: Forwarded as ``device_map`` when loading the base model.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been switched to
        eval mode.
    """
    print(f"Loading base model: {base_model}")
    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    base_kwargs = {
        "quantization_config": quantization,
        "device_map": device,
        "torch_dtype": torch.bfloat16,
    }
    backbone = AutoModelForCausalLM.from_pretrained(base_model, **base_kwargs)

    print(f"Loading LoRA adapter: {model_id}")
    adapted = PeftModel.from_pretrained(backbone, model_id)
    adapted.eval()
    return adapted, tokenizer
def _detect_language(code: str) -> str:
    """Guess the language of *code* from a handful of literal markers.

    Returns one of "Java", "C", "Python", "JavaScript", or the generic
    label "code" when no marker matches.
    """
    # Java is tested first: ordinary Java contains "void " (for example
    # "public static void main") and "import " lines, so the C and Python
    # checks below would otherwise claim it and the Java branch would
    # almost never be reached.
    if "public class" in code or "System.out" in code:
        return "Java"
    if "#include" in code or "malloc" in code or "void " in code:
        return "C"
    if "def " in code or "import " in code:
        return "Python"
    if "function " in code or "const " in code or "=>" in code:
        return "JavaScript"
    return "code"


def scan_code(code: str, model, tokenizer, language: str = "auto", max_new_tokens: int = 2048) -> str:
    """Ask the model to analyze *code* and return its decoded reply.

    Args:
        code: Source text to analyze.
        model: Object exposing ``device`` and ``generate`` (as returned by
            ``load_model``).
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``__call__``
            and ``decode``.
        language: Language label inserted into the prompt and, lowercased,
            used as the code-fence tag. ``"auto"`` triggers a simple
            marker-based guess.
        max_new_tokens: Upper bound on generated tokens, forwarded to
            ``model.generate``.

    Returns:
        The generated text with the prompt portion and special tokens
        removed.
    """
    if language == "auto":
        language = _detect_language(code)

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Analyze the following {language} code for security vulnerabilities and provide a fix if needed:\n\n```{language.lower()}\n{code}\n```"},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.3,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.1,
        )

    # Skip the first prompt-length tokens so only the newly generated part
    # of the sequence is decoded.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response
def _read_snippet():
    """Collect stdin lines until a line that is exactly 'END' once stripped.

    Returns:
        The collected lines joined with newlines, or ``None`` when the user
        typed 'quit' or stdin reached end-of-file.
    """
    lines = []
    while True:
        try:
            line = input()
        except EOFError:
            # Ctrl-D or exhausted piped input: leave quietly instead of
            # crashing with a traceback.
            return None
        command = line.strip()
        if command == "END":
            return "\n".join(lines)
        if command == "quit":
            return None
        lines.append(line)


def main():
    """Command-line entry point.

    Parses the arguments, loads the model once, then runs exactly one mode,
    checked in this order: ``--code`` (analyze a string), ``--file``
    (analyze a file's contents), ``--interactive`` (read snippets from
    stdin in a loop), or, when none is given, a built-in demo snippet.
    """
    parser = argparse.ArgumentParser(description="Zero-Day Exploit Scanner & Fixer")
    parser.add_argument("--code", type=str, help="Code string to analyze")
    parser.add_argument("--file", type=str, help="Path to file to analyze")
    parser.add_argument("--interactive", action="store_true", help="Interactive mode")
    parser.add_argument("--language", type=str, default="auto", help="Programming language")
    parser.add_argument("--model", type=str, default=MODEL_ID, help="Model ID")
    parser.add_argument("--base-model", type=str, default=BASE_MODEL, help="Base model ID")
    args = parser.parse_args()

    model, tokenizer = load_model(args.model, args.base_model)

    if args.code:
        result = scan_code(args.code, model, tokenizer, args.language)
        print(result)
    elif args.file:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's default encoding; undecodable bytes are replaced so a
        # stray byte does not abort the scan.
        with open(args.file, "r", encoding="utf-8", errors="replace") as f:
            code = f.read()
        print(f"\nScanning: {args.file}")
        print("=" * 60)
        result = scan_code(code, model, tokenizer, args.language)
        print(result)
    elif args.interactive:
        print("Zero-Day Exploit Scanner & Fixer")
        print("Enter code to analyze (type 'END' on a new line to submit, 'quit' to exit)")
        print("=" * 60)
        while True:
            print("\nEnter code:")
            code = _read_snippet()
            if code is None:
                # 'quit' or end-of-input.
                return
            if not code.strip():
                # Nothing but whitespace was submitted; prompt again.
                continue
            print("\nAnalyzing...")
            result = scan_code(code, model, tokenizer, args.language)
            print("\n" + result)
    else:
        # Demo with example vulnerable code
        demo_code = '''
void process_input(char *user_input) {
char buffer[64];
strcpy(buffer, user_input); // No bounds checking
printf("Processed: %s\\n", buffer);
}
int main() {
char input[1024];
gets(input); // Unsafe input
process_input(input);
return 0;
}
'''
        print("Demo: Scanning example vulnerable C code")
        print("=" * 60)
        result = scan_code(demo_code, model, tokenizer, "C")
        print(result)


if __name__ == "__main__":
    main()