import gc
import os
import torch
import yaml
import json
import argparse
import numpy as np

from scripts.llama_child import Llama
from scripts.mistral_child import Mistral
from scripts.smollm_child import SmolLM

import torch.nn.functional as F
from torch.utils.data import DataLoader
from trl import DataCollatorForCompletionOnlyLM
from peft import get_peft_model
from transformers import GenerationConfig


# Device that input batches and the model are moved to.
device='cuda'
# Module-level result stores. The functions in this file fill them in place,
# keyed first by the task key.
# key -> example index -> parameter name -> detached gradient tensor
gradient_store = {}
# key -> str(example index) -> l1_norm / l2_norm / fisher / variance / num_param
gradient_stats = {}
# Only mentioned in commented-out code in this part of the file.
mean_gradient_store = {}
# Not referenced in this part of the file.
gradient_stats_layer = {}
# Not referenced in this part of the file.
variance_batch = {}
# key -> str(example index) -> target / predicted token probabilities and their averages
probas = {}
# key -> str(example index) -> loss entries for that example
losses = {}
# key -> str(example i) -> str(example j) -> cosine similarity, plus a flat 'all' list
cosines = {}
# Same layout as `cosines`, holding raw dot products instead.
dots = {}
# Not referenced in this part of the file.
task_outputs = {}
# key -> str(example index) -> list of perplexity values
ppls = {}


def get_dataloader(data, tokenizer, assistant_start_token, batch_size=1, include_target_label=True):
    """Build a non-shuffling DataLoader over the tokenized `data['text']` entries.

    Each text is tokenized on its own (no special tokens added) and batches are
    assembled by a `DataCollatorForCompletionOnlyLM` that uses
    `assistant_start_token` as its response template.

    When `include_target_label` is False, everything from the first occurrence of
    `assistant_start_token` onwards is dropped and replaced by
    "\\n " + `assistant_start_token`, so only the prompt is fed to the model
    (used when the model's own generated answer is evaluated).
    """

    collator = DataCollatorForCompletionOnlyLM(tokenizer=tokenizer,
                                               response_template=assistant_start_token)

    def _text_to_encode(example_text):
        # keep the full example when the reference answer should be part of the input
        if include_target_label:
            return example_text
        prompt_part = example_text.split(assistant_start_token)[0]
        return prompt_part + f"\n {assistant_start_token}"

    encoded_examples = []
    for example_text in data['text']:
        encoded = tokenizer(_text_to_encode(example_text), padding=True, add_special_tokens=False)
        encoded_examples.append(encoded)

    return DataLoader(
        encoded_examples,
        batch_size=batch_size,
        collate_fn=collator,
        shuffle=False,
    )


def mask_non_answer_tokens(model_prefix, batch):
    """Mask additional tokens not masked by the collator; this affects the loss and gradient calculation.

    Args:
        model_prefix: one of 'mistral', 'llama' or 'smollm'; selects which token
            ids are treated as non-answer tokens.
        batch: mapping holding a 'labels' tensor. The tensor is modified in place:
            every label equal to one of the excluded ids is set to -100.

    Returns:
        The same `batch` object that was passed in.

    Raises:
        ValueError: if `model_prefix` is not one of the supported prefixes.
    """

    exclude_tok_ids_by_model = {
        'mistral': [2, 29473],   # <\s> that does not get masked and ''
        'llama': [271],          # \n\n which is at the beginning of the answer; it already excludes eos tokens
        'smollm': [0, 2, 198],   # <|endoftext|>(pad token), <|im_end|> (eos token), \n
    }
    if model_prefix not in exclude_tok_ids_by_model:
        # fail with a clear message instead of a TypeError from iterating over None
        supported = sorted(exclude_tok_ids_by_model)
        raise ValueError(f"Unsupported model_prefix {model_prefix!r}; expected one of {supported}")

    labels = batch['labels']
    exclude_tok_ids = torch.tensor(exclude_tok_ids_by_model[model_prefix],
                                   dtype=labels.dtype, device=labels.device)
    # in-place write so that the tensor stored in `batch` is the one that changes
    labels[torch.isin(labels, exclude_tok_ids)] = -100

    return batch


def generate_text(model, tokenizer, batch, num_rep):
    """Prompt the model {num_rep} times to get the model response.

    Every tensor in `batch` is repeated `num_rep` times along the first dimension
    and moved to `device`, then passed to `model.generate`. Decoding is
    deterministic (no sampling) with 4 beams and at most 30 new tokens.

    Returns:
        The list of decoded sequences produced by `tokenizer.batch_decode`
        on the generation output.
    """

    generation_config = GenerationConfig(
        max_new_tokens=30,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        # The keyword is `num_beams`; the previous spelling `num_beam` was stored
        # as an unused attribute, so beam search was never actually enabled.
        num_beams=4,
        do_sample=False,
    )
    batch = {k: v.repeat(num_rep, 1).to(device) for k, v in batch.items()}
    out = model.generate(**batch, generation_config=generation_config)
    generated_text = tokenizer.batch_decode(out)

    # drop the references to the device tensors before returning
    del batch
    del out

    return generated_text


# def clean_model_output(out, assistant_start_token, eos_token):
#     """Extract the label from model generated answer
#     """
    
#     out = out.split(assistant_start_token)[-1]
#     out = out.split('####')[-1] # a few datasets with #### as the final answer designator
#     out = out.strip('\n').split('\n')[0] # take the very first answer before the line change
#     out = out.split(eos_token)[0].strip().lower() # misc cleaning: remove end of sequence tok, remove space, lower case
#     out = out.strip('.').strip() # remove '.' in case of short answers
    
#     return out


# def clean_target_output(tar):
#     """Extract the target label from dataset
#     """
    
#     tar = str(tar)
#     tar = tar.split('####')[-1]
#     tar = tar.strip('.').replace(' ','').strip().lower()

#     return tar


def cosine_similarity(v1, v2):
    """Cosine similarity and dot product between two per-layer gradient dictionaries.

    Args:
        v1: dictionary of model layer name -> gradient tensor.
        v2: dictionary of model layer name -> gradient tensor.

    Only layers present in both dictionaries contribute; layers of `v1` missing
    from `v2` are reported and skipped. All shared layers are treated as one
    long concatenated vector.

    Returns:
        (cosine similarity value, dot product value) as Python floats.
        If the dictionaries share no layer the result is (nan, 0.0); if either
        vector has zero norm the cosine is nan.
    """

    num = 0
    norm_v1 = 0
    norm_v2 = 0
    has_shared_layer = False
    for name, grad_1 in v1.items():
        if name not in v2:
            print(f"{name} not in v2, skipping this layer")
            continue
        grad_2 = v2[name]
        num += torch.dot(grad_1.flatten(), grad_2.flatten())
        norm_v1 += (grad_1 ** 2).sum()
        norm_v2 += (grad_2 ** 2).sum()
        has_shared_layer = True

    if not has_shared_layer:
        # the accumulators are still plain ints here, which torch.sqrt rejects;
        # the cosine of two empty vectors is undefined and their dot product is 0
        return float('nan'), 0.0

    norm_v1 = torch.sqrt(norm_v1)
    norm_v2 = torch.sqrt(norm_v2)
    cosine_sim = float(num / (norm_v1 * norm_v2))
    dot_val = float(num)

    return cosine_sim, dot_val


def calculate_cosine_similarity(key):
    """Compute pairwise gradient similarities for all examples currently stored under `key`.

    For every unordered pair of examples in `gradient_store[key]` (each example
    is also paired with itself) the cosine similarity and the dot product are
    written to `cosines[key][str(i)][str(j)]` and `dots[key][str(i)][str(j)]`,
    and appended to the flat `cosines[key]['all']` / `dots[key]['all']` lists.
    Any previous content stored under `key` in `cosines` and `dots` is replaced.
    """

    example_ids = list(gradient_store[key].keys())

    cos_table = {'all': []}
    dot_table = {'all': []}
    cosines[key] = cos_table
    dots[key] = dot_table

    for start, first_id in enumerate(example_ids):
        cos_row = {}
        dot_row = {}
        cos_table[str(first_id)] = cos_row
        dot_table[str(first_id)] = dot_row

        first_grads = gradient_store[key][first_id]
        # start at `start` so each pair is visited once (upper triangle incl. diagonal)
        for second_id in example_ids[start:]:
            second_grads = gradient_store[key][second_id]
            cos_val, dot_val = cosine_similarity(first_grads, second_grads)

            cos_row[str(second_id)] = cos_val
            cos_table['all'].append(cos_val)

            dot_row[str(second_id)] = dot_val
            dot_table['all'].append(dot_val)

    print("Done with cosine similarity calc!")


def calculate_model_proba(key, batch, logits, example_num):
    """Record per-token probabilities for the target tokens of one example.

    Target tokens are the entries of `batch['labels']` that are not -100. For
    each of them the softmax over the vocabulary is taken at the preceding
    position of `logits`, and the following is stored in
    `probas[key][str(example_num)]`: the probability of the target token, the
    target token ids, the highest probability and its token id, the mean of
    (1 - target probability) and the mean of the highest probabilities.
    """

    labels = batch['labels']
    is_target = labels != -100

    # 1. target token ids and the sequence positions they sit at
    target_idx = labels[is_target]
    target_pos = is_target.nonzero()[:, 1]

    # 2. distribution at the position just before each target (shift left by one)
    shifted_logits = logits[:, target_pos - 1, :]
    softmaxed_proba = F.softmax(shifted_logits, dim=-1).squeeze(0)
    model_pred_proba, model_pred = softmaxed_proba.max(dim=-1)

    row_ids = np.arange(0, len(target_idx))
    target_proba = softmaxed_proba[row_ids, target_idx]

    # 3. summary metrics plus the raw per-token values
    record = probas[key][str(example_num)]
    record.update({
        'target_proba': target_proba.tolist(),
        'target': target_idx.tolist(),
        'model_pred_proba': model_pred_proba.tolist(),
        'model_pred': model_pred.tolist(),
        'avg_error': torch.mean(1 - target_proba).item(),
        'avg_confidence': torch.mean(model_pred_proba).item(),
    })


def calculate_gradient_stats(key, example_num):
    """Summarise the stored gradient of one example over all of its layers.

    Reads `gradient_store[key][example_num]` and writes to
    `gradient_stats[key][str(example_num)]`: the L1 norm, the L2 norm, the sum
    of squared entries (stored as 'fisher'), the variance of the entries around
    their overall mean, and the total number of entries.
    """

    layer_grads = gradient_store[key][example_num]
    stats = {}
    gradient_stats[key][str(example_num)] = stats

    # first pass: global sums over every layer
    total = 0
    abs_total = 0
    squared_total = 0
    num_param = 0
    for grad in layer_grads.values():
        total += grad.sum().item()
        abs_total += grad.abs().sum().item()
        squared_total += (grad ** 2).sum().item()
        num_param += grad.numel()

    overall_mean = total / num_param

    # second pass: squared deviations from the overall mean
    squared_dev_total = 0
    for grad in layer_grads.values():
        squared_dev_total += ((grad - overall_mean) ** 2).sum().item()

    stats.update({
        'l1_norm': abs_total,
        'l2_norm': np.sqrt(squared_total),
        'fisher': squared_total,
        'variance': squared_dev_total / num_param,
        'num_param': num_param,
    })


def save_iteration_to_json(filename, key, data_to_store):
    """Append `data_to_store` as the next iteration under `key` in a JSON results file.

    The file lives in "../results/per_sample_result/" (relative to the current
    working directory); the directory and the file are created when missing.
    The file holds a mapping key -> iteration number (as a string) -> data;
    the new entry gets the highest existing iteration number plus one, or "0"
    when `key` has no entries yet.

    Args:
        filename: name of the JSON file inside the results directory.
        key: top-level key the data is stored under.
        data_to_store: JSON-serializable payload for this iteration.
    """

    result_dir = "../results/per_sample_result/"
    os.makedirs(result_dir, exist_ok=True)
    file_path = os.path.join(result_dir, filename)

    if os.path.isfile(file_path):
        with open(file_path, "r") as f:
            stored = json.load(f)
    else:
        print("new file created")
        stored = {}

    iterations = stored.setdefault(key, {})
    # keys come back from JSON as strings, so compare them numerically
    next_iter = max((int(i) for i in iterations), default=-1) + 1
    iterations[str(next_iter)] = data_to_store

    # serialize first so a failure cannot leave a truncated file behind
    payload = json.dumps(stored, indent=4)
    with open(file_path, "w") as f:
        f.write(payload)
    print(f"Saved the new data to {filename}!")


def load_and_calculate_ppl(model, tokenizer, assistant_start_token, data, key, num_rep, model_prefix):
    """Iterate through data and calculate PPL on model generated output.

    For every example (prompt only, batch size 1):
      1. generate a response `num_rep` times given the prompt,
      2. re-tokenize the generated texts, mask everything but the answer tokens,
      3. run the model on them and store exp(loss) in `ppls[key][str(idx)]`.

    Examples whose tokenized prompt has 2048 tokens or more are skipped.
    `ppls[key]` and `losses[key]` are reset at the start of the call.
    """

    loader = get_dataloader(data=data, tokenizer=tokenizer, assistant_start_token=assistant_start_token,batch_size=1, include_target_label=False)
    # The collator does not depend on the example, so build it once instead of per iteration.
    mycollator = DataCollatorForCompletionOnlyLM(response_template=assistant_start_token, tokenizer=tokenizer)
    ppls[key] = {}
    losses[key] = {}

    for idx, batch0 in enumerate(loader):
        if batch0['input_ids'].shape[-1] >= 2048: # dataloader mapping needed just for this filtering
            continue

        ppls[key][str(idx)] = []
        losses[key][str(idx)] = []

        # Generate the text
        generated_text = generate_text(model, tokenizer, batch0, num_rep)
        encoded_text = [tokenizer(t, padding=True, add_special_tokens=False) for t in generated_text]

        # Get target perplexity
        processed_sample = mycollator.torch_call(encoded_text) # pass the encoded text to the collator to mask out prompt/target
        batch1 = {k: v.to(device) for k,v in processed_sample.items()}
        batch1 = mask_non_answer_tokens(model_prefix=model_prefix, batch=batch1)

        # pass it through the model to get avg loss across generated samples
        with torch.no_grad():
            output = model(**batch1)
            neg_log_likelihood = output.loss

        ppls[key][str(idx)].append(torch.exp(neg_log_likelihood).item())

        if idx == 0:
            print(batch0) # get the batch
            print(generated_text) # this is model outputted answer

        # release the forward-pass results before the next example
        del output
        del neg_log_likelihood


def load_and_calculate_gradient(model, tokenizer, data, key, num_sample, model_prefix, assistant_start_token,
                                file_postfix='v2',example_nums=None, cosine_batch_size=None):
    """Iterate through data and calculate gradient related stats.

    Each example is passed through the model on its own (batch size 1); the loss
    is back-propagated and a copy of every parameter gradient is stored in
    `gradient_store[key][idx]`. The example loss is recorded in `losses[key]`.

    Args:
        model: model to evaluate; its gradients are zeroed after every example.
        tokenizer: tokenizer used to build the dataloader.
        data: object whose 'text' entry holds the examples.
        key: task key under which all results are stored.
        num_sample: number of examples to process when `example_nums` is None
            (or when no `cosine_batch_size` is given).
        model_prefix: model family, forwarded to `mask_non_answer_tokens` and
            used in the output file names.
        assistant_start_token: response template for the collator.
        file_postfix: tag appended to the similarity file names.
        example_nums: optional list mapping the position in `data` to the
            example id used as storage key. When None, positions are used and
            per-example probabilities and gradient stats are computed as well.
        cosine_batch_size: when set, gradients are accumulated and pairwise
            similarities are computed and saved after every `cosine_batch_size`
            processed examples, for the first 10 such groups.

    Examples with 2048 tokens or more are skipped and do not count as processed.
    """

    gradient_store[key] = {}
    gradient_stats[key] = {}
    losses[key] = {}
    probas[key] = {}

    loader = get_dataloader(data=data, tokenizer=tokenizer, assistant_start_token=assistant_start_token, batch_size=1, include_target_label=True)

    use_all_examples = example_nums is None
    # The number of examples to process does not change inside the loop.
    # Fall back to num_sample when no cosine batch size is available.
    if use_all_examples or not cosine_batch_size:
        cutoff = num_sample
    else:
        cutoff = cosine_batch_size * 10

    output = None  # stays None when no example is processed, so the cleanup below is safe
    c = 0
    for idx_loader, batch in enumerate(loader):
        if batch['input_ids'].shape[-1] >= 2048:
            continue

        idx = idx_loader if use_all_examples else example_nums[idx_loader]

        gradient_store[key][idx] = {}
        losses[key][str(idx)] = {}
        probas[key][str(idx)] = {}

        # calculate the gradients
        batch = {k: v.to(device) for k,v in batch.items()}
        batch = mask_non_answer_tokens(model_prefix, batch)
        output = model(**batch)
        losses[key][str(idx)]['example_loss'] = output.loss.item()
        output.loss.backward()

        # store the gradients for gradient stats calculation
        for name, param in model.named_parameters():
            if param.grad is not None:
                gradient_store[key][idx][name] = param.grad.clone().detach()

        # once done, empty the gradients so they do not accumulate across examples
        for p in model.parameters():
            if p.grad is not None:
                p.grad.zero_()

        if use_all_examples:
            # calculate the model proba
            calculate_model_proba(key=key, batch=batch, logits=output.logits, example_num=idx)
            # calculate the gradient stats for this example
            calculate_gradient_stats(key, example_num=idx)

        # if cosine_batch_size specified, accumulate for batch_size many examples
        if cosine_batch_size and (c / cosine_batch_size <= 10):
            # for every X steps, calculate cosine similarity (for 10 iterations, due to computation time)
            if ((c+1) % cosine_batch_size == 0) and (c > 0):
                calculate_cosine_similarity(key)
                save_iteration_to_json(filename=f'{model_prefix}_cosine_similarity_metrics_{cosine_batch_size}_{file_postfix}.json', key=key, data_to_store=cosines[key])
                save_iteration_to_json(filename=f'{model_prefix}_dot_product_{cosine_batch_size}_{file_postfix}.json', key=key, data_to_store=dots[key])

                # drop all gradients stored under this key and start a fresh group
                gradient_store[key] = {}
                gc.collect()
                torch.cuda.empty_cache()
        else:
            # gradients are not needed for similarity any more; free them right away
            del gradient_store[key][idx]

        c += 1
        # c now counts processed examples; stop once exactly `cutoff` were handled
        if c >= cutoff:
            break

    gradient_store[key] = {}
    del output
    gc.collect()
    torch.cuda.empty_cache()

    print("Done passing losses!")


def run(model_name, num_sample, use_peft=True, eval_on_model_output = False, data_split='train', subset_path=False, file_postfix="v2", num_rep=None):
    """Compute per-sample metrics for every configured dataset/task and save them as JSON.

    Args:
        model_name: base model name; must contain 'mistral', 'llama' or 'smollm'
            (case-insensitive), otherwise a message is printed and nothing is run.
        num_sample: maximum number of examples taken from each data split.
        use_peft: when true, the loaded model is wrapped with `get_peft_model`.
        eval_on_model_output: when true, perplexity on the model's own generated
            answers is computed; otherwise gradient based metrics are computed.
        data_split: name of the data split to use.
        subset_path: optional path to a JSON file mapping task key ->
            {'example_num': [...]} that restricts the examples; any falsy value
            (None, False, '') means no restriction.
        file_postfix: tag appended to the output file names.
        num_rep: number of generations per prompt; required when
            `eval_on_model_output` is true.

    Raises:
        ValueError: if `eval_on_model_output` is true and `num_rep` is None.
    """

    # the base model
    MODEL_PREFIX = None
    FILE_POSTFIX=file_postfix
    NUM_SAMPLE=num_sample
    COSINE_BATCH_SIZE= 32
    NUM_REP=num_rep

    model_cls = None
    model_args = {
        "model_name":model_name,
        "dataset_name":"deepmind/aqua_rat", # a random dataset name, because it cannot be empty
        "data_size":None,
        "task":"tokenized",
        "use_quantized":False,
        # NOTE(review): fixed to True regardless of the `use_peft` argument - confirm this is intended
        "use_peft":True,
        "peft_method":"lora",
        "density":None,
        "r":64,
        "seed":123
    }

    if 'mistral' in model_name.lower():
        model_cls = Mistral(**model_args)
        MODEL_PREFIX = 'mistral'
    elif 'llama' in model_name.lower():
        model_cls = Llama(**model_args)
        MODEL_PREFIX='llama'
    elif 'smollm' in model_name.lower():
        model_cls = SmolLM(**model_args)
        MODEL_PREFIX='smollm'
    else:
        print("Model class not supported!")
        return

    # Fail before the model is loaded: generation repeats each prompt num_rep times
    # and cannot work without a number.
    if eval_on_model_output and num_rep is None:
        raise ValueError("num_rep must be set when eval_on_model_output is True")

    ASSISTANT_START_TOKEN=model_cls.assistant_start_token

    model, tokenizer = model_cls.get_model_and_tokenizer(use_flash_attention=True, use_safetensors=False, on_vector=False)
    if use_peft:
        peft_config = model_cls.get_peft_config()
        print(peft_config)
        model = get_peft_model(model, peft_config)
    model.to(device)
    model.eval()

    with open('../prompts/prompts_by_task_modified.yaml','r') as f:
        prompt = yaml.safe_load(f)

    # The subset file does not depend on the task, so read it once.
    # A truthiness check is used because the default value is False, and
    # open(False) would open file descriptor 0 (stdin).
    subset_examples = None
    if subset_path:
        with open(subset_path, 'r') as f:
            subset_examples = json.load(f)

    tasks_to_exclude = ['cardiffnlp/tweet_topic_single','pacovaldez/stackoverflow-questions', 'nvidia/OpenMathInstruct-2','allenai/ai2_arc',
                        'maveriq/bigbenchhard', 'SahandNZ/cryptonews-articles-with-price-momentum-labels',
                        'saier/unarXive_imrad_clf', 'allenai/qasc', 'allenai/sciq','jpwahle/machine-paraphrase-dataset',
                        'masakhane/masakhanews', '/h/371/jayje/ft-intrinsic-dim/data/elementary_math_qa_question_only.json',
                        'openai/gsm8k','/h/371/jayje/ft-intrinsic-dim/data/qa_wikidata.json', '/h/371/jayje/ft-intrinsic-dim/data/disfl_qa.json', '/h/371/jayje/ft-intrinsic-dim/data/polish_sequence_labeling.json'
                        ]
    datasets_to_check = set(list(prompt.keys())) - set(tasks_to_exclude)

    print(datasets_to_check)

    for k in datasets_to_check:
        print(k)
        gc.collect()
        torch.cuda.empty_cache()

        dataset_name = k
        tasks = list(prompt[k].keys())

        for task in tasks:
            key = task.lower() if task is not None else dataset_name.split("/")[-1].replace('-','_').lower()

            # set model class params
            model_cls.dataset_name = dataset_name
            model_cls.task = task
            model_cls.task_prompt = model_cls.load_task_prompt()

            # load data
            data = model_cls.get_data(tokenizer=tokenizer, max_seq_len=2048, filter_long_seq=False, load_from_cache_file = True)

            # shuffle to keep the randomness consistent with actual training
            data[data_split] = data[data_split].shuffle(seed=123)
            if data[data_split].num_rows > NUM_SAMPLE:
                data[data_split] = data[data_split].select(range(NUM_SAMPLE))

            data_in = data[data_split]
            print(data_in['text'][0])
            print("num rows: ", data_in.num_rows)

            # restrict to check the gradients for
            subset_example_nums = None
            if subset_examples is not None:
                if key not in subset_examples:
                    print(f"{key} is not in json file")
                    continue

                subset_example_nums = subset_examples[key]['example_num']
                data_in = data[data_split][subset_example_nums]

            # pass the loss and calculate gradient related stats
            if eval_on_model_output:
                load_and_calculate_ppl(model=model, tokenizer=tokenizer, data=data_in, key=key, assistant_start_token=ASSISTANT_START_TOKEN, num_rep=NUM_REP, model_prefix=MODEL_PREFIX)
                save_iteration_to_json(filename=f'{MODEL_PREFIX}_ppl_per_sample_model_gen_rep{NUM_REP}_{NUM_SAMPLE}_{FILE_POSTFIX}.json', key=key, data_to_store=ppls[key])
            else:
                load_and_calculate_gradient(model=model, tokenizer=tokenizer, data=data_in, key=key, num_sample=NUM_SAMPLE,
                                   model_prefix=MODEL_PREFIX, assistant_start_token=ASSISTANT_START_TOKEN,
                                   file_postfix=FILE_POSTFIX, example_nums=subset_example_nums, cosine_batch_size=COSINE_BATCH_SIZE
                                   )
                # per-sample metrics are only computed (and saved) when no subset is used
                if subset_example_nums is None:
                    save_iteration_to_json(filename=f'{MODEL_PREFIX}_gradient_metrics_per_sample_{NUM_SAMPLE}_{FILE_POSTFIX}.json', key=key, data_to_store=gradient_stats[key])
                    save_iteration_to_json(filename=f'{MODEL_PREFIX}_model_proba_per_sample_{NUM_SAMPLE}_{FILE_POSTFIX}.json', key=key, data_to_store=probas[key])
                    save_iteration_to_json(filename=f'{MODEL_PREFIX}_model_loss_per_sample_{NUM_SAMPLE}_{FILE_POSTFIX}.json', key=key, data_to_store=losses[key])

            gc.collect()
            torch.cuda.empty_cache()

if __name__=='__main__':

    # Command line entry point: parse the options and forward them to run().
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name',
                    type=str,
                    required=True,  # run() calls .lower() on it, so it cannot be omitted
                    help='Base model name')
    parser.add_argument('--num_sample',
                    default=2500,
                    type=int,
                    help='total # samples to process')
    parser.add_argument('--data_split',
                    default='train',
                    type=str,
                    help='which data subset to calculate the metrics for'
                    )
    parser.add_argument('--subset_path',
                    default=None,
                    type=str,
                    help='use if the metrics are calculated on specified examples in the subset path'
                    )
    parser.add_argument('--file_postfix',
                    default='v2',  # same default as run(); without it file names would end in "_None"
                    type=str,
                    help='specify the file tag; e.g. v2, v2_wrong, etc.'
                    )
    parser.add_argument('--num_rep',
                    default=None,
                    type=int,
                    help='specify if generating model output for num_rep many times'
                    )
    parser.add_argument('--use_peft',
                    action="store_true",
                    help="if downsizing the model params for easier gradient calculation"
                    )
    parser.add_argument('--eval_on_model_output',
                    action="store_true",
                    help="if calculating ppl on model's own answer"
                    )
    args = parser.parse_args()

    run(model_name=args.model_name,
        num_sample=args.num_sample,
        use_peft=args.use_peft,
        eval_on_model_output=args.eval_on_model_output,
        data_split=args.data_split,
        subset_path=args.subset_path,
        file_postfix=args.file_postfix,
        num_rep=args.num_rep
       )
