import os, re, json
import itertools
from tqdm import tqdm

import torch, numpy as np
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

import argparse

from nethook import TraceDict, get_module, get_parameter

# Include prompt creation helper functions
from utils import make_valid_path_name
from utils.prompt_utils import *
from utils.intervention_utils import *
from utils.model_utils import *
from utils.eval_utils import *
from compute_average_activations import get_mean_head_activations
from compute_indirect_effect import compute_indirect_effect
from evaluate_function_pointer import compute_universal_function_pointer



def portability_eval(dataset, fp_vector, edit_layer, model, model_config, tokenizer, n_eval_templates=15, seed=42):
    """
    Evaluate an intervention vector across a random sample of prompt templates.

    A pool of templates is formed from every (prefix, separator) pairing defined
    below; `n_eval_templates` of them are drawn without replacement after seeding
    with `seed`. For each drawn template three evaluations are run:
      1. a 10-shot evaluation without intervention, whose zero-rank entries
         (`clean_rank_list == 0`) define `filter_set`,
      2. a 0-shot evaluation via `n_shot_eval` using `fp_vector` at `edit_layer`,
         restricted to `filter_set`,
      3. a 10-shot evaluation via `n_shot_eval` with `shuffle_labels=True`,
         restricted to `filter_set`.

    Parameters:
    dataset: dataset object forwarded to the evaluation helpers
    fp_vector: vector forwarded to `n_shot_eval` (presumably the function pointer
        vector to add at `edit_layer` -- confirm against `n_shot_eval`)
    edit_layer: layer specification forwarded to `n_shot_eval`
    model: model forwarded to the evaluation helpers
    model_config: model config forwarded to the evaluation helpers
    tokenizer: tokenizer forwarded to the evaluation helpers
    n_eval_templates: number of templates to sample from the pool
    seed: seed used for template sampling and re-applied before each template's evaluations

    Returns:
    fs_res_dict: few-shot (no intervention) results, keyed by template index
    zs_res_dict: zero-shot intervention results, keyed by template index
    fs_shuffled_res_dict: shuffled-label few-shot intervention results, keyed by template index
    templates: string representation of each sampled template, in evaluation order
    """
    # (input label, output label) pairs that make up the candidate prefixes
    prefix_pairs = [('A:', 'B:'),
                    ('input:', 'output:'),
                    ('Input:', 'Output:'),
                    ('In:', 'Out:'),
                    ('question:', 'answer:'),
                    ('Question:', 'Answer:'),
                    ('', ' ->'),
                    ('', ' :'),
                    ('text:', 'label:'),
                    ('x:', 'f(x):'),
                    ('x:', 'y:'),
                    ('X:', 'Y:')]

    # (separator after input, separator after output) pairs
    separator_pairs = [(' ', ''),
                       (' ', '\n'),
                       (' ', '\n\n'),
                       ('\n', '\n'),
                       ('\n', '\n\n'),
                       ('\n\n', '\n\n'),
                       (' ', '|'),
                       ('\n', '|'),
                       ('|', '\n'),
                       ('|', '\n\n')]

    all_prefixes = [{'input': inp, 'output': out, 'instructions': ''} for inp, out in prefix_pairs]
    all_separators = [{'input': inp, 'output': out, 'instructions': ''} for inp, out in separator_pairs]

    # Draw n_eval_templates distinct (prefix, separator) combinations
    template_pool = list(itertools.product(all_prefixes, all_separators))
    set_seed(seed)
    chosen_indices = np.random.choice(np.arange(len(template_pool)), n_eval_templates, replace=False)
    random_combos = [list(template_pool[idx]) for idx in chosen_indices]

    fs_res_dict, zs_res_dict, fs_shuffled_res_dict = {}, {}, {}
    templates = []
    for template_idx, (prefix, separator) in enumerate(random_combos):
        # Readable form of the template, with '{}' standing in for the input and the output
        templates.append(f"{prefix['input']}{{}}{separator['input']}{prefix['output']}{{}}{separator['output']}")

        # Re-seed so every template is evaluated from the same random state
        set_seed(seed)

        # Few-shot eval without intervention; keep only the zero-rank examples for the later evals
        fs_results = n_shot_eval_no_intervention(dataset=dataset, n_shots=10, model=model, model_config=model_config,
                                                 tokenizer=tokenizer, compute_ppl=False, prefixes=prefix, separators=separator)
        fs_res_dict[template_idx] = fs_results
        filter_set = np.where(np.array(fs_results['clean_rank_list']) == 0)[0]

        # Zero-shot eval with intervention
        zs_res_dict[template_idx] = n_shot_eval(dataset, fp_vector, edit_layer, 0, model, model_config, tokenizer,
                                                filter_set=filter_set, prefixes=prefix, separators=separator)

        # Few-shot eval with shuffled labels and intervention
        fs_shuffled_res_dict[template_idx] = n_shot_eval(dataset, fp_vector, edit_layer, 10, model, model_config, tokenizer,
                                                         filter_set=filter_set, prefixes=prefix, separators=separator,
                                                         shuffle_labels=True)

    return fs_res_dict, zs_res_dict, fs_shuffled_res_dict, templates


if __name__ == "__main__":
    # Script entry point: parse arguments, load the model and dataset, load or compute
    # mean head activations, build the function pointer vector, run the portability
    # evaluation over sampled prompt templates, and write results + arguments to disk.

    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset_name', help='Name of the dataset to be loaded', type=str, required=True)
    # Optional: the default of 15 applies when the flag is omitted
    parser.add_argument('--n_eval_templates', help='Number of templates to evaluate with', required=False, type=int, default=15)
    parser.add_argument('--edit_layer', help='Layer for intervention. If -1, sweep over all layers', type=int, required=False, default=9)

    parser.add_argument('--n_top_heads', help='Number of attenion head outputs used to compute function pointer', required=False, type=int, default=10)
    parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b')
    parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='dataset_files')
    parser.add_argument('--save_path_root', help='File path to save to', type=str, required=False, default='results')
    parser.add_argument('--seed', help='Randomized seed', type=int, required=False, default=5678)
    parser.add_argument('--device', help='Device to run on', type=str, required=False, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--mean_activations_path', help='Path to file containing mean_head_activations for the specified task', required=False, type=str, default=None)
    # type=float so a value given on the command line matches the float default instead of arriving as a str
    parser.add_argument('--test_split', help="Percentage corresponding to test set split size", type=float, required=False, default=0.3)
    parser.add_argument('--n_shots', help="Number of shots in each in-context prompt", type=int, required=False, default=10)
    parser.add_argument('--n_trials', help="Number of in-context prompts to average over for indirect_effect", type=int, required=False, default=25)
    parser.add_argument('--prefixes', help='Prompt template prefixes to be used', type=json.loads, required=False, default={"input":"Q:", "output":"A:", "instructions":""})
    parser.add_argument('--separators', help='Prompt template separators to be used', type=json.loads, required=False, default={"input":"\n", "output":"\n\n", "instructions":""})

    args = parser.parse_args()

    dataset_name = args.dataset_name
    model_name = args.model_name
    root_data_dir = args.root_data_dir
    save_path_root = f"{args.save_path_root}/{dataset_name}"
    seed = args.seed
    device = args.device
    mean_activations_path = args.mean_activations_path
    n_top_heads = args.n_top_heads
    eval_edit_layer = args.edit_layer

    test_split = args.test_split
    n_shots = args.n_shots
    n_trials = args.n_trials

    prefixes = args.prefixes
    separators = args.separators

    n_eval_templates = args.n_eval_templates

    print(args)

    # Load Model & Tokenizer (gradients are disabled globally; this script only runs inference)
    torch.set_grad_enabled(False)
    print("Loading Model")
    model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name, device=device)

    if args.edit_layer == -1: # sweep over all layers if edit_layer=-1
        # presumably interpreted by n_shot_eval as a [start, end] layer range -- confirm against n_shot_eval
        eval_edit_layer = [0, model_config['n_layers']]

    # Load the dataset
    print("Loading Dataset")
    set_seed(seed)
    dataset = load_dataset(dataset_name, root_data_dir=root_data_dir, test_size=test_split, seed=seed)

    os.makedirs(save_path_root, exist_ok=True)

    # Load or Re-Compute mean_head_activations
    default_mean_activations_path = f'{save_path_root}/{dataset_name}_mean_head_activations.pt'
    if mean_activations_path is not None and os.path.exists(mean_activations_path):
        mean_activations = torch.load(mean_activations_path)
    elif mean_activations_path is None and os.path.exists(default_mean_activations_path):
        mean_activations_path = default_mean_activations_path
        # Record the file actually used so the saved arguments reflect it
        args.mean_activations_path = mean_activations_path
        mean_activations = torch.load(mean_activations_path)
    else:
        print("Computing Mean Activations")
        set_seed(seed)
        mean_activations = get_mean_head_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer,
                                                     n_icl_examples=n_shots, N_TRIALS=n_trials, prefixes=prefixes, separators=separators)
        args.mean_activations_path = default_mean_activations_path
        torch.save(mean_activations, args.mean_activations_path)

    # Compute Function Pointer
    fp, top_heads = compute_universal_function_pointer(mean_activations, model, model_config=model_config, n_top_heads=n_top_heads)

    print("Computing Portability")
    # NOTE(review): --seed is not forwarded here, so portability_eval runs with its own default seed.
    # Left unchanged to keep template sampling identical to earlier runs -- confirm whether this is intended.
    fs_res_dict, zs_res_dict, fs_shuffled_res_dict, templates = portability_eval(dataset, fp, eval_edit_layer, model, model_config, tokenizer, n_eval_templates=n_eval_templates)

    # Stored on args so the evaluated templates are written out together with the arguments
    args.templates = templates

    # Portability results go to a separate "<save_path_root>_port/<dataset_name>" directory
    save_path_root = f"{args.save_path_root}_port/{dataset_name}"
    os.makedirs(save_path_root, exist_ok=True)

    fs_results_file_name = make_valid_path_name(f'{save_path_root}/fs_port_eval.json')
    args.fs_results_file_name = fs_results_file_name
    with open(fs_results_file_name, 'w') as fs_results_file:
        json.dump(fs_res_dict, fs_results_file, indent=2)

    fs_shuffled_results_file_name = make_valid_path_name(f'{save_path_root}/fs_shuffled_port_eval.json')
    args.fs_shuffled_results_file_name = fs_shuffled_results_file_name
    with open(fs_shuffled_results_file_name, 'w') as fs_shuffled_results_file:
        json.dump(fs_shuffled_res_dict, fs_shuffled_results_file, indent=2)

    zs_results_file_name = make_valid_path_name(f'{save_path_root}/zs_port_eval.json')
    args.zs_results_file_name = zs_results_file_name
    with open(zs_results_file_name, 'w') as zs_results_file:
        json.dump(zs_res_dict, zs_results_file, indent=2)

    # Written last so it includes the templates and result file names added to args above
    args_file_name = make_valid_path_name(f'{save_path_root}/port_eval_args.txt')
    with open(args_file_name, 'w') as arg_file:
        json.dump(args.__dict__, arg_file, indent=2)