import argparse
import torch
import numpy as np
import os
import glob
import json
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, set_seed
from peft import PeftModel
from accelerate import Accelerator
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from math import exp
import omegaconf
from tqdm import tqdm
from utils import get_model_identifiers_from_yaml

from data_module import TextDatasetQA, get_batch_loss, convert_raw_data_to_model_format, custom_data_collator_with_indices
from evaluate_util import get_dataloader, eval_bleu, eval_rouge_recall, eval_perturbation_ratio, run_generation
from aggregate_eval_stat import get_forget_quality

from huggingface_hub import login
# Authenticate with the Hugging Face Hub. The token is taken from the HF_TOKEN
# environment variable so that a real credential never has to be written into
# (and committed with) this file; the placeholder is used only when the
# variable is not set.
hf_token = os.environ.get("HF_TOKEN", "INSERT TOKEN")
login(hf_token)


# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--sample_size", type=int, default=100)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--max_length", type=int, default=512)
parser.add_argument("--split", type=str, required=True)
parser.add_argument("--num_samples", type=int, default=10)
parser.add_argument("--model_id", type=str, default="NousResearch/Meta-Llama-3.1-8B-Instruct")
parser.add_argument("--model_family", type=str, default="llama3.1-8b")
parser.add_argument("--base_model_path", type=str, required=True)
# The adapter directory is used unconditionally further down: it is handed to
# PeftModel.from_pretrained and it is the directory the results JSON is
# written into. Requiring it here fails at parse time instead of after the
# base model has already been loaded.
parser.add_argument("--adapter_path", type=str, required=True)
args = parser.parse_args()

seed = args.seed
sample_size = args.sample_size
batch_size = args.batch_size
max_length = args.max_length
# NOTE(review): the original comment gave "forget01_perturbed" as an example,
# but "_perturbed" is appended to this value when cfg.split_list is patched
# below — presumably the expected value is e.g. "forget01"; confirm.
split = args.split
num_samples = args.num_samples
model_id = args.model_id  # e.g. "NousResearch/Meta-Llama-3.1-8B-Instruct"
model_family = args.model_family  # e.g. "llama3.1-8b"
base_model_path = args.base_model_path
adapter_path = args.adapter_path

accelerator = Accelerator()
device = accelerator.device
set_seed(seed)


# ---------------------------------------------------------------------------
# Evaluation config: loaded from YAML, then patched with the CLI values
# ---------------------------------------------------------------------------
cfg = omegaconf.OmegaConf.load('config/eval_perturb.yaml')
# Only the last configured split is redirected to the requested one.
cfg.split_list[-1] = f'{split}_perturbed'
cfg.batch_size = batch_size
cfg.model_id = model_id
cfg.model_family = model_family
cfg.model_path = adapter_path

# ---------------------------------------------------------------------------
# Tokenizer and model
# ---------------------------------------------------------------------------
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse EOS as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Load the base weights, attach the PEFT adapter, then fold the adapter into
# the base weights so the rest of the script works with a plain model.
# NOTE(review): `use_flash_attention_2` is reported as deprecated by recent
# transformers releases in favour of `attn_implementation="flash_attention_2"`;
# confirm against the pinned transformers version before switching.
model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    config=config,
    use_flash_attention_2=True,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="cuda:0"
)
model = PeftModel.from_pretrained(
    model,
    adapter_path,
    torch_device="cuda:0"
)

model = model.merge_and_unload()
model.eval()

def get_qa(eval_logs):
    """Group generated answers by the question text they belong to.

    Args:
        eval_logs: Mapping with a ``'generated_text'`` entry. That entry maps
            an example key to an indexable record whose element 0 is a list
            of question strings and whose element 1 is the position-aligned
            list of generated answers.

    Returns:
        A dict mapping each distinct question string to the list of answers
        recorded for it, in the order they were encountered.

    Raises:
        IndexError: If a record has fewer answers than questions.
    """
    grouped = {}
    generated = eval_logs['generated_text']
    for key in generated:
        prompts = generated[key][0]
        for position, prompt in enumerate(prompts):
            # Identical prompts (e.g. repeated samples of one example) share a
            # single bucket.
            grouped.setdefault(prompt, []).append(generated[key][1][position])
    return grouped

@torch.no_grad()
def gather_statistics(cfg, model, tokenizer, eval_dataloader, base_eval_dataloader, perturb_dataloader, desc="DEFAULT", sample_count=1, gen_chunk_size=25):
    """Collect ground-truth loss statistics and sampled generations per example.

    For every batch the model is run once on the tokenized inputs to score the
    labelled tokens, and then ``sample_count`` sampled generations are drawn
    for each example through ``run_generation``.

    Args:
        cfg: Evaluation config; forwarded unchanged to ``run_generation``.
        model: Model used for the forward pass and for generation. Batch
            tensors are moved to ``model.device``.
        tokenizer: Forwarded to ``run_generation``.
        eval_dataloader: Iterable yielding
            ``(input_ids, labels, attention_mask, indices)`` batches.
        base_eval_dataloader: Not used here; accepted so existing call sites
            keep working.
        perturb_dataloader: Not used here; accepted so existing call sites
            keep working.
        desc: Label shown on the progress bar.
        sample_count: Number of generations to draw per example. Values
            below 1 draw none (loss statistics are still collected).
        gen_chunk_size: Upper bound on how many copies of the batch are
            stacked into a single ``run_generation`` call. Requests larger
            than this are split into several rounds, the last of which may be
            smaller.

    Returns:
        A dict with the keys ``'avg_gt_loss'``, ``'gt_loss'``,
        ``'num_token_gt'`` and ``'generated_text'``. Each maps an example
        index to its value; for ``'generated_text'`` the value is a tuple
        ``(input_strings, generated_outputs, ground_truths)`` of lists with
        one element per drawn sample. All four keys are present even when the
        dataloader yields nothing.

    Raises:
        ValueError: If ``gen_chunk_size`` is smaller than 1.
    """
    if gen_chunk_size < 1:
        raise ValueError(f"gen_chunk_size must be >= 1, got {gen_chunk_size}")

    # Plan the generation rounds once: as many full-size rounds as fit, plus
    # one smaller round for the remainder, so exactly `sample_count` samples
    # are produced per example.
    full_rounds, remainder = divmod(max(int(sample_count), 0), gen_chunk_size)
    round_sizes = [gen_chunk_size] * full_rounds
    if remainder:
        round_sizes.append(remainder)

    eval_logs = {
        'avg_gt_loss': {},
        'gt_loss': {},
        'num_token_gt': {},
        'generated_text': {},
    }

    for batch in tqdm(eval_dataloader, desc=desc):
        input_ids, labels, attention_mask, indices = batch
        index_list = indices.cpu().numpy().tolist()
        bsz = len(index_list)

        batch = {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}
        # send to device
        for k, v in batch.items():
            batch[k] = v.to(model.device)

        # Score the ground-truth answer. The statistics are extracted right
        # away so the logits do not have to stay alive during sampling.
        outputs = model(**batch)
        gt_loss = get_batch_loss(outputs.logits, batch['labels']).float()
        del outputs
        # Counts label positions that differ from -100 (presumably the
        # ignore index used by get_batch_loss — confirm).
        num_token_gt = (batch['labels'] != -100).sum(-1)
        gt_loss_per_token = (gt_loss / num_token_gt).float()

        # One bucket per example in the batch; filled across all rounds.
        batch_input_string = [[] for _ in range(bsz)]
        batch_gen_output = [[] for _ in range(bsz)]
        batch_gt = [[] for _ in range(bsz)]

        gen_batch = None
        gen_batch_reps = None
        for reps in round_sizes:
            # Rebuild the stacked batch only when the round size changes
            # (i.e. at most twice per batch).
            if reps != gen_batch_reps:
                # repeat(reps, 1) tiles the whole batch along dim 0, giving
                # rows ordered [ex0, ex1, ..., ex0, ex1, ...].
                # Assumes the tensors are 2-D (batch, seq) — TODO confirm.
                gen_batch = {k: v.repeat(reps, 1) for k, v in batch.items()}
                gen_batch_reps = reps

            input_string, gen_output, gt = run_generation(
                cfg, gen_batch, model,
                tokenizer=tokenizer,
                do_sample=True
            )
            # Assumes run_generation returns one entry per input row, in row
            # order; with the tiling above, every bsz-th entry starting at
            # b_idx then belongs to example b_idx.
            for b_idx in range(bsz):
                batch_input_string[b_idx].extend(input_string[b_idx::bsz])
                batch_gen_output[b_idx].extend(gen_output[b_idx::bsz])
                batch_gt[b_idx].extend(gt[b_idx::bsz])

        eval_logs['avg_gt_loss'].update(zip(index_list, gt_loss_per_token.cpu().numpy().tolist()))
        eval_logs['gt_loss'].update(zip(index_list, gt_loss.cpu().numpy().tolist()))
        eval_logs['num_token_gt'].update(zip(index_list, num_token_gt.cpu().numpy().tolist()))
        eval_logs['generated_text'].update(
            zip(index_list, zip(batch_input_string, batch_gen_output, batch_gt))
        )

    return eval_logs

# Results of every processed task, keyed by "<eval_task>.json".
aggregated_eval_logs = {}

# The config stores one list per field; position k across all lists describes
# the k-th evaluation task, so the lists are walked in lockstep.
task_specs = zip(
    cfg.data_path,
    cfg.split_list,
    cfg.question_key,
    cfg.answer_key,
    cfg.eval_task,
    cfg.base_answer_key,
    cfg.perturbed_answer_key,
)

for i, task_spec in enumerate(task_specs):
    (folder, split, question_key, answer_key,
     eval_task, base_answer_key, perturbed_answer_key) = task_spec
    print(split, eval_task)

    # Only the forget-set task is evaluated by this script.
    if eval_task != "eval_log_forget":
        continue

    # Build the dataloaders for this task.
    eval_dataloader, base_eval_dataloader, perturb_dataloader = get_dataloader(
        cfg,
        eval_task,
        tokenizer,
        folder,
        split,
        question_key,
        answer_key,
        base_answer_key,
        perturbed_answer_key,
    )

    # Draw `num_samples` generations per example (plus loss statistics), then
    # keep only the question -> generated answers mapping.
    eval_logs = gather_statistics(
        cfg,
        model,
        tokenizer,
        eval_dataloader,
        base_eval_dataloader,
        perturb_dataloader,
        desc=eval_task,
        sample_count=num_samples,
    )
    aggregated_eval_logs[f'{eval_task}.json'] = get_qa(eval_logs)


# Write the collected samples as indented JSON into the adapter directory.
# `args.split` is used because the loop above rebinds the name `split`.
output_path = os.path.join(adapter_path, f"{args.split}_generated_samples_hf_final.json")
with open(output_path, "w") as f:
    json.dump(aggregated_eval_logs, f, indent=4)