import os
os.environ['HF_HOME'] = '...'

import numpy as np
import json
import argparse
import os
import time
# Hugging Face cache locations. They are assigned here, ahead of the
# `transformers` import further down, presumably so the library sees them
# when it loads (confirm for the installed version). Replace "..." with
# real paths.
os.environ.update({
    "HF_DATASETS_CACHE": "...",   # change path
    "TRANSFORMERS_CACHE": "...",  # change path
})
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
cache_dir = os.environ.get("TRANSFORMERS_CACHE")

import transformers
import itertools
import pandas as pd
import math
from transformers import GPTNeoXForCausalLM, AutoTokenizer
from transformers import DataCollatorForLanguageModeling,DataCollatorWithPadding
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig
from transformers import (
    set_seed,
)
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager
from opacus.data_loader import DPDataLoader
from transformers import get_scheduler
import wandb
import pickle
import string
from datasets import Dataset, DatasetDict, load_dataset
import torch
import torch.nn as nn
import string
import torch.nn.functional as F
import logging




# Cache directory handed to every `from_pretrained` call in this script.
cache_path = cache_dir

# Optimiser / scheduler settings.
LEARNING_RATE: float = 5e-05
LR_SCHEDULER: str = 'linear'
WARMUP_RATIO: float = 0.025
WARMUP_STEPS: int = 10

# Global seeding happens right here, at import time.
SEED: int = 42
set_seed(SEED)

# Evaluation / logging / checkpoint cadence settings.
EVAL_STRATEGY: str = 'epoch'
# EVAL_STEPS = 10
EVAL_ACCUMULATION_STEPS: int = 200
LOGGING_STRATEGY: str = 'steps'
LOGGING_STEPS: int = 10
SAVE_STRATEGY: str = 'epoch'
NUM_TRAIN_EPOCHS: int = 50

# Unreduced cross-entropy: one loss value per target position.
closs = torch.nn.CrossEntropyLoss(reduction='none')

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Command-line interface. Parsed once at import time into the module-level
# `args` namespace that the rest of the script reads.
parser = argparse.ArgumentParser(description = "")
parser.add_argument("-d","--dataset", help="Dataset name", type=str, required=True, choices=['customersim', 'enron', 'synbio'])
parser.add_argument("-m","--model_name", help="Model name", type=str, required=True, choices=['pythia','llama2', 'gemma', 'mistral','llama3.1','qwen2.5'])
# argparse derives `dest` from the first long option, so these two are read
# as args.nr and args.cg (forwarded to Opacus as noise_multiplier and
# max_grad_norm).
parser.add_argument("--nr", "--noise_ratio", help="Noise ratio", type=float, default = 0.1)
parser.add_argument("--cg", "--clip_grad", help="Clipping Gradient Norm", type=float, default = 1e-2)
parser.add_argument("-e", "--eval_batch_size", help="eval batch", type=int, default = 32)
# The script reads args.train_batch_size for the training DataLoader, but the
# option was never defined, so every run died with AttributeError.
# The default of 8 is chosen here; adjust it to the available GPU memory.
parser.add_argument("-b", "--train_batch_size", help="train batch", type=int, default = 8)
parser.add_argument("-s", "--train_size", help="train size", type=int,default = 100)

args = parser.parse_args()
  
print("Dataset -- "+str(args.dataset))
print("Model -- "+str(args.model_name))

# wandb configuration. The project name encodes dataset and model.
WANDB_PROJECT_NAME = "privacy-marker-dpsgd-"+str(args.dataset)+"--"+str(args.model_name)
WANDB__SERVICE_WAIT=300

os.environ["WANDB_PROJECT"] = WANDB_PROJECT_NAME
os.environ["WANDB_WATCH"] = "all"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = str(WANDB__SERVICE_WAIT)
# Only fall back to the in-source placeholder when no key is already exported.
# Overwriting unconditionally would clobber a real WANDB_API_KEY from the
# environment. Prefer exporting the key rather than hard-coding it here.
os.environ.setdefault("WANDB_API_KEY", "...")

trainsize = args.train_size

run = wandb.init(project=WANDB_PROJECT_NAME)

# Per-run output directories for the per-token loss / prediction artifacts.
# Tag format: <dataset>--train_size--<n>--model--<m>--noise--<nr>--clip--<cg>
_artifact_run_tag = (
    f"{args.dataset}--train_size--{args.train_size}--model--{args.model_name}"
    f"--noise--{args.nr}--clip--{args.cg}"
)
os.makedirs('./artifacts/loss_tokens_' + _artifact_run_tag, exist_ok=True)
os.makedirs('./artifacts/preds_tokens_' + _artifact_run_tag, exist_ok=True)


def generate_customersim_data(seed):
    """Load the CustomerSim conversations into a torch-formatted ``DatasetDict``.

    The ``train`` split holds the stripped ``conv`` strings of the training CSV.
    The ``test`` split holds those same training conversations followed by the
    validation ones, so evaluation covers both.

    Args:
        seed: passed to ``np.random.default_rng``; the generator is not used
            afterwards.
    """
    np.random.default_rng(seed)  # retained from the original; result unused

    train_frame = pd.read_csv('.../CustomerSim/data_process/data_csim_train.csv')
    val_frame = pd.read_csv('.../CustomerSim/data_process/data_csim_val.csv')

    train_texts = [conversation.strip() for conversation in train_frame['conv']]
    val_texts = [conversation.strip() for conversation in val_frame['conv']]

    splits = DatasetDict({
        "train": Dataset.from_dict({"text": train_texts}),
        "test": Dataset.from_dict({"text": train_texts + val_texts}),
    })
    splits.set_format("torch")
    return splits

def generate_synbio_data(seed):
    """Load the SynBio texts into a torch-formatted ``DatasetDict``.

    The ``train`` split holds the stripped ``text`` values of the training CSV.
    The ``test`` split holds the training texts followed by the validation
    texts. Evaluating on both gives a privacy signal from the training part
    and a utility signal from the held-out part.

    Args:
        seed: passed to ``np.random.default_rng``; the generator is not used
            afterwards.
    """
    np.random.default_rng(seed)  # retained from the original; result unused

    train_frame = pd.read_csv('.../SynBio/data_synbio_train.csv')
    val_frame = pd.read_csv('.../SynBio/data_synbio_val.csv')

    train_texts = [entry.strip() for entry in train_frame['text']]
    val_texts = [entry.strip() for entry in val_frame['text']]

    splits = DatasetDict({
        "train": Dataset.from_dict({"text": train_texts}),
        "test": Dataset.from_dict({"text": train_texts + val_texts}),
    })
    splits.set_format("torch")
    return splits


def tokenize_string(tokenizer, max_tokens, dataset, *, truncation=True):
    """Tokenize the ``text`` column of ``dataset`` to a fixed length.

    Every sequence is padded to ``max_tokens`` (``padding='max_length'``).
    Padding alone does not shorten texts that are already longer than
    ``max_tokens``, so rows could end up with different lengths. Per-token
    arrays built from such rows are ragged and cannot be stacked into one
    numeric array. ``truncation=True`` (the default) caps every sequence at
    ``max_tokens`` so all rows share a single length.

    Args:
        tokenizer: Hugging Face tokenizer, called on batches of strings.
        max_tokens: target sequence length (pad-to and, with truncation, cap).
        dataset: object exposing ``.map`` whose rows have a ``text`` field.
        truncation: forwarded to the tokenizer; pass ``False`` for the old
            pad-only behaviour.

    Returns:
        The result of ``dataset.map``: the original columns plus the
        tokenizer outputs.
    """
    def encode(example: dict):
        return tokenizer(
            example["text"],
            max_length=max_tokens,
            padding='max_length',
            truncation=truncation,
        )

    return dataset.map(encode, batched=True)


#metric = load_metric('accuracy',keep_in_memory=True)



def _append_rows_and_save(path, rows):
    """Save per-sequence ``rows`` as an ``.npy`` array at ``path``, appending to existing data.

    A new file receives ``np.array(rows)``. If ``path`` already exists (e.g. a
    repeated run with the same configuration), the stored array is loaded. A
    leading singleton axis is squeezed away when it has more than two
    dimensions, and the new rows are appended along axis 0.
    """
    array = np.array(rows)
    if os.path.exists(path):
        # allow_pickle mirrors how these self-written artifacts were read
        # before; never point this at files from an untrusted source.
        with open(path, 'rb') as fp:
            previous = np.load(fp, allow_pickle=True)
        if previous.ndim > 2:
            previous = np.squeeze(previous, axis=0)
        array = np.append(previous, array, axis=0)
    with open(path, 'wb') as fp:
        np.save(fp, array)


def _evaluate_per_token(model, val_loader, loss_fn, epoch):
    """Run ``model`` over ``val_loader`` and collect per-token diagnostics.

    Returns ``(batch_losses, token_losses, token_preds)``:

    * ``batch_losses``: the model's own loss per batch (each also logged to
      wandb as ``eval_batch_loss``).
    * ``token_losses``: one array per sequence, holding ``loss_fn`` applied to
      the logits at positions ``:-1`` against the labels at ``1:``.
    * ``token_preds``: one array per sequence, holding the argmax token id at
      every position.
    """
    model.eval()
    batch_losses, token_losses, token_preds = [], [], []
    # Gradients are not needed for evaluation; without no_grad every eval
    # forward pass would build (and hold) an autograd graph.
    with torch.no_grad():
        for idx, eval_batch in enumerate(val_loader):
            eval_batch = eval_batch.to(device)
            eval_outputs = model(input_ids=eval_batch['input_ids'],
                                 attention_mask=eval_batch['attention_mask'],
                                 labels=eval_batch['labels'])
            batch_loss = eval_outputs.loss.item()
            batch_losses.append(batch_loss)
            run.log({"eval_batch_loss": batch_loss, "batch": idx, "epoch": epoch})

            # .float() is a no-op for fp32 models; it keeps .numpy() usable
            # if the model is ever loaded in bf16.
            logits = eval_outputs.logits.detach().float().cpu()
            labels = eval_batch['labels'].long().cpu()
            for i in range(logits.shape[0]):
                token_losses.append(loss_fn(logits[i, :-1, :], labels[i, 1:]).numpy())
            token_preds.extend(np.argmax(logits.numpy(), axis=2))
    return batch_losses, token_losses, token_preds


def train(dataset,tokenizer,max_tokens,model_name,model,train_batch_size,eval_batch_size):
    """Fine-tune ``model`` with DP-SGD (Opacus) and record per-token evaluation artifacts.

    For each of ``NUM_TRAIN_EPOCHS`` epochs this function:

    * trains on ``dataset["train"]``;
    * evaluates on ``dataset["test"]``;
    * appends per-token losses and argmax predictions to per-epoch ``.npy``
      files under ``./artifacts``;
    * keeps the best-so-far checkpoint (lowest mean eval loss) plus the
      latest one under ``OUTPUT_DIR``.

    After training it saves the per-epoch mean training loss and the epsilon
    reported by the RDP accountant (delta = 8e-5) after each epoch.

    Args:
        dataset: DatasetDict with ``train`` and ``test`` splits of raw ``text``.
        tokenizer: Hugging Face tokenizer used for encoding and collation.
        max_tokens: fixed sequence length passed to ``tokenize_string``.
        model_name: model tag used in the per-epoch artifact file names.
        model: causal LM, wrapped by ``PrivacyEngine.make_private``.
        train_batch_size: batch size of the training DataLoader.
        eval_batch_size: batch size of the evaluation DataLoader.

    Also reads the module-level ``args``, ``run`` (wandb) and ``device``.
    """
    OUTPUT_DIR = '.../saves-dpsgd/' #change this path

    # All output names share these tags; they reproduce the original strings.
    noise_clip = f"--noise--{args.nr}--clip--{args.cg}"
    run_tag = f"{args.dataset}--train_size--{args.train_size}--model--{args.model_name}{noise_clip}"
    ckpt_dir = f"{OUTPUT_DIR}/checkpoints-{args.dataset}--model--{args.model_name}{noise_clip}"
    loss_dir = f"./artifacts/loss_tokens_{run_tag}"
    preds_dir = f"./artifacts/preds_tokens_{run_tag}"
    for directory in (ckpt_dir, loss_dir, preds_dir):
        os.makedirs(directory, exist_ok=True)

    def checkpoint_path(epoch):
        return f"{ckpt_dir}/checkpoint-epoch-{epoch}-{args.dataset}--model--{args.model_name}{noise_clip}.pth"

    def epoch_file_tag(epoch):
        return f"{epoch}--{args.dataset}--model--{model_name}{noise_clip}.npy"

    encoded_dataset = tokenize_string(tokenizer, max_tokens, dataset)
    training_dataset = encoded_dataset.remove_columns(["text"])
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    train_loader = torch.utils.data.DataLoader(training_dataset["train"], shuffle=True, batch_size=train_batch_size, collate_fn=data_collator)
    val_loader = torch.utils.data.DataLoader(training_dataset["test"], shuffle=False, batch_size=eval_batch_size, collate_fn=data_collator)
    print(len(train_loader))

    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    privacy_engine = PrivacyEngine(accountant = "rdp")
    model, optimizer, train_loader = privacy_engine.make_private(
        module=model,
        optimizer=optimizer,
        data_loader=train_loader,
        noise_multiplier=args.nr,
        max_grad_norm=args.cg,
        poisson_sampling=False
    )

    num_training_steps = NUM_TRAIN_EPOCHS * len(train_loader)
    scheduler = get_scheduler(
        LR_SCHEDULER,
        optimizer=optimizer,
        num_warmup_steps=WARMUP_STEPS,
        num_training_steps=num_training_steps,
    )

    closs = torch.nn.CrossEntropyLoss(reduction='none')
    target_delta = 8e-5
    train_loss = []
    epvalues = []
    last_best_loss = np.inf
    # None until some epoch improves on last_best_loss (e.g. stays None if the
    # first eval loss is NaN); previously this was unbound in that case.
    best_epoch = None

    for epoch in range(NUM_TRAIN_EPOCHS):
        print("Epoch")
        print(epoch)
        start = time.time()
        model.train()
        losses = []
        num_batches = len(train_loader)
        for idx, batch in enumerate(train_loader):
            # Same device as evaluation (the original picked a random GPU per
            # batch, which also failed on CPU-only hosts).
            batch = batch.to(device)
            outputs = model(input_ids=batch['input_ids'],
                            attention_mask=batch['attention_mask'],
                            labels=batch['labels'])
            loss = outputs.loss

            torch.cuda.empty_cache()
            loss.backward()
            losses.append(loss.item())
            run.log({"train_batch_loss": loss.item(), "time_step": epoch + idx/num_batches})

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        train_loss.append(np.mean(losses))
        # Log the epoch mean (the original logged the last batch's loss here).
        run.log({"train_loss": train_loss[-1], "epoch": epoch})

        # Privacy budget spent so far. NOTE(review): Opacus' RDP accounting
        # assumes Poisson sampling, while poisson_sampling=False is used above,
        # so treat this value as approximate.
        epsilon = privacy_engine.get_epsilon(target_delta)
        epvalues.append(epsilon)
        run.log({"epsilon": epsilon, "epoch": epoch})

        print("Evaluation")
        eval_agg_loss, eval_loss_store, eval_pred_store = _evaluate_per_token(model, val_loader, closs, epoch)
        _append_rows_and_save(f"{loss_dir}/loss_tokens_epochs--{epoch_file_tag(epoch)}", eval_loss_store)
        _append_rows_and_save(f"{preds_dir}/pred_tokens_epochs--{epoch_file_tag(epoch)}", eval_pred_store)

        eval_avg_loss = np.mean(eval_agg_loss)
        run.log({"eval_loss": eval_avg_loss, "epoch": epoch})
        if eval_avg_loss < last_best_loss:
            # New best: wipe the checkpoint directory and keep only this one.
            best_epoch = epoch
            for mfile in os.listdir(ckpt_dir):
                os.remove(os.path.join(ckpt_dir, mfile))
            torch.save(model.state_dict(), checkpoint_path(epoch))
            last_best_loss = eval_avg_loss
        else:
            # Keep best + latest: drop the previous "latest" unless it is the best.
            previous = checkpoint_path(epoch - 1)
            if best_epoch != epoch - 1 and os.path.exists(previous):
                os.remove(previous)
            torch.save(model.state_dict(), checkpoint_path(epoch))

        print("Epoch finished in ")
        print(time.time()-start)

    with open(f"./artifacts/train_epoch_loss_{run_tag}.npy", 'wb') as fp:
        np.save(fp, np.array(train_loss))

    with open(f"./artifacts/train_epoch_eps_{run_tag}.npy", 'wb') as fp:
        np.save(fp, np.array(epvalues))

    print("Epoch finished in ")
    print(time.time()-start)
    #wandb.finish()
    
    
  
# Script entry: build the dataset, load model + tokenizer, then train.
tbatch = args.train_batch_size
ebatch = args.eval_batch_size

if args.dataset=='customersim':
    dataset = generate_customersim_data(seed=42)
elif args.dataset=='synbio':
    dataset = generate_synbio_data(seed=42)
else:
    # argparse also accepts 'enron', but there is no loader for it. Stop here
    # instead of failing later on an undefined `dataset`.
    print("Dataset not found")
    raise SystemExit(1)

print("Dataset generated")

#peft_config = LoraConfig(inference_mode=False,lora_alpha=args.alpha,lora_dropout=0.1,r=args.rank,task_type='CAUSAL_LM')

# Hugging Face hub id for each supported --model_name. 'mistral' is accepted
# by argparse (and has max-token entries below) but has no model id here.
_HF_MODEL_IDS = {
    'pythia': "EleutherAI/pythia-1b",
    'llama2': "meta-llama/Llama-2-7b",
    'llama3.1': "meta-llama/Llama-3.1-8B",
    'qwen2.5': "Qwen/Qwen2.5-7B",
    'gemma': "google/gemma-2b",
}
if args.model_name not in _HF_MODEL_IDS:
    # Previously this fell through and crashed later with NameError on MODEL_NAME.
    print("Model not found")
    raise SystemExit(1)

MODEL_NAME = _HF_MODEL_IDS[args.model_name]
_load_kwargs = {'cache_dir': cache_path, 'device_map': 'auto'}
if args.model_name=='pythia':
    # Hub revision to load; the labels below name step0 as untrained and
    # step143000 as fully trained.
    PRE_TRAINING_CHECKPOINT = 'step143000'
    if PRE_TRAINING_CHECKPOINT == 'step143000':
        PRE_TRAINED = "Fully Trained"
    elif PRE_TRAINING_CHECKPOINT == 'step0':
        PRE_TRAINED = "Untrained"
    else:
        PRE_TRAINED = "Partially Trained"
    _load_kwargs['revision'] = PRE_TRAINING_CHECKPOINT

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **_load_kwargs)

print("Model loaded")

# Fixed sequence length per (dataset, model). Models without an explicit
# entry (gemma) use the per-dataset default.
_MAX_TOKENS = {
    'customersim': {'pythia': 220, 'mistral': 250, 'llama2': 265, 'llama3.1': 204, 'qwen2.5': 218},
    'synbio': {'pythia': 700, 'mistral': 775, 'llama2': 825, 'llama3.1': 672, 'qwen2.5': 678},
}
_DEFAULT_MAX_TOKENS = {'customersim': 235, 'synbio': 645}
max_tokens = _MAX_TOKENS[args.dataset].get(args.model_name, _DEFAULT_MAX_TOKENS[args.dataset])

print("Model loaded")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME,cache_dir = cache_path)
# The tokenizer's EOS token doubles as the padding token.
tokenizer.pad_token = tokenizer.eos_token
train(dataset,tokenizer,max_tokens,args.model_name,model,tbatch,ebatch)
#print(time.time()-start)
