
import os
import csv
import argparse
import random
import pandas as pd
from collections import Counter, defaultdict
import torch
import datasets
from eval.utils import get_next_word_predictions_with_guidance, dynamic_import_function, load_hf_score_lm_and_tokenizer, load_hooked_lm_and_tokenizer
from hooked_models.utils import get_act_name
import numpy as np
from functools import partial
from transformers import PreTrainedTokenizerBase
from hooked_models.utils import seed_torch, get_bbh_prompt, get_mmlu_prompt

# Module-level store written by ``layer_cache_hook``: maps the value returned by
# ``hook.layer()`` to a list of detached CPU tensors, one per hook call, each
# holding ``patched_values - value[..., neurons]``.
Activation_cache = defaultdict(list)

def layer_all_patch_hook(value, hook, patched_values):
    """Replace a whole activation with ``patched_values``.

    The replacement is first moved to ``value``'s device, then cast to
    ``value``'s dtype. If either step fails, the error is printed together
    with ``hook`` and the original activation is returned untouched.
    """
    try:
        replacement = patched_values.to(value.device)
        replacement = replacement.to(value.dtype)
    except Exception as err:
        print(f'Error in hook {hook}', err)
        return value
    return replacement

def layer_patch_hook(value, hook, neurons, patched_values):
    """Overwrite ``value[..., neurons]`` in place with ``patched_values``.

    Non-tensor ``patched_values`` are wrapped with ``torch.tensor``. They are
    then converted to ``value``'s dtype and device via ``Tensor.to(value)``.
    Any failure (e.g. a shape mismatch) is printed rather than raised.
    ``value`` itself is always returned.
    """
    try:
        replacement = (
            patched_values
            if isinstance(patched_values, torch.Tensor)
            else torch.tensor(patched_values)
        )
        value[..., neurons] = replacement.to(value)
    except Exception as err:
        print(f'Error in hook {hook}', err)
    return value

def layer_cache_hook(value, hook, neurons, patched_values):
    """Record ``patched_values - value[..., neurons]`` without changing ``value``.

    The difference is detached, moved to CPU and appended to the module-level
    ``Activation_cache`` list keyed by ``hook.layer()``. Errors are printed and
    swallowed, and the activation is always returned unmodified.
    """
    global Activation_cache
    try:
        target = (
            patched_values
            if isinstance(patched_values, torch.Tensor)
            else torch.tensor(patched_values)
        )
        target = target.to(value)
        # Look up the bucket before computing the difference. This matches the
        # evaluation order of ``cache[key].append(expr)``, so the defaultdict
        # entry is created even when the subtraction below fails.
        bucket = Activation_cache[hook.layer()]
        bucket.append((target - value[..., neurons]).detach().cpu())
    except Exception as err:
        print(f'Error in hook {hook}', err)
    return value

def patch_hook(activation, hook, neuron, value):
    """Set ``activation[..., neuron]`` to ``value`` in place and return ``activation``.

    A tensor ``value`` is cloned and detached first, so the caller's tensor is
    never aliased. Anything else is wrapped with ``torch.tensor``. The
    replacement is moved to ``activation``'s device; the assignment itself
    casts it to ``activation``'s dtype.
    """
    replacement = (
        value.clone().detach()
        if isinstance(value, torch.Tensor)
        else torch.tensor(value)
    )
    activation[..., neuron] = replacement.to(activation.device)
    return activation

def perturb_hook(activation, hook, neuron, value):
    """Add Gaussian noise scaled by ``2 * value`` to ``activation[..., neuron]`` in place.

    ``value`` must be a tensor, since ``torch.randn_like`` needs one. Any
    error, including a non-tensor ``value``, is printed and the activation is
    returned unchanged.
    """
    try:
        noise = torch.randn_like(value) * 2 * value
        activation[..., neuron] += noise.to(activation.device)
    except Exception as err:
        print(f'Error in hook {hook}', err)
    return activation

def flip_hook(activation, hook, neuron):
    """Negate ``activation[..., neuron]`` in place and return ``activation``."""
    activation[..., neuron] = activation[..., neuron] * -1
    return activation

def is_same_tokenizer(
    tokenizer: PreTrainedTokenizerBase,
    other_tokenizer: PreTrainedTokenizerBase,
) -> bool:
    """Return True if both tokenizers are interchangeable.

    They are considered interchangeable when they are the very same object, or
    when they share a class and ``get_vocab()`` returns equal vocabularies.
    """
    if tokenizer is other_tokenizer:
        return True
    if not tokenizer.__class__ == other_tokenizer.__class__:
        return False
    return tokenizer.get_vocab() == other_tokenizer.get_vocab()

def tensor_intersect(a, b):
    """Return the rows of ``a`` that also occur as rows of ``b``.

    Args:
        a: 2-D tensor of index rows (e.g. ``(layer, neuron)`` pairs). Row
            order is preserved in the result.
        b: 2-D tensor of rows to test membership against.

    Returns:
        A tensor with ``a``'s dtype. When nothing matches, the result is an
        empty ``(0, k)`` tensor rather than a 1-D float ``torch.tensor([])``,
        so callers can still use ``.shape[0]`` and ``[:, 0]``.
    """
    # Build the lookup set once so each membership test is O(1).
    b_rows = {tuple(row) for row in b.tolist()}
    kept = [row for row in a.tolist() if tuple(row) in b_rows]
    if not kept:
        width = a.shape[-1] if a.dim() > 1 else 2
        return torch.empty((0, width), dtype=a.dtype)
    return torch.tensor(kept, dtype=a.dtype)
def tensor_substract(a, b):
    """Return the rows of ``a`` that do not occur as rows of ``b``.

    Args:
        a: 2-D tensor of index rows (e.g. ``(layer, neuron)`` pairs). Row
            order is preserved in the result.
        b: 2-D tensor of rows to remove.

    Returns:
        A tensor with ``a``'s dtype. When every row is removed, the result is
        an empty ``(0, k)`` tensor rather than a 1-D float
        ``torch.tensor([])``, so callers can still use ``.shape[0]`` and
        ``[:, 0]``.
    """
    # Build the lookup set once so each membership test is O(1).
    b_rows = {tuple(row) for row in b.tolist()}
    kept = [row for row in a.tolist() if tuple(row) not in b_rows]
    if not kept:
        width = a.shape[-1] if a.dim() > 1 else 2
        return torch.empty((0, width), dtype=a.dtype)
    return torch.tensor(kept, dtype=a.dtype)

def get_save_name(args):
    """Build a descriptive run name from the intervention options in ``args``.

    The prefix names the intervention. The first set flag wins, in the order
    patch_mean, patch_zero, add_noise, patch_flip, guided_generation;
    otherwise the prefix is ``'vanilla'``. Suffixes are then appended for the
    index / value / ignore / intersect files, ``generation_startswith``,
    random-neuron sampling and the sliding window.

    Args:
        args: parsed command-line namespace of this script.

    Returns:
        The run name as a string.
    """
    def stem(path):
        # Text before the first '.', e.g. 'dir/top.v2.pt' -> 'top'.
        return os.path.basename(path).split(".")[0]

    if args.patch_mean:
        save_name = 'patch_mean'
    elif args.patch_zero:
        save_name = 'patch_zero'
    elif args.add_noise:
        save_name = 'add_noise'
    elif args.patch_flip:
        save_name = 'flip'
    elif args.guided_generation:
        guide = args.blue_peft_path[-1] if args.blue_peft_path is not None else args.model_name_or_path
        save_name = f'guided_by_{os.path.basename(guide)}'
    else:
        save_name = 'vanilla'

    if args.index_path:
        save_name += f'_idx_{stem(args.index_path)}'
        # Treat an unset value_path like one equal to index_path (no suffix)
        # instead of crashing in os.path.basename(None).
        if args.value_path is not None and args.value_path != args.index_path:
            save_name += f'_value_{stem(args.value_path)}'
    if args.ignore_index_path:
        save_name += f'_sub_{stem(args.ignore_index_path)}'
    if args.intersect_index_path:
        save_name += f'_intersect_{stem(args.intersect_index_path)}'
    if args.generation_startswith != '':
        save_name += f'_startswith_{args.generation_startswith}'
    if args.use_random_neurons:
        save_name += '_random_neurons'
    if args.sliding_window is not None:
        save_name += f'_window_{args.sliding_window}'
    return save_name

def next_token_acc(preds: list, targets: list) -> float:
    """Return the fraction of positions where ``preds`` equals ``targets``.

    Pairs are compared position-wise with ``==``. The denominator is
    ``len(preds)``, so predictions without a corresponding target count as
    misses. An empty ``preds`` returns ``0.0`` instead of raising
    ZeroDivisionError.
    """
    # len() rather than truthiness so array-like inputs are handled too.
    if len(preds) == 0:
        return 0.0
    hits = sum(1 for pred, target in zip(preds, targets) if pred == target)
    return hits / len(preds)
   
def _build_prompts(args):
    """Load and subsample the evaluation prompts from ``args.dataset``.

    Each record's ``prompt`` field is optionally wrapped with the chat template
    named by ``args.chat_formatting_function``, then suffixed with
    ``args.generation_startswith``. Up to ``args.num_samples`` prompts are
    drawn at random from the last ``3 * args.num_samples`` records.
    """
    eval_data = datasets.load_dataset('json', data_files=args.dataset)["train"]
    chat_formatting_function = dynamic_import_function(args.chat_formatting_function) if args.use_chat_format else None
    prompts = []
    for example in eval_data:
        prompt = example['prompt']
        if args.use_chat_format:
            messages = [{"role": "user", "content": prompt}]
            prompt = chat_formatting_function(messages, add_bos=False)
        prompts.append(prompt + args.generation_startswith)
    pool = prompts[-3 * args.num_samples:]
    # Clamp the sample size: random.sample raises ValueError when asked for
    # more items than the pool holds (a dataset smaller than num_samples).
    return random.sample(pool, min(args.num_samples, len(pool)))


def _load_hooked_model(args, peft_path):
    """Load the base model/tokenizer from ``args`` with ``peft_path`` adapter(s) applied."""
    return load_hooked_lm_and_tokenizer(
        model_name_or_path=args.model_name_or_path,
        tokenizer_name_or_path=args.tokenizer_name_or_path if args.tokenizer_name_or_path is not None else args.model_name_or_path,
        load_in_8bit=args.load_in_8bit,
        device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
        peft_name_or_path=peft_path,
    )


def _load_neuron_indices(args):
    """Load the ranked neuron index and the optional ignore/intersect indices.

    Each file is read with ``torch.load``, and its second element is taken as
    the index tensor. The tensor's rows are consumed as ``(layer, neuron)``
    pairs by ``_select_topk_index``. The remaining elements (mean/std
    statistics) are not used by this script, so ``--value_path`` is not read
    here.

    Returns:
        ``(index, ignore_index, intersect_index)``. ``index`` is an empty list
        when ``--index_path`` is not given. The other two are ``None`` when
        their paths are not given.
    """
    index = []
    ignore_index = intersect_index = None
    if args.index_path:
        index = torch.load(args.index_path)[1]
    if args.ignore_index_path:
        ignore_index = torch.load(args.ignore_index_path)[1]
    if args.intersect_index_path:
        intersect_index = torch.load(args.intersect_index_path)[1]
    return index, ignore_index, intersect_index


def _select_topk_index(args, topk, index, ignore_index, intersect_index, model):
    """Choose the ``(layer, neuron)`` rows to patch for one ``topk`` setting.

    The steps are applied in order:
    1. Take the first ``topk`` rows of ``index``, or the window
       ``[topk - sliding_window, topk)`` when ``--sliding_window`` is set.
    2. Drop rows that also appear in ``ignore_index``.
    3. Keep only rows that also appear in ``intersect_index``.
    4. With ``--use_random_neurons``, replace the selection by the same number
       of uniformly sampled neurons per layer.
    """
    if args.sliding_window is not None:
        # Clamp at 0: a negative start would be read relative to the end of
        # ``index`` and silently select the wrong range.
        topk_index = index[max(topk - args.sliding_window, 0):topk]
    else:
        topk_index = index[:topk]
    if args.ignore_index_path:
        topk_index = tensor_substract(topk_index, ignore_index[:topk_index.shape[0]])
    if args.intersect_index_path:
        topk_index = tensor_intersect(topk_index, intersect_index[:topk_index.shape[0]])
    if args.use_random_neurons:
        counts = Counter(topk_index[:, 0].tolist())
        print('Number of neuron each layer: ', counts)
        sampled = []
        for layer, num in counts.items():
            neurons = random.sample(range(model.config.intermediate_size), num)
            sampled += [[layer, neuron] for neuron in neurons]
        topk_index = torch.tensor(sampled)
        print('Number of neuron each layer after sample: ', Counter(topk_index[:, 0].tolist()))
    return topk_index


def main(args):
    """Compare red vs. blue next-token predictions under neuron patching.

    The red model (base + ``--red_peft_path``) is run once to obtain reference
    predictions and probabilities. Then, for each value in ``--topk_ablate``,
    the blue model (base + ``--blue_peft_path``) is run with the selected
    neurons patched from the red model through ``layer_patch_hook``. The
    resulting KL-based score and the next-token agreement with the red model
    are printed.
    """
    seed_torch()
    os.makedirs(args.save_dir, exist_ok=True)

    print("loading data and model...")
    prompts = _build_prompts(args)
    if args.model_name_or_path is None:
        print("No --model_name_or_path given: evaluation with an OpenAI engine is not implemented in this script.")
        return
    batch_size = args.eval_batch_size or 1

    red_model, tokenizer = _load_hooked_model(args, args.red_peft_path)
    red_predictions, red_probs = get_next_word_predictions_with_guidance(
        model=red_model,
        tokenizer=tokenizer,
        prompts=prompts,
        batch_size=batch_size,
    )
    red_probs = torch.tensor(red_probs)

    blue_model, tokenizer = _load_hooked_model(args, args.blue_peft_path)
    red_name = os.path.basename(args.red_peft_path[-1]) if args.red_peft_path else os.path.basename(args.model_name_or_path)

    index, ignore_index, intersect_index = _load_neuron_indices(args)
    # argparse returns the default as-is, so a scalar default arrives as a bare int.
    topk_values = args.topk_ablate if isinstance(args.topk_ablate, (list, tuple)) else [args.topk_ablate]

    for topk in topk_values:
        topk_index = _select_topk_index(args, topk, index, ignore_index, intersect_index, blue_model)

        print(f"running with {topk} neurons replaced by {red_name} model")
        blue_predictions, blue_probs = get_next_word_predictions_with_guidance(
            model=blue_model,
            tokenizer=tokenizer,
            prompts=prompts,
            batch_size=batch_size,
            guided_model=red_model,
            index=topk_index,
            hook_fn=layer_patch_hook,
        )
        blue_probs = torch.tensor(blue_probs)
        # F.kl_div(input, target, log_target=True) sums exp(target) * (target - input),
        # i.e. KL(blue || red) here. Assumes both prob tensors are distributions
        # over the same support. TODO confirm against get_next_word_predictions_with_guidance.
        kl_div = torch.nn.functional.kl_div(red_probs.log(), blue_probs.log(), reduction='batchmean', log_target=True)
        print('- log KL divergence: ', -kl_div.log().item())
        print('Next token accuracy: ', next_token_acc(blue_predictions, red_predictions))
        

def build_parser():
    """Build the command-line parser for this red/blue neuron-patching evaluation."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_samples",
        type=int,
        default=100,
        help="Number of samples to evaluate.",
    )
    parser.add_argument(
        "--topk_ablate",
        type=int,
        # nargs='+' yields a list, so the default must be a list too; a bare
        # int default made ``for topk in args.topk_ablate`` raise TypeError.
        default=[1],
        nargs='+',
        help="Number of top different neurons to ablate.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="",
        help="Dataset to evaluate.",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="results/alpaca_farm")
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
        help="If specified, we will load the model to generate the predictions.",
    )
    parser.add_argument(
        "--tokenizer_name_or_path",
        type=str,
        default=None,
        help="If specified, we will load the tokenizer from here.",
    )
    parser.add_argument(
        "--cost_model_name_or_path",
        type=str,
        default=None,
        help="If specified, we will load the model to generate cost scores for predictions.",
    )
    parser.add_argument(
        "--cost_tokenizer_name_or_path",
        type=str,
        default=None,
        help="If specified, we will load the tokenizer from here.",
    )
    parser.add_argument(
        "--openai_engine",
        type=str,
        default=None,
        help="If specified, we will use the OpenAI API to generate the predictions.",
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=1,
        help="Batch size for evaluation."
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=256,
        help="Maximum number of new tokens to generate."
    )
    parser.add_argument(
        "--load_in_8bit",
        action="store_true",
        help="Load model in 8bit mode, which will reduce memory and speed up inference.",
    )
    parser.add_argument(
        "--gptq",
        action="store_true",
        help="If given, we're evaluating a 4-bit quantized GPTQ model.",
    )
    parser.add_argument(
        "--use_chat_format",
        action="store_true",
        help="If given, we will use the chat format for the prompts."
    )
    parser.add_argument(
        "--chat_formatting_function",
        type=str,
        default="eval.templates.create_prompt_with_tulu_chat_format",
        help="The function to use to create the chat format. This function will be dynamically imported. Please see examples in `eval/templates.py`."
    )
    parser.add_argument(
        "--use_vllm",
        action="store_true",
        help="If given, we will use vLLM to generate the predictions - much faster.",
    )
    parser.add_argument(
        "--red_peft_path",
        nargs='+',
        default=None,
        help="The folder contains lora checkpoint saved with PeftModel.save_pretrained()."
    )
    parser.add_argument(
        "--blue_peft_path",
        nargs='+',
        default=None,
        help="The folder contains lora checkpoint saved with PeftModel.save_pretrained()."
    )
    parser.add_argument(
        "--index_path",
        type=str,
        default=None,
        help="torch.save file whose second element is the ranked (layer, neuron) index tensor."
    )
    parser.add_argument(
        "--ignore_index_path",
        type=str,
        default=None,
        help="Index file (same format as --index_path); its neurons are removed from the selection."
    )
    parser.add_argument(
        "--intersect_index_path",
        type=str,
        default=None,
        help="Index file (same format as --index_path); only neurons also present in it are kept."
    )
    parser.add_argument(
        "--value_path",
        type=str,
        default=None,
        help="File providing the neuron mean/std statistics, in the same format as --index_path."
    )
    parser.add_argument(
        "--use_random_neurons",
        action="store_true",
        help="Use random neurons instead of found."
    )
    parser.add_argument(
        "--add_noise",
        action="store_true",
        help="Add random Gaussian noise to selected neurons."
    )
    parser.add_argument(
        "--patch_mean",
        action="store_true",
        help="Patch with neuron activation mean."
    )
    parser.add_argument(
        "--patch_zero",
        action="store_true",
        help="Patch with 0."
    )
    parser.add_argument(
        "--patch_flip",
        action="store_true",
        help="Flip neuron activation."
    )
    parser.add_argument(
        "--guided_generation",
        action="store_true",
        help="Guided generation using aligned model activation."
    )
    parser.add_argument(
        "--generation_startswith",
        type=str,
        default='',
        help="Generate completion start with given string."
    )
    parser.add_argument(
        "--sliding_window",
        type=int,
        default=None,
        help="If specified, use the neurons ranked in [topk - sliding_window, topk) instead of the first topk."
    )
    parser.add_argument(
        "--cache_difference",
        action="store_true",
        help="Cache model activation difference between red and blue model on red model generation."
    )
    return parser


if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()

    # Exactly one of model_name_or_path / openai_engine must be given. Use
    # parser.error rather than assert, which is stripped under ``python -O``.
    if (args.model_name_or_path is None) == (args.openai_engine is None):
        parser.error("Either model_name_or_path or openai_engine should be specified.")
    main(args)