# This module provides the forget (unlearning) trainers, not the data loader.
import os
import csv
import copy
import json 
import deepspeed
import numpy as np

import torch
import torch.nn as nn 
import torch.nn.functional as F

from tqdm import tqdm
from pathlib import Path
from transformers import Trainer
from transformers.integrations.deepspeed import deepspeed_init

from evaluate_all import get_dataloader, get_all_evals
from data_module import get_batch_loss 
from utils import merge_dicts, get_forget_quality, get_model_utility

# TRAKer integration is still on the TODO list.
from trak import TRAKer

def printll(name, inp):
    """Print ``name`` followed by the items of ``inp`` rounded to 4 decimals."""
    rounded = list(map(lambda value: round(value, 4), inp))
    print(name, rounded)

class CustomTrainer(Trainer):
    """Causal-LM ``Trainer`` for batches given as ``(input_ids, labels, attention_mask)``.

    When ``enable_profiling`` is true, up to five training steps taken at a
    ``global_step`` divisible by 20 run under ``torch.profiler``; results are
    printed and exported as Chrome traces by :meth:`_log_profile`.
    """

    def __init__(self, enable_profiling=False, **kwargs):
        super().__init__(**kwargs)
        self.enable_profiling = enable_profiling
        if not self.enable_profiling:
            return

        try:
            from torch.profiler import profile, record_function, ProfilerActivity
        except ImportError:
            # Fall back to un-profiled training instead of failing construction.
            print("Warning: Profiling requested but torch.profiler not available")
            self.enable_profiling = False
            return

        # Keep handles on the profiler API so training_step does not re-import it.
        self.profiler = profile
        self.record_function = record_function
        self.ProfilerActivity = ProfilerActivity
        self.profile_count = 0  # number of steps profiled so far (capped at 5)
        print("Profiling enabled successfully")

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the model's built-in LM loss, plus the outputs if ``return_outputs``.

        ``num_items_in_batch`` is accepted for ``Trainer`` compatibility and unused.
        """
        token_ids, target_ids, mask = inputs
        result = model(token_ids, labels=target_ids, attention_mask=mask)
        step_loss = result.loss
        if return_outputs:
            return (step_loss, result)
        return step_loss

    def prediction_step(self, model, inputs, prediction_loss_only: bool, ignore_keys=None):
        """Evaluate one batch without gradients and return ``(loss, logits, labels)``.

        ``prediction_loss_only`` and ``ignore_keys`` are accepted for ``Trainer``
        compatibility and unused.
        """
        token_ids, target_ids, mask = inputs
        with torch.no_grad():
            result = model(token_ids, labels=target_ids, attention_mask=mask)
        return (result.loss, result.logits, target_ids)

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Run one ``Trainer`` training step, wrapping it in the profiler when due.

        A step is profiled when profiling is enabled, ``global_step`` is a multiple
        of 20 and fewer than five steps have been profiled so far.

        NOTE: ``num_items_in_batch`` is not forwarded to ``Trainer.training_step``;
        the parent falls back to its own default for it.
        """
        profile_now = (
            self.enable_profiling
            and self.state.global_step % 20 == 0
            and self.profile_count < 5
        )
        if not profile_now:
            return super().training_step(model, inputs)

        self.profile_count += 1
        profiler_ctx = self.profiler(
            activities=[self.ProfilerActivity.CPU, self.ProfilerActivity.CUDA],
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
            on_trace_ready=self._log_profile,
        )
        with profiler_ctx, self.record_function("training_step"):
            step_loss = super().training_step(model, inputs)
        return step_loss

    def _log_profile(self, prof):
        """``on_trace_ready`` callback: print and export results (only when local_rank <= 0)."""
        if self.args.local_rank > 0:
            return
        step = self.state.global_step

        print(f"\n====== Profiling Results for Step {step} ======")
        summary = prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)
        print(summary)

        # Also save the full trace to the output directory.
        trace_path = f"{self.args.output_dir}/profile_step_{step}.json"
        prof.export_chrome_trace(trace_path)

        print("\n====== GPU Communication Stats ======")
        for event in prof.events():
            lowered = event.name.lower()
            # Report NCCL events, and CUDA events whose name mentions memcpy.
            if "nccl" in lowered or ("cuda" in lowered and "memcpy" in lowered):
                print(f"{event.name}: {event.cuda_time_total/1000:.2f}ms")

class CustomTrainerForgetting(Trainer):
    """``Trainer`` that optimizes an unlearning objective selected by ``forget_loss``.

    Extra keyword arguments (popped before ``Trainer.__init__`` sees them):
        forget_loss: loss type, one of ``"grad_ascent"``, ``"grad_diff"``,
            ``"KL"``, ``"idk"`` or ``"dpo"``.
        oracle_model: reference model; wrapped for inference by
            :meth:`e_prepare_deepspeed` when the loss is ``"KL"`` or ``"dpo"``.
        eval_cfg: evaluation config consumed by :meth:`evaluate`.
        score_dict: pre-computed per-sample score keyed by question; the mean
            score of a batch scales its forget-side loss.

    Except for ``"idk"``, each batch element is unpacked as
    ``((input_ids, labels, attention_mask), questions)``, matching the output of
    the collators in this module.
    """

    def __init__(self, *args, **kwargs):
        self.loss_type = kwargs.pop('forget_loss')
        self.oracle_model = kwargs.pop('oracle_model')
        self.eval_cfg = kwargs.pop('eval_cfg')
        # Eve: the pre-calculated score per sample is passed in as a trainer argument.
        self.score_dict = kwargs.pop('score_dict')
        super(CustomTrainerForgetting, self).__init__(*args, **kwargs)
        if self.loss_type == "KL" or self.loss_type == "dpo":
            self.oracle_model = self.e_prepare_deepspeed(self.oracle_model)

    def e_prepare_deepspeed(self, model):
        """Wrap the oracle (reference) model for frozen inference under DeepSpeed.

        If the accelerator has no DeepSpeed plugin, ``model`` is returned unchanged.
        Otherwise a deep copy of the trainer's DeepSpeed config is adapted: ZeRO-3
        is kept, with bucket sizes derived from the model's hidden size when
        available; any other stage is forced to stage 0; the optimizer is disabled.
        The model is then wrapped by ``deepspeed.initialize``, put in eval mode, and
        all its parameters are frozen.
        """
        deepspeed_plugin = getattr(self.accelerator.state, 'deepspeed_plugin', None)

        # Without DeepSpeed the oracle model is used as-is.
        if deepspeed_plugin is None:
            print("Warning: DeepSpeed not enabled. Oracle model will not use DeepSpeed.")
            return model

        # Copy so the trainer's own DeepSpeed config is not mutated.
        config_kwargs = copy.deepcopy(deepspeed_plugin.deepspeed_config)

        if model is not None:
            if hasattr(model, "config"):
                hidden_size = (
                    max(model.config.hidden_sizes)
                    if getattr(model.config, "hidden_sizes", None)
                    else getattr(model.config, "hidden_size", None)
                )
                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
                    # NOTE(review): these dotted keys are added at the top level of the
                    # config rather than nested under "zero_optimization" -- confirm
                    # DeepSpeed actually applies them.
                    config_kwargs.update(
                        {
                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                        }
                    )

        # If ZeRO-3 is used, we shard both the active and reference model.
        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
        if config_kwargs["zero_optimization"]["stage"] != 3:
            config_kwargs["zero_optimization"]["stage"] = 0
        # The oracle is never trained, so no optimizer is built for it.
        config_kwargs["optimizer"] = {"type": None}
        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
        model.eval()
        # Freeze every parameter of the wrapped model.
        for param in model.parameters():
            param.requires_grad = False

        return model

    def _mean_score(self, questions):
        """Return ``np.mean`` of the pre-computed scores of ``questions`` (looked up in ``self.score_dict``)."""
        return np.mean([self.score_dict[q] for q in questions])

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the unlearning loss for ``self.loss_type`` (and the outputs if ``return_outputs``).

        Loss types:
            grad_ascent: negated LM loss on the forget batch, scaled by the batch's mean score.
            grad_diff: the grad_ascent term plus the LM loss on the retain batch.
            KL: the grad_ascent term plus ``kl_div(model, oracle)`` on the retain batch,
                averaged over all retain token positions.
            idk: LM loss on the idk and retain batches concatenated.
            dpo: DPO-style loss preferring idk over forget responses relative to the
                oracle (beta = 0.1); each side's ``get_batch_loss`` term is scaled by
                that side's mean score.

        ``num_items_in_batch`` is accepted for ``Trainer`` compatibility and unused.

        Raises:
            ValueError: if ``self.loss_type`` is not one of the types above.
        """
        if self.loss_type == "grad_ascent":
            forget_inputs, _retain_inputs = inputs
            # Eve: questions are yielded next to the tensors so their scores can be looked up.
            (input_ids, labels, attention_mask), questions = forget_inputs
            outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
            forget_loss = outputs.loss * self._mean_score(questions)
            # Negate so that minimizing performs gradient ascent on the forget set.
            loss = forget_loss * -1

        elif self.loss_type == "grad_diff":
            forget_inputs, retain_inputs = inputs
            (input_ids, labels, attention_mask), forget_questions = forget_inputs
            outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
            forget_loss = outputs.loss * self._mean_score(forget_questions)
            forget_loss = forget_loss * -1

            (retain_input_ids, retain_labels, retain_attention_mask), _ = retain_inputs
            retain_outputs = model(retain_input_ids, labels=retain_labels, attention_mask=retain_attention_mask)
            loss = forget_loss + retain_outputs.loss

        elif self.loss_type == "KL":
            forget_inputs, retain_inputs = inputs
            (input_ids, labels, attention_mask), forget_questions = forget_inputs
            outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
            forget_loss = outputs.loss * self._mean_score(forget_questions)
            forget_loss = forget_loss * -1

            (retain_input_ids, retain_labels, retain_attention_mask), _ = retain_inputs
            with torch.no_grad():
                retain_outputs = self.oracle_model(retain_input_ids, labels=retain_labels, attention_mask=retain_attention_mask)

            # Flatten to (num_positions, vocab) log-probabilities.
            retain_probs = F.log_softmax(retain_outputs.logits, dim=-1)
            retain_probs = retain_probs.view(-1, retain_outputs.logits.shape[-1])

            current_outputs = model(retain_input_ids, labels=retain_labels, attention_mask=retain_attention_mask)
            current_probs = F.log_softmax(current_outputs.logits, dim=-1)
            current_probs = current_probs.view(-1, current_outputs.logits.shape[-1])

            # Keep the current model close to the oracle on retain data.
            retain_loss = nn.functional.kl_div(current_probs, retain_probs, reduction='batchmean', log_target=True)
            loss = forget_loss + retain_loss

        elif self.loss_type == "idk":
            # This baseline is not enabled for data-attribution-driven unlearning yet.
            # NOTE(review): unlike the other branches, each element is unpacked directly
            # as (input_ids, labels, attention_mask); custom_data_collator_forget yields
            # ((input_ids, labels, attention_mask), questions) instead, so this branch
            # presumably relies on a different collator -- confirm before enabling.
            idk_inputs, retain_inputs = inputs
            idk_input_ids, idk_labels, idk_attention_mask = idk_inputs
            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs

            # Concatenate the inputs: a single forward pass is much more efficient.
            input_ids = torch.cat((idk_input_ids, retain_input_ids), dim=0)
            labels = torch.cat((idk_labels, retain_labels), dim=0)
            attention_mask = torch.cat((idk_attention_mask, retain_attention_mask), dim=0)

            outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
            loss = outputs.loss

        elif self.loss_type == "dpo":
            idk_inputs, forget_inputs, _retain_inputs = inputs
            (forget_input_ids, forget_labels, forget_attention_mask), forget_questions = forget_inputs
            (idk_input_ids, idk_labels, idk_attention_mask), idk_questions = idk_inputs

            idk_outputs = model(idk_input_ids, labels=idk_labels, attention_mask=idk_attention_mask)
            forget_outputs = model(forget_input_ids, labels=forget_labels, attention_mask=forget_attention_mask)

            with torch.no_grad():
                idk_outputs_oracle = self.oracle_model(idk_input_ids, labels=idk_labels, attention_mask=idk_attention_mask)
                forget_outputs_oracle = self.oracle_model(forget_input_ids, labels=forget_labels, attention_mask=forget_attention_mask)
                idk_logits_oracle = idk_outputs_oracle.logits
                forget_logits_oracle = forget_outputs_oracle.logits

            forget_weight = self._mean_score(forget_questions)
            idk_weight = self._mean_score(idk_questions)

            idk_loss_oracle = -1 * get_batch_loss(idk_logits_oracle, idk_labels) * idk_weight
            forget_loss_oracle = -1 * get_batch_loss(forget_logits_oracle, forget_labels) * forget_weight

            idk_loss_current = -1 * get_batch_loss(idk_outputs.logits, idk_labels) * idk_weight
            forget_loss_current = -1 * get_batch_loss(forget_outputs.logits, forget_labels) * forget_weight

            pi_logratios = idk_loss_current - forget_loss_current
            ref_logratios = idk_loss_oracle - forget_loss_oracle

            beta = 0.1
            loss = -F.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()

            outputs = forget_outputs

        else:
            # Previously this fell through with `loss` unbound (UnboundLocalError).
            raise ValueError(f"Unknown forget_loss type: {self.loss_type!r}")

        return (loss, outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only: bool, ignore_keys=None):
        """Evaluate one ``(input_ids, labels, attention_mask)`` batch without gradients.

        Returns ``(loss, logits, labels)``. ``prediction_loss_only`` and
        ``ignore_keys`` are accepted for ``Trainer`` compatibility and unused.
        """
        input_ids, labels, attention_mask = inputs
        with torch.no_grad():
            outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
            logits = outputs.logits
            loss = outputs.loss
        return (loss, logits, labels)

    def evaluate(
        self,
        eval_dataset = None,
        ignore_keys = None,
        metric_key_prefix = "eval",
    ):
        """Run every task in ``self.eval_cfg`` and write results under ``eval_cfg.save_dir``.

        Per-task logs are written to ``checkpoint-<global_step>/<eval_task>.json``.
        With several processes, each rank first writes ``<eval_task>_<rank>.json``,
        and the local main process merges them. The local main process then writes
        all task logs to ``eval_log_aggregated.json`` and, when
        ``eval_cfg.retain_result`` is set, utility / forget-quality stats to
        ``aggregate_stat.csv``.

        ``eval_dataset``, ``ignore_keys`` and ``metric_key_prefix`` are accepted for
        ``Trainer`` compatibility and unused. Returns ``None``.
        """
        # if eval is called w/o train, handle model prep here
        if self.is_deepspeed_enabled and self.deepspeed is None:
            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
        args = self.args
        model = self._wrap_model(self.model, training=False, dataloader=None)
        print(self.is_in_train, args.device, model.dtype, self.args.dataloader_num_workers, self.eval_cfg.split_list, self.eval_cfg.split)
        if len(self.accelerator._models) == 0 and model is self.model:
            model = (
                self.accelerator.prepare(model)
                if self.is_deepspeed_enabled
                else self.accelerator.prepare_model(model, evaluation_mode=True)
            )

            if self.is_fsdp_enabled:
                self.model = model

            # for the rest of this function `model` is the outside model, whether it was wrapped or not
            if model is not self.model:
                self.model_wrapped = model

            # backward compatibility
            if self.is_deepspeed_enabled:
                self.deepspeed = self.model_wrapped

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)
        model.eval()

        eval_cfg = self.eval_cfg
        curr_step = self.state.global_step
        # Read once up front so it is defined even when there are no eval tasks.
        world_size = self.accelerator.num_processes
        curr_save_dir = os.path.join(eval_cfg.save_dir, f"checkpoint-{curr_step}")
        Path(curr_save_dir).mkdir(parents=True, exist_ok=True)

        with torch.no_grad():
            task_specs = zip(eval_cfg.data_path, eval_cfg.split_list, eval_cfg.question_key, eval_cfg.answer_key, eval_cfg.eval_task, eval_cfg.base_answer_key, eval_cfg.perturbed_answer_key)
            for folder, split, question_key, answer_key, eval_task, base_answer_key, perturbed_answer_key in task_specs:
                # For some reason, Hydra is not interpreting the split correctly
                if eval_task == 'eval_log_forget':
                    split = eval_cfg.split
                print(f'Working on eval task {eval_task} with split {split}')
                if world_size == 1:
                    save_filename = os.path.join(curr_save_dir, f"{eval_task}.json")
                else:
                    # One file per rank; merged below by the local main process.
                    # NOTE(review): uses local_process_index while world_size counts all
                    # processes -- likely wrong for multi-node runs; confirm.
                    save_filename = os.path.join(curr_save_dir, f"{eval_task}_{self.accelerator.local_process_index}.json")
                if os.path.exists(save_filename) and not eval_cfg.overwrite:
                    print(f"Skipping {eval_task} because {save_filename} already exists")
                    continue

                eval_dataloader, base_eval_dataloader, perturb_dataloader = get_dataloader(eval_cfg, eval_task, self.tokenizer, folder, split, question_key, answer_key, base_answer_key, perturbed_answer_key)
                eval_dataloader = self.accelerator.prepare(eval_dataloader)
                base_eval_dataloader = self.accelerator.prepare(base_eval_dataloader)
                perturb_dataloader = self.accelerator.prepare(perturb_dataloader)
                normalize_gt = False

                eval_logs = get_all_evals(eval_cfg, model, self.tokenizer, eval_task, eval_dataloader, base_eval_dataloader, perturb_dataloader, normalize_gt=normalize_gt)

                with open(save_filename, "w") as f:
                    json.dump(eval_logs, f, indent=4)

            # wait for all processes to finish writing their task files
            self.accelerator.wait_for_everyone()

            if not self.accelerator.is_local_main_process:
                return

            aggregated_eval_logs = {}
            for eval_task in eval_cfg.eval_task:
                if world_size > 1:
                    eval_logs = self._merge_rank_logs(curr_save_dir, eval_task, world_size)
                else:
                    # Single process: load the task file written above (or by an earlier
                    # run) so the aggregate is not left empty.
                    with open(os.path.join(curr_save_dir, f"{eval_task}.json")) as f:
                        eval_logs = json.load(f)
                aggregated_eval_logs[f'{eval_task}.json'] = eval_logs

            self._write_aggregate_results(curr_save_dir, aggregated_eval_logs)

    @staticmethod
    def _merge_rank_logs(save_dir, eval_task, world_size):
        """Merge per-rank ``<eval_task>_<i>.json`` logs into ``<eval_task>.json``, delete the parts, return the merge."""
        rank_files = [os.path.join(save_dir, f"{eval_task}_{i}.json") for i in range(world_size)]
        with open(rank_files[0]) as f:
            eval_logs = json.load(f)
        for filename in rank_files[1:]:
            with open(filename) as f:
                eval_logs = merge_dicts(eval_logs, json.load(f))

        with open(os.path.join(save_dir, f"{eval_task}.json"), "w") as f:
            json.dump(eval_logs, f, indent=4)

        for filename in rank_files:
            os.remove(filename)
        return eval_logs

    def _write_aggregate_results(self, save_dir, aggregated_eval_logs):
        """Write ``eval_log_aggregated.json`` and, if ``eval_cfg.retain_result`` is set, ``aggregate_stat.csv``."""
        with open(os.path.join(save_dir, "eval_log_aggregated.json"), 'w') as f:
            json.dump(aggregated_eval_logs, f, indent=4)

        if self.eval_cfg.retain_result is None:
            return

        model_utility = get_model_utility(aggregated_eval_logs)
        with open(self.eval_cfg.retain_result, 'r') as f:
            retain_result = json.load(f)
        forget_quality = get_forget_quality(aggregated_eval_logs, retain_result)
        aggregate_stat = {**model_utility, **forget_quality}

        # newline='' as required by the csv module to avoid blank lines on Windows.
        with open(os.path.join(save_dir, "aggregate_stat.csv"), 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=list(aggregate_stat.keys()))
            writer.writeheader()
            writer.writerow(aggregate_stat)

def custom_data_collator_forget_dpo(samples):
    """Collate DPO forget samples into idk / forget / retain batches.

    Each element of ``samples`` is ``(idk_item, forget_item, retain_item)``, and
    every item is ``((input_ids, labels, attention_mask), question)``.

    Returns:
        A list of three ``((input_ids, labels, attention_mask), questions)`` tuples
        in the order idk, forget, retain. Each tensor is stacked with
        ``torch.stack`` (new leading batch dimension); ``questions`` is a list.
    """
    idk_samples = [sample[0] for sample in samples]
    forget_samples = [sample[1] for sample in samples]
    retain_samples = [sample[2] for sample in samples]
    rets = []
    # Iterate the groups directly instead of looking each one up by name with eval().
    for data in (idk_samples, forget_samples, retain_samples):
        input_ids = [s[0][0] for s in data]
        labels = [s[0][1] for s in data]
        attention_mask = [s[0][2] for s in data]
        question = [s[1] for s in data]
        rets.append(((torch.stack(input_ids), torch.stack(labels), torch.stack(attention_mask)), question))
    return rets

def custom_data_collator_forget(samples):
    """Collate forget samples into a forget batch and a retain batch.

    Each element of ``samples`` is ``(forget_item, retain_item)``, and every item
    is ``((input_ids, labels, attention_mask), question)``.

    Returns:
        ``[forget_batch, retain_batch]``; each batch is
        ``((input_ids, labels, attention_mask), questions)``, with every tensor
        stacked via ``torch.stack`` and ``questions`` as a list.
    """
    groups = ([pair[0] for pair in samples], [pair[1] for pair in samples])
    batches = []
    for group in groups:
        ids_list = [entry[0][0] for entry in group]
        label_list = [entry[0][1] for entry in group]
        mask_list = [entry[0][2] for entry in group]
        questions = [entry[1] for entry in group]
        tensors = (torch.stack(ids_list), torch.stack(label_list), torch.stack(mask_list))
        batches.append((tensors, questions))
    return batches

def compute_metrics(pred):
    """Compute next-token accuracy and loss for ``Trainer`` evaluation.

    Args:
        pred: object with numpy ``predictions`` (logits; presumably
            ``(batch, seq, vocab)`` -- TODO confirm) and numpy ``label_ids``
            where ``-100`` marks ignored positions.

    Returns:
        ``{"eval accuracy": float, "eval loss": float}``. Accuracy compares the
        prediction at position ``t`` with label ``t + 1`` and, like ``get_loss``,
        skips label positions equal to ``-100``.
    """
    logits, labels = torch.from_numpy(pred.predictions), torch.from_numpy(pred.label_ids)
    preds = logits.argmax(-1)
    shifted_labels = labels[..., 1:].contiguous()
    shifted_preds = preds[..., :-1]
    # Ignored positions (-100) can never match a prediction; counting them as
    # misses would deflate accuracy and disagree with the loss.
    valid = shifted_labels != -100
    correct = ((shifted_preds == shifted_labels) & valid).sum()
    # clamp avoids 0/0 when every position is ignored (reports 0.0).
    acc = (correct.float() / valid.sum().clamp(min=1)).item()
    loss = get_loss(logits, labels)
    return {"eval accuracy": acc, "eval loss": loss.item()}

def get_loss(output, labels):
    """Return the mean next-token cross-entropy of ``output`` logits against ``labels``.

    Logits at position ``t`` are scored against label ``t + 1``. Label positions
    equal to ``-100`` are ignored.

    Args:
        output: logits tensor whose last dim is the vocabulary and whose
            second-to-last dim is the sequence.
        labels: token-id tensor aligned with ``output`` minus the vocab dim.
    """
    vocab_size = output.size(-1)
    predicted = output[..., :-1, :].contiguous().view(-1, vocab_size)
    targets = labels[..., 1:].contiguous().view(-1)
    return nn.functional.cross_entropy(predicted, targets, ignore_index=-100)
