# basics
import argparse
from argparse import ArgumentParser
import sys 
import os
import os.path as op
import time
import json
import yaml
import copy
import pandas as pd
import numpy as np
import scipy
import ml_collections
from ml_collections import config_dict

# ML pipeline
import datasets
import torch
import matplotlib.pyplot as plt
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from sklearn import metrics

from datasets import load_dataset, DatasetDict, Dataset, load_from_disk
from transformers import AutoTokenizer
from transformers import GPTNeoXForCausalLM, GPTNeoXForSequenceClassification # PYTHIA
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification # BLOOM

# Privacy related imports
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager
from opacus.validators import ModuleValidator

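# Project-local dataset helpers (this file uses prep_data.get_dataset and prep_data.DS).
# Nothing here modifies sys.path, so the `helpers` package must already be importable.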
import helpers.prepare_dataset as prep_data # import get_dataset, DS

def eval_accuracy(model, data_loader, device):
    """Return the classification accuracy of ``model`` over ``data_loader``.

    Each batch must provide ``'input_embeds'``, ``'attention_mask'`` and
    ``'label'`` entries. The model is called with ``inputs_embeds`` and its
    ``.logits`` are arg-maxed along dim 1 to obtain the predicted classes.

    Args:
        model: Model whose call returns an object with a ``logits`` attribute.
        data_loader: Iterable of batch dicts as described above.
        device: Device the model and batch tensors are moved to.

    Returns:
        Fraction of correctly classified examples over all batches. Every example
        counts equally, so a smaller final batch is not over-weighted.

    Raises:
        ValueError: If ``data_loader`` yields no examples.
    """
    model.to(device).eval()
    n_correct = 0
    n_total = 0
    # Evaluation only: skip building the autograd graph. Comparing on-device
    # tensors avoids calling .numpy() on CUDA tensors, which raises.
    with torch.no_grad():
        for idx, batch in enumerate(data_loader):
            inputs_embeds = batch['input_embeds'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device).reshape(-1)
            batch_size = len(labels)
            logits = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask).logits
            preds = logits.argmax(1)
            print('batch size:', batch_size)
            print('idx:', idx)
            n_correct += (preds == labels).sum().item()
            n_total += batch_size
    if n_total == 0:
        raise ValueError('eval_accuracy: data_loader yielded no examples')
    return n_correct / n_total

def compute_causal_loss(pred_scores: torch.Tensor, labels: torch.Tensor, attention_mask=None):
    """Next-token cross-entropy loss for a causal language model.

    Mirrors the loss computed by Hugging Face's GPT-NeoX causal LM:
    https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/gpt_neox/modeling_gpt_neox.py#L1050
    The logit at position t is scored against the label at position t + 1, so the
    last logit and the first label are dropped.

    Args:
        pred_scores: Logits indexed as (batch, seq, vocab).
        labels: Target token ids indexed as (batch, seq). Must have the same seq
            length as ``pred_scores``.
        attention_mask: Optional (batch, seq) mask. Target positions where it is 0
            (e.g. padding) are excluded from the loss. ``labels`` is never
            modified in place.

    Returns:
        Scalar mean cross-entropy over the non-ignored target positions.
    """
    shift_logits = pred_scores[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    if attention_mask is not None:
        # -100 is CrossEntropyLoss's default ignore_index. masked_fill is
        # out-of-place, so the caller's labels are untouched.
        shift_labels = shift_labels.masked_fill(attention_mask[:, 1:] == 0, -100)
    loss_fct = CrossEntropyLoss()
    return loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

def compute_classification_loss(pred_scores: torch.Tensor, labels: torch.Tensor):
    """Mean cross-entropy between class logits and integer class labels.

    Args:
        pred_scores: Logits indexed as (batch, num_classes).
        labels: Class ids, one per example. Any shape with ``batch`` elements is
            flattened, e.g. (batch,) or (batch, 1).

    Returns:
        Scalar mean cross-entropy loss.
    """
    loss_fct = CrossEntropyLoss()
    # reshape (unlike view) also accepts non-contiguous label tensors.
    return loss_fct(pred_scores, labels.reshape(-1))

def _load_weak_model(config):
    """Load the frozen weak model whose greedy generations serve as WTS labels.

    The checkpoint path uses the same naming pattern as the non-WTS models saved
    by this script: saved_models/<model>_<dataset>_<size>_eps<eps>_id<id>.
    """
    print('Loading weak model ...')
    pth = f'saved_models/{config.WTS.weak_model}_{config.dataset_name}_{config.dataset_size}_eps{config.WTS.epsilon}_id{config.id}'
    weak_model = AutoModelForCausalLM.from_pretrained(pth).to(config.DEVICE)
    weak_model.eval()
    return weak_model


def _batch_loss(model, config, batch, weak_model=None, prefix_iter=None):
    """Forward one training batch through ``model`` and return its training loss.

    For ``config.task == 'generation'`` the target is the batch's own input ids.
    When ``weak_model`` is given (WTS training), the target is instead the weak
    model's greedy continuation of the next prefix batch drawn from ``prefix_iter``.
    For ``'classification'`` the target is ``batch['label']``.

    NOTE(review): in WTS mode the strong model is conditioned on the ground-truth
    ``input_ids`` while its targets are the weak model's prefix+generated ids.
    ``compute_causal_loss`` needs both to have the same sequence length. This is
    presumably guaranteed by prep_data.DS / suffix_length — confirm.

    Raises:
        NotImplementedError: For any other ``config.task``.
    """
    input_ids = batch['input_ids'].to(config.DEVICE)
    attention_mask = batch['attention_mask'].to(config.DEVICE)
    pred_scores = model(input_ids=input_ids, attention_mask=attention_mask).logits

    if config.task == 'generation':
        if weak_model is not None:
            weak_batch = next(prefix_iter)
            weak_input_ids = weak_batch['input_ids'].to(config.DEVICE)
            weak_attention_mask = weak_batch['attention_mask'].to(config.DEVICE)
            # Pass the mask explicitly. Prefixes are left-padded and the saved
            # config sets pad == eos, so generate() cannot infer it from the ids.
            # do_sample=False is explicit greedy decoding; temperature=0 was
            # ignored or rejected by generate().
            labels = weak_model.generate(
                weak_input_ids,
                attention_mask=weak_attention_mask,
                min_new_tokens=config.suffix_length,
                max_new_tokens=config.suffix_length,
                do_sample=False,
            )
        else:
            labels = input_ids
        return compute_causal_loss(pred_scores, labels)

    if config.task == 'classification':
        labels = batch['label'].to(config.DEVICE)
        return compute_classification_loss(pred_scores, labels)

    raise NotImplementedError


def train_dp(model, config, loaders: dict, lr: float=5e-5, eps: float=1e-08, betas: tuple=(0.9, 0.999), weight_decay: float=0):
    """Fine-tune ``model`` with Adam, privately (Opacus DP-SGD) when epsilon is finite.

    Training is private when ``config.privacy.EPSILON < config.privacy.eps_inf_indicator``.
    With ``config.WTS.training`` set, generation targets come from a weak model
    applied to the prefix loader (weak-to-strong training).

    Args:
        model: Hugging Face model to train (moved to ``config.DEVICE``).
        config: Experiment ConfigDict. Reads DEVICE, privacy.*, WTS.*, task,
            num_epochs, suffix_length, dataset_name, dataset_size and id.
        loaders: ``{'train': DataLoader, 'prefix_train': DataLoader}``. The two
            loaders must iterate in the same order for WTS label alignment.
        lr, eps, betas, weight_decay: Adam hyper-parameters.

    Returns:
        ``(model.eval(), losses)``. In the private case ``model`` is the
        Opacus-wrapped module. ``losses`` holds one detached numpy scalar per
        optimizer step.
    """
    # All Adam hyper-parameters are forwarded. Previously only lr was used and
    # eps/betas/weight_decay were silently ignored.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps,
                                 weight_decay=weight_decay)

    # set seed for data loader
    torch.manual_seed(2809)

    model.to(config.DEVICE).train()

    train_loader = loaders['train']
    if config.privacy.EPSILON < config.privacy.eps_inf_indicator:
        privacy_engine = PrivacyEngine(accountant="rdp")
        model, optimizer, dp_loader = privacy_engine.make_private_with_epsilon(
            module=model,
            optimizer=optimizer,
            data_loader=loaders['train'],
            target_delta=config.privacy.DELTA,
            target_epsilon=config.privacy.EPSILON,
            epochs=config.num_epochs,
            max_grad_norm=config.privacy.MAX_GRAD_NORM)
        if config.WTS.training:
            # Weak labels are paired with train batches by position in the prefix
            # loader. The Poisson-sampled DP loader cannot preserve that pairing.
            print('WARNING: DP + WTS training iterates the non-private train loader to stay '
                  'aligned with the prefix loader; the target epsilon assumes Poisson '
                  'sampling and may not hold for this run.')
        else:
            # Opacus' privacy accounting assumes training uses the loader it returns.
            train_loader = dp_loader
        print('Starting private Training ...')

    weak_model = _load_weak_model(config) if config.WTS.training else None

    losses = []
    for j in range(config.num_epochs):
        # Fresh prefix iterator each epoch, consumed in lock-step with train_loader.
        prefix_iter = iter(loaders['prefix_train']) if weak_model is not None else None
        print(f'Epoch {j + 1}/{config.num_epochs}')
        for batch in train_loader:
            optimizer.zero_grad()
            lm_loss = _batch_loss(model, config, batch, weak_model, prefix_iter)
            lm_loss.backward()
            optimizer.step()
            losses.append(lm_loss.clone().detach().cpu().numpy())

    return model.eval(), losses

def train_model(config):
    """Build the training loaders, load the base model and optionally fine-tune it.

    Uses the ``tokenize_text`` / ``tokenize_prefix`` helpers, which this file
    defines inside its ``if __name__ == '__main__':`` block. The function therefore
    only works when the file is run as a script.

    Returns:
        ``(model, dataset)``: the model (fine-tuned unless ``config.skip_training``)
        and the raw dataset returned by ``prep_data.get_dataset``.

    Raises:
        NotImplementedError: For tasks other than generation/classification, or
            for any model/task pair other than a Pythia generation model.
    """
    dataset = prep_data.get_dataset(config)

    # Tokenize full texts and their prefixes over the whole dataset in one batch.
    tokenized = dataset.map(tokenize_text, batched=True, batch_size=None)
    tokenized_pre = dataset.map(tokenize_prefix, batched=True, batch_size=None)
    print(tokenized)

    columns_by_task = {
        'generation': ["input_ids", "attention_mask"],
        'classification': ["input_ids", "attention_mask", "label"],
    }
    if config.task not in columns_by_task:
        raise NotImplementedError
    for tok_ds in (tokenized, tokenized_pre):
        tok_ds.set_format("torch", columns=columns_by_task[config.task])
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Only Pythia causal LMs are supported.
    if 'pythia' not in config.model_path or config.task != 'generation':
        raise NotImplementedError
    model = GPTNeoXForCausalLM.from_pretrained(config.model_path).to(config.DEVICE)
    model.config.pad_token_id = model.config.eos_token_id

    # WTS runs train on the second split; all other runs use the first.
    partition_key = 'train2' if config.WTS.training else 'train1'

    def make_loader(tok_ds):
        # shuffle=False keeps the text and prefix loaders in the same order, so each
        # batch can be paired with the matching weak-model labels.
        split = prep_data.DS(tok_ds, partition_key=partition_key, max_seq_len=config.max_seq_len)
        return DataLoader(dataset=split, batch_size=config.batch_size, shuffle=False, drop_last=False)

    loaders = {'train': make_loader(tokenized), 'prefix_train': make_loader(tokenized_pre)}

    if config.skip_training:
        print("2) Skipping the finetuning stage.\n Returning the pretrained model ...")
        return model, dataset

    print("2) Finetune the model ...")
    model, _ = train_dp(model, config, loaders)
    return model, dataset

def measure_perplexity(model, encodings, config, stride=256):
    """Sliding-window perplexity of ``model`` over ``encodings['input_ids']``.

    Follows https://huggingface.co/docs/transformers/en/perplexity. Windows of up
    to ``max_position_embeddings`` tokens start every ``stride`` tokens. Only the
    tokens not already scored by the previous window contribute to the loss
    (earlier positions are labelled -100). Each row of the 2-D ``input_ids``
    tensor is scored separately within every window.

    NOTE(review): no attention_mask is passed and padding is not masked out of the
    labels. If rows are padded (the tokenizer set up in __main__ left-pads with
    EOS), pad tokens are attended to and scored — confirm this is intended.

    Args:
        model: Causal LM, optionally Opacus-wrapped, that returns ``.loss`` when
            called with ``labels``.
        encodings: Mapping/dataset whose ``'input_ids'`` is a 2-D (rows, seq) tensor.
        config: Reads ``DEVICE``.
        stride: Step between window start positions.

    Returns:
        ``(ppl, nlls)``: ``exp`` of the mean per-(window, row) NLL, and the list of
        those NLLs in window-major, row-minor order.

    Raises:
        ValueError: If there is nothing to score (no rows or zero-length rows).
    """
    # Opacus keeps the wrapped HF model in ``._module``. Unwrap by attribute rather
    # than by the epsilon setting, so a model that was never made private (e.g.
    # skip_training with a finite epsilon) still works.
    base_model = getattr(model, '_module', model)
    max_length = base_model.config.max_position_embeddings
    # Read the column once. For a HF Dataset every access re-materializes it.
    all_input_ids = encodings['input_ids']
    seq_len = all_input_ids.size(1)
    model.to(config.DEVICE)
    model.eval()
    nlls = []
    prev_end_loc = 0
    for begin_loc in range(0, seq_len, stride):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        input_ids = all_input_ids[:, begin_loc:end_loc].to(config.DEVICE)
        print('input id shape:', input_ids.shape)
        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100
        with torch.no_grad():
            for i in range(input_ids.shape[0]):
                # outputs.loss averages CrossEntropyLoss over the valid (non -100)
                # labels. The model shifts labels internally, so trg_len - 1 are used.
                outputs = model(input_ids[i, :].reshape(1, -1), labels=target_ids[i, :].reshape(1, -1))
                neg_log_likelihood = outputs.loss
                print('test loss:', neg_log_likelihood)
                nlls.append(neg_log_likelihood)
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    if not nlls:
        raise ValueError('measure_perplexity: encodings contain no tokens to score')
    stacked = torch.stack(nlls)
    mean_nll = stacked.mean()
    print('mean negative log likelihood:', mean_nll)
    ppl = float(torch.exp(mean_nll).detach().cpu().numpy())
    return ppl, stacked.detach().cpu().numpy().tolist()

def measure_verbatim_memorization(tok_txts: dict, model, config):
    """Measure how much of each ground-truth suffix greedy generation reproduces.

    For each of the first ``config.mem_test`` examples, ``config.suffix_length``
    tokens are generated greedily from ``tok_txts['prefix']['input_ids'][i]``. The
    result is compared position-wise with the last ``config.suffix_length`` tokens
    of ``tok_txts['suffix']['input_ids'][i]``.

    NOTE(review): prefixes are generated one at a time without an attention mask.
    If they are left-padded (as by the tokenizer set up in __main__), the pad
    tokens are attended to — confirm this is intended.

    Args:
        tok_txts: ``{'prefix': ..., 'suffix': ...}``, each exposing a 2-D
            ``'input_ids'`` tensor.
        model: Causal LM, optionally Opacus-wrapped.
        config: Reads ``DEVICE``, ``mem_test`` and ``suffix_length``.

    Returns:
        ``(fracs, preds, gt_suffixes)``:
        - ``fracs``: CPU float tensor with the fraction of matching suffix tokens
          per example.
        - ``preds``: the generated suffix ids, one row per example.
        - ``gt_suffixes``: the first ``mem_test`` rows of the suffix ``input_ids``.
    """
    # Opacus keeps the HF model (which provides ``generate``) in ``._module``.
    # Unwrap by attribute so an unwrapped model works even with a finite epsilon.
    generator = getattr(model, '_module', model)
    model.to(config.DEVICE)
    model.eval()
    # Read the columns once. For a HF Dataset every access re-materializes the
    # whole column, which made the loop quadratic.
    prefix_ids = tok_txts['prefix']['input_ids']
    suffix_ids = tok_txts['suffix']['input_ids']
    suffix_length = config.suffix_length

    fracs = []
    preds = []
    for i in range(config.mem_test):
        # do_sample=False is explicit greedy decoding (temperature=0 was ignored).
        outs = generator.generate(prefix_ids[i, :].to(config.DEVICE).reshape(1, -1),
                                  min_new_tokens=suffix_length,
                                  max_new_tokens=suffix_length,
                                  do_sample=False)
        pred = outs[:, -suffix_length:]
        preds.append(pred)
        # Compare with the trailing suffix_length ground-truth tokens. With left
        # padding these are the real tokens, and the shapes always line up.
        target = suffix_ids[i, -suffix_length:].to(config.DEVICE)
        fracs.append((pred == target).sum() / suffix_length)
    fracs = torch.stack(fracs).float().cpu()
    preds = torch.cat(preds)
    gt_suffixes = suffix_ids[0:config.mem_test, :]
    print('Reconstruction rate:', fracs.mean())
    return fracs, preds, gt_suffixes

def load_config(config_path):
    """Read a YAML file and wrap the parsed data in an ``ml_collections.ConfigDict``.

    Parsing uses ``yaml.safe_load``, which does not construct arbitrary Python objects.

    Args:
        config_path: Path to the YAML config file (read as UTF-8).

    Returns:
        ``ml_collections.ConfigDict`` built from the parsed YAML.
    """
    # The local is named ``raw`` so it no longer shadows the module-level
    # ``from ml_collections import config_dict`` import.
    with open(config_path, "r", encoding="utf-8") as f:
        raw = yaml.safe_load(f)
    return ml_collections.ConfigDict(raw)

def load_yaml(file_path):
    """Parse a YAML file with ``yaml.safe_load`` and return the resulting Python data.

    Args:
        file_path: Path to the YAML file. It is read as UTF-8 rather than in the
            platform's default encoding.

    Returns:
        Whatever ``yaml.safe_load`` produces (for a config file, typically a dict).
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)

if __name__ == '__main__':

    # Parsing Arguments
    parser = ArgumentParser()
    parser.add_argument('--dataset_name', default=None, type=str, help="One of >Enron<, ...")
    parser.add_argument('--dataset_size', default=15000, type=int, help="All experiments are run with default size of n. Make sure to train new models & rerun eval.")
    parser.add_argument('--model_path', default=None, type=str, help="Small model: >EleutherAI/pythia-14m<")
    # The epsilon help previously repeated the model-path example.
    parser.add_argument('--epsilon', default=None, type=float, help="Training (DP) epsilon; values >= privacy.eps_inf_indicator (5e9) mean non-private training.")
    parser.add_argument('--config', default=None, type=str, help="Path to the YAML experiment config.")

    arg_ = parser.parse_args()
    # parser.error prints usage and exits with status 2. NameError was the wrong
    # exception type for a missing CLI argument.
    for attr, label in (('dataset_name', 'dataset_name'), ('dataset_size', 'dataset_size'),
                        ('model_path', 'model_path'), ('epsilon', 'epsilon'),
                        ('config', 'config_file')):
        if getattr(arg_, attr) is None:
            parser.error(f"Include a {label} in the argument please.")

    # Getting configurations
    config = load_config(arg_.config)

    # specify extra config
    config.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    config.dataset_name = arg_.dataset_name
    config.model_path = arg_.model_path
    # Use the last path component: 'EleutherAI/pythia-14m' -> 'pythia-14m'.
    # split('/')[1] raised IndexError for ids/paths without a '/'.
    config.MODEL_NAME = config.model_path.rstrip('/').split('/')[-1]

    print(f'Saving trained model: {config.save_flag}')

    if 'pythia' in config.model_path:
        config.model_name = "pythia"
    else:
        raise NotImplementedError

    # setup tokenizer (left padding/truncation keeps the end of each text)
    tokenizer = AutoTokenizer.from_pretrained(config.model_path, padding_side='left')
    tokenizer.pad_token = tokenizer.eos_token  # needs to be added for pythia models
    tokenizer.truncation_side = 'left'

    # training configs
    config.tokenizer = tokenizer
    config.dataset_size = arg_.dataset_size
    config.mem_test = config.dataset_size // 3

    # privacy params
    config.privacy = config_dict.ConfigDict()
    config.privacy.BATCH_SIZE = config.batch_size
    config.privacy.MAX_PHYSICAL_BATCH_SIZE = 20
    config.privacy.MAX_GRAD_NORM = 10
    config.privacy.EPSILON = arg_.epsilon
    config.privacy.DELTA = 1 / config.dataset_size
    config.privacy.LOGGING_INTERVAL = 5000
    # Epsilons at or above this value select non-private (vanilla) training.
    config.privacy.eps_inf_indicator = 5000000000

    print('WTS training?', config.WTS.training)
    print('------------------------------------------------------')
    print(f'Weak-to-Strong Training\n weak model epsilon: {config.WTS.epsilon}\n weak model architecture: {config.WTS.weak_model}, strong model architecture:{config.MODEL_NAME}')
    print('------------------------------------------------------')

    # These live at module scope (inside this if-block) because train_model
    # looks them up as globals.
    def tokenize_text(batch):
        return tokenizer(batch["text"], truncation=True, padding=True, max_length=config.max_seq_len)

    def tokenize_prefix(batch):
        return tokenizer(batch["prefixes"], truncation=True, padding=True, max_length=config.max_seq_len)

    def tokenize_suffix(batch):
        return tokenizer(batch["suffixes"], truncation=True, padding=True, max_length=config.max_seq_len)

    # do train + eval loop for #reps
    collect_results = {}
    for rep in range(config.reps):
        config.id = rep
        # train model
        print(f'Run Train + Eval at iteration {rep}')
        model, dataset = train_model(config)

        # save model & corresponding config file
        if config.save_flag:
            # DP-trained models are Opacus-wrapped, with the HF model in ._module.
            # Unwrapping by attribute also covers skip_training with a finite
            # epsilon, where the model was never wrapped.
            base_model = getattr(model, '_module', model)
            if config.WTS.training:
                # WTS checkpoints always get the weak-model suffix, matching the
                # results file name. DP + WTS used to overwrite the plain DP checkpoint.
                pth = f'saved_models/{config.MODEL_NAME}_{config.dataset_name}_{config.dataset_size}_eps{config.privacy.EPSILON}_weakmodel{config.WTS.weak_model}_weakeps{config.WTS.epsilon}_id{rep}/'
            else:
                pth = f'saved_models/{config.MODEL_NAME}_{config.dataset_name}_{config.dataset_size}_eps{config.privacy.EPSILON}_id{rep}/'
            os.makedirs(pth, exist_ok=True)
            base_model.config.to_json_file(pth + 'config.json')
            torch.save(base_model.state_dict(), pth + 'pytorch_model.bin')

        # do evaluation
        results = {}
        for data_type in ['test', 'train1', 'train2']:
            tokenized_text_test = dataset[data_type].map(tokenize_text, batched=True, batch_size=None)
            tokenized_text_test.set_format("torch", columns=["input_ids"])
            # Perplexity
            ppl, nlls = measure_perplexity(model, tokenized_text_test, config)
            print('Evaluation Part 1): Model performance ...')
            print(f'Perplexity {ppl} for {data_type}')
            if data_type == 'test':
                results['perplexity'] = ppl
            else:
                results[f'perplexity_{data_type}'] = ppl
            results[f'nlls_{data_type}'] = nlls

        # train
        for i in range(1, 3):
            # tokenize
            tokenized_text = dataset[f'train{i}'].map(tokenize_text, batched=True, batch_size=None)
            tokenized_text.set_format("torch", columns=["input_ids"])
            tokenized_pre = dataset[f'train{i}'].map(tokenize_prefix, batched=True, batch_size=None)
            tokenized_pre.set_format("torch", columns=["input_ids"])
            tokenized_suff = dataset[f'train{i}'].map(tokenize_suffix, batched=True, batch_size=None)
            tokenized_suff.set_format("torch", columns=["input_ids"])
            tok_txts = {'text': tokenized_text, 'prefix': tokenized_pre, 'suffix': tokenized_suff}
            # memorization
            print(f'Evaluation Part 2-{i}): Memorization ...')
            fracs, preds, gt_suffixes = measure_verbatim_memorization(tok_txts, model, config)
            # collect results to save
            fracs = list(fracs.numpy().astype(float))
            gt_suffixes = tokenizer.batch_decode(tokenized_suff['input_ids'][0:config.mem_test, :], batched=True)
            preds = tokenizer.batch_decode(preds, batched=True)
            results[f'extraction{i}'] = fracs
            results[f'gt_suffixes{i}'] = gt_suffixes
            results[f'predictions{i}'] = preds

        collect_results[str(rep)] = results
        del model
        del dataset

    # save results (the directory may not exist yet on a fresh checkout)
    os.makedirs('results', exist_ok=True)
    if config.WTS.training:
        # weak to strong results
        pth = f'results/{config.MODEL_NAME}_{config.dataset_name}_{config.dataset_size}_eps{config.privacy.EPSILON}_weakmodel{config.WTS.weak_model}_weakeps{config.WTS.epsilon}_epochs{config.num_epochs}.json'
    elif config.skip_training:
        pth = f'results/{config.MODEL_NAME}_{config.dataset_name}_{config.dataset_size}_eps{config.privacy.EPSILON}_baseline.json'
    else:
        # DP or vanilla training results
        pth = f'results/{config.MODEL_NAME}_{config.dataset_name}_{config.dataset_size}_eps{config.privacy.EPSILON}_epochs{config.num_epochs}.json'
    with open(pth, 'w') as fp:
        json.dump(collect_results, fp, indent=4)