
import argparse
import torch
import numpy as np
import gc
import matplotlib.pyplot as plt
from attacks.mia_utils import get_losses
import pickle
from scipy import stats
import math


def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--init_checkpoint', type=str, help='initial checkpoint')
    parser.add_argument('--out_dir',type=str, help='output_directory')
    parser.add_argument('--max_seq_length', type=int, default=256, help='evaluation length')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--topk', type=int, default=10, help='number of top tokens to select')
    parser.add_argument('--ratio_change', type=float, default=0.01, help='ratio of samples to change')
    parser.add_argument('--data_cache_dir', type=str, help='data cache directory', required=True)
    parser.add_argument('--num_samples', type=int, default=250, help='Number of samples')
    parser.add_argument('--num_nonmembers', type=int, default=256, help='Number of nonmembers')
    parser.add_argument('--score_pretrained', type=bool, default=False, help='Score pretrained model')

    return parser.parse_args(args)

def run_data_extraction(output, output_none, info, model, tokenizer, output_dir, args):
    """Estimate exposure of selected train and val samples and pickle the results.

    For each selected sample, a population of ``args.num_nonmembers`` reference
    sequences is built from the matching ``output_none`` row. Each reference has
    its first ``prefix_len`` tokens replaced by random tokens drawn from the
    training prefixes. The model's per-token losses on the target and on the
    references are then compared at several prefix lengths. This gives two
    exposure estimates per length:

    * empirical: ``-log2(fraction of references with loss <= target loss)``;
    * parametric: ``-log2(CDF(target loss))`` under a skew-normal fitted to the
      reference losses, falling back to a normal fit if the skew-normal fit raises.

    Args:
        output: dict providing ``'train_prefix_tokens'``, ``'train_selected_samples'``,
            ``'train_tokens'`` and the corresponding ``'val_*'`` entries.
        output_none: dict providing ``'train_tokens'`` / ``'val_tokens'``, indexed
            with the same selected-sample indices as ``output``.
        info: not used by this function; kept for interface compatibility.
        model: model passed to ``get_losses``; must expose ``.device``.
        tokenizer: tokenizer supporting ``batch_decode`` and ``__call__``.
        output_dir: directory (``pathlib.Path`` or ``str``) that receives
            ``data_extraction.pkl``.
        args: parsed CLI namespace; reads ``num_samples``, ``num_nonmembers`` and
            ``batch_size``.

    Side effects:
        Prints progress and the exposure dicts, and writes
        ``<output_dir>/data_extraction.pkl``.
    """
    from pathlib import Path

    train_prefix_tokens = output['train_prefix_tokens']
    train_selected_samples = output['train_selected_samples']
    train_set = output['train_tokens'][train_selected_samples[:args.num_samples]]
    train_set_none = output_none['train_tokens'][train_selected_samples[:args.num_samples]]

    val_prefix_tokens = output['val_prefix_tokens']
    val_selected_samples = output['val_selected_samples']
    val_set = output['val_tokens'][val_selected_samples[:args.num_samples]]
    val_set_none = output_none['val_tokens'][val_selected_samples[:args.num_samples]]

    # Vocabulary from which random reference prefixes are drawn.
    selected_tokens = train_prefix_tokens.unique()

    def get_exposure(model, x, x_none, n_samples, prefix_lengths):
        """Return ``({k: [empirical, parametric]}, losses)`` for one target row.

        ``x`` is the target (one row); ``x_none`` is repeated ``n_samples`` times
        to build the references. Row 0 of ``losses`` is the target and rows
        1.. are the references.
        """
        # NOTE(review): the val prefix length is used for train samples too;
        # this presumably assumes train/val prefixes share a length — confirm.
        prefix_len = val_prefix_tokens.shape[1]
        references = x_none.repeat(n_samples, 1)
        random_idx = torch.randint(0, len(selected_tokens), (n_samples * prefix_len,))
        random_prefix = selected_tokens[random_idx].reshape(n_samples, prefix_len)
        # Prepend the random prefix and truncate so the row length is unchanged.
        # `max(..., 0)` keeps this correct when prefix_len == 0, where the old
        # `[:, :-0]` slice dropped the entire row.
        keep = max(references.shape[1] - prefix_len, 0)
        references = torch.cat([random_prefix, references[:, :keep]], dim=1)
        batch = torch.cat([x, references], dim=0)
        # Round-trip through text so every row is re-tokenized and padded to the
        # same fixed length.
        encoded = tokenizer(tokenizer.batch_decode(batch, skip_special_tokens=True),
                            padding="max_length", truncation=True,
                            max_length=max(prefix_lengths), return_tensors='pt')
        losses = get_losses(model, encoded.to(model.device), args.batch_size)

        exposures = {}
        for k in prefix_lengths:
            target_loss = losses[0, :k].mean(-1)
            reference_losses = losses[1:, :k].mean(-1)
            # Empirical rank of the target among the references; 1e-30 avoids log2(0).
            exposure = -torch.log2((reference_losses <= target_loss).float().mean(-1) + 1e-30)
            try:
                theta = stats.distributions.skewnorm.fit(reference_losses.cpu().numpy())
                exposure_model = -torch.log2(
                    torch.tensor(stats.distributions.skewnorm.cdf(target_loss.item(), *theta)) + 1e-30)
            except Exception as e:
                # Best-effort: the skew-normal fit can fail; fall back to a normal fit.
                print('Error: ', e)
                theta = stats.distributions.norm.fit(reference_losses.cpu().numpy())
                exposure_model = -torch.log2(
                    torch.tensor(stats.distributions.norm.cdf(target_loss.item(), *theta)) + 1e-30)

            exposures[k] = [exposure, exposure_model]
        return exposures, losses

    def get_exposures(model, samples, none_samples, n_samples, prefix_lengths=(6, 10, 32, 64, 128, 255)):
        """Run ``get_exposure`` over paired rows and collect per-length float lists.

        Returns ``(empirical, parametric, stacked_losses)``, where the first two
        map each prefix length to a list of floats, one per sample.
        """
        exposures = {}
        exposures_model = {}
        all_losses = []
        for x, x_none in zip(samples, none_samples):
            per_length, loss = get_exposure(model, x.unsqueeze(0), x_none.unsqueeze(0), n_samples, prefix_lengths)
            all_losses.append(loss)
            gc.collect()
            torch.cuda.empty_cache()
            for k, (exposure, exposure_model) in per_length.items():
                exposures.setdefault(k, []).append(exposure.item())
                exposures_model.setdefault(k, []).append(exposure_model.item())
        return exposures, exposures_model, torch.stack(all_losses, dim=0)

    print('train')
    train_exposures, train_exposures_model, train_losses = get_exposures(model, train_set, train_set_none, n_samples=args.num_nonmembers)
    print('val')
    val_exposures, val_exposures_model, val_losses = get_exposures(model, val_set, val_set_none, n_samples=args.num_nonmembers)

    print('train_exposures', train_exposures)
    print('val_exposures', val_exposures)

    gc.collect()
    torch.cuda.empty_cache()

    results = {
        'train_exposures': train_exposures,
        'val_exposures': val_exposures,
        'train_exposures_model': train_exposures_model,
        'val_exposures_model': val_exposures_model,
        'train_losses': train_losses,
        'val_losses': val_losses,
    }
    # Close the file deterministically so the pickle is fully flushed to disk.
    with open(Path(output_dir) / "data_extraction.pkl", "wb") as f:
        pickle.dump(results, f)
