import torch
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn.functional as F
import transformers
import wandb
import time
import argparse
import os
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import random
import numpy as np


# CWE identifiers handled by this script; ``main`` formats an entry as
# ``cwe-{id:03d}`` to locate the per-CWE JSONL training file.
TYPE_OF_VUL = [22, 78, 79, 89, 125, 190, 416, 476, 787]

# Natural-language description of each CWE, keyed by the ``vul_type`` string
# stored in the dataset records (e.g. 'cwe-022'). ``CodeDataset`` looks the
# description up per record and places it at the start of the model prompt.
Prompts_OF_VUL = {
    'cwe-022': 'CWE-22, commonly called “Path Traversal,” is a vulnerability when an application fails to appropriately limit the paths users can access through a user-provided input.',
    'cwe-078': 'CWE-78, means improper neutralization of special elements used in an OS command (OS command injection), constructs all or part of an OS command using externally-influenced input from an upstream component, but it does not neutralize or incorrectly neutralizes special elements that could modify the intended OS command when it is sent to a downstream component.',
    'cwe-079': 'CWE-79, improper neutralization of input during web page generation (cross-site scripting). The code does not neutralize or incorrectly neutralizes user-controllable input before it is placed in output that is used as a web page that is served to other users.',
    'cwe-089': 'CWE-89, Improper Neutralization of Special Elements used in an SQL Command (SQL Injection). The product constructs all or part of an SQL command using externally-influenced input from an upstream component, but it does not neutralize or incorrectly neutralizes special elements that could modify the intended SQL command when it is sent to a downstream component. Without sufficient removal or quoting of SQL syntax in user-controllable inputs, the generated SQL query can cause those inputs to be interpreted as SQL instead of ordinary user data.',
    'cwe-125': 'CWE-125, Out-of-bounds Read. The product reads data past the end, or before the beginning, of the intended buffer.',
    'cwe-190': 'CWE-190, Integer Overflow or Wraparound. The product performs a calculation that can produce an integer overflow or wraparound when the logic assumes that the resulting value will always be larger than the original value. This occurs when an integer value is incremented to a value that is too large to store in the associated representation. When this occurs, the value may become a very small or negative number.',
    'cwe-416': 'CWE-416, Use After Free. The product reuses or references memory after it has been freed. At some point afterward, the memory may be allocated again and saved in another pointer, while the original pointer references a location somewhere within the new allocation. Any operations using the original pointer are no longer valid because the memory "belongs" to the code that operates on the new pointer.',
    'cwe-476': 'CWE-476, NULL Pointer Dereference. The product dereferences a pointer that it expects to be valid but is NULL.',
    'cwe-787': 'CWE-787, Out-of-bounds Write. The product writes data past the end, or before the beginning, of the intended buffer.',
}



def check_if_add_delete_together(jsonl_path):
    """Return the JSONL records whose change set has both deleted and added lines.

    Each non-blank line of ``jsonl_path`` is parsed as a JSON object that must have
    a ``line_changes`` mapping with ``'deleted'`` and ``'added'`` lists. The number
    of matching records is printed.

    Args:
        jsonl_path: path to a JSON-lines file, or ``None`` (yields an empty result).

    Returns:
        List of the parsed records that have at least one deleted AND one added line.
    """
    data = []
    if jsonl_path is not None:
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # A trailing/blank line would otherwise make json.loads('') raise.
                    continue
                item = json.loads(line)
                changes = item['line_changes']
                if changes['deleted'] and changes['added']:
                    data.append(item)
    print('add_delete_together: ', len(data))
    return data

def _code_prompt_head(vul_prompt, few_shot_data_prompt):
    """Return the prompt text that precedes the quoted source code."""
    return f"{vul_prompt} {few_shot_data_prompt} Fix the vulnerability in the following code: '"


def _layer_hidden_states(model, input_ids, layer_index):
    """Run ``model`` on ``input_ids`` without gradients; return ``hidden_states[layer_index][0]``."""
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids.to(model.device),
            output_hidden_states=True,
        )
    return outputs.hidden_states[layer_index][0]


def _changed_span_features(model, tokenizer, prompt_head, source, changes, layer_index):
    """Return ``(features, gt_tokens, start_char, start_tok)`` for the changed region of ``source``.

    The region spans from the smallest ``char_start`` to the largest ``char_end``
    in ``changes``. Its token offsets inside the full prompt are found by
    tokenizing the prompt truncated at those two characters.
    """
    start_char = min(change['char_start'] for change in changes)
    end_char = max(change['char_end'] for change in changes)
    # The truncated prompts must NOT carry the closing "' " of the full prompt:
    # those extra tokens would shift both offsets to the right.
    start_tok = len(tokenizer(prompt_head + source[:start_char])['input_ids'])
    end_tok = len(tokenizer(prompt_head + source[:end_char])['input_ids'])
    full_ids = tokenizer(f"{prompt_head}{source}' ", return_tensors="pt")['input_ids']
    # Ground-truth ids come from the same (prompted) tokenization the offsets index into.
    gt_tokens = full_ids[0, start_tok:end_tok].tolist()
    features = _layer_hidden_states(model, full_ids, layer_index)[start_tok:end_tok]
    return features, gt_tokens, start_char, start_tok


def _tail_features(model, tokenizer, source, start_char, start_tok, layer_index):
    """Hidden states of ``source[start_char:-1]`` encoded on its own, or ``None`` if that text is empty."""
    tail = source[start_char:-1]
    if not tail:
        return None
    tail_ids = tokenizer(tail, return_tensors="pt")['input_ids']
    # NOTE(review): the tail is encoded WITHOUT the prompt, yet it is sliced from
    # ``start_tok`` (an offset into the prompted sequence). This is kept from the
    # original; with a long prompt the slice is usually empty -- confirm intended start.
    return _layer_hidden_states(model, tail_ids, layer_index)[start_tok:-1]


def get_second_to_last_layer(model, tokenizer, single_data, vul_prompt=None, few_shot_data_prompt=None, layer_index=-1):
    """Extract hidden states around the edited region of one vulnerability-fix record.

    The code is embedded in the prompt
    ``"{vul_prompt} {few_shot_data_prompt} Fix the vulnerability in the following code: '<code>' "``
    and the hidden states at the token positions covering the changed characters
    are returned. ``None`` prompt parts are rendered as the text ``"None"``.

    Args:
        model: causal LM called as ``model(input_ids=..., output_hidden_states=True)``;
            must expose ``.device``.
        tokenizer: callable returning a mapping with ``'input_ids'`` (a list, or a
            tensor of shape ``[1, seq]`` when ``return_tensors="pt"``).
        single_data: record with ``func_src_before``, ``func_src_after`` and
            ``line_changes['deleted' | 'added']`` lists of ``{'char_start', 'char_end'}``.
        vul_prompt: CWE description placed at the start of the prompt.
        few_shot_data_prompt: few-shot example text placed after ``vul_prompt``.
        layer_index: index into ``hidden_states``; the default ``-1`` selects the LAST
            entry, despite this function's name.

    Returns:
        ``(features_in, features_out, deleted_gt_token, added_gt_token)``:

        - deletions and additions: deleted-span features (before-code) and
          added-span features (after-code), truncated to a common length, plus both
          ground-truth token-id lists;
        - additions only: before-code tail features vs. added-span features;
          ``deleted_gt_token`` is ``None``;
        - deletions only: deleted-span features vs. after-code tail features;
          ``added_gt_token`` is ``None``;
        - no changes: ``(None, None, None, None)``.

        When a required tail is empty, both feature tensors are returned with zero rows
        instead of raising.
    """
    deleted = single_data['line_changes']['deleted']
    added = single_data['line_changes']['added']
    src_before = single_data['func_src_before']
    src_after = single_data['func_src_after']
    prompt_head = _code_prompt_head(vul_prompt, few_shot_data_prompt)

    if deleted:
        deleted_features, deleted_gt_token, del_start_char, del_start_tok = _changed_span_features(
            model, tokenizer, prompt_head, src_before, deleted, layer_index)
    if added:
        added_features, added_gt_token, add_start_char, add_start_tok = _changed_span_features(
            model, tokenizer, prompt_head, src_after, added, layer_index)

    if deleted and added:
        len_train = min(len(deleted_features), len(added_features))
        return deleted_features[:len_train], added_features[:len_train], deleted_gt_token, added_gt_token

    if added:
        # Pure insertion: map the before-code (from the insertion point) onto the added span.
        untarget = _tail_features(model, tokenizer, src_before, add_start_char, add_start_tok, layer_index)
        if untarget is None:
            untarget = added_features[:0]
        len_train = min(len(untarget), len(added_features))
        return untarget[:len_train], added_features[:len_train], None, added_gt_token

    if deleted:
        # Pure deletion: map the deleted span onto the after-code from the deletion point.
        target = _tail_features(model, tokenizer, src_after, del_start_char, del_start_tok, layer_index)
        if target is None:
            target = deleted_features[:0]
        len_train = min(len(deleted_features), len(target))
        return deleted_features[:len_train], target[:len_train], deleted_gt_token, None

    print('no change')
    return None, None, None, None
    


class FeatureExtractor(nn.Module):
    """Wrap a causal LM so a dataset record can be turned into (input, target) hidden-state pairs.

    ``forward`` delegates to ``get_second_to_last_layer`` and moves the returned
    feature tensors to ``device``. The wrapped model is switched to eval mode here.
    """

    def __init__(self, model, tokenizer, device, reasoning=False):
        super().__init__()
        # If ``model`` is an nn.Module, assigning it registers it as a submodule, so
        # ``.to(...)`` on this wrapper also moves the model. Nothing is removed from it.
        self.model = model
        model.eval()
        self.tokenizer = tokenizer
        self.device = device
        # Stored for callers; not read anywhere in this class.
        self.reasoning = reasoning

    def forward(self, single_data, vul_prompt, few_shot_data_prompt):
        """Return ``(features_in, features_out, deleted_gt_token, added_gt_token)`` for one record.

        Feature tensors are moved to ``self.device``; ``None`` features (a record
        with no line changes) are passed through unchanged instead of failing.
        """
        features_in, features_out, deleted_gt_token, added_gt_token = get_second_to_last_layer(
            self.model, self.tokenizer, single_data, vul_prompt, few_shot_data_prompt)
        if features_in is not None:
            features_in = features_in.to(self.device)
        if features_out is not None:
            features_out = features_out.to(self.device)
        return features_in, features_out, deleted_gt_token, added_gt_token

class LinearProbe(nn.Module):
    """Probe made of a single affine layer: ``input_dim`` features -> ``num_classes`` outputs."""

    def __init__(self, input_dim=4096, num_classes=4096):
        super().__init__()
        # The attribute name ``linear`` determines the state_dict keys, so keep it.
        self.linear = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Apply the affine map along the last dimension of ``x``."""
        projected = self.linear(x)
        return projected

class MLPProbe(nn.Module):
    """MLP probe: ``[Linear -> ReLU -> Dropout]`` per hidden width, then a final ``Linear``."""

    def __init__(self, input_dim=4096, hidden_dims=(4096, 4096), num_classes=4096, dropout=0.2):
        """Build the layer stack.

        Args:
            input_dim: width of the input features.
            hidden_dims: widths of the hidden layers, in order (any iterable of ints).
                The default is a tuple so it cannot be mutated and shared between calls.
            num_classes: width of the output layer.
            dropout: dropout probability applied after each hidden ReLU.
        """
        super().__init__()

        layers = []
        prev_dim = input_dim

        # Hidden layers.
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            prev_dim = hidden_dim

        # Output layer.
        layers.append(nn.Linear(prev_dim, num_classes))

        # The attribute name ``mlp`` fixes the state_dict keys (``mlp.<i>.weight`` ...)
        # that ``main`` saves with ``torch.save(trainer.probe.state_dict(), ...)``.
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        return self.mlp(x)
    
class FittingTrainer:
    """Train a residual probe mapping "before" hidden states onto "after" hidden states.

    The probe predicts a delta that is added to its input; the objective is the
    MSE between ``features + probe(features)`` and the supervision features.
    """

    def __init__(
        self,
        probe_type,
        classification_gen_head,
        num_classes,
        feature_dim,
        learning_rate=0.1,
        half_precision=False,
        num_epochs=8000,
        device='cuda:0' if torch.cuda.is_available() else 'cpu'
    ):
        """Build the probe, SGD optimizer and cosine LR schedule.

        Args:
            probe_type: ``'linear'`` (``LinearProbe``) or ``'mlp'`` (``MLPProbe`` with two
                hidden layers of width ``feature_dim``).
            classification_gen_head: output head (``model.lm_head`` in ``main``). It is
                converted with ``.to(torch.float32)``, which for an nn.Module happens in place.
            num_classes: unused; the probe always maps ``feature_dim`` -> ``feature_dim``.
            feature_dim: width of the hidden-state features.
            learning_rate: initial SGD learning rate.
            half_precision: cast the probe to float16.
            num_epochs: epochs run by ``train``; also the cosine schedule length.
            device: device the probe is placed on.

        Raises:
            ValueError: if ``probe_type`` is neither ``'linear'`` nor ``'mlp'``.
        """
        self.device = device
        self.half_precision = half_precision
        self.num_epochs = num_epochs

        if probe_type == 'linear':
            self.probe = LinearProbe(feature_dim, feature_dim).to(device)
        elif probe_type == 'mlp':
            self.probe = MLPProbe(input_dim=feature_dim, hidden_dims=[feature_dim, feature_dim], num_classes=feature_dim).to(device)
        else:
            raise ValueError(f"probe_type must be 'linear' or 'mlp', got {probe_type!r}")

        if self.half_precision:
            self.probe.half()

        self.optimizer = optim.SGD(self.probe.parameters(),
                                   lr=learning_rate, momentum=0.9, weight_decay=1e-3)
        # Stepped once per epoch in ``train`` so the LR anneals over ``num_epochs``.
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=num_epochs)

        self.mse_loss = nn.MSELoss()
        # CE / KL criteria are kept for experimentation; only the MSE term is optimised.
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')
        self.classification_gen_head = classification_gen_head.to(torch.float32)

    def _predict(self, features):
        """Return ``features + probe(features)``: the probe's estimate of the target features."""
        return features + self.probe(features)

    def train_batch(self, features_for_train, features_for_supervise):
        """Take one optimisation step and return the batch MSE as a Python float.

        Args:
            features_for_train: features of the before code, ``[bs, feature_dim]``.
            features_for_supervise: features of the after code, ``[bs, feature_dim]``.
        """
        self.probe.train()
        self.optimizer.zero_grad()
        final_feature = self._predict(features_for_train)
        loss_batch = self.mse_loss(final_feature, features_for_supervise)
        loss_batch.backward()
        torch.nn.utils.clip_grad_norm_(self.probe.parameters(), max_norm=1)
        self.optimizer.step()
        return loss_batch.item()

    def train(self, train_dataloader):
        """Train for ``num_epochs`` epochs and return per-epoch statistics.

        Each epoch logs the last batch loss (``"loss"``) and the mean batch loss
        (``"epoch_loss"``) to wandb, then steps the LR scheduler.

        Raises:
            ValueError: if ``train_dataloader`` yields no batches.
        """
        stats = []
        for epoch in range(self.num_epochs):
            batch_losses = []
            for batch in train_dataloader:
                batch_losses.append(self.train_batch(batch['features_in'], batch['features_out']))
            if not batch_losses:
                raise ValueError('train_dataloader yielded no batches')
            self.scheduler.step()

            loss_batch = batch_losses[-1]
            epoch_loss = sum(batch_losses) / len(batch_losses)
            print(f'Epoch: {epoch+1}/{self.num_epochs}')
            print('--------------------')
            wandb.log({"loss": loss_batch, "epoch_loss": epoch_loss})
            stats.append({'epoch': epoch, 'loss_batch': loss_batch, 'epoch_loss': epoch_loss})
            print(f'Train Loss: {loss_batch}')
        return stats

    def eval(self, eval_dataloader):
        """Return the mean per-batch MSE over ``eval_dataloader`` WITHOUT updating the probe.

        Returns ``float('nan')`` when the loader yields no batches.
        """
        self.probe.eval()
        losses = []
        with torch.no_grad():
            for batch in eval_dataloader:
                prediction = self._predict(batch['features_in'])
                losses.append(self.mse_loss(prediction, batch['features_out']).item())
        return sum(losses) / len(losses) if losses else float('nan')


class CodeDataset(Dataset):
    """Dataset over a JSON-lines file of vulnerability-fix records.

    With ``extract_features=True``, records 0 and 1 are used only to build the
    few-shot prompt. Hidden-state features are computed at construction time for
    records 2 onwards and stored on each record (``features_in``, ``features_out``,
    ``deleted_gt_token``, ``added_gt_token``).
    """

    def __init__(self, jsonl_path, extract_features=False, feature_extractor=None):
        """Load records from ``jsonl_path`` (skipped when ``None``) and optionally extract features.

        Raises:
            ValueError: if features are requested but the file has fewer than 2 records.
        """
        self.data = []
        if jsonl_path is not None:
            with open(jsonl_path, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if line:  # tolerate blank / trailing lines
                        self.data.append(json.loads(line))
        self.extract_features = extract_features
        # Always set, so slicing also works for datasets built without feature extraction.
        self.feature_extractor = feature_extractor
        if extract_features and jsonl_path is not None:
            self._extract_all_features()

    def _extract_all_features(self):
        """Build the few-shot prompt from records 0-1 and attach features to records 2 onwards."""
        if len(self.data) < 2:
            raise ValueError(
                f'extract_features needs at least 2 records for the few-shot prompt, got {len(self.data)}')
        first, second = self.data[0], self.data[1]
        self.few_shot_data_prompt = f"For example, {first['func_src_before']} is vulnerable while {first['func_src_after']} is not vulnerable. {second['func_src_before']} is vulnerable while {second['func_src_after']} is not vulnerable. "
        for item in self.data[2:]:
            vul_prompt = Prompts_OF_VUL[item['vul_type']]
            features_in, features_out, deleted_gt_token, added_gt_token = self.feature_extractor(item, vul_prompt, self.few_shot_data_prompt)
            # Features are None for a record without line changes; store None rather than crash.
            item['features_in'] = features_in.float() if features_in is not None else None
            item['features_out'] = features_out.float() if features_out is not None else None
            item['deleted_gt_token'] = deleted_gt_token
            item['added_gt_token'] = added_gt_token

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return one record as a dict, or a new ``CodeDataset`` sharing the records for a slice."""
        if isinstance(idx, slice):
            # Re-use the already loaded (and already featurised) records; nothing is re-read.
            dataset_slice = CodeDataset(None, self.extract_features, self.feature_extractor)
            dataset_slice.data = self.data[idx]
            if self.extract_features:
                dataset_slice.few_shot_data_prompt = getattr(self, 'few_shot_data_prompt', None)
            return dataset_slice

        # Single record.
        item = self.data[idx]
        if self.extract_features:
            # Records 0 and 1 (few-shot examples) carry no features and raise KeyError here.
            return {
                'func_name': item['func_name'],
                'features_in': item['features_in'],
                'features_out': item['features_out'],
                'deleted_gt_token': item['deleted_gt_token'],
                'added_gt_token': item['added_gt_token'],
                'file_name': item['file_name'],
                'vul_type': item['vul_type'],
            }
        return {
            'func_name': item['func_name'],
            'func_src_before': item['func_src_before'],
            'func_src_after': item['func_src_after'],
            'line_changes': item['line_changes'],
            'char_changes': item['char_changes'],
            'file_name': item['file_name'],
            'vul_type': item['vul_type'],
        }

def set_seed(seed):
    """Seed every RNG the script uses: torch (CPU and all CUDA devices), NumPy and ``random``."""
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    

def main(args):
    """Extract hidden-state features for one CWE's training split and fit an MLP probe on them.

    Steps:
    1. Seed the RNGs.
    2. Load the causal LM (fp16) and its tokenizer onto ``args.device``.
    3. Build a ``CodeDataset``, which runs the model on every record.
    4. Train a ``FittingTrainer`` with an MLP probe, logging to wandb.
    5. Save the probe weights to ``mlp_probe_{model_name}_CWE-{vul:03d}.pt`` in the
       working directory.

    Statement order matters: ``FittingTrainer`` calls ``model.lm_head.to(torch.float32)``,
    so it is created only after all features have been extracted.

    Args:
        args: namespace from ``get_args``; reads ``device``, ``reasoning``, ``lr`` and
            ``num_epochs``.
    """
    set_seed(2025)
    # Load the causal LM (currently Qwen2.5-Coder-7B-Instruct; alternatives are commented out).
    device = args.device
    
    # model_path = "meta-llama/Llama-3.1-8B-Instruct"
    model_path = 'Qwen/Qwen2.5-Coder-7B-Instruct'
    # model_path = 'Qwen/Qwen2.5-7B-Instruct'
    # model_path = 'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B'
    # model_path = 'Salesforce/codegen-350M-multi' # context length is not enough
    # model_path = 'Qwen/Qwen2.5-Coder-7B-Instruct'
    model_name = model_path.split('/')[-1]
    
    model = AutoModelForCausalLM.from_pretrained(
        model_path, 
        torch_dtype=torch.float16, 
        # device_map="auto"
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Wrap the LM so CodeDataset can extract per-record hidden states.
    feature_extractor = FeatureExtractor(model, tokenizer, device, reasoning=args.reasoning).to(device)
    
    # Per-CWE training statistics; collected but not returned or saved.
    train_stats = []
    
    def my_collate(batch):
        # Concatenate each sample's per-token feature rows along dim 0, so one
        # training "batch" is all token rows of up to batch_size samples.
        # Token-id lists stay as per-sample Python lists.
        batch_ = {}
        batch_['features_in'] = torch.cat([item['features_in'] for item in batch], dim=0)
        batch_['features_out'] = torch.cat([item['features_out'] for item in batch], dim=0)
        batch_['deleted_gt_token'] = [item['deleted_gt_token'] for item in batch]
        batch_['added_gt_token'] = [item['added_gt_token'] for item in batch]
        return batch_
    
    # Only the first CWE (TYPE_OF_VUL[0]) is processed; widen the range to run more.
    for _ in range(1):
        # index +=1
        index = _
        # index = args.vul_index
        vul = TYPE_OF_VUL[index]
        jsonl_path = f'./sven/data_train_val/train/cwe-{vul:03d}.jsonl'
        # check_if_add_delete_together(jsonl_path)
        
        dataset = CodeDataset(jsonl_path, extract_features=True, feature_extractor=feature_extractor)
        # save dataset
        # torch.save(dataset, f'dataset_{model_name}_CWE-{vul:03d}.pt')
        # dataset = torch.load(f'dataset_{model_name}_CWE-{vul:03d}.pt')
        # construct the train and eval dataset
        # Records 0 and 1 are the few-shot examples and have no features, so skip them.
        train_dataset = dataset[2:]  # returns a new CodeDataset instance
        
        
        
        # minus the center vector
        # NOTE(review): disabled (center is hard-coded False). As written it would not work:
        # ``ave_features`` is recomputed per item (and reduced to a scalar by torch.mean),
        # and CodeDataset.__getitem__ returns a fresh dict, so the subtraction is not stored.
        center = False
        if center==True:
            ave_features = 0
            for item in train_dataset:   
                ave_features = torch.mean((item['features_in'].mean(dim=0) + item['features_out'].mean(dim=0)) / 2 / item['features_in'].shape[0], dim=0)
                print(ave_features.shape)
                item['features_in'] = item['features_in'] - ave_features
                item['features_out'] = item['features_out'] - ave_features
        
        
        # get the loader
        train_dataloader = DataLoader(
            train_dataset, 
            batch_size=64, 
            shuffle=True, 
            collate_fn=my_collate
        )
        # Hidden size, read off the last record's features.
        feature_dim = dataset[-1]['features_in'].shape[-1]
        trainer = FittingTrainer(
            probe_type='mlp',
            classification_gen_head=model.lm_head,
            num_classes=feature_dim,  # unused by FittingTrainer; the probe maps feature_dim -> feature_dim
            feature_dim=feature_dim,  # Feature dimension from your backbone 3584 for qwen2.5, 4096 for llama3
            learning_rate=args.lr,
            device=device,
            num_epochs=args.num_epochs
        )
        # NOTE(review): the run name ends in "_center" even though centering is disabled above.
        wandb.init(project="model_editing", name=f"{time.time()}{model_path}_mlp_CWE-{vul}_center", config={"model": model_path})
        stats = trainer.train(train_dataloader)
        # save the model
        torch.save(trainer.probe.state_dict(), f'mlp_probe_{model_name}_CWE-{vul:03d}.pt')
        wandb.finish()
        train_stats.append(stats)  #list of list
        
    
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cuda:6')
    parser.add_argument('--num_epochs', type=int, default=2000)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--reasoning', type=bool, default=False)
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    # Script entry point: parse the CLI and run the extraction + probe training.
    main(get_args())