import torch
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn.functional as F
import transformers
import wandb
import time
import torch.nn as nn
import torch.optim as optim

# CWE identifiers handled by this script. Entry i of Prompts_OF_VUL describes
# TYPE_OF_VUL[i], so the two lists must stay aligned index-by-index.
TYPE_OF_VUL = [22, 78, 79, 89, 125, 190, 416, 476, 787]

# Natural-language description of each CWE above. These strings are inserted
# verbatim into model prompts, so editing them changes the model input.
Prompts_OF_VUL = [
    'CWE-22, commonly called “Path Traversal,” is a vulnerability when an application fails to appropriately limit the paths users can access through a user-provided input.',
    'CWE-78, means improper neutralization of special elements used in an OS command (OS command injection), constructs all or part of an OS command using externally-influenced input from an upstream component, but it does not neutralize or incorrectly neutralizes special elements that could modify the intended OS command when it is sent to a downstream component.',
    'CWE-79, improper neutralization of input during web page generation (cross-site scripting). The code does not neutralize or incorrectly neutralizes user-controllable input before it is placed in output that is used as a web page that is served to other users.',
    'CWE-89: Improper Neutralization of Special Elements used in an SQL Command (SQL Injection). The product constructs all or part of an SQL command using externally-influenced input from an upstream component, but it does not neutralize or incorrectly neutralizes special elements that could modify the intended SQL command when it is sent to a downstream component. Without sufficient removal or quoting of SQL syntax in user-controllable inputs, the generated SQL query can cause those inputs to be interpreted as SQL instead of ordinary user data.',
    'CWE-125: Out-of-bounds Read. The product reads data past the end, or before the beginning, of the intended buffer.',
    'CWE-190: Integer Overflow or Wraparound. The product performs a calculation that can produce an integer overflow or wraparound when the logic assumes that the resulting value will always be larger than the original value. This occurs when an integer value is incremented to a value that is too large to store in the associated representation. When this occurs, the value may become a very small or negative number.',
    'CWE-416: Use After Free. The product reuses or references memory after it has been freed. At some point afterward, the memory may be allocated again and saved in another pointer, while the original pointer references a location somewhere within the new allocation. Any operations using the original pointer are no longer valid because the memory "belongs" to the code that operates on the new pointer.',
    'CWE-476: NULL Pointer Dereference. The product dereferences a pointer that it expects to be valid but is NULL.',
    'CWE-787: Out-of-bounds Write. The product writes data past the end, or before the beginning, of the intended buffer.',
]

def get_second_to_last_layer_mean(model, tokenizer, dataset, keyword, vul_prompt, skip_examples=False):
    """Collect the last-token activation of the second-to-last hidden-state entry per record.

    Despite the name, no mean is taken. For each record the function builds a
    two-shot prompt (records 0 and 1 are the in-context "vulnerable / not
    vulnerable" examples), runs ``model`` with ``output_hidden_states=True``,
    and keeps ``outputs.hidden_states[-2][:, -1]`` (final token position).

    Args:
        model: callable accepting ``input_ids`` and ``output_hidden_states``;
            must expose ``.device`` and return an object with ``hidden_states``.
        tokenizer: callable returning a mapping with an ``'input_ids'`` tensor.
        dataset: indexable sequence of dicts with ``'func_src_before'`` and
            ``'func_src_after'`` keys; needs at least two records.
        keyword: key of the code snippet classified for each record.
        vul_prompt: vulnerability description placed at the start of the prompt.
        skip_examples: if True, records 0 and 1 (already shown in the prompt as
            examples) are not featurized. Defaults to False, which featurizes
            every record as before.

    Returns:
        Tensor built by concatenating the per-record activations on dim 0
        (one row per featurized record, since one string is tokenized at a time).

    Raises:
        ValueError: if ``dataset`` has fewer than two records, or no record is
            left to featurize.
    """
    if len(dataset) < 2:
        raise ValueError(
            f"dataset needs at least 2 records for the few-shot examples, got {len(dataset)}"
        )

    # The two in-context examples are the same for every query, so build the
    # prompt prefix once instead of on every iteration.
    first, second = dataset[0], dataset[1]
    few_shot_prefix = (
        f"{vul_prompt} For example, {first['func_src_before']} is vulnerable while "
        f"{first['func_src_after']} is not vulnerable. {second['func_src_before']} is "
        f"vulnerable while {second['func_src_after']} is not vulnerable. "
    )

    start = 2 if skip_examples else 0
    second_to_last_layer_list = []
    for index in range(start, len(dataset)):
        single_data = dataset[index]
        input_text = (
            few_shot_prefix
            + "Does the following code has such vulnerability? "
            + f"'''{single_data[keyword]}'''. Answer with simply yes or no."
        )
        inputs = tokenizer(input_text, return_tensors="pt")

        # Forward pass only; no gradients are needed for feature extraction.
        with torch.no_grad():
            outputs = model(
                input_ids=inputs['input_ids'].to(model.device),
                output_hidden_states=True,
            )

        # Second-to-last hidden-state entry, final token position.
        second_to_last_layer_list.append(outputs.hidden_states[-2][:, -1])

    if not second_to_last_layer_list:
        raise ValueError("no records left to featurize after skipping the few-shot examples")

    features = torch.cat(second_to_last_layer_list, dim=0)
    print(features.shape)
    return features





class FeatureExtractor(nn.Module):
    """Extract probe features for vulnerable and fixed code from a language model.

    Nothing is stripped from ``model``; it is only switched to eval mode.
    ``forward`` calls ``get_second_to_last_layer_mean`` once for the
    ``'func_src_before'`` snippets and once for the ``'func_src_after'``
    snippets of the dataset.
    """

    def __init__(self, model, tokenizer, device='cuda:0'):
        super().__init__()
        self.tokenizer = tokenizer
        self.device = device
        self.model = model
        # Inference mode for layers such as dropout while extracting features.
        model.eval()

    def forward(self, dataset, vul_prompt):
        """Return ``(features_before, features_after)``, both moved to ``self.device``."""
        before, after = (
            get_second_to_last_layer_mean(self.model, self.tokenizer, dataset, key, vul_prompt)
            for key in ('func_src_before', 'func_src_after')
        )
        return before.to(self.device), after.to(self.device)

class LinearProbe(nn.Module):
    """Single affine layer mapping a feature vector to ``num_classes`` logits."""

    def __init__(self, input_dim=4096, num_classes=2):
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        logits = self.linear(x)
        return logits

class MLPProbe(nn.Module):
    """Multi-layer perceptron probe.

    Each hidden layer is ``Linear -> ReLU -> Dropout``; a final ``Linear``
    produces the class logits.

    Args:
        input_dim: width of each input feature vector.
        hidden_dims: widths of the hidden layers, in order. An immutable tuple
            default avoids the shared-mutable-default pitfall; any iterable of
            ints (e.g. a list) is accepted. An empty sequence yields a single
            Linear layer.
        num_classes: number of output logits.
        dropout: dropout probability applied after every hidden ReLU.
    """

    def __init__(self, input_dim=4096, hidden_dims=(512, 256), num_classes=2, dropout=0.1):
        super().__init__()

        layers = []
        prev_dim = input_dim

        # Hidden layers
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            prev_dim = hidden_dim

        # Output layer
        layers.append(nn.Linear(prev_dim, num_classes))

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Return logits with last dimension ``num_classes`` for features ``x``."""
        return self.mlp(x)
    
class ProbingTrainer:
    """Train and evaluate a probe that separates vulnerable from fixed code features.

    Features of the vulnerable version ("before") are labelled 0 and features
    of the fixed version ("after") are labelled 1. Each ``train_epoch`` call is
    one full-batch SGD step (momentum 0.9, weight decay 1e-4) on the sum of the
    two groups' cross-entropy losses.
    """

    def __init__(
        self,
        probe_type,
        num_classes,
        feature_dim,
        learning_rate=0.1,
        device='cuda:0' if torch.cuda.is_available() else 'cpu'
    ):
        """Build the probe, optimizer and loss.

        Args:
            probe_type: ``'linear'`` (LinearProbe) or ``'mlp'`` (MLPProbe).
            num_classes: number of logits the probe outputs.
            feature_dim: width of each input feature vector.
            learning_rate: SGD learning rate.
            device: device holding the probe; features and labels are moved here.

        Raises:
            ValueError: if ``probe_type`` is neither ``'linear'`` nor ``'mlp'``.
        """
        self.device = device

        if probe_type == 'linear':
            self.probe = LinearProbe(feature_dim, num_classes).to(device)
        elif probe_type == 'mlp':
            self.probe = MLPProbe(input_dim=feature_dim, num_classes=num_classes).to(device)
        else:
            # Fail here with a clear message instead of a later AttributeError on self.probe.
            raise ValueError(
                f"unknown probe_type {probe_type!r}; expected 'linear' or 'mlp'"
            )

        self.optimizer = optim.SGD(
            self.probe.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4
        )
        self.criterion = nn.CrossEntropyLoss()

    def _prepare(self, features):
        """Return ``features`` on the probe's device as float32.

        The probe is built with default (float32) parameters. Half-precision
        features (``main`` loads the LM in float16) would otherwise fail in the
        forward pass with a dtype mismatch.
        """
        return features.to(device=self.device, dtype=torch.float32)

    @staticmethod
    def _accuracy(logits, labels):
        """Percentage of rows whose argmax over dim 1 equals ``labels``; NaN if empty."""
        count = labels.shape[0]
        if count == 0:
            return float('nan')
        correct = logits.argmax(dim=1).eq(labels).sum().item()
        return 100.0 * correct / count

    def train_epoch(self, features_before_train, features_after_train):
        """Run one full-batch optimisation step.

        Returns:
            ``(loss, acc_vul, acc_safe)``. ``loss`` is the summed cross-entropy
            of both groups divided by the number of "before" samples. The
            accuracies are the percentage of each group classified correctly,
            measured on the logits computed before the parameter update.
        """
        self.probe.train()

        features_before_train = self._prepare(features_before_train)
        features_after_train = self._prepare(features_after_train)
        labels_before = torch.zeros(features_before_train.shape[0], dtype=torch.long, device=self.device)
        labels_after = torch.ones(features_after_train.shape[0], dtype=torch.long, device=self.device)

        self.optimizer.zero_grad()
        outputs_before = self.probe(features_before_train)
        outputs_after = self.probe(features_after_train)
        loss = self.criterion(outputs_before, labels_before) + self.criterion(outputs_after, labels_after)
        loss.backward()
        self.optimizer.step()

        # NOTE(review): CrossEntropyLoss already averages over samples, so this
        # extra division makes the reported value shrink with dataset size.
        # Kept unchanged so logged losses stay comparable with earlier runs.
        reported_loss = loss.item() / features_before_train.shape[0]

        # Each group's accuracy uses its own size, which is correct even when
        # the two groups differ in size.
        return (
            reported_loss,
            self._accuracy(outputs_before, labels_before),
            self._accuracy(outputs_after, labels_after),
        )

    def evaluate(self, features_before_eval, features_after_eval):
        """Return ``(acc_vul, acc_safe)``: per-group accuracy percentages, without updating the probe."""
        self.probe.eval()

        with torch.no_grad():
            features_before_eval = self._prepare(features_before_eval)
            features_after_eval = self._prepare(features_after_eval)
            labels_before = torch.zeros(features_before_eval.shape[0], dtype=torch.long, device=self.device)
            labels_after = torch.ones(features_after_eval.shape[0], dtype=torch.long, device=self.device)

            acc_vul = self._accuracy(self.probe(features_before_eval), labels_before)
            acc_safe = self._accuracy(self.probe(features_after_eval), labels_after)

        return acc_vul, acc_safe

    def train(self, features, num_epochs):
        """Alternate ``train_epoch`` and ``evaluate`` for ``num_epochs`` epochs.

        Args:
            features: indexable as ``(before_train, after_train, before_eval,
                after_eval)`` feature tensors.
            num_epochs: number of full-batch steps.

        Returns:
            List with one metrics dict per epoch. The same values are logged
            with ``wandb.log`` and printed.
        """
        stats = []

        features_before_train = features[0]
        features_after_train = features[1]
        features_before_eval = features[2]
        features_after_eval = features[3]

        for epoch in range(num_epochs):
            train_loss, train_acc_vul, train_acc_safe = self.train_epoch(features_before_train, features_after_train)
            val_acc_vul, val_acc_safe = self.evaluate(features_before_eval, features_after_eval)

            stats.append({
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'train_acc_vul': train_acc_vul,
                'train_acc_safe': train_acc_safe,
                'val_acc_vul': val_acc_vul,
                'val_acc_safe': val_acc_safe
            })
            # NOTE(review): assumes the caller has already called wandb.init().
            wandb.log({"loss": train_loss, "acc_vul": train_acc_vul, "acc_safe": train_acc_safe, "val_acc_vul": val_acc_vul, "val_acc_safe": val_acc_safe})

            print(f'Epoch: {epoch+1}/{num_epochs}')
            print(f'Train Loss: {train_loss:.4f} | Train Acc Vul: {train_acc_vul:.2f}% | Train Acc Safe: {train_acc_safe:.2f}%')
            print(f'Val Acc Vul: {val_acc_vul:.2f}% | Val Acc Safe: {val_acc_safe:.2f}%')
            print('--------------------')

        return stats



def get_dataset(jsonl_path):
    """Load a JSON Lines file into a list of parsed records.

    The file is read as UTF-8, and each non-blank line is parsed with
    ``json.loads``. Blank or whitespace-only lines (e.g. a trailing empty line)
    are skipped instead of raising ``JSONDecodeError``. The number of loaded
    records is printed.

    Args:
        jsonl_path: path to the ``.jsonl`` file.

    Returns:
        List of the decoded JSON values, in file order.
    """
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        dataset = [json.loads(line) for line in f if line.strip()]
    print('dataset size', len(dataset))

    return dataset




def steer_generate(steer_vec, ori_activation, alpha=1.0):
    """
    Return ``ori_activation + alpha * steer_vec``.

    ``steer_vec`` is first moved to ``ori_activation``'s device when they
    differ. No shape check is performed; the addition relies on PyTorch
    broadcasting.

    Args:
        steer_vec: The steering vector to add
        ori_activation: Original activation from the model
        alpha: Scaling factor for the steering vector
    """
    target_device = ori_activation.device
    aligned = steer_vec if steer_vec.device == target_device else steer_vec.to(target_device)
    return ori_activation + alpha * aligned

def generate_with_modified_activation(model, tokenizer, prompt, steer_vec, layer_idx=-4, alpha=100.0,
                                    max_length=1000, temperature=0.7):
    """
    Sample a completion while adding ``alpha * steer_vec`` to one decoder layer's output.

    A forward hook is registered on ``model.model.layers[layer_idx]`` only for
    the duration of this call. It is removed in a ``finally`` block, so later
    forward passes (including subsequent calls) are not steered and hooks do
    not accumulate.

    Args:
        model: The language model; must expose ``model.model.layers`` and ``.device``
        tokenizer: The tokenizer
        prompt: Input prompt text
        steer_vec: Steering vector added to the layer output (last dim must match hidden size)
        layer_idx: Index into ``model.model.layers`` (negative counts from the end)
        alpha: Scaling factor for the steering vector
        max_length: Maximum number of NEW tokens to generate (passed as ``max_new_tokens``)
        temperature: Sampling temperature

    Returns:
        ``outputs[0]`` from ``model.generate`` decoded with ``skip_special_tokens=True``.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    def modify_activation_hook(module, input, output):
        # Decoder layers may return a tuple whose first element is the hidden states.
        hidden_states = output[0] if isinstance(output, tuple) else output
        # Leave outputs untouched when their width does not match the steering vector.
        if hidden_states.shape[-1] != steer_vec.shape[-1]:
            return output
        modified = steer_generate(steer_vec, hidden_states, alpha)
        if isinstance(output, tuple):
            return (modified,) + output[1:]
        return modified

    # NOTE(review): assumes a decoder exposing ``model.model.layers`` (the
    # Qwen2.5 / Llama checkpoints used in this file presumably do).
    target_layer = model.model.layers[layer_idx]
    handle = target_layer.register_forward_hook(modify_activation_hook)

    try:
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.get('attention_mask'),
                max_new_tokens=max_length,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )
    finally:
        # Always detach the hook, even if generation raises.
        handle.remove()

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def get_perpendicular_vector(probe_weights):
    """
    Return a unit direction derived from a linear probe's class weight rows.

    With two classes the logit difference is ``(w1 - w0) . x + (b1 - b0)``, so
    ``w1 - w0`` is normal to the decision boundary. Moving along it increases
    the class-1 logit relative to class 0. With more classes, the pairwise
    differences ``w_i - w_j`` (i < j) are summed and the sum is normalised.
    That is a heuristic direction, not a single boundary normal.

    Args:
        probe_weights: 2-D tensor such as ``trainer.probe.linear.weight``
            (shape: [num_classes, hidden_dim]).

    Returns:
        perp_vec: Unit-L2-norm tensor of shape [1, hidden_dim].

    Raises:
        ValueError: if the direction has zero norm (identical class rows, or a
            single class), which would otherwise yield a NaN vector.
    """
    n_classes = probe_weights.shape[0]

    if n_classes == 2:
        direction = probe_weights[1] - probe_weights[0]
    else:
        direction = torch.zeros_like(probe_weights[0])
        for i in range(n_classes):
            for j in range(i + 1, n_classes):
                direction = direction + (probe_weights[i] - probe_weights[j])

    norm = torch.norm(direction)
    if norm == 0:
        raise ValueError(
            "cannot normalise a zero direction; the probe's class weight rows are degenerate"
        )
    return (direction / norm).unsqueeze(0)




def main():
    """Compare plain sampling with activation-steered sampling on one code file.

    Loads Qwen2.5-Coder-7B-Instruct in float16 on ``cuda:1`` and reads the
    example source file. It builds a "please fix it" prompt, then prints one
    normally sampled completion followed by one generated by
    ``generate_with_modified_activation`` with the saved probe direction.
    Only the first CWE in ``TYPE_OF_VUL`` is processed.
    """
    device = 'cuda:1'
    model_path = 'Qwen/Qwen2.5-Coder-7B-Instruct'

    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    for vul in TYPE_OF_VUL[:1]:
        code_path = './scenity/cwe22_from_github/exampl2.py'
        with open(code_path, 'r') as f:
            code = f.read()
        prompt = "".join([
            "The following code is vulnerable. Please fix it.\n",
            code,
            "\nPut the fix code in the \\boxed{}.",
            'Assistant:',
        ])

        # Steering direction saved by an earlier probing run, cast to the model's fp16.
        perp_vec = torch.load(f'./analysis/perp_vec_Qwen2.5-Coder-7B-Instruct_CWE-{vul:03d}.pt').half()

        # Baseline: ordinary sampling without steering.
        encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
        baseline_ids = model.generate(
            encoded.input_ids,
            max_new_tokens=1000,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
        baseline_text = tokenizer.decode(baseline_ids[0], skip_special_tokens=True)
        print('--------------------------------normal_generated_text--------------------------------')
        print(baseline_text)

        # Same prompt, with the probe direction added to a decoder layer's output.
        steered_text = generate_with_modified_activation(model, tokenizer, prompt, perp_vec)
        print('--------------------------------modeified_generated_text--------------------------------')
        print(steered_text)

if __name__ == '__main__':
    # Script entry point: run the steering comparison demo.
    main()