####################################################################################################
# Calculate data attribution (the influence of each data sample) before unlearning.
# reference
# 1. library github: https://github.com/MadryLab/trak
# 2. tutorials & quick start for TRAK: https://trak.readthedocs.io/en/latest/
# 3. API documentation: https://trak.readthedocs.io/en/latest/trak.html
# 4. example for TRAK on BERT (classification task): https://trak.readthedocs.io/en/latest/bert.html
####################################################################################################


import os
import re
import shutil
from pathlib import Path



import torch
import hydra
import transformers
from omegaconf import OmegaConf
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, set_seed
from peft import LoraConfig, get_peft_model, PeftModel

from data_module import TextForgetDatasetQA, TextForgetDatasetDPOQA
from dataloader import CustomTrainerForgetting, custom_data_collator_forget
from utils import get_model_identifiers_from_yaml, find_all_linear_names, print_trainable_parameters
import trak
from trak import TRAKer
from tqdm import tqdm

from torch.utils.data import DataLoader

import torch
from torch import nn
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import datasets
from utils import get_model_identifiers_from_yaml, add_dataset_index

from data_module import convert_raw_data_to_model_format
from trak.modelout_functions import AbstractModelOutput
from trak_utils import Projector

class TextForgetDatasetQA(Dataset):
    """Paired forget/retain QA dataset built from a TOFU-style HF dataset.

    Note: this local definition shadows the ``TextForgetDatasetQA`` imported
    from ``data_module`` at the top of the file.

    Each item is a list of two model-formatted samples, one per split in
    ``(self.split1, self.split2)``:

    * ``loss_type == "idk"``: ``("idk", "retain")`` -- a forget-set question
      answered with a random line from ``data/idontknow.jsonl``, plus a random
      retain sample.
    * otherwise: ``("forget", "retain")`` -- the forget sample itself, plus a
      random retain sample.

    Args:
        data_path: Dataset name/path passed to ``datasets.load_dataset``.
        tokenizer: Tokenizer forwarded to ``convert_raw_data_to_model_format``.
        model_family: Key for ``get_model_identifiers_from_yaml``.
        max_length: Maximum tokenized length per sample.
        split: Forget split name of the form ``"forgetXX"``; the retain split
            used is the complement ``"retain(100-XX)"``, e.g. forget01 -> retain99.
        loss_type: ``"idk"`` selects the I-don't-know answer pairing.
    """

    def __init__(self, data_path, tokenizer, model_family,  max_length=512, split = "forget10", loss_type="idk"):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.forget_data = datasets.load_dataset(data_path, split)["train"]
        # forgetXX is paired with its complementary retain split (100 - XX, zero-padded).
        retain_split = "retain" + str(100 - int(split.replace("forget", ""))).zfill(2)
        self.retain_data = datasets.load_dataset(data_path, retain_split)["train"]
        self.model_configs = get_model_identifiers_from_yaml(model_family)
        self.loss_type = loss_type

        if self.loss_type == "idk":
            self.split1, self.split2 = "idk", "retain"
            self.idontknowfile = "data/idontknow.jsonl"
            # Read with a context manager so the file handle is closed right away.
            with open(self.idontknowfile, "r", encoding="utf-8") as idk_file:
                self.idk = idk_file.readlines()
        else:
            self.split1, self.split2 = "forget", "retain"

    def __len__(self):
        """Number of items equals the size of the forget split."""
        return len(self.forget_data)

    def __getitem__(self, idx):
        """Return ``[sample_for_split1, sample_for_split2]`` for forget index ``idx``."""
        rets = []
        for data_type in [self.split1, self.split2]:
            if data_type == "retain":
                data = self.retain_data
                # Pair the forget item with a randomly chosen retain item.
                row = (idx + torch.randint(0, len(self.retain_data), (1,)).item()) % len(self.retain_data)
            else:
                # Both "idk" and "forget" take their question from the forget set.
                data = self.forget_data
                row = idx
            question = data[row]['question']
            answer = data[row]['answer']

            if data_type == "idk":
                # Replace the real answer with a random "I don't know" response.
                rand_pos = torch.randint(0, len(self.idk), (1,)).item()
                answer = self.idk[rand_pos].strip()

            converted_data = convert_raw_data_to_model_format(self.tokenizer, self.max_length, question, answer, self.model_configs)
            rets.append(converted_data)
        return rets
    





#### define model and load checkpoints
def load_model(model_family, model_path):
    """Load a fine-tuned causal-LM checkpoint and its tokenizer onto the GPU.

    The architecture config and the tokenizer come from the base model named by
    the family's ``hf_key``; the weights come from ``model_path``.

    Args:
        model_family: Key passed to ``get_model_identifiers_from_yaml``.
        model_path: Directory of the fine-tuned checkpoint to load weights from.

    Returns:
        ``(model, tokenizer)``: the model in bfloat16 moved to CUDA, and a
        tokenizer whose ``pad_token`` falls back to ``eos_token`` when unset.
    """
    family_cfg = get_model_identifiers_from_yaml(model_family)
    base_id = family_cfg["hf_key"]
    base_config = AutoConfig.from_pretrained(base_id)

    print("Loading from checkpoint")
    # The YAML stores the flag as the string "true"/"false".
    wants_flash_attn = family_cfg["flash_attention2"] == "true"
    checkpoint_model = AutoModelForCausalLM.from_pretrained(
        model_path,
        config=base_config,
        use_flash_attention_2=wants_flash_attn,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    gpu_model = checkpoint_model.cuda()

    base_tokenizer = AutoTokenizer.from_pretrained(base_id)
    # Padding needs a pad token; reuse EOS when the tokenizer defines none.
    if base_tokenizer.pad_token is None:
        base_tokenizer.pad_token = base_tokenizer.eos_token

    return gpu_model, base_tokenizer



def load_data(data_path, split, model_family, tokenizer, forget_loss = "grad_ascent", batch_size=4, max_length=500):
    """Build unshuffled DataLoaders over the raw forget and retain splits.

    The dataset wrapper is only used to resolve and load the matching
    forget/retain HF splits. The loaders iterate the raw rows, not the
    tokenized pairs, so each batch is collated by the default collate
    function, and ``process_batch`` reads ``batch['question']`` /
    ``batch['answer']`` from it.

    Args:
        data_path: Dataset name/path passed to ``datasets.load_dataset``.
        split: Forget split name, e.g. ``"forget01"``.
        model_family: Key for ``get_model_identifiers_from_yaml``.
        tokenizer: Tokenizer handed to the dataset wrapper.
        forget_loss: Loss type forwarded to the wrapper as ``loss_type``.
        batch_size: Loader batch size. Previously this was read from a
            module-global that only exists when the file runs as a script;
            the default matches that script value.
        max_length: Maximum tokenized length given to the dataset wrapper.

    Returns:
        ``(loader_forget, loader_retain)``, both with ``shuffle=False`` so
        sample order (and thus attribution indices) is deterministic.
    """
    ds = TextForgetDatasetQA(
        data_path,
        tokenizer=tokenizer,
        model_family=model_family,
        max_length=max_length,
        split=split,
        loss_type=forget_loss,
    )
    loader_forget = DataLoader(ds.forget_data, batch_size=batch_size, shuffle=False)
    loader_retain = DataLoader(ds.retain_data, batch_size=batch_size, shuffle=False)
    return loader_forget, loader_retain
    
def my_convert_raw_data_to_model_format(tokenizer, max_length,  question, answer, model_configs):
    """Tokenize one QA pair into fixed-length tensors for a causal LM.

    The text is ``question_start_tag + question + question_end_tag`` followed
    by ``answer_tag + answer``. It is truncated to ``max_length`` tokens and
    right-padded to exactly ``max_length`` with ``tokenizer.eos_token_id``.

    Args:
        tokenizer: Tokenizer providing ``tokenize``, ``__call__`` and
            ``eos_token_id``.
        max_length: Output sequence length.
        question: Question text.
        answer: Answer text.
        model_configs: Mapping with ``question_start_tag``,
            ``question_end_tag`` and ``answer_tag``.

    Returns:
        ``(input_ids, labels, attention_mask, token_type_ids)`` as 1-D tensors
        of length ``max_length``. ``labels`` is -100 on the question tokens
        and on padding. The exception is the first pad position, which keeps
        the EOS id as a target.
    """
    new_question = model_configs['question_start_tag'] + question + model_configs['question_end_tag']
    full_text = new_question + model_configs['answer_tag'] + answer
    # Leading prompt tokens whose labels are masked out below.
    num_question_tokens = len(tokenizer.tokenize(new_question, add_special_tokens=True))

    encoded = tokenizer(
        full_text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
        return_token_type_ids=True,
    )
    # Copy so later in-place label edits never touch the tokenizer's output.
    input_ids = list(encoded['input_ids'])
    pad_length = max_length - len(input_ids)

    pad_input_ids = input_ids + [tokenizer.eos_token_id] * pad_length
    pad_token_type_ids = list(encoded['token_type_ids']) + [0] * pad_length
    pad_attention_mask = list(encoded['attention_mask']) + [0] * pad_length

    if pad_length <= 0:
        label = list(input_ids)
    else:
        # Supervise one EOS after the answer; ignore the remaining padding.
        label = input_ids + [tokenizer.eos_token_id] + [-100] * (pad_length - 1)

    # Mask the prompt. Clamp so a question longer than max_length does not
    # index past the end of the (truncated) label list.
    for i in range(min(num_question_tokens, len(label))):
        label[i] = -100

    return (
        torch.tensor(pad_input_ids),
        torch.tensor(label),
        torch.tensor(pad_attention_mask),
        torch.tensor(pad_token_type_ids),
    )
    

def process_batch(batch, batch_size, tokenizer, model_family, max_length):
    """Tokenize a raw collated QA batch into stacked model tensors.

    Args:
        batch: Collated batch with ``batch['question']`` and
            ``batch['answer']`` sequences of equal length.
        batch_size: Upper bound on how many samples to convert. The actual
            batch length is always respected, so a short final batch from the
            DataLoader no longer raises IndexError. ``None`` means "all".
        tokenizer: Tokenizer for ``my_convert_raw_data_to_model_format``.
        model_family: Key for ``get_model_identifiers_from_yaml``.
        max_length: Padded sequence length per sample.

    Returns:
        Tuple ``(input_ids, labels, attention_mask, token_type_ids)``, each the
        per-sample tensors stacked along a new leading dim 0.
    """
    # Loop-invariant: the config depends only on the model family.
    model_configs = get_model_identifiers_from_yaml(model_family)

    num_samples = len(batch['question'])
    if batch_size is not None:
        num_samples = min(num_samples, batch_size)

    input_ids, labels, attention_mask, token_type_ids = [], [], [], []
    for i in range(num_samples):
        ids, lbl, mask, ttype = my_convert_raw_data_to_model_format(
            tokenizer, max_length, batch['question'][i], batch['answer'][i], model_configs
        )
        input_ids.append(ids)
        labels.append(lbl)
        attention_mask.append(mask)
        token_type_ids.append(ttype)

    return (
        torch.stack(input_ids, dim=0),
        torch.stack(labels, dim=0),
        torch.stack(attention_mask, dim=0),
        torch.stack(token_type_ids, dim=0),
    )



# # see source code of TRAK custom abstractmodel: https://trak.readthedocs.io/en/latest/_modules/trak/modelout_functions.html#AbstractModelOutput
# class TextGeneration():
#     def __init__(self):
#         super().__init__()
#         pass
    
#     # def get_output(
#     #                 model,
#     #                 weights,
#     #                 buffers,
#     #                 input_id,
#     #                 token_type_id,
#     #                 attention_mask,
#     #                 label,
#     #                 ):
#     #     kw_inputs = {
#     #         "input_ids": input_id.unsqueeze(0),
#     #         "token_type_ids": token_type_id.unsqueeze(0),
#     #         "attention_mask": attention_mask.unsqueeze(0),
#     #     }

#     #     logits = torch.func.functional_call(
#     #         model, (weights, buffers), args=(), kwargs=kw_inputs
#     #     )
#     #     bindex = torch.arange(logits.shape[0]).to(logits.device, non_blocking=False)
#     #     logits_correct = logits[bindex, label.unsqueeze(0)]

#     #     cloned_logits = logits.clone()
#     #     cloned_logits[bindex, label.unsqueeze(0)] = torch.tensor(
#     #         -torch.inf, device=logits.device, dtype=logits.dtype
#     #     )

#     #     margins = logits_correct - cloned_logits.logsumexp(dim=-1)
#     #     return margins.sum()

    
    
#     def get_output(self, 
#                     model,
#                     weights,
#                     buffers,
#                     input_id,
#                     token_type_id,
#                     attention_mask,
#                     label):
        
#         print("calling get_output")
#         # kw_inputs = {
#         #             "input_ids": input_id.unsqueeze(0),
#         #             "token_type_ids": token_type_id.unsqueeze(0),
#         #             "attention_mask": attention_mask.unsqueeze(0),
#         #             "label": label.unsqueeze(0)
#         #             }
#         kw_inputs = {
#             "input_ids": input_id.unsqueeze(0),
#             "token_type_ids": token_type_id.unsqueeze(0),
#             "attention_mask": attention_mask.unsqueeze(0),
#         }

#         logits = torch.func.functional_call(
#             model, (weights, buffers), args=(), kwargs=kw_inputs
#         )
        
        
#         wts = {}
#         bfs = {}
#         for idx, (name, param) in enumerate(model.named_parameters()):
#             wts[name] = param
#             bfs[name] = param
        
#         logits = torch.func.functional_call(
#                                             model, {"weights": wts, "buffers": bfs}, args=(), kwargs=kw_inputs
#                                             )
#         logits = torch.func.functional_call(
#                                             model, ("weights": weights, "buffers": buffers), args=(), kwargs=kw_inputs
#                                             )
        
#     def get_output(self, 
#                     model,
#                     input_id,
#                     attention_mask,
#                     label):
        
#         output = model(input_ids = input_id, labels=label, attention_mask=attention_mask)
#         loss -=  output.loss
#         return loss
    
    
#     def get_out_to_loss_grad(self, model, batch):
#         loss = 0
#         input_ids, labels, attention_mask, _ = batch
#         output = model(input_ids = input_ids, labels=labels, attention_mask=attention_mask)
#         loss -=  output.loss
#         grads = torch.autograd.grad(loss, model.parameters())
#         return grads



    


if __name__ == "__main__":
    #### define args
    model_path = "./checkpoints/ft_epoch2_lr1e-05_phi_forget01_wd0.01/checkpoint-20"
    model_family = "phi"

    data_path = "locuslab/TOFU"
    split = "forget01"
    batch_size = 4
    max_length = 500
    proj_dim = 1
    # Experiment name under which TRAK stores the scores of this run; it is a
    # required argument of start_scoring_checkpoint / finalize_scores.
    exp_name = f"{model_family}_{split}"

    model, tokenizer = load_model(model_family, model_path)
    loader_train, loader_val = load_data(data_path, split, model_family, tokenizer, forget_loss = "grad_ascent")

    def to_cuda_batch(raw_batch):
        """Tokenize a raw loader batch with process_batch and move it to CUDA."""
        tensors = process_batch(batch=raw_batch,
                                batch_size=batch_size,
                                tokenizer=tokenizer,
                                model_family=model_family,
                                max_length=max_length)
        return [x.to("cuda") for x in tensors]

    traker = TRAKer(
                    model=model,
                    # task="TextGeneration",  # we define a new task class for text generation
                    # NOTE(review): the gradient code below still calls this task's
                    # get_output, so this is not just a placeholder. process_batch
                    # yields (input_ids, labels, attention_mask, token_type_ids),
                    # while the text-classification get_output sketched in the
                    # commented code above takes (input_id, token_type_id,
                    # attention_mask, label). Confirm the argument order and the
                    # per-token labels before trusting the results.
                    task = "text_classification",
                    # TRAK expects the size of the whole featurized set here, not one batch.
                    train_set_size = len(loader_train.dataset),
                    # save_dir="./trak/",
                    device="cuda",
                    # proj_dim=1024,
                    proj_dim = proj_dim
                    )

    traker.load_checkpoint(model.state_dict(), model_id=0)

    from trak_utils import ProjectionType
    projector = Projector(
                grad_dim= traker.num_params_for_grad,
                proj_dim= traker.proj_dim,
                seed= traker.proj_seed,
                proj_type = ProjectionType.normal,
                max_batch_size = 32, # default value in traker initialization
                dtype=traker.dtype,
                device=traker.device,
            )

    # TRAK's functional gradient computer holds the model, its weights/buffers,
    # and the task (modelout_fn) that the manual code below needs.
    grad_computer = traker.gradient_computer
    # Gradient w.r.t. weights (second argument of get_output, hence argnums=1).
    grads_loss = torch.func.grad(
        grad_computer.modelout_fn.get_output, has_aux=False, argnums=1
    )

    for batch in tqdm(loader_train, desc='Featurizing..'):
        batch = to_cuda_batch(batch)
        #==================================================================
        # original code by TRAK for calculating grads & projecting grads
        # (compute_per_sample_grad performs the same computation as the
        # manual code below; running both would double the per-sample
        # backward cost, so it is kept for reference only)
        #==================================================================
        # grads = grad_computer.compute_per_sample_grad(batch=batch)
        # grads = traker.projector.project(grads, model_id=traker.saver.current_model_id)
        # traker.featurize(batch=batch, num_samples=batch[0].shape[0])

        #==================================================================
        # manual implementation for calculating grads & projecting grads
        #==================================================================
        # Map over batch dims (0 for each batch tensor, None for model/weights/buffers).
        grads = torch.func.vmap(
            grads_loss,
            in_dims=(None, None, None, *([0] * len(batch))),
            randomness="different",
        )(grad_computer.model, grad_computer.func_weights, grad_computer.func_buffers, *batch)

        # TODO: `grads` is not projected or stored anywhere yet, so this loop
        # contributes nothing to finalize_features() below. Work in progress:
        # proj_grads = projector.project(grads=grads, model_id=traker.saver.current_model_id)
        # proj_grads /= traker.normalize_factor
        # traker.saver.current_store["grads"][inds] = (
        #     proj_grads.to(traker.dtype).cpu().clone().detach()
        # )
        # loss_grads = grad_computer.compute_loss_grad(batch)
        # traker.saver.current_store["out_to_loss"][inds] = (
        #     loss_grads.to(traker.dtype).cpu().clone().detach()
        # )

    traker.finalize_features()

    traker.start_scoring_checkpoint(exp_name=exp_name,
                                    checkpoint=model.state_dict(),
                                    model_id=0,
                                    num_targets=len(loader_val.dataset))

    for batch in tqdm(loader_val, desc='Scoring..'):
        # Same conversion as featurization (previously called with a missing-argument TypeError).
        batch = to_cuda_batch(batch)
        traker.score(batch=batch, num_samples=batch[0].shape[0])

    scores = traker.finalize_scores(exp_name=exp_name)