import torch
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn.functional as F
import transformers
from openai import OpenAI
import re
import base64
import anthropic
import wandb
import time


# CWE identifiers covered by this script. The list is index-aligned with
# Prompts_OF_VUL below: entry i here is described by prompt i there. main()
# uses the number to build the data file name (cwe-XXX.jsonl, zero-padded to
# three digits).
TYPE_OF_VUL = [22, 78, 79, 89, 125, 190, 416, 476, 787]

# Natural-language description of each CWE, in the same order as TYPE_OF_VUL.
# The selected description is placed at the start of the prompt fed to the
# language model, so these strings are runtime data, not documentation.
Prompts_OF_VUL = [
            'CWE-22, commonly called “Path Traversal,” is a vulnerability when an application fails to appropriately limit the paths users can access through a user-provided input.',
            'CWE-78, means improper neutralization of special elements used in an OS command (OS command injection), constructs all or part of an OS command using externally-influenced input from an upstream component, but it does not neutralize or incorrectly neutralizes special elements that could modify the intended OS command when it is sent to a downstream component.',
            'CWE-79, improper neutralization of input during web page generation (cross-site scripting). The code does not neutralize or incorrectly neutralizes user-controllable input before it is placed in output that is used as a web page that is served to other users.',
            'CWE-89: Improper Neutralization of Special Elements used in an SQL Command (SQL Injection). The product constructs all or part of an SQL command using externally-influenced input from an upstream component, but it does not neutralize or incorrectly neutralizes special elements that could modify the intended SQL command when it is sent to a downstream component. Without sufficient removal or quoting of SQL syntax in user-controllable inputs, the generated SQL query can cause those inputs to be interpreted as SQL instead of ordinary user data.',
            'CWE-125: Out-of-bounds Read. The product reads data past the end, or before the beginning, of the intended buffer.',
            'CWE-190: Integer Overflow or Wraparound. The product performs a calculation that can produce an integer overflow or wraparound when the logic assumes that the resulting value will always be larger than the original value. This occurs when an integer value is incremented to a value that is too large to store in the associated representation. When this occurs, the value may become a very small or negative number.',
            'CWE-416: Use After Free. The product reuses or references memory after it has been freed. At some point afterward, the memory may be allocated again and saved in another pointer, while the original pointer references a location somewhere within the new allocation. Any operations using the original pointer are no longer valid because the memory "belongs" to the code that operates on the new pointer.',
            'CWE-476: NULL Pointer Dereference. The product dereferences a pointer that it expects to be valid but is NULL.',
            'CWE-787: Out-of-bounds Write. The product writes data past the end, or before the beginning, of the intended buffer.',
        ]

def get_second_to_last_layer_mean(model, tokenizer, dataset, keyword, vul_prompt):
    """Extract one second-to-last-layer feature vector per dataset sample.

    For every sample a few-shot prompt is built from ``vul_prompt``, the first
    two samples of ``dataset`` (used as vulnerable / not-vulnerable examples)
    and the sample's own ``keyword`` field. The prompt is run through ``model``
    and the hidden state of the *final token* in the second-to-last entry of
    ``hidden_states`` is kept. Despite the function name, no mean over tokens
    is taken.

    Args:
        model: Callable model accepting ``input_ids`` and
            ``output_hidden_states`` and exposing ``.device``; its output must
            provide ``.hidden_states``.
        tokenizer: Callable returning a mapping with an ``'input_ids'`` tensor
            when invoked with ``return_tensors="pt"``.
        dataset: Indexable collection of mappings that contain
            ``'func_src_before'``, ``'func_src_after'`` and ``keyword``.
        keyword: Key of the code snippet to classify in each sample.
        vul_prompt: Description of the vulnerability class, placed at the
            start of every prompt.

    Returns:
        Tensor with one row per sample, concatenated along dim 0.

    Raises:
        ValueError: If ``dataset`` has fewer than two samples, since the first
            two are required as in-prompt examples.
    """
    if len(dataset) < 2:
        raise ValueError(
            "dataset must contain at least two samples because the first two "
            f"are used as few-shot examples (got {len(dataset)})"
        )

    # The few-shot preamble is the same for every query, so build it once.
    few_shot_prefix = (
        f"{vul_prompt} For example, "
        f"{dataset[0]['func_src_before']} is vulnerable while "
        f"{dataset[0]['func_src_after']} is not vulnerable. "
        f"{dataset[1]['func_src_before']} is vulnerable while "
        f"{dataset[1]['func_src_after']} is not vulnerable. "
        "Does the following code has such vulnerability? '''"
    )
    prompt_suffix = "'''. Answer with simply yes or no."

    # NOTE(review): the two few-shot samples are themselves encoded as queries
    # (the original code had a no-op check on indices 0 and 1). Confirm whether
    # they should be skipped to avoid the query appearing in its own prompt.
    features = []
    for single_data in dataset:
        input_text = f"{few_shot_prefix}{single_data[keyword]}{prompt_suffix}"
        inputs = tokenizer(input_text, return_tensors="pt")

        # Inference only: no gradients are needed for feature extraction.
        with torch.no_grad():
            outputs = model(
                input_ids=inputs['input_ids'].to(model.device),
                output_hidden_states=True
            )

        # hidden_states[-2] is the second-to-last layer; [:, -1] selects the
        # final token position for each batch row.
        features.append(outputs.hidden_states[-2][:, -1])

    stacked = torch.cat(features, dim=0)
    print(stacked.shape)
    return stacked


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

class FeatureExtractor(nn.Module):
    """Wrap a language model and tokenizer to produce paired probe features.

    The wrapped model is switched to evaluation mode on construction. Calling
    the module extracts features for the 'func_src_before' and
    'func_src_after' fields of every sample and moves both results to
    ``device``.
    """

    def __init__(self, model, tokenizer, device='cuda:0'):
        super().__init__()
        self.device = device
        self.tokenizer = tokenizer
        self.model = model
        # Put the backbone in inference mode (e.g. disables dropout).
        self.model.eval()

    def forward(self, dataset, vul_prompt):
        """Return ``(features_before, features_after)`` for ``dataset``.

        Both tensors come from ``get_second_to_last_layer_mean`` and are moved
        to ``self.device`` before being returned.
        """
        extracted = [
            get_second_to_last_layer_mean(self.model, self.tokenizer, dataset, field, vul_prompt)
            for field in ('func_src_before', 'func_src_after')
        ]
        before, after = (feats.to(self.device) for feats in extracted)
        return before, after

class LinearProbe(nn.Module):
    """Single affine layer that maps a feature vector to class logits."""

    def __init__(self, input_dim=4096, num_classes=2):
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Return unnormalised class scores for ``x``."""
        logits = self.linear(x)
        return logits

class MLPProbe(nn.Module):
    """Multi-layer perceptron probe that maps feature vectors to class logits.

    Each entry of ``hidden_dims`` adds a ``Linear -> ReLU -> Dropout`` stage;
    a final ``Linear`` layer produces ``num_classes`` outputs. An empty
    ``hidden_dims`` therefore yields a single linear layer.

    Args:
        input_dim: Size of the incoming feature vectors.
        hidden_dims: Iterable of hidden layer widths, applied in order. The
            default is an immutable tuple so it cannot be shared and mutated
            across instances.
        num_classes: Number of output logits.
        dropout: Dropout probability used after every hidden activation.
    """

    def __init__(self, input_dim=4096, hidden_dims=(512, 256), num_classes=2, dropout=0.1):
        super().__init__()

        layers = []
        prev_dim = input_dim

        # Build the hidden layers.
        for hidden_dim in hidden_dims:
            layers += [
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
            prev_dim = hidden_dim

        # Output layer.
        layers.append(nn.Linear(prev_dim, num_classes))

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalised class scores for ``x``."""
        return self.mlp(x)
    
class ProbingTrainer:
    """Train and evaluate a probe that separates "before" from "after" features.

    Features extracted from the 'func_src_before' snippets are labelled 0 and
    features from the 'func_src_after' snippets are labelled 1. Every epoch is
    a single full-batch SGD step.

    Args:
        probe_type: ``'linear'`` for ``LinearProbe`` or ``'mlp'`` for
            ``MLPProbe``.
        num_classes: Number of output classes of the probe.
        feature_dim: Dimensionality of the input feature vectors.
        learning_rate: SGD learning rate.
        half_precision: When true, the probe weights are converted to fp16.
        device: Device the probe and the labels live on.

    Raises:
        ValueError: If ``probe_type`` is not one of the supported names.
    """

    def __init__(
        self,
        probe_type,
        num_classes,
        feature_dim,
        learning_rate=0.1,
        half_precision=True,
        device='cuda:0' if torch.cuda.is_available() else 'cpu'
    ):
        self.device = device
        self.half_precision = half_precision

        # Build the requested probe; fail fast on an unknown type instead of
        # leaving self.probe undefined.
        if probe_type == 'linear':
            self.probe = LinearProbe(feature_dim, num_classes).to(device)
        elif probe_type == 'mlp':
            self.probe = MLPProbe(input_dim=feature_dim, num_classes=num_classes).to(device)
        else:
            raise ValueError(
                f"Unknown probe_type {probe_type!r}; expected 'linear' or 'mlp'"
            )
        if self.half_precision:
            self.probe.half()

        # Set up optimizer and loss.
        self.optimizer = optim.SGD(
            self.probe.parameters(),
            lr=learning_rate, momentum=0.9, weight_decay=1e-4,
        )
        self.criterion = nn.CrossEntropyLoss()

    def _prepare(self, features):
        """Move ``features`` to the trainer device and the probe's dtype.

        This is a no-op when the tensor already matches; it lets a
        full-precision probe consume fp16 features and vice versa.
        """
        reference = next(self.probe.parameters())
        return features.to(device=self.device, dtype=reference.dtype)

    @staticmethod
    def _accuracy_percent(logits, labels):
        """Return the percentage of rows whose arg-max matches ``labels``.

        An empty split yields 0.0 rather than a division by zero.
        """
        count = labels.size(0)
        if count == 0:
            return 0.0
        _, predicted = logits.max(1)
        return 100. * predicted.eq(labels).sum().item() / count

    def train_epoch(self, features_before_train, features_after_train):
        """Run one full-batch optimisation step.

        Returns:
            ``(loss, acc_vul, acc_safe)`` where ``loss`` is the sum of the
            mean cross-entropy over the "before" and the "after" samples, and
            the accuracies are percentages computed on the pre-step outputs,
            each normalised by its own sample count.
        """
        self.probe.train()

        features_before_train = self._prepare(features_before_train)
        features_after_train = self._prepare(features_after_train)

        labels_before = torch.zeros(
            features_before_train.shape[0], dtype=torch.long, device=self.device)
        labels_after = torch.ones(
            features_after_train.shape[0], dtype=torch.long, device=self.device)

        # Forward pass through the probe.
        self.optimizer.zero_grad()
        outputs_before = self.probe(features_before_train)
        outputs_after = self.probe(features_after_train)

        # Calculate loss and backpropagate.
        loss = (self.criterion(outputs_before, labels_before)
                + self.criterion(outputs_after, labels_after))
        loss.backward()
        self.optimizer.step()

        acc_vul = self._accuracy_percent(outputs_before, labels_before)
        acc_safe = self._accuracy_percent(outputs_after, labels_after)

        # CrossEntropyLoss already averages over the samples, so the value is
        # reported as-is instead of being divided by the sample count again.
        return loss.item(), acc_vul, acc_safe

    def evaluate(self, features_before_eval, features_after_eval):
        """Return ``(acc_vul, acc_safe)`` percentages on held-out features."""
        self.probe.eval()

        with torch.no_grad():
            features_before_eval = self._prepare(features_before_eval)
            features_after_eval = self._prepare(features_after_eval)

            labels_before = torch.zeros(
                features_before_eval.shape[0], dtype=torch.long, device=self.device)
            labels_after = torch.ones(
                features_after_eval.shape[0], dtype=torch.long, device=self.device)

            acc_vul = self._accuracy_percent(
                self.probe(features_before_eval), labels_before)
            acc_safe = self._accuracy_percent(
                self.probe(features_after_eval), labels_after)

        return acc_vul, acc_safe

    def train(self, features, num_epochs):
        """Train for ``num_epochs`` epochs, logging metrics every epoch.

        Args:
            features: Sequence holding, in order, the training "before",
                training "after", evaluation "before" and evaluation "after"
                feature tensors.
            num_epochs: Number of full-batch steps to run.

        Returns:
            List with one dict of metrics per epoch.
        """
        stats = []

        features_before_train = features[0]
        features_after_train = features[1]
        features_before_eval = features[2]
        features_after_eval = features[3]

        for epoch in range(num_epochs):
            # Train one epoch.
            train_loss, train_acc_vul, train_acc_safe = self.train_epoch(
                features_before_train, features_after_train)

            # Evaluate.
            val_acc_vul, val_acc_safe = self.evaluate(features_before_eval, features_after_eval)

            # Save statistics.
            stats.append({
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'train_acc_vul': train_acc_vul,
                'train_acc_safe': train_acc_safe,
                'val_acc_vul': val_acc_vul,
                'val_acc_safe': val_acc_safe
            })
            wandb.log({"loss": train_loss, "acc_vul": train_acc_vul, "acc_safe": train_acc_safe, "val_acc_vul": val_acc_vul, "val_acc_safe": val_acc_safe})

            print(f'Epoch: {epoch+1}/{num_epochs}')
            print(f'Train Loss: {train_loss:.4f} | Train Acc Vul: {train_acc_vul:.2f}% | Train Acc Safe: {train_acc_safe:.2f}%')
            print(f'Val Acc Vul: {val_acc_vul:.2f}% | Val Acc Safe: {val_acc_safe:.2f}%')
            print('--------------------')

        return stats



def get_dataset(jsonl_path):
    """Load a JSON Lines file into a list of decoded records.

    Blank or whitespace-only lines (for example a trailing empty line) are
    skipped instead of raising a decode error. The file is read as UTF-8 so
    the result does not depend on the platform's default encoding.

    Args:
        jsonl_path: Path to a file with one JSON document per line.

    Returns:
        List with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        dataset = [json.loads(line) for line in f if line.strip()]
    print('dataset size',len(dataset))

    return dataset



def main():
    """Extract hidden-state features and train an MLP probe per CWE type.

    For each selected CWE the training JSONL file is loaded, features of the
    vulnerable ('func_src_before') and patched ('func_src_after') snippets are
    extracted, mean-centred, split into train / held-out parts and used to
    train a probe. Metrics are logged to Weights & Biases.

    Raises:
        ValueError: If a dataset is too small to leave any training samples
            after the held-out split.
    """
    # One device is used for the backbone, the extracted features and the
    # probe so that nothing silently falls back to a default GPU.
    device = 'cuda:1'
    # Number of trailing samples held out for evaluation.
    eval_size = 10
    # Only the first entry of TYPE_OF_VUL is processed.
    num_vul_types = 1

    # Alternative backbone: "meta-llama/Llama-3.1-8B-Instruct".
    model_path = 'Qwen/Qwen2.5-Coder-7B-Instruct'

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    feature_extractor = FeatureExtractor(model, tokenizer, device=device).to(device)

    train_stats = []

    selected = zip(TYPE_OF_VUL[:num_vul_types], Prompts_OF_VUL[:num_vul_types])
    for vul, vul_prompt in selected:
        jsonl_path = f'./sven/data_train_val/train/cwe-{vul:03d}.jsonl'
        dataset = get_dataset(jsonl_path)

        features_before, features_after = feature_extractor(dataset, vul_prompt)

        if features_before.shape[0] <= eval_size:
            raise ValueError(
                f"{jsonl_path} has {features_before.shape[0]} samples; more than "
                f"{eval_size} are needed to keep a non-empty training split"
            )

        # Subtract the centre vector (midpoint of the two class means).
        # NOTE(review): the centre is computed before the split, so it includes
        # the held-out samples; confirm this is intended.
        ave_features = 0.5*(torch.mean(features_before, dim=0, keepdim=True) + torch.mean(features_after, dim=0, keepdim=True))
        features_before = features_before - ave_features
        features_after = features_after - ave_features

        features = [
            features_before[:-eval_size],   # train, vulnerable
            features_after[:-eval_size],    # train, patched
            features_before[-eval_size:],   # eval, vulnerable
            features_after[-eval_size:],    # eval, patched
        ]

        trainer = ProbingTrainer(
            probe_type='mlp',
            num_classes=2,
            # Taken from the data: 3584 for qwen2.5, 4096 for llama3.
            feature_dim=features[0].shape[-1],
            learning_rate=0.001,
            device=device,
        )
        wandb.init(project="linear_probing", name=f"{time.time()}{model_path}_mlp_CWE-{vul}_center", config={"model": model_path})
        stats = trainer.train(features, num_epochs=5000)
        train_stats.append(stats)

    # Average the final-epoch "vulnerable" accuracies over all trained probes.
    ave_train_acc = sum(stats[-1]['train_acc_vul'] for stats in train_stats) / len(train_stats)
    ave_eval_acc = sum(stats[-1]['val_acc_vul'] for stats in train_stats) / len(train_stats)
    wandb.log({"ave_train_acc": ave_train_acc, "ave_eval_acc": ave_eval_acc})
    wandb.finish()

if __name__ == "__main__":
    main()