import gc
import os
import torch
import yaml
import json
import numpy as np

from scripts.llama_full import Llama
import torch.nn.functional as F
from torch.utils.data import DataLoader
# from trl import DataCollatorForSeq2Seq
from peft import get_peft_model
from trl import DataCollatorForCompletionOnlyLM
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling, GenerationConfig

# Device that the model and every batch tensor are moved to.
device = "cuda"

# Module-level result stores, keyed by task key and then by the sample index
# (as a string). They are filled by the calculate_* functions and read by run().
ppls = dict()
losses = dict()

def get_dataloader(data, tokenizer, batch_size=1, prompt_ppl=False):
    """Tokenize every sample of ``data`` and wrap the result in a ``DataLoader``.

    Args:
        data: Iterable of rows, each exposing its raw string under ``'text'``.
        tokenizer: Tokenizer used for encoding and handed to the collator.
        batch_size: Number of samples per batch.
        prompt_ppl: If true, batches are collated with
            ``DataCollatorForLanguageModeling(mlm=False)``; otherwise with
            ``DataCollatorForCompletionOnlyLM`` using the assistant header as
            the response template.

    Returns:
        A non-shuffling ``DataLoader`` over the tokenized samples.
    """
    if prompt_ppl:
        mycollator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    else:
        mycollator = DataCollatorForCompletionOnlyLM(response_template="<|start_header_id|>assistant<|end_header_id|>",
                                                     tokenizer=tokenizer)

    # Each sample is tokenized on its own; batching and padding across samples
    # are left to the collator.
    sample_ls = [
        tokenizer(row['text'], padding=True, truncation=True, add_special_tokens=True)
        for row in data
    ]

    return DataLoader(
        sample_ls,
        shuffle=False,
        collate_fn=mycollator,
        batch_size=batch_size,
    )


def generate_text(model, tokenizer, batch, num_rep):
    """Sample ``num_rep`` completions for the prompt of the first sequence in ``batch``.

    The first row of ``batch['input_ids']`` is decoded, cut right after the
    first assistant header, re-tokenized, repeated ``num_rep`` times and passed
    to ``model.generate``.

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        batch: Collated batch; only ``batch['input_ids'][0]`` is read.
        num_rep: Number of copies of the prompt to generate from.

    Returns:
        List of decoded strings (prompt plus continuation), one per copy.
    """
    assistant_header = "<|start_header_id|>assistant<|end_header_id|>"

    # Keep everything up to and including the first assistant header so the
    # model writes the assistant turn itself.
    full_text = tokenizer.decode(batch['input_ids'][0])
    prompt_text = full_text.split(assistant_header)[0] + assistant_header
    # NOTE(review): decode() keeps special tokens and this call uses
    # add_special_tokens=True, so a tokenizer that prepends a BOS token will
    # presumably add one more here - confirm this is intended.
    prompt_tokens = tokenizer(prompt_text, padding=True, truncation=True, return_tensors="pt", add_special_tokens=True)

    # The keyword is `num_beams`; the previous spelling `num_beam` was stored
    # as an unused attribute, so generation effectively ran with the default
    # beam count. With the spelling fixed, memory use grows with
    # num_rep * num_beams sequences.
    generation_config = GenerationConfig(
        temperature=0.5,
        max_new_tokens=30,
        pad_token_id=tokenizer.eos_token_id,
        num_beams=4,
        do_sample=True,
    )
    model_inputs = {k: v.repeat(num_rep, 1).to(device) for k, v in prompt_tokens.items()}
    generated_ids = model.generate(**model_inputs, generation_config=generation_config)

    return tokenizer.batch_decode(generated_ids)


def calculate_ppl_target(model, tokenizer, data, key, num_rep):
    """Score the model's own generations and store perplexity and loss per sample.

    For every sample shorter than 2048 tokens, ``num_rep`` completions are
    generated, re-encoded, collated with a completion-only collator and passed
    through the model once. The resulting batch loss and its exponential are
    stored as one-element lists in ``losses[key][str(idx)]`` and
    ``ppls[key][str(idx)]``.

    Args:
        model: Causal LM returning an object with a ``loss`` attribute.
        tokenizer: Tokenizer matching ``model``.
        data: Samples accepted by ``get_dataloader``.
        key: Task key under which results are stored in ``ppls``/``losses``.
        num_rep: Number of generations per sample.
    """
    loader = get_dataloader(data, tokenizer=tokenizer, batch_size=1, prompt_ppl=False)
    # The collator does not depend on the sample, so build it once rather than
    # once per iteration.
    completion_collator = DataCollatorForCompletionOnlyLM(response_template="<|start_header_id|>assistant<|end_header_id|>", tokenizer=tokenizer)
    ppls[key] = {}
    losses[key] = {}

    for idx, prompt_batch in enumerate(loader):
        # The loader is only needed up to here to filter out over-long samples.
        if prompt_batch['input_ids'].shape[-1] >= 2048:
            continue

        # Generate the text and re-encode every generation separately.
        generated_text = generate_text(model, tokenizer, prompt_batch, num_rep)
        encoded_text = [tokenizer(t, padding=True, truncation=True, add_special_tokens=False) for t in generated_text]

        processed_sample = completion_collator.torch_call(encoded_text)
        model_inputs = {k: v.to(device) for k, v in processed_sample.items()}

        # One forward pass over all generations; the model's loss is taken as
        # the average loss across the generated samples.
        with torch.no_grad():
            neg_log_likelihood = model(**model_inputs).loss

        sample_id = str(idx)
        ppls[key][sample_id] = [torch.exp(neg_log_likelihood).item()]
        losses[key][sample_id] = [neg_log_likelihood.item()]

        if idx == 0:
            print(prompt_batch)  # the collated input batch
            print(generated_text)  # the model's generated answers

        del neg_log_likelihood


def calculate_ppl(model, tokenizer, data, key):
    """Compute the perplexity of the prompt part of every sample.

    For every sample shorter than 2048 tokens, the first token is dropped,
    the labels after the last ``end_of_header`` token are set to -100, and the
    model loss is stored in ``losses[key][str(idx)]`` with its exponential in
    ``ppls[key][str(idx)]``.

    Args:
        model: Causal LM returning an object with a ``loss`` attribute.
        tokenizer: Tokenizer matching ``model``.
        data: Samples accepted by ``get_dataloader``.
        key: Task key under which results are stored in ``ppls``/``losses``.
    """
    loader = get_dataloader(data, tokenizer=tokenizer, batch_size=1, prompt_ppl=True)
    # presumably the id of "<|end_header_id|>" in the Llama-3 vocabulary - TODO confirm
    end_of_header = 128007
    ppls[key] = {}
    # Also record the raw loss: run() reads losses[key] after this call, so
    # leaving it unset raised a KeyError in prompt-perplexity mode.
    losses[key] = {}

    for idx, batch in enumerate(loader):
        if batch['input_ids'].shape[-1] >= 2048:
            continue
        if idx == 0:
            print(batch)

        batch = {k: v.to(device) for k, v in batch.items()}

        # remove the "<begin_of_text>" in front
        batch['input_ids'] = batch['input_ids'][:, 1:]
        batch['attention_mask'] = batch['attention_mask'][:, 1:]
        batch['labels'] = batch['labels'][:, 1:]

        # Column index of the last end-of-header token, i.e. the end of the prompt.
        prompt_end_index = torch.where(batch['input_ids'] == end_of_header)[-1][-1]
        # Ignore the response tokens: only the prompt perplexity is of interest.
        batch['labels'][:, prompt_end_index + 1:] = -100

        # Forward pass only; no gradients are needed to read the loss.
        with torch.no_grad():
            neg_log_likelihood = model(**batch).loss

        sample_id = str(idx)
        ppls[key][sample_id] = torch.exp(neg_log_likelihood).item()
        losses[key][sample_id] = neg_log_likelihood.item()

        del neg_log_likelihood


def save_iteration_to_json(filename, key, data_to_store):
    """Append ``data_to_store`` as the next iteration of ``key`` in a results file.

    The file ``../results/<filename>`` holds a JSON object of the form
    ``{key: {"0": ..., "1": ..., ...}}``. Each call stores ``data_to_store``
    under the next free iteration number of ``key``. The results directory and
    the file are created when missing.

    Args:
        filename: Name of the JSON file inside ``../results/``.
        key: Top-level key (task name) to append to.
        data_to_store: JSON-serializable payload for this iteration.
    """
    results_dir = "../results/"
    os.makedirs(results_dir, exist_ok=True)
    path = os.path.join(results_dir, filename)

    if os.path.isfile(path):
        with open(path, "r") as f:
            stored = json.load(f)
    else:
        print("new file created")
        stored = {}

    iterations = stored.setdefault(key, {})
    # Keys loaded from JSON are strings, so new keys are written as strings as
    # well. default=-1 makes an existing but empty key start at iteration 0.
    next_iter = max((int(i) for i in iterations), default=-1) + 1
    iterations[str(next_iter)] = data_to_store

    # The read handle is closed by now; rewrite the whole file.
    with open(path, "w") as f:
        json.dump(stored, f, indent=4)
    print(f"Saved the new data to {filename}!")


def run(NUM_SAMPLE=2500, NUM_REP=1, prompt_ppl=False):
    """Compute per-sample perplexities for every configured dataset/task and save them.

    Loads the base model, reads the task prompts from
    ``../prompts/prompts_by_task_modified.yaml``, and for every dataset not in
    the exclusion list and every task of that dataset: loads the data, keeps at
    most ``NUM_SAMPLE`` shuffled training rows, computes perplexities and
    appends them to the JSON result files.

    Args:
        NUM_SAMPLE: Maximum number of training rows evaluated per task.
        NUM_REP: Number of generations per sample (target-perplexity mode);
            also part of the output file names.
        prompt_ppl: If true, compute prompt perplexity with ``calculate_ppl``;
            otherwise score model generations with ``calculate_ppl_target``.
    """
    # the base model
    model_cls = Llama(model_name="meta-llama/Meta-Llama-3.1-8B-Instruct",
                    dataset_name="deepmind/aqua_rat", # a random dataset name, because it cannot be empty
                    data_size=None,
                    task="tokenized",
                    use_quantized=False,
                    use_peft=False,
                    peft_method="",
                    density=None,
                    r=None,
                    seed = 123
                    )
    model, tokenizer = model_cls.get_model_and_tokenizer(use_flash_attention=True, use_safetensors=False, on_vector=False)
    model.to(device)
    model.eval()

    with open('../prompts/prompts_by_task_modified.yaml','r') as f:
        prompt = yaml.safe_load(f)

    tasks_to_exclude = {
        'cardiffnlp/tweet_topic_single', 'pacovaldez/stackoverflow-questions', 'nvidia/OpenMathInstruct-2', 'allenai/ai2_arc',
        'maveriq/bigbenchhard', 'SahandNZ/cryptonews-articles-with-price-momentum-labels',
        'saier/unarXive_imrad_clf', 'allenai/qasc', 'allenai/sciq', 'jpwahle/machine-paraphrase-dataset',
        'masakhane/masakhanews',
        '/h/371/jayje/ft-intrinsic-dim/data/elementary_math_qa_question_only.json',
    }
    # Keep the YAML order instead of iterating a set, so datasets are
    # processed in the same order on every run.
    datasets_to_check = [name for name in prompt if name not in tasks_to_exclude]
    print(datasets_to_check)

    for dataset_name in datasets_to_check:
        print(dataset_name)
        gc.collect()
        torch.cuda.empty_cache()

        for task in prompt[dataset_name]:
            # A task of None falls back to a key derived from the dataset name.
            key = task if task is not None else dataset_name.split("/")[-1].replace('-','_').lower()

            # set model class params
            model_cls.dataset_name = dataset_name
            model_cls.task = task
            model_cls.task_prompt = model_cls.load_task_prompt()

            # load data
            data = model_cls.get_data(tokenizer=tokenizer, max_seq_len=2048, filter_long_seq=False, load_from_cache_file = True)

            # shuffle to keep the randomness consistent with actual training
            data['train'] = data['train'].shuffle(seed=123)
            if data['train'].num_rows > NUM_SAMPLE:
                data['train'] = data['train'].select(range(NUM_SAMPLE))

            # Index the row first so only one sample is read, not the whole column.
            print(data['train'][0]['text'])
            print("num rows: ", data['train'].num_rows)

            # compute the per-sample perplexities (forward passes only)
            if prompt_ppl:
                calculate_ppl(model=model, tokenizer=tokenizer, data=data['train'], key=key)
            else:
                calculate_ppl_target(model=model, tokenizer=tokenizer, data=data['train'], key=key, num_rep=NUM_REP)

            # NOTE(review): the file names do not encode prompt_ppl, so both
            # modes append to the same "model_gen" files - confirm this is intended.
            save_iteration_to_json(filename=f'ppl_per_sample_model_gen_rep{NUM_REP}_{NUM_SAMPLE}_v2.json', key=key, data_to_store=ppls[key])
            # Only save losses when the chosen mode recorded them; an
            # unconditional lookup raised KeyError otherwise.
            if key in losses:
                save_iteration_to_json(filename=f'ppl_loss_per_sample_model_gen_rep{NUM_REP}_{NUM_SAMPLE}_v2.json', key=key, data_to_store=losses[key])

            gc.collect()
            torch.cuda.empty_cache()

# Script entry point: score model generations with 8 repetitions per sample.
if __name__ == "__main__":
    run(
        NUM_SAMPLE=2500,
        NUM_REP=8,
        prompt_ppl=False,
    )
