from torch.nn.modules.module import Module
from transformers import TrainerCallback
import math
from src import logger
import os
import torch
from pathlib import Path
from tqdm import tqdm
from copy import deepcopy
import pickle
from opacus.grad_sample import GradSampleModule
import gc
from collections import Counter


class PPLCallback(TrainerCallback):
    """Trainer callback that logs perplexity derived from the latest log entry."""

    def on_log(self, args, state, control, **kwargs):
        """Log ``exp(loss)`` and ``exp(eval_loss)`` from the last log-history entry.

        Each metric is best-effort: it is skipped when the history is empty,
        the key is absent from the last entry, the value is not numeric, or
        the exponential overflows.
        """
        for key, label in (("loss", "ppl"), ("eval_loss", "eval_ppl")):
            try:
                ppl = math.exp(state.log_history[-1][key])
            except (IndexError, KeyError, TypeError, OverflowError):
                # Metric not available (or not representable) on this log
                # event; skip it without interrupting training.
                continue
            # Logging happens outside the try so logger errors are not hidden.
            logger.info(f"{label}: {ppl}")

def _save_state_to_file(model, path):
    """Write the model's state dict to ``path`` using ``torch.save``.

    When ``model`` is a ``GradSampleModule``, the state dict of the submodule
    registered under the ``'_module'`` key is saved instead of the wrapper's.
    """
    if isinstance(model, GradSampleModule):
        target = model._modules['_module']
    else:
        target = model
    weights = target.state_dict()
    torch.save(weights, path)


def _load_state_from_file(model, state_path):
    """Load a state dict from ``state_path`` into ``model`` and return ``model``.

    When ``model`` is a ``GradSampleModule``, the weights are loaded into the
    submodule registered under the ``'_module'`` key; otherwise they are loaded
    into ``model`` directly. The same ``model`` object is returned in both
    cases so callers can rely on a consistent return value.
    """
    # NOTE(review): torch.load unpickles the file contents; only use this
    # with checkpoints from a trusted source.
    state_dict = torch.load(state_path)
    if isinstance(model, GradSampleModule):
        target = model._modules['_module']
    else:
        target = model
    target.load_state_dict(state_dict)
    return model



class AuditorCallback(TrainerCallback):
    """Base callback that stores canary datasets and offers a per-token loss helper.

    Args:
        save_dir: Output directory; created (including parents) if missing.
        list_of_canary_datasets: Stored unchanged on the instance.
        canary_indices: Stored unchanged on the instance.
    """

    def __init__(self, save_dir, list_of_canary_datasets, canary_indices):
        self.save_dir = save_dir
        self.canary_indices = canary_indices
        self.list_of_canary_datasets = list_of_canary_datasets
        Path(self.save_dir).mkdir(exist_ok=True, parents=True)

    @staticmethod
    def compute_loss(model, input_ids, attention_masks, device='cuda', batch_size=256,
                     loss_fct=torch.nn.CrossEntropyLoss(reduction='none')):
        """Compute unreduced next-token losses for every sample, batch by batch.

        Args:
            model: Callable accepting ``input_ids``, ``attention_mask``,
                ``labels`` and ``return_dict`` and returning an object with
                a ``logits`` attribute.
            input_ids: Sliceable tensor of token ids
                (assumes shape (num_samples, seq_len) — TODO confirm).
            attention_masks: Tensor sliced in lockstep with ``input_ids``.
            device: Device each batch is moved to before the forward pass.
            batch_size: Number of samples per forward pass.
            loss_fct: Loss applied to the flattened shifted logits/labels; it
                must return one value per token (no reduction).

        Returns:
            A list with one entry per sample, each a list of per-token losses
            for the shifted positions (one fewer than the sequence length).
        """
        # NOTE(review): the attention mask is only forwarded to the model; the
        # returned losses are NOT masked, so padded positions are included.
        # Presumably callers mask downstream — confirm.
        losses = []
        for i in tqdm(range(0, len(input_ids), batch_size)):
            with torch.no_grad():
                batch_input_ids = input_ids[i:i+batch_size].to(device)
                batch_attention_mask = attention_masks[i:i+batch_size].to(device)

                outputs = model(
                        input_ids=batch_input_ids,
                        attention_mask=batch_attention_mask,
                        labels=batch_input_ids,
                        return_dict=True
                    )

                # Position t's logits are scored against the token at t+1.
                shift_labels = batch_input_ids[..., 1:].contiguous()
                shift_logits = outputs.logits[..., :-1, :].contiguous()

                # Calculate per-token loss
                loss = loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)
                )

                # Restore the (batch, shifted_seq_len) layout before collecting.
                loss_per_sample = loss.view(shift_logits.size(0), shift_logits.size(1))
                losses.extend(loss_per_sample.cpu().tolist())
                # Drop references to the batch tensors before the next iteration.
                del batch_input_ids, batch_attention_mask, outputs, shift_labels, shift_logits, loss, loss_per_sample
                gc.collect()
        return losses



class BlackBoxAuditorCallback(AuditorCallback):
    """Audit callback that pickles canary loss scores at train start and end.

    Scores are written to ``initial_scores.pkl`` and ``final_scores.pkl``
    inside ``save_dir``, each as a dict with keys ``"canary_indices"`` and
    ``"scores"``.
    """

    def __init__(self, save_dir, list_of_canary_datasets, canary_indices):
        super().__init__(save_dir, list_of_canary_datasets, canary_indices)

    def _dump_scores(self, filename, scores):
        """Pickle ``scores`` together with the canary indices into ``save_dir``."""
        with open(Path(self.save_dir) / filename, "wb") as f:
            pickle.dump(
                {
                    "canary_indices": self.canary_indices,
                    "scores": scores,
                },
                f,
            )

    def on_train_begin(self, args, state, control, **kwargs):
        """Audit the model passed in ``kwargs["model"]`` and save the initial scores."""
        model = kwargs["model"]
        logger.info(f"Model type: {type(model)}")
        model.eval()
        scores = self.audit(model)
        self._dump_scores("initial_scores.pkl", scores)
        # Training is about to start, so make sure the model is in train mode.
        model.train()

    def on_train_end(self, args, state, control, **kwargs):
        """Audit the model passed in ``kwargs["model"]`` and save the final scores."""
        model = kwargs["model"]
        logger.info(f"Model type: {type(model)}")
        scores = self.audit(model)
        self._dump_scores("final_scores.pkl", scores)

    def audit(self, model, device='cuda'):
        """Compute per-token losses for every canary dataset.

        The model is put in eval mode for the duration of the audit and is
        always switched back to train mode afterwards, even if computing the
        losses raises.

        Args:
            model: Model forwarded to ``compute_loss``.
            device: Device forwarded to ``compute_loss``.

        Returns:
            A list with one ``compute_loss`` result per canary dataset, in the
            order of ``self.list_of_canary_datasets``.
        """
        model.eval()
        all_scores = []
        try:
            for i, cd in enumerate(self.list_of_canary_datasets):
                logger.info(f"Canary dataset {i+1}")
                losses = self.compute_loss(model, cd['input_ids'], cd['attention_mask'], device=device)
                all_scores.append(losses)
        finally:
            # Do not leave the model stuck in eval mode when an audit fails.
            model.train()
        return all_scores

