import os
# Hugging Face cache root. It is exported before `transformers` / `datasets`
# are imported further down -- presumably so those libraries see it at import
# time; keep this ordering (TODO confirm against the installed versions).
os.environ['HF_HOME'] = '...'
import numpy as np
import argparse
import time

# Separate cache locations for datasets and model weights.
os.environ["HF_DATASETS_CACHE"] = "..."
os.environ["TRANSFORMERS_CACHE"] = "..."

#os.environ['CUDA_VISIBLE_DEVICES'] ='0,1'
# Read back the value set above; it is later passed explicitly as
# `cache_dir=` to the from_pretrained() calls at the bottom of the script.
cache_dir = os.getenv("TRANSFORMERS_CACHE")
import transformers
import itertools
import pandas as pd
import math
from transformers import GPTNeoXForCausalLM, AutoTokenizer
from transformers import DataCollatorForLanguageModeling,DataCollatorWithPadding
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig
from transformers import (
    set_seed,
)
import wandb
import pickle
import string
from datasets import Dataset, DatasetDict, load_dataset
import torch
import string
import torch.nn.functional as F
import logging

# Directory handed to every from_pretrained(..., cache_dir=...) call below.
cache_path = cache_dir

# ---- Optimisation hyper-parameters (train() also logs these to wandb) ----
LEARNING_RATE = 0.00025
LR_SCHEDULER = 'linear'
#WARMUP_RATIO = 0.025
WARMUP_STEPS = 10

SEED = 42
set_seed(SEED)

# ---- Evaluation / logging / checkpoint cadence ----
EVAL_STRATEGY = 'epoch'
#EVAL_STEPS = 10
EVAL_ACCUMULATION_STEPS = 200 # https://discuss.huggingface.co/t/cuda-out-of-memory-during-evaluation-but-training-is-fine/1783/2
LOGGING_STRATEGY = 'steps'
LOGGING_STEPS = 10
SAVE_STRATEGY = 'epoch'
NUM_TRAIN_EPOCHS = 50

# reduction='none' keeps one loss value per target token instead of a batch
# mean; train() writes those per-token losses to disk.
closs = torch.nn.CrossEntropyLoss(reduction='none')

# Root of everything this script writes. exist_ok avoids the
# check-then-create race of `if not exists: mkdir`.
os.makedirs('./artifacts', exist_ok=True)


# NOTE(review): 'enron' and 'mistral' are accepted choices, but the script has
# no data loader / model loader for them further down.
parser = argparse.ArgumentParser(description = "")
parser.add_argument("-d","--dataset", help="Dataset name", type=str, required=True, choices=['customersim', 'enron', 'synbio'])
parser.add_argument("-m","--model_name", help="Model name", type=str, required=True, choices=['pythia','llama2', 'gemma', 'mistral','llama3.1','qwen2.5'])
parser.add_argument("-t", "--train_batch_size", help="training batch", type=int, default = 32)
parser.add_argument("-e", "--eval_batch_size", help="eval batch", type=int, default = 32)
parser.add_argument("-s", "--train_size", help="train size", type=int,default = 100)


args = parser.parse_args()

print("Dataset -- "+str(args.dataset))
print("Model -- "+str(args.model_name))

# One sub-directory per (dataset, train_size, model) for the per-token losses
# and one for the per-token predictions that train() dumps at every evaluation.
_run_tag = str(args.dataset)+'--train_size--'+str(args.train_size)+'_model_'+str(args.model_name)
for _prefix in ('loss_data_', 'preds_data_'):
    os.makedirs('./artifacts/'+_prefix+_run_tag, exist_ok=True)


trainsize = args.train_size

def generate_customersim_data(seed):
    """Load the CustomerSim conversations and wrap them in a ``DatasetDict``.

    Reads the ``conv`` column of the train and validation CSVs and strips
    surrounding whitespace from every entry.

    NOTE(review): ``--train_size`` is not consulted here; every row of the
    training CSV is returned -- confirm that is intended.

    Args:
        seed: Not used by the loader. Kept so existing callers
            (``generate_customersim_data(seed=42)``) keep working.

    Returns:
        A torch-formatted ``DatasetDict`` with a single ``"text"`` column:
        ``"train"`` holds the training conversations, ``"test"`` holds the
        training conversations followed by the validation conversations.
    """
    train_data = pd.read_csv('.../CustomerSim/data_process/data_csim_train.csv')
    val_data = pd.read_csv('.../CustomerSim/data_process/data_csim_val.csv')

    train_sequences = [conv.strip() for conv in train_data['conv']]
    val_sequences = [conv.strip() for conv in val_data['conv']]

    datasets = DatasetDict(
        {
            "train": Dataset.from_dict({"text": train_sequences}),
            # Evaluating on both to get the privacy measure from the training
            # rows and the utility from the held-out rows.
            "test": Dataset.from_dict({"text": train_sequences + val_sequences}),
        }
    )
    datasets.set_format("torch")
    return datasets

def generate_synbio_data(seed):
    """Load the SynBio texts and wrap them in a ``DatasetDict``.

    Reads the ``text`` column of the train and validation CSVs and strips
    surrounding whitespace from every entry.

    NOTE(review): ``--train_size`` is not consulted here; every row of the
    training CSV is returned -- confirm that is intended.

    Args:
        seed: Not used by the loader. Kept so existing callers
            (``generate_synbio_data(seed=42)``) keep working.

    Returns:
        A torch-formatted ``DatasetDict`` with a single ``"text"`` column:
        ``"train"`` holds the training texts, ``"test"`` holds the training
        texts followed by the validation texts.
    """
    train_data = pd.read_csv('.../SynBio/data_synbio_train.csv')
    val_data = pd.read_csv('.../SynBio/data_synbio_val.csv')

    train_sequences = [text.strip() for text in train_data['text']]
    val_sequences = [text.strip() for text in val_data['text']]

    datasets = DatasetDict(
        {
            "train": Dataset.from_dict({"text": train_sequences}),
            # The evaluation split contains the training rows as well as the
            # validation rows, in that order.
            "test": Dataset.from_dict({"text": train_sequences + val_sequences}),
        }
    )
    datasets.set_format("torch")
    return datasets


def tokenize_string(tokenizer,max_tokens,dataset):
    """Tokenize the ``"text"`` column of ``dataset`` to a fixed length.

    Every sequence is padded to ``max_tokens`` and, if it is longer, truncated
    to ``max_tokens``. Without truncation an over-long sequence would keep its
    own length, and the per-batch arrays that ``train`` concatenates along
    axis 0 would no longer share a sequence dimension.

    Args:
        tokenizer: Callable tokenizer accepting ``max_length``, ``padding``
            and ``truncation`` keyword arguments.
        max_tokens: Target length (in tokens) of every encoded sequence.
        dataset: Object exposing ``map(fn, batched=True)`` whose batches carry
            a ``"text"`` entry.

    Returns:
        Whatever ``dataset.map`` returns for the batched encode function.
    """
    def encode(example: dict):
        # With batched=True the "text" entry is a list of strings.
        return tokenizer(
            example["text"],
            max_length=max_tokens,
            padding='max_length',
            truncation=True,
        )

    return dataset.map(
        encode,
        batched=True,
    )


#metric = load_metric('accuracy',keep_in_memory=True)


def compute_metrics(eval_pred):
    """Return a constant placeholder metric.

    ``eval_pred`` is ignored; the function only exists so that a
    ``compute_metrics`` callable can be handed to the ``Trainer``.
    """
    return dict(f1=1)


def train(dataset,tokenizer,max_tokens,model_name,model,train_batch_size,eval_batch_size):
    """Fine-tune ``model`` with the HF ``Trainer`` and dump per-token artifacts.

    At every evaluation the per-token cross-entropy losses and the arg-max
    token predictions of each evaluation batch are appended to
    ``./artifacts/loss_data_<run>/loss_tokens_epoch-<E>--<dataset>--model--<model>.npy``
    and
    ``./artifacts/preds_data_<run>/pred_tokens_epoch-<E>--<dataset>--model--<model>.npy``.
    Both files hold a 2-D array with one row per evaluated sequence. A file
    left over from an earlier run of the script is overwritten rather than
    appended to.

    Relies on the module-level ``args``, ``closs`` and hyper-parameter
    constants, and assumes the artifact directories already exist.

    Args:
        dataset: ``DatasetDict`` with ``"train"`` and ``"test"`` splits and a
            ``"text"`` column.
        tokenizer: Tokenizer used for encoding and by the LM data collator.
        max_tokens: Fixed sequence length passed to ``tokenize_string``.
        model_name: Not read inside the function; artifact and run names are
            derived from ``args.model_name``.
        model: Causal LM to fine-tune.
        train_batch_size: Per-device training batch size.
        eval_batch_size: Per-device evaluation batch size.

    Returns:
        None. Prints the wall-clock training time in seconds.
    """
    # Pieces of the artifact file names that do not change during the run.
    run_tag = str(args.dataset)+'--train_size--'+str(args.train_size)+'_model_'+str(args.model_name)
    file_tag = '--'+str(args.dataset)+'--model--'+str(args.model_name)+'.npy'
    # Artifact files this run has already (re)started. Anything on disk that
    # is not in here was written by an earlier run and must not be extended.
    started_files = set()

    def _artifact_path(kind, stem, epoch):
        # e.g. ./artifacts/loss_data_<run>/loss_tokens_epoch-3--<dataset>--model--<model>.npy
        return './artifacts/'+kind+'_data_'+run_tag+'/'+stem+'_tokens_epoch-'+str(epoch)+file_tag

    def _append_rows(path, rows):
        # Extend the 2-D array stored at `path` with `rows`; the first write of
        # this run replaces whatever file may already be there.
        if path in started_files and os.path.exists(path):
            with open(path, 'rb') as fp:
                rows = np.append(np.load(fp), rows, axis=0)
        started_files.add(path)
        with open(path, 'wb') as fp:
            np.save(fp, rows)

    def preprocess_logits_for_metrics(logits, labels):
        # `trainer` is assigned further down in train(); this callback only
        # runs during evaluation, after that assignment.
        epoch = int(trainer.state.epoch)

        # Position t of the logits is scored against label t+1. The loop goes
        # sample by sample (as before) so only one (seq, vocab) slice is
        # handed to the loss at a time instead of the flattened batch.
        loss_rows = np.array([
            closs(logits[i][:-1, :], labels[i][1:]).detach().float().cpu().numpy()
            for i in range(logits.shape[0])
        ])
        _append_rows(_artifact_path('loss', 'loss', epoch), loss_rows)

        # Computed once and reused for both the dump and the return value.
        pred_ids = torch.argmax(logits, dim=-1)
        _append_rows(_artifact_path('preds', 'pred', epoch), pred_ids.cpu().numpy())

        return pred_ids, labels

    WANDB_PROJECT_NAME = "privacy-marker-fft-"+str(args.dataset)+"--"+str(args.model_name)
    OUTPUT_DIR = '.../saves-fft/'

    # Make sure output directory exists
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    os.environ["WANDB_PROJECT"] = WANDB_PROJECT_NAME
    os.environ["WANDB_WATCH"] = "all"
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    os.environ["WANDB__SERVICE_WAIT"] = "300"
    # Set API key (for wandb)
    # NOTE(review): this overwrites any WANDB_API_KEY already present in the
    # environment with a value stored in source -- consider supplying the key
    # from outside the script instead.
    os.environ["WANDB_API_KEY"]="..."
    RUN_NAME = "fft--"+str(args.dataset)+"--model-" + str(args.model_name) + "-epochs-" + str(NUM_TRAIN_EPOCHS)
    OUTPUT_DIR = os.path.join(OUTPUT_DIR, RUN_NAME)
    wandb.init(project=WANDB_PROJECT_NAME)
    wandb.run.name = RUN_NAME

    params_dict = {
        'LEARNING_RATE': LEARNING_RATE,
        'SEED': SEED,
        'LR_SCHEDULER': LR_SCHEDULER,
        'WARMUP_STEPS' : WARMUP_STEPS,
        'NUM_TRAIN_EPOCHS': NUM_TRAIN_EPOCHS
    }

    wandb.config.update(params_dict)


    encoded_dataset = tokenize_string(tokenizer, max_tokens, dataset)
    # The raw strings are dropped once they have been tokenized.
    training_dataset = encoded_dataset.remove_columns(["text"])
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir = OUTPUT_DIR,
        eval_strategy = EVAL_STRATEGY,
        # eval_steps = EVAL_STEPS,
        eval_accumulation_steps = EVAL_ACCUMULATION_STEPS,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        learning_rate = LEARNING_RATE,
        lr_scheduler_type = LR_SCHEDULER,
        save_total_limit = 1, #Save latest checkpoint
        load_best_model_at_end = True, #Save best checkpoint too alongside
        warmup_steps = WARMUP_STEPS,
        num_train_epochs = NUM_TRAIN_EPOCHS,
        save_strategy = SAVE_STRATEGY,
        run_name = RUN_NAME,
        report_to="wandb",
        logging_dir = OUTPUT_DIR,
        logging_strategy = LOGGING_STRATEGY,
        logging_steps = LOGGING_STEPS,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=training_dataset["train"],
        eval_dataset=training_dataset["test"],
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics = preprocess_logits_for_metrics
    )

    start = time.time()
    trainer.train()
    print(time.time()-start)
    wandb.finish()
  
tbatch = args.train_batch_size
ebatch = args.eval_batch_size

# Loader for every dataset this script can actually build.
_DATASET_BUILDERS = {
    'customersim': generate_customersim_data,
    'synbio': generate_synbio_data,
}

# Pythia is loaded from a specific pre-training checkpoint (revision).
PRE_TRAINING_CHECKPOINT = 'step143000'
PRE_TRAINED = {
    'step143000': "Fully Trained",
    'step0': "Untrained",
}.get(PRE_TRAINING_CHECKPOINT, "Partially Trained")

# --model_name -> (Hugging Face checkpoint id, extra from_pretrained kwargs).
_MODEL_SPECS = {
    'pythia': ("EleutherAI/pythia-1b", {'revision': PRE_TRAINING_CHECKPOINT}),
    # The "-hf" repository is the Transformers-format conversion; the plain
    # "meta-llama/Llama-2-7b" repository ships the original Meta checkpoint
    # files, which from_pretrained cannot load.
    'llama2': ("meta-llama/Llama-2-7b-hf", {'device_map': 'auto'}),
    'llama3.1': ("meta-llama/Llama-3.1-8B", {'device_map': 'auto'}),
    'qwen2.5': ("Qwen/Qwen2.5-7B", {'device_map': 'auto'}),
    'gemma': ("google/gemma-2b", {}),
}

# Fixed padded sequence length per (dataset, model) pair.
_MAX_TOKENS = {
    'customersim': {'pythia': 220, 'llama2': 265, 'llama3.1': 204, 'qwen2.5': 218, 'gemma': 235},
    'synbio': {'pythia': 700, 'llama2': 825, 'llama3.1': 672, 'qwen2.5': 678, 'gemma': 645},
}

# Fail early with a clear message for choices the parser accepts but for
# which no loader exists, instead of hitting a NameError further down.
if args.dataset not in _DATASET_BUILDERS:
    raise SystemExit("Dataset not found: " + str(args.dataset))
if args.model_name not in _MODEL_SPECS:
    raise SystemExit("Model not supported: " + str(args.model_name))

dataset = _DATASET_BUILDERS[args.dataset](seed=42)

print("Dataset generated")


MODEL_NAME, _load_kwargs = _MODEL_SPECS[args.model_name]
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir = cache_path, **_load_kwargs)

max_tokens = _MAX_TOKENS[args.dataset][args.model_name]

print("Model loaded")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME,cache_dir = cache_path)
# Reuse the end-of-sequence token for padding.
tokenizer.pad_token = tokenizer.eos_token

train(dataset,tokenizer,max_tokens,args.model_name,model,tbatch,ebatch)