import torch
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn.functional as F
import transformers
from openai import OpenAI
import re
import base64
import anthropic


def get_second_to_last_layer_mean(model, tokenizer, dataset, keyword):
    """Collect the penultimate-layer hidden state of the last token per sample.

    Despite the name, no mean is computed: for every record the prompt is run
    through ``model`` and the last-position slice of ``hidden_states[-2]`` is
    kept.

    Args:
        model: Callable accepting ``input_ids`` and ``output_hidden_states``
            and returning an object exposing ``hidden_states``.
        tokenizer: Callable that, given a string and ``return_tensors="pt"``,
            returns a mapping with an ``input_ids`` tensor.
        dataset: Iterable of mappings; ``record[keyword]`` must be a string.
        keyword: Key of the code snippet to embed.

    Returns:
        The per-sample states concatenated along dim 0 (presumably
        ``(n_samples, hidden_size)`` for a Hugging Face causal LM).
    """
    # Inputs must live on the same device as the model, otherwise a model
    # placed on GPU fails on the CPU tensors produced by the tokenizer.
    device = getattr(model, "device", None)

    features = []
    for single_data in dataset:
        # NOTE: the prompt text is kept verbatim (spelling included) so the
        # extracted features stay comparable with earlier runs.
        input_text = single_data[keyword] + 'Is the following code vulnarable?'
        input_ids = tokenizer(input_text, return_tensors="pt")['input_ids']
        if device is not None:
            input_ids = input_ids.to(device)

        # Inference only: no gradients are needed for feature extraction.
        with torch.no_grad():
            outputs = model(input_ids=input_ids, output_hidden_states=True)

        # hidden_states[-2] is the second-to-last layer; [:, -1] selects the
        # final token position of every batch row.
        features.append(outputs.hidden_states[-2][:, -1])

    stacked = torch.cat(features, dim=0)
    print(stacked.shape)
    return stacked

def code_aug(ori_code, aug_para_len):
    """Return the middle slice of ``ori_code`` with both end chunks removed.

    The sequence is split into ``aug_para_len`` chunks of
    ``len(ori_code) // aug_para_len`` items; the slice starts after the first
    chunk and stops where the last chunk would begin.
    """
    chunk = len(ori_code) // aug_para_len
    start = chunk
    stop = chunk * (aug_para_len - 1)
    return ori_code[start:stop]

import re

def extract_answer(text):
    """
    Extract 'yes' or 'no' from text strings that start with 'Answer:' or '[Answer:'

    Matching is case-insensitive and ignores surrounding whitespace. The
    verdict must be a whole word, so 'Answer: nothing' is not read as 'no'.

    Args:
        text (str): Input text containing an answer

    Returns:
        str: 'yes' or 'no' if found, None otherwise
    """
    # Remove any whitespace and convert to lowercase
    normalized = text.strip().lower()

    # Pattern matches, anchored at the start of the text:
    # - Optional '[' at start
    # - 'answer:' (with optional whitespace)
    # - Captures 'yes' or 'no' as a whole word
    found = re.match(r"\[?\s*answer:\s*(yes|no)\b", normalized)
    if found is None:
        return None
    return found.group(1)

def extract_vulnerability_answer(text):
    """
    Extract a yes/no verdict from a model response.

    Three formats are recognised; when several are present, the one listed
    later takes precedence:

    1. ``Answer: yes`` / ``Answer: no`` anywhere in the text.
    2. ``CWE-<id>: yes`` / ``CWE-<id>: no`` anywhere in the text.
    3. A response whose first word is ``yes`` or ``no``
       (e.g. "No, the code does not have...").

    Args:
        text (str): Raw model output.

    Returns:
        str: 'yes' or 'no' (lowercase), or None when no verdict is found.
    """
    if not text:
        return None

    verdict = None

    # "Answer: yes/no" format
    found = re.search(r"answer:\s*(yes|no)", text, re.IGNORECASE)
    if found:
        verdict = found.group(1).lower()

    # "CWE-<id>: yes/no" format; ids are not limited to three digits
    found = re.search(r"CWE-\d+:\s*(yes|no)", text, re.IGNORECASE)
    if found:
        verdict = found.group(1).lower()

    # Leading "yes"/"no". The word boundary keeps words that merely start
    # with those letters ("Note", "Nothing", "Yesterday") from being read
    # as a verdict and overriding an explicit answer found above.
    found = re.match(r"\s*(yes|no)\b", text, re.IGNORECASE)
    if found:
        verdict = found.group(1).lower()

    return verdict

def extract_answer_from_brackets(text):
    """
    Return the content of the last square-bracketed span in ``text``.

    Examples:
    "blah blah [yes] blah [no]" -> "no"
    "no brackets" -> None
    "[yes]" -> "yes"

    The content is stripped of surrounding whitespace but otherwise returned
    unchanged; None is returned when there is no bracketed span.
    """
    # Non-greedy match so each [...] pair is captured separately.
    bracketed = re.findall(r'\[(.*?)\]', text)
    return bracketed[-1].strip() if bracketed else None
    

def sanity_check_llama(dataset, model_path, type_of_vul):
    """Few-shot yes/no vulnerability check using a text-generation pipeline.

    The first two records of ``dataset`` serve as in-context examples; every
    remaining record is queried twice, once with its ``func_src_before`` code
    (a "yes" is counted as correct) and once with its ``func_src_after`` code
    (a "no" is counted as correct). The few-shot records are not evaluated,
    so all rates are fractions of ``len(dataset) - 2``.

    Args:
        dataset: Sequence of mappings with ``func_src_before`` and
            ``func_src_after`` keys; needs at least three records.
        model_path: Model identifier passed to ``transformers.pipeline``.
        type_of_vul: Vulnerability description that opens every prompt.

    Returns:
        Tuple ``(before, after, nonsense_before, nonsense_after)`` of rates.

    Raises:
        ValueError: If ``dataset`` has fewer than three records.
    """
    # Validate before loading the model: two records are consumed as few-shot
    # examples, so at least one more is needed to evaluate anything.
    if len(dataset) < 3:
        raise ValueError(
            "dataset needs at least 3 records: 2 few-shot examples plus 1 to evaluate"
        )

    pipeline = transformers.pipeline(
        "text-generation", model=model_path, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto"
    )

    # The few-shot prefix is identical for every query, so build it once.
    few_shot = (
        f"{type_of_vul} For example, {dataset[0]['func_src_before']} is vulnerable, while {dataset[0]['func_src_after']} is not vulnerable. "
        f"{dataset[1]['func_src_before']} is vulnerable, while {dataset[1]['func_src_after']} is not vulnerable. "
        "Does the following code has such vulnerability? "
    )

    def ask(code, index):
        """Query the model about ``code`` and return the extracted verdict."""
        prompt = f"{few_shot}'''{code}'''. Answer the question with simply yes or no."
        # The pipeline output starts with the prompt; drop it to keep only the completion.
        generated = pipeline(prompt, max_new_tokens=500)[0]['generated_text'][len(prompt):]
        print(f'index:{index}' + generated + '\n')
        verdict = extract_vulnerability_answer(generated)
        print(f'index:{index},extracted:{verdict}')
        return verdict

    before, after, nonsense_before, nonsense_after = 0, 0, 0, 0
    dataset_size = len(dataset) - 2
    # Skip the two few-shot records: evaluating them would leak the answer and
    # make the counts inconsistent with the dataset_size denominator.
    for index, single_data in enumerate(dataset[2:], start=1):
        verdict = ask(single_data['func_src_before'], index)
        if verdict is None or ('yes' not in verdict and 'no' not in verdict):
            nonsense_before += 1
        elif 'yes' in verdict:
            before += 1

        verdict = ask(single_data['func_src_after'], index)
        if verdict is None or ('yes' not in verdict and 'no' not in verdict):
            nonsense_after += 1
        elif 'no' in verdict:
            after += 1

    rates = (
        before / dataset_size,
        after / dataset_size,
        nonsense_before / dataset_size,
        nonsense_after / dataset_size,
    )
    print('before, after, nonsense_before, nonsense_after', *rates)
    return rates

def sanity_check_reasoning(dataset, model_path, type_of_vul):
    """Few-shot vulnerability check that asks the model to reason first.

    The first two records of ``dataset`` serve as in-context examples; every
    remaining record is queried twice through the chat interface of a
    text-generation pipeline, once with its ``func_src_before`` code (a
    bracketed "yes" is counted as correct) and once with its
    ``func_src_after`` code (a bracketed "no" is counted as correct). The
    verdict is the content of the last ``[...]`` span of the reply, compared
    case-insensitively. The few-shot records are not evaluated, so all rates
    are fractions of ``len(dataset) - 2``.

    Args:
        dataset: Sequence of mappings with ``func_src_before`` and
            ``func_src_after`` keys; needs at least three records.
        model_path: Model identifier passed to ``transformers.pipeline``.
        type_of_vul: Vulnerability description that opens every prompt.

    Returns:
        Tuple ``(before, after, nonsense_before, nonsense_after)`` of rates.

    Raises:
        ValueError: If ``dataset`` has fewer than three records.
    """
    # Validate before loading the model: two records are consumed as few-shot
    # examples, so at least one more is needed to evaluate anything.
    if len(dataset) < 3:
        raise ValueError(
            "dataset needs at least 3 records: 2 few-shot examples plus 1 to evaluate"
        )

    pipeline = transformers.pipeline(
        "text-generation", model=model_path, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto"
    )

    # The few-shot prefix is identical for every query, so build it once.
    few_shot = (
        f"{type_of_vul} For example, {dataset[0]['func_src_before']} is vulnerable, while {dataset[0]['func_src_after']} is not vulnerable. "
        f"{dataset[1]['func_src_before']} is vulnerable, while {dataset[1]['func_src_after']} is not vulnerable. "
        "Does the following code has such vulnerability? "
    )
    instruction = ". First think step by step, and then answer the question with simply yes or no within a bracket, like [yes] or [no]."

    def ask(code, index):
        """Query the model about ``code`` and return the lowercased verdict."""
        messages = [{"role": "user", "content": f"{few_shot}'''{code}'''{instruction}"}]
        # For chat input the pipeline returns the conversation; the reply is
        # the last message.
        reply = pipeline(messages, max_new_tokens=2048)[0]['generated_text'][-1]['content']
        print(f'index:{index}', reply + '\n')
        extracted = extract_answer_from_brackets(reply)
        print(f'index:{index},extracted:{extracted}')
        # Lowercase so that "[Yes]" / "[NO]" are not miscounted as nonsense.
        return extracted.lower() if extracted is not None else None

    before, after, nonsense_before, nonsense_after = 0, 0, 0, 0
    dataset_size = len(dataset) - 2
    # Skip the two few-shot records: evaluating them would leak the answer and
    # make the counts inconsistent with the dataset_size denominator.
    for index, single_data in enumerate(dataset[2:], start=1):
        verdict = ask(single_data['func_src_before'], index)
        if verdict is None or ('yes' not in verdict and 'no' not in verdict):
            nonsense_before += 1
        elif 'yes' in verdict:
            before += 1

        verdict = ask(single_data['func_src_after'], index)
        if verdict is None or ('yes' not in verdict and 'no' not in verdict):
            nonsense_after += 1
        elif 'no' in verdict:
            after += 1

    rates = (
        before / dataset_size,
        after / dataset_size,
        nonsense_before / dataset_size,
        nonsense_after / dataset_size,
    )
    print('before, after, nonsense_before, nonsense_after', *rates)
    return rates

def sanity_check_anthopic(dataset, type_of_vul):
    """Few-shot yes/no vulnerability check using ``claude_generate_response``.

    The first two records of ``dataset`` serve as in-context examples; every
    remaining record is queried twice, once with its ``func_src_before`` code
    (a "yes" is counted as correct) and once with its ``func_src_after`` code
    (a "no" is counted as correct). The few-shot records are not evaluated,
    so all rates are fractions of ``len(dataset) - 2``.

    Args:
        dataset: Sequence of mappings with ``func_src_before`` and
            ``func_src_after`` keys; needs at least three records.
        type_of_vul: Vulnerability description that opens every prompt.

    Returns:
        Tuple ``(before, after, nonsense_before, nonsense_after)`` of rates.

    Raises:
        ValueError: If ``dataset`` has fewer than three records.
    """
    # Two records are consumed as few-shot examples, so at least one more is
    # needed; fail before any API call is made.
    if len(dataset) < 3:
        raise ValueError(
            "dataset needs at least 3 records: 2 few-shot examples plus 1 to evaluate"
        )

    # The few-shot prefix is identical for every query, so build it once.
    few_shot = (
        f"{type_of_vul} For example, {dataset[0]['func_src_before']} is vulnerable while {dataset[0]['func_src_after']} is not vulnerable. "
        f"{dataset[1]['func_src_before']} is vulnerable while {dataset[1]['func_src_after']} is not vulnerable. "
        "Does the following code has such vulnerability? "
    )

    def ask(code, index):
        """Query the model about ``code`` and return the extracted verdict."""
        reply = claude_generate_response(f"{few_shot}{code}. Answer with simply yes or no.")
        verdict = extract_vulnerability_answer(reply)
        print(f'index:{index},extracted:{verdict}')
        return verdict

    before, after, nonsense_before, nonsense_after = 0, 0, 0, 0
    dataset_size = len(dataset) - 2
    # Skip the two few-shot records: evaluating them would leak the answer and
    # make the counts inconsistent with the dataset_size denominator.
    for index, single_data in enumerate(dataset[2:], start=1):
        verdict = ask(single_data['func_src_before'], index)
        if verdict is None or ('yes' not in verdict and 'no' not in verdict):
            nonsense_before += 1
        elif 'yes' in verdict:
            before += 1

        verdict = ask(single_data['func_src_after'], index)
        if verdict is None or ('yes' not in verdict and 'no' not in verdict):
            nonsense_after += 1
        elif 'no' in verdict:
            after += 1

    rates = (
        before / dataset_size,
        after / dataset_size,
        nonsense_before / dataset_size,
        nonsense_after / dataset_size,
    )
    print('before, after, nonsense_before, nonsense_after', *rates)
    return rates

def claude_generate_response(text_input):
    """Send ``text_input`` as a single user message to Claude and return the reply text.

    Args:
        text_input (str): Prompt text.

    Returns:
        The ``text`` of the first content block of the response.
    """
    # The client is constructed without arguments, so credentials are expected
    # to come from the SDK's default configuration.
    client = anthropic.Anthropic()
    # Use the client created above; the previous code referenced an undefined
    # ``anthropic_client`` name and raised NameError on every call.
    message = client.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=1024,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": text_input
                    }
                ],
            }
        ],
    )
    print(message)
    return message.content[0].text
    

def dataset_aug(ori_dataset):
    """Add truncated variants of both code fields to every record.

    For each record, ``code_aug`` is applied to ``func_src_before`` and
    ``func_src_after`` with 3, 4 and 5 parts; the results are stored under
    ``<field>_aug1`` .. ``<field>_aug3``. Records are modified in place and
    also collected into the returned list.
    """
    augmented = []
    for record in ori_dataset:
        for suffix, parts in (('aug1', 3), ('aug2', 4), ('aug3', 5)):
            for field in ('func_src_before', 'func_src_after'):
                record[f'{field}_{suffix}'] = code_aug(record[field], parts)
        augmented.append(record)
    return augmented
   

def feature_ana(dataset, model_path, model, tokenizer,):
    """Compare hidden-state features of the "before" and "after" code.

    Prints and returns two statistics:

    * the per-sample cosine similarity between the ``func_src_before`` and
      ``func_src_after`` features, and
    * the mean pairwise cosine similarity between the per-sample difference
      vectors (before - after), taken over all distinct sample pairs.

    Args:
        dataset: Sequence of mappings with ``func_src_before`` and
            ``func_src_after`` keys.
        model_path: Unused; kept for interface compatibility.
        model: Model passed to ``get_second_to_last_layer_mean``.
        tokenizer: Tokenizer passed to ``get_second_to_last_layer_mean``.

    Returns:
        Tuple ``(cosine_sim, diff_mean)``; ``diff_mean`` is NaN when fewer
        than two samples are available.
    """
    # dataset_aug adds the *_aug* fields to the records in place; only that
    # side effect is wanted here, the returned list is not used.
    dataset_aug(dataset)

    features_before = get_second_to_last_layer_mean(model, tokenizer, dataset, 'func_src_before')
    features_after = get_second_to_last_layer_mean(model, tokenizer, dataset, 'func_src_after')

    # Row-wise cosine similarity between matching before/after features.
    cosine_sim = torch.sum(
        F.normalize(features_before, p=2, dim=-1) * F.normalize(features_after, p=2, dim=-1),
        dim=-1,
    )
    print(cosine_sim)

    ## inter similarity
    # Normalise every difference vector first so entry (i, j) of the product
    # is <d_i, d_j> / (|d_i| * |d_j|). Dividing the Gram matrix by the 1-D norm
    # vector twice would instead scale each column by |d_j|**2.
    diff_unit = F.normalize(features_before - features_after, p=2, dim=-1)
    diff_sim = diff_unit @ diff_unit.T

    # The strict upper triangle holds each unordered pair exactly once.
    num_ = diff_sim.shape[0]
    num_pairs = num_ * (num_ - 1) // 2
    if num_pairs:
        diff_mean = torch.triu(diff_sim, diagonal=1).sum() / num_pairs
    else:
        diff_mean = torch.tensor(float('nan'))
    print(diff_mean)
    return cosine_sim, diff_mean
    

def main():
    """Run the reasoning sanity check for each selected CWE and print mean rates.

    For every CWE id in ``vul_list`` the matching JSONL training file is
    loaded, ``sanity_check_reasoning`` is run on it, and the "before" and
    "after" rates are averaged over the selected CWEs.
    """
    model_path = 'deepseek-ai/DeepSeek-R1-Distill-Qwen-32B'

    # CWE id -> description used as the opening of every prompt. Keying by id
    # guarantees that the description always matches the dataset being
    # evaluated, whichever subset is selected below.
    cwe_descriptions = {
        22: 'CWE-22, commonly called “Path Traversal,” is a vulnerability when an application fails to appropriately limit the paths users can access through a user-provided input.',
        78: 'CWE-78, means improper neutralization of special elements used in an OS command (OS command injection), constructs all or part of an OS command using externally-influenced input from an upstream component, but it does not neutralize or incorrectly neutralizes special elements that could modify the intended OS command when it is sent to a downstream component.',
        79: 'CWE-79, improper neutralization of input during web page generation (cross-site scripting). The code does not neutralize or incorrectly neutralizes user-controllable input before it is placed in output that is used as a web page that is served to other users.',
        89: 'CWE-89: Improper Neutralization of Special Elements used in an SQL Command (SQL Injection). The product constructs all or part of an SQL command using externally-influenced input from an upstream component, but it does not neutralize or incorrectly neutralizes special elements that could modify the intended SQL command when it is sent to a downstream component. Without sufficient removal or quoting of SQL syntax in user-controllable inputs, the generated SQL query can cause those inputs to be interpreted as SQL instead of ordinary user data.',
        125: 'CWE-125: Out-of-bounds Read. The product reads data past the end, or before the beginning, of the intended buffer.',
        190: 'CWE-190: Integer Overflow or Wraparound. The product performs a calculation that can produce an integer overflow or wraparound when the logic assumes that the resulting value will always be larger than the original value. This occurs when an integer value is incremented to a value that is too large to store in the associated representation. When this occurs, the value may become a very small or negative number.',
        416: 'CWE-416: Use After Free. The product reuses or references memory after it has been freed. At some point afterward, the memory may be allocated again and saved in another pointer, while the original pointer references a location somewhere within the new allocation. Any operations using the original pointer are no longer valid because the memory "belongs" to the code that operates on the new pointer.',
        476: 'CWE-476: NULL Pointer Dereference. The product dereferences a pointer that it expects to be valid but is NULL.',
        787: 'CWE-787: Out-of-bounds Write. The product writes data past the end, or before the beginning, of the intended buffer.',
    }

    # CWEs to evaluate in this run; use list(cwe_descriptions) for all of them.
    vul_list = [22, ]

    acc_before = 0
    acc_after = 0
    for vul in vul_list:
        vul_prompt = cwe_descriptions[vul]
        jsonl_path = f'./data_train_val/train/cwe-{str(vul).zfill(3)}.jsonl'
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            # One JSON object per line; blank lines are skipped.
            dataset = [json.loads(line) for line in f if line.strip()]
        print('dataset size',len(dataset))

        # Alternative evaluators with the same result layout:
        # sanity_check_llama(dataset, model_path, vul_prompt)
        # sanity_check_anthopic(dataset, vul_prompt)
        results = sanity_check_reasoning(dataset, model_path, vul_prompt)

        acc_before += results[0]
        acc_after += results[1]
    print(acc_before/(len(vul_list)), acc_after/(len(vul_list)))
    

# Script entry point: run the evaluation only when executed directly.
if __name__ == "__main__":
    
    # Disabled manual smoke test for extract_vulnerability_answer: uncomment
    # to print the verdict extracted from each sample response below.
    # sample_texts = [
    #     'index:50 CWE-22 is a vulnerability... Answer: yes. Reason: The code allows...',
    #     'CWE-22: no. Reason: The code is not vulnerable...',
    #     'No, the code does not have such vulnerability...',
    #     '[Answer: no. Reason: The `static_file()` function is used to serve static files, it does not directly access the file system. The `path` parameter is ',
    #     '(Please fill in the answer and reason below.)\n [Answer: no. Re',
        
    #     ]

    # for text in sample_texts:
    #     print(f"Text analysis: {extract_vulnerability_answer(text)}")
    

    main()