import argparse
import torch
import numpy as np
import os
import glob
import json
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, set_seed
from peft import PeftModel, AutoPeftModelForCausalLM, PeftConfig
from accelerate import Accelerator
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from math import exp
import omegaconf
from tqdm import tqdm
from collections import defaultdict
from utils import get_model_identifiers_from_yaml

from data_module import TextDatasetQA, get_batch_loss, convert_raw_data_to_model_format, custom_data_collator_with_qa


class GeneratedDataset(Dataset):
    """Dataset of (question, generated answer) pairs tokenized on access.

    Each item is converted with ``convert_raw_data_to_model_format`` using the
    model configuration returned by ``get_model_identifiers_from_yaml`` for
    the given model family.
    """

    def __init__(
        self,
        questions,
        answers,
        tokenizer,
        model_family,
        max_length=512
    ):
        """Keep the raw inputs and resolve the model-family configuration.

        Args:
            questions: Indexable collection of questions.
            answers: Collection indexed in parallel with ``questions``; every
                entry is either one answer string or an iterable of answers.
            tokenizer: Tokenizer handed to the conversion helper.
            model_family: Key given to ``get_model_identifiers_from_yaml``.
            max_length: Length limit handed to the conversion helper.
        """
        super().__init__()
        self.questions = questions
        self.answers = answers
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.model_configs = get_model_identifiers_from_yaml(model_family)
        # Positional ids: item ``k`` simply reports ``k`` as its index.
        self.indices = list(range(len(questions)))

    def __len__(self):
        """Return the number of questions held by the dataset."""
        return len(self.questions)

    def __getitem__(self, idx):
        """Convert the ``idx``-th question with each of its answers.

        Returns:
            A 6-tuple: the stacked-then-squeezed first, second and third
            outputs of the conversion helper (presumably padded input ids,
            labels and attention mask -- confirm against ``data_module``),
            the item's index as a tensor, the question, and the last answer
            that was converted.
        """
        question = self.questions[idx]
        position = self.indices[idx]
        candidates = self.answers[idx]
        if isinstance(candidates, str):
            # A lone answer string is treated as a one-element list.
            candidates = [candidates]

        encoded = []
        last_answer = None
        for candidate in candidates:
            encoded.append(
                convert_raw_data_to_model_format(
                    self.tokenizer,
                    self.max_length,
                    question,
                    candidate,
                    self.model_configs,
                )
            )
            last_answer = candidate

        # squeeze() drops the leading axis when there is a single answer.
        input_ids = torch.stack([item[0] for item in encoded]).squeeze()
        labels = torch.stack([item[1] for item in encoded]).squeeze()
        attention_mask = torch.stack([item[2] for item in encoded]).squeeze()

        return (
            input_ids,
            labels,
            attention_mask,
            torch.tensor(position),
            question,
            last_answer,
        )

# === command-line interface ===
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--split", type=str, required=True)
parser.add_argument("--model_id", type=str, default="NousResearch/Meta-Llama-3.1-8B-Instruct")
parser.add_argument("--model_family", type=str, default="llama3.1-8b")
parser.add_argument("--test_model", type=str, required=True)
parser.add_argument("--gen_model", type=str, required=True)
parser.add_argument("--tag", type=str, required=True)
args = parser.parse_args()

# === arguments (module-level aliases used throughout the script) ===
seed, batch_size = args.seed, args.batch_size
# NOTE(review): the previous comment gave "forget01_perturbed" as an example
# value, but "_perturbed" is appended to the split further down -- confirm
# which form callers are expected to pass.
split = args.split
model_id, model_family = args.model_id, args.model_family  # e.g. "NousResearch/Meta-Llama-3.1-8B-Instruct", "llama3.1-8b"
test_model, gen_model, tag = args.test_model, args.gen_model, args.tag

# === initialize accelerator ===
accelerator = Accelerator()
device = accelerator.device
set_seed(seed)

# === load unlearning arguments ===
# Start from the perturbation-eval config and override it with the CLI values.
cfg = omegaconf.OmegaConf.load('config/eval_perturb.yaml')
cfg.split_list[-1] = f'{split}_perturbed'
cfg.model_family = model_family
cfg.model_path = test_model
cfg.batch_size = batch_size

# Question delimiters configured for the selected model family.
model_cfg = omegaconf.OmegaConf.load('config/model_config.yaml')
_family_cfg = model_cfg[cfg.model_family]
q_start_tag = _family_cfg['question_start_tag']
q_end_tag = _family_cfg['question_end_tag']

### Load Model
# A checkpoint path containing any of these markers is loaded through PEFT
# and merged; every other path is loaded as a regular causal LM.
_ADAPTER_MARKERS = ("grad_ascent", "grad_diff", "idk", "npo_grad_diff")
_is_adapter_checkpoint = any(marker in test_model for marker in _ADAPTER_MARKERS)

# Keyword arguments shared by both loading paths.
_load_kwargs = dict(
    use_flash_attention_2=True,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="cuda:0",
)

if _is_adapter_checkpoint:
    config = PeftConfig.from_pretrained(cfg.model_path)
else:
    config = AutoConfig.from_pretrained(model_id)

# The tokenizer always comes from the base model id; EOS doubles as padding.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

if _is_adapter_checkpoint:
    model = AutoPeftModelForCausalLM.from_pretrained(
        cfg.model_path,
        config=config,
        **_load_kwargs
    )
    model = model.merge_and_unload()
else:
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_path,
        config=config,
        **_load_kwargs
    )
model.eval()

def preprocess_log(generated_text, eval_task, start_tag=None, end_tag=None):
    """Flatten a generation log into unique (question, answer) pairs.

    Args:
        generated_text: Mapping that contains the key ``f'{eval_task}.json'``,
            whose value maps each (tagged) question to a list of answers.
        eval_task: Task name used to build that key.
        start_tag: Prefix to remove from every question. Defaults to the
            module-level ``q_start_tag``.
        end_tag: Suffix to remove from every question. Defaults to the
            module-level ``q_end_tag``.

    Returns:
        A tuple ``(questions, answers, counts)``. ``questions`` and ``answers``
        are parallel lists with one entry per distinct (question, answer)
        pair; ``counts[question][answer]`` is how many times that answer
        occurred for that question.
    """
    if start_tag is None:
        start_tag = q_start_tag
    if end_tag is None:
        end_tag = q_end_tag

    counts = {}
    for question, raw_answers in generated_text[f'{eval_task}.json'].items():
        # str.strip(chars) removes any of the *characters* in ``chars`` from
        # both ends, so stripping with a tag can eat into the question text.
        # Remove the exact prefix / suffix instead.
        filtered_question = question
        if start_tag and filtered_question.startswith(start_tag):
            filtered_question = filtered_question[len(start_tag):]
        # The truthiness guard matters: ``s[:-0]`` would yield an empty string.
        if end_tag and filtered_question.endswith(end_tag):
            filtered_question = filtered_question[:-len(end_tag)]

        answer_counts = defaultdict(int)
        for answer in raw_answers:
            answer_counts[answer] += 1
        counts[filtered_question] = answer_counts

    questions = []
    answers = []
    for question, answer_counts in counts.items():
        for answer in answer_counts:
            questions.append(question)
            answers.append(answer)
    return questions, answers, counts

@torch.no_grad()
def get_scores(model, eval_dataloader, counts, desc="DEFAULT"):
    """Score every (question, answer) pair of a dataloader under ``model``.

    Args:
        model: Model called as ``model(input_ids=..., labels=...,
            attention_mask=...)``; its output must expose ``.logits`` and the
            model must expose ``.device``.
        eval_dataloader: Iterable yielding 6-tuples
            ``(input_ids, labels, attention_mask, indices, questions, answers)``.
        counts: Nested mapping ``counts[question][answer]`` giving how often
            each answer was observed for each question.
        desc: Label shown on the progress bar.

    Returns:
        A dict with the keys ``avg_gt_loss``, ``gt_loss``, ``num_token_gt``,
        ``counts``, ``questions`` and ``answers``; each maps an example index
        to the corresponding per-example value.
    """
    eval_logs = {
        'avg_gt_loss': {},
        'gt_loss': {},
        'num_token_gt': {},
        'counts': {},
        'questions': {},
        'answers': {},
    }
    for batch in tqdm(eval_dataloader, desc=desc):
        input_ids, labels, attention_mask, indices, questions, answers = batch

        # Convert the index tensor once; it keys every per-example record.
        index_list = indices.cpu().numpy().tolist()

        # Send the model inputs to the model's device.
        model_inputs = {
            "input_ids": input_ids.to(model.device),
            "labels": labels.to(model.device),
            "attention_mask": attention_mask.to(model.device),
        }

        # The function-level no_grad decorator already disables autograd here.
        outputs = model(**model_inputs)

        gt_loss = get_batch_loss(outputs.logits, model_inputs['labels']).float()
        # Positions labelled -100 are excluded from the token count.
        num_token_gt = (model_inputs['labels'] != -100).sum(-1).float()
        # NOTE(review): a row whose labels are all -100 makes this a division
        # by zero -- confirm every example has at least one target token.
        gt_loss_per_token = (gt_loss / num_token_gt).float()

        eval_logs['avg_gt_loss'].update(
            zip(index_list, gt_loss_per_token.cpu().numpy().tolist())
        )
        eval_logs['gt_loss'].update(
            zip(index_list, gt_loss.cpu().numpy().tolist())
        )
        eval_logs['num_token_gt'].update(
            zip(index_list, num_token_gt.cpu().numpy().tolist())
        )
        eval_logs['counts'].update(
            zip(
                index_list,
                [counts[question][answer] for question, answer in zip(questions, answers)],
            )
        )
        eval_logs['questions'].update(zip(index_list, questions))
        eval_logs['answers'].update(zip(index_list, answers))

    return eval_logs

aggregated_eval_logs = {}

# Generated samples written under the generating model's directory.
filename = os.path.join(gen_model, f"{split}_generated_samples_hf_final.json")
with open(filename, "r") as f:
    saved_eval_logs = json.load(f)

# The per-task config lists are walked in lockstep (zip stops at the shortest).
task_specs = zip(
    cfg.data_path,
    cfg.split_list,
    cfg.question_key,
    cfg.answer_key,
    cfg.eval_task,
    cfg.base_answer_key,
    cfg.perturbed_answer_key,
)
for folder, data_split, question_key, answer_key, eval_task, base_answer_key, perturbed_answer_key in task_specs:
    print(data_split)
    # Only the forget-set task is scored by this script.
    if eval_task != "eval_log_forget":
        continue

    #################################################################
    # Score the saved generated answers under the loaded test model
    #################################################################
    questions, answers, counts = preprocess_log(saved_eval_logs, eval_task)
    dataset = GeneratedDataset(
        questions,
        answers,
        tokenizer,
        model_family,
        max_length=512,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        collate_fn=custom_data_collator_with_qa,
    )
    aggregated_eval_logs[f'{eval_task}.json'] = get_scores(
        model,
        dataloader,
        counts,
        desc=eval_task,
    )

# Pretty-print the aggregated logs as JSON next to the test model.
output_path = os.path.join(test_model, f"f1_{split}_vs_{tag}_hf_final.json")
with open(output_path, "w") as f:
    json.dump(aggregated_eval_logs, f, indent=4)