from transformers import TrainerCallback
import math
from llm_logger import train_logger
import os
import torch
import numpy as np
from attacks.mia_utils import get_losses
import sys
import pickle
import gc
import scipy.stats as stats



def get_exposure(model, x, x_none, n_samples, prefix_lengths, prefix_tokens, selected_tokens, tokenizer, batch_size):
    """Estimate the exposure of one target sequence against random-prefix references.

    ``x_none`` is repeated ``n_samples`` times; each copy is shifted right by
    ``prefix_tokens.shape[1]`` positions (its tail is cut off) and a randomly
    drawn prefix is placed in front. The target ``x`` is stacked on top of
    these references, the batch is decoded and re-tokenized to
    ``max(prefix_lengths)`` tokens, and ``get_losses`` is evaluated on it.

    Args:
        model: Model handed to ``get_losses``; its ``device`` attribute is used
            to place the batch.
        x: Token ids of the target. ``get_exposures`` passes one unsqueezed
            sample, i.e. a single row.
        x_none: Token ids of the reference sequence, same layout as ``x``.
        n_samples: Number of random-prefix references to generate.
        prefix_lengths: Iterable of lengths ``k``; the loss is averaged over
            the first ``k`` positions for each of them.
        prefix_tokens: Only ``prefix_tokens.shape[1]`` is read, as the number
            of tokens in each random prefix.
        selected_tokens: Pool of token ids, indexable by an integer tensor,
            from which prefix tokens are drawn uniformly with replacement.
        tokenizer: Tokenizer used for the decode / re-encode round trip.
        batch_size: Batch size forwarded to ``get_losses``.

    Returns:
        A pair ``(d, losses)``. ``d[k]`` is ``[exposure, exposure_model]``:
        the first is ``-log2`` of the fraction of references whose mean loss is
        at most the target's, the second is ``-log2`` of the CDF at the target
        loss of a distribution fitted to the reference losses. ``losses`` is
        the raw result of ``get_losses`` (row 0 is the target, the remaining
        rows are the references).
    """
    prefix_width = prefix_tokens.shape[1]

    # Build the references: n_samples copies of x_none, each one shifted right
    # to make room for a freshly drawn random prefix.
    references = x_none.repeat(n_samples, 1)
    drawn = torch.randint(0, len(selected_tokens), (n_samples * prefix_width,))
    random_prefix = selected_tokens[drawn].reshape(-1, prefix_width)
    references = torch.cat([random_prefix, references[:, :-random_prefix.shape[1]]], dim=1)

    # Row 0 is the target, rows 1.. are the references.
    batch = torch.cat([x, references], dim=0)
    texts = tokenizer.batch_decode(batch, skip_special_tokens=True)
    encoded = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=max(prefix_lengths),
        return_tensors='pt',
    )
    losses = get_losses(model, encoded.to(model.device), batch_size, disable_tqdm=True)

    def _fitted_exposure(dist, reference_losses, target):
        # Fit `dist` to the reference losses and score the target under it.
        params = dist.fit(reference_losses.cpu().numpy())
        return -torch.log2(torch.tensor(dist.cdf(target.item(), *params)) + 1e-30)

    per_length = {}
    for k in prefix_lengths:
        target_loss = losses[0, :k].mean(-1)
        nonmembers = losses[1:, :k].mean(-1)

        # Empirical estimate: share of references doing at least as well.
        share_at_or_below = (nonmembers <= target_loss).float().mean(-1)
        exposure = -torch.log2(share_at_or_below + 1e-30)

        # Parametric estimate: skew-normal first, plain normal if that fails.
        try:
            exposure_model = _fitted_exposure(stats.distributions.skewnorm, nonmembers, target_loss)
        except Exception as e:
            print('Error: ', e)
            exposure_model = _fitted_exposure(stats.distributions.norm, nonmembers, target_loss)

        per_length[k] = [exposure, exposure_model]
    return per_length, losses

def get_exposures(model, samples, none_samples, n_samples, prefix_tokens, selected_tokens, tokenizer, batch_size, prefix_lengths=(6, 10, 32, 64, 128, 255)):
    """Run ``get_exposure`` on every (sample, reference) pair and collect the results.

    Args:
        model: Model forwarded to ``get_exposure``.
        samples: Iterable of target token-id rows; each is unsqueezed to a
            single-row batch before being passed on.
        none_samples: Iterable of reference rows, paired with ``samples`` by
            position.
        n_samples: Number of random-prefix references generated per pair.
        prefix_tokens: Forwarded unchanged to ``get_exposure``.
        selected_tokens: Indexed by the pair's position; entry ``i`` is the
            token pool used for pair ``i``.
        tokenizer: Forwarded unchanged to ``get_exposure``.
        batch_size: Forwarded unchanged to ``get_exposure``.
        prefix_lengths: Lengths over which exposures are computed.

    Returns:
        ``(exposures, exposures_model, losses)``. The two dicts map each
        prefix length to a list of Python scalars, one per pair; ``losses`` is
        the per-pair loss tensors stacked along a new leading dimension.
    """
    exposures, exposures_model = {}, {}
    collected_losses = []

    for idx, (sample, none_sample) in enumerate(zip(samples, none_samples)):
        per_length, pair_losses = get_exposure(
            model,
            sample.unsqueeze(0),
            none_sample.unsqueeze(0),
            n_samples,
            prefix_lengths,
            prefix_tokens,
            selected_tokens[idx],
            tokenizer,
            batch_size,
        )
        collected_losses.append(pair_losses)

        # Release memory between pairs before the next forward passes.
        gc.collect()
        torch.cuda.empty_cache()

        for k, (empirical, modelled) in per_length.items():
            exposures.setdefault(k, []).append(empirical.item())
            exposures_model.setdefault(k, []).append(modelled.item())

    return exposures, exposures_model, torch.stack(collected_losses, dim=0)

class PPLCallback(TrainerCallback):
    """Trainer callback that stores losses/exposures on evaluation and logs perplexity.

    On every ``on_evaluate`` call a new folder ``<output_dir>/epoch-<counter>``
    is created and filled with:

    * ``train_exposures.pkl``, ``train_exposures_model.pkl`` and
      ``train_losses.npy`` (only when prefix information was supplied),
    * ``losses_<name>.npy`` for each of the token sets in ``TOKEN_SETS``,
    * ``info.pkl`` holding the process environment and ``sys.argv``.

    On every ``on_log`` call the perplexity (``exp`` of the loss) of the most
    recent log entry is written to ``train_logger``.
    """

    # Names of the token sets read from ``output`` / ``output_none``.
    TOKEN_SETS = ('train_tokens', 'val_tokens', 'z_tokens')
    # Number of random-prefix references generated per sample for exposures.
    EXPOSURE_N_SAMPLES = 63
    # Batch size used for every loss evaluation done by this callback.
    EVAL_BATCH_SIZE = 8

    def __init__(self, output, output_none, tokenizer, max_seq_length, model):
        """Re-tokenize the supplied token sets and store the prefix metadata.

        Args:
            output: Mapping with the keys ``'prefix_type'``, ``'dataset'``,
                ``'train_prefix_tokens'`` and the three ``TOKEN_SETS``. When
                ``'train_prefix_tokens'`` is not ``None`` it must also hold
                ``'train_selected_samples'``, ``'train_selected_tokens'``,
                ``'val_prefix_tokens'``, ``'val_selected_samples'`` and
                ``'val_selected_tokens'``.
            output_none: Mapping with the three ``TOKEN_SETS``; used as the
                reference sequences for exposure (presumably the same data
                without the inserted prefix -- confirm against the caller).
            tokenizer: Tokenizer used to decode and re-encode the token sets.
            max_seq_length: Length every sequence is padded/truncated to.
            model: Model evaluated in ``on_evaluate``.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.counter = 0
        self.model = model
        self.prefix_type = output['prefix_type']
        self.dataset = output['dataset']
        self.tokens = {name: self._retokenize(output[name]) for name in self.TOKEN_SETS}
        self.tokens_none = {name: self._retokenize(output_none[name]) for name in self.TOKEN_SETS}

        # Always define the prefix attributes so later reads never hit an
        # AttributeError when no prefix information was supplied.
        self.train_prefix_tokens = None
        self.train_selected_samples = None
        self.train_selected_tokens = None
        self.val_prefix_tokens = None
        self.val_selected_samples = None
        self.val_selected_tokens = None

        print('train_prefix_tokens:', output['train_prefix_tokens'])
        if output['train_prefix_tokens'] is not None:
            self.train_prefix_tokens = output['train_prefix_tokens']
            self.train_selected_samples = output['train_selected_samples']
            self.train_selected_tokens = output['train_selected_tokens']
            self.val_prefix_tokens = output['val_prefix_tokens']
            self.val_selected_samples = output['val_selected_samples']
            self.val_selected_tokens = output['val_selected_tokens']

            print('train_prefix_tokens:', self.train_prefix_tokens.shape)
            print('train_selected_samples:', self.train_selected_samples.shape, self.train_selected_samples)
            print('train_selected_tokens:', self.train_selected_tokens.shape)
            print('val_prefix_tokens:', self.val_prefix_tokens.shape)
            print('val_selected_samples:', self.val_selected_samples.shape, self.val_selected_samples)
            print('val_selected_tokens:', self.val_selected_tokens.shape)

    def _retokenize(self, token_ids):
        """Decode ``token_ids`` to text and re-encode it padded to ``max_seq_length``."""
        texts = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
        return self.tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors='pt',
        )

    @staticmethod
    def _dump_pickle(obj, path):
        """Pickle ``obj`` to ``path``, closing the file even if pickling fails."""
        with open(path, 'wb') as handle:
            pickle.dump(obj, handle)

    def on_evaluate(
        self,
        args,
        state,
        control,
        **kwargs,
    ):
        """Compute and save exposures and losses for the current evaluation.

        Args:
            args: Training arguments; only ``args.output_dir`` is read.
            state: Unused.
            control: Returned unchanged.
            **kwargs: Unused.

        Returns:
            The ``control`` object that was passed in.
        """
        checkpoint_folder = os.path.join(
            args.output_dir, f"epoch-{self.counter}"
        )
        os.makedirs(checkpoint_folder, exist_ok=True)
        with torch.no_grad():

            if self.train_prefix_tokens is not None:
                print('Running train exposures')
                train_exposures, train_exposures_model, train_losses = get_exposures(
                    model=self.model,
                    samples=self.tokens['train_tokens'].input_ids[self.train_selected_samples],
                    none_samples=self.tokens_none['train_tokens'].input_ids[self.train_selected_samples],
                    n_samples=self.EXPOSURE_N_SAMPLES,
                    prefix_tokens=self.train_prefix_tokens,
                    selected_tokens=self.train_selected_tokens,
                    tokenizer=self.tokenizer,
                    batch_size=self.EVAL_BATCH_SIZE,
                )
                self._dump_pickle(train_exposures, f'{checkpoint_folder}/train_exposures.pkl')
                self._dump_pickle(train_exposures_model, f'{checkpoint_folder}/train_exposures_model.pkl')
                np.save(f'{checkpoint_folder}/train_losses.npy', train_losses.numpy(force=True))

            print('Running losses')
            # Save the losses of every token set for this evaluation round.
            for name in self.TOKEN_SETS:
                losses = get_losses(self.model, self.tokens[name], self.EVAL_BATCH_SIZE, disable_tqdm=True).numpy(force=True)
                np.save(f'{checkpoint_folder}/losses_{name}.npy', losses)
                print(f"mean losses for {name}: {np.mean(losses)}")

        self.counter += 1
        # NOTE(review): this writes the complete process environment to disk,
        # which may include credentials or tokens -- confirm that is intended.
        self._dump_pickle(
            {
                'environ': dict(os.environ),
                'argv': sys.argv,
            },
            f"{checkpoint_folder}/info.pkl",
        )

        return control

    def on_log(self, args, state, control, **kwargs):
        """Log the perplexity of the latest training and/or evaluation loss.

        Best-effort: nothing is logged when the log history is empty, when the
        latest entry has neither ``'loss'`` nor ``'eval_loss'``, or when the
        value is not numeric. A loss too large for ``math.exp`` is reported as
        ``inf`` instead of being dropped.
        """
        history = getattr(state, 'log_history', None)
        if not history:
            return

        latest = history[-1]
        for key, label in (('loss', 'ppl'), ('eval_loss', 'eval_ppl')):
            if key not in latest:
                continue
            try:
                perplexity = math.exp(latest[key])
            except OverflowError:
                # exp() overflowed: the loss has diverged, report it as such.
                perplexity = math.inf
            except (TypeError, ValueError):
                # Non-numeric entry; skip it rather than abort training.
                continue
            train_logger.info(f"{label}: {perplexity}")

