import os
import sys
import math
import random
import re
import argparse
import copy
from copy import deepcopy as cp
from collections import OrderedDict
import dotenv
from datetime import datetime

import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.nn.init as init
from torch.autograd.function import InplaceFunction
from torch.autograd import Variable

from peft import get_peft_model, LoraConfig

from transformers import AutoModelWithLMHead, AutoModelForCausalLM, \
    EncoderDecoderModel, pipeline, AutoModelForTokenClassification, \
    AutoTokenizer, BertTokenizer, AutoModelForSequenceClassification, AutoModel, \
    AutoModelForSeq2SeqLM

import wandb

from tqdm import tqdm

from data.data_utils import remove_emojis, remove_html, remove_email, clean_tweets, calculate_CMI, language_identification
from data.custom_tokenizers import custom_wp_tokenizer
from data.datasets import TransformerDataset, PreTrainedTransformerDataset
from models.utils import calculate_metrics, get_model, create_masks, generate_pretrained_transformer
from trainer import meta_trainer, trainer_without_metalearning, pretrained_model_trainer

from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu
if __name__ == '__main__':
    # Script entry point.
    #
    # Pipeline:
    #   1. Parse CLI arguments and seed the random number generators.
    #   2. Load the tab-separated tweet file, keep rows whose base language is
    #      'hin', keep authors with more than two distinct posts, clean text.
    #   3. Split into train / quiz (validation) / test partitions, stratified
    #      by author.
    #   4. Load the pretrained tokenizer and model and run the trainer.
    #   5. Generate text from one-word prompts and score the generations
    #      (calculate_metrics, CMI, ROUGE / BLEU over language-ID tag
    #      sequences).
    #   6. Write the per-sample results to a TSV file and optionally log to
    #      wandb.
    dotenv.load_dotenv()

    parser = argparse.ArgumentParser(prog='Trainer', conflict_handler='resolve')

    parser.add_argument('--data_file', type=str, default='../../data/twitter_data.csv', required=False,
                        help='train data')

    parser.add_argument('--max_text_len', type=int, default=60, required=False,
                        help='maximum length of text')

    parser.add_argument('--model_type', type=str, default='decoder', required=False,
                        help='whether model - decoder or encoder-decoder')
    parser.add_argument('--model_name', type=str, default='gpt2', required=False,
                        help='pretrained model name from huggingface')

    parser.add_argument('--epochs', type=int, default=50, required=False,
                        help='number of epochs')
    parser.add_argument('--lr', type=float, default=0.00003, required=False,
                        help='learning rate')
    parser.add_argument('--early_stopping_rounds', type=int, default=10, required=False,
                        help='number of epochs for early stopping')
    parser.add_argument('--lr_schedule_round', type=int, default=30, required=False,
                        help='number of epochs for learning rate scheduling')

    parser.add_argument('--train_batch_size', type=int, default=4, required=False,
                        help='train batch size')
    parser.add_argument('--eval_batch_size', type=int, default=4, required=False,
                        help='eval batch size')

    parser.add_argument('--model_save_path', type=str, default='./models/', required=False,
                        help='model save path')

    parser.add_argument('--wandb_logging', action='store_true',
                        help='wandb logging needed')
    parser.add_argument('--wandb_project_name', type=str, default='CodeMixed Generation', required=False,
                        help='wandb project name')

    parser.add_argument('--seed', type=int, default=42, required=False,
                        help='seed')

    args = parser.parse_args()
    print(args)

    # Make the run reproducible: --seed was previously parsed but never applied.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # ------------------------------------------------------------------ data
    df = pd.read_csv(args.data_file, sep='\t', lineterminator='\n').dropna(subset=['text']).reset_index(drop=True)
    # .copy() so the column assignments below do not write into a view of the
    # unfiltered frame (avoids pandas' SettingWithCopy behaviour).
    df = df[df.base_language == 'hin'].copy()

    if 'author' not in df.columns:
        df['author'] = df['author_id'].copy()

    # Keep only authors with more than two distinct posts; the post identifier
    # column is 'cid' when present and 'id' otherwise.
    id_column = 'cid' if 'cid' in df.columns else 'id'
    authors = df.groupby(['author'])[id_column].nunique().reset_index()
    authors = authors[authors[id_column] > 2][['author']]

    df = pd.merge(df, authors, how='inner', on='author')
    df['text'] = df.text.apply(lambda x: remove_emojis(remove_html(remove_email(clean_tweets(x)))).lower())
    df = df.sort_values(['author', 'posted_time_in_years'], ascending=[True, False]).reset_index(drop=True)

    # First split: train vs. held-out, stratified by author and unshuffled,
    # so only the first fold of the 2-fold split is used.
    kf = StratifiedKFold(n_splits=2, shuffle=False)
    train_index, test_index = next(kf.split(df.text, df.author))

    train_df = df.iloc[train_index].copy()
    val_df = df.iloc[test_index].reset_index(drop=True)

    # Second split: held-out -> quiz (validation) vs. test. The shuffle is
    # seeded so the partition is identical across runs with the same --seed.
    kf2 = StratifiedKFold(n_splits=2, shuffle=True, random_state=args.seed)
    val_index, test_index = next(kf2.split(val_df.text, val_df.author))

    test_df = val_df.iloc[test_index].copy()
    val_df = val_df.iloc[val_index].copy()

    # ----------------------------------------------------------------- model
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    if args.model_type == 'decoder':
        model_name_lower = args.model_name.lower()
        if 'bloom' in model_name_lower:
            model = AutoModelForCausalLM.from_pretrained(args.model_name)
        elif 'llama' in model_name_lower:
            # NOTE: a LoRA/PEFT wrapper (get_peft_model) was previously
            # sketched here but is not enabled.
            model = AutoModelForCausalLM.from_pretrained(args.model_name, load_in_8bit=True, device_map="auto")
        else:
            try:
                model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
            except ValueError:
                # Decoder-only checkpoints (e.g. the default 'gpt2') have no
                # seq2seq head and are rejected by AutoModelForSeq2SeqLM;
                # load them as causal language models instead.
                model = AutoModelForCausalLM.from_pretrained(args.model_name)
    else:
        model = EncoderDecoderModel.from_encoder_decoder_pretrained(args.model_name, args.model_name)
        model.config.decoder_start_token_id = tokenizer.cls_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.vocab_size = model.config.decoder.vocab_size

    # Integer-encode authors; the encoder is fitted on the training authors
    # only and reused for the other two partitions.
    ll = LabelEncoder()
    train_df['author_id'] = ll.fit_transform(train_df.author.values)
    val_df['author_id'] = ll.transform(val_df.author.values)
    test_df['author_id'] = ll.transform(test_df.author.values)

    train_dataset = PreTrainedTransformerDataset(train_df.text.values.tolist(),
                                                 tokenizer=tokenizer, MAX_LEN=args.max_text_len)
    quiz_dataset = PreTrainedTransformerDataset(val_df.text.values.tolist(),
                                                tokenizer=tokenizer, MAX_LEN=args.max_text_len)
    val_dataset = PreTrainedTransformerDataset(test_df.text.values.tolist(),
                                               tokenizer=tokenizer, MAX_LEN=args.max_text_len)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=False)

    quiz_loader = torch.utils.data.DataLoader(
        quiz_dataset, batch_size=args.train_batch_size, shuffle=False)

    # Not shuffled: the texts returned by the trainer are attached to the
    # rows of test_df by position further down, so the loader order must
    # match the row order.
    # NOTE(review): assumes the trainer returns texts in loader order - confirm.
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.eval_batch_size, shuffle=False)

    print("Total number of parameters={}".format(sum(p.numel() for p in model.parameters() if p.requires_grad)))

    now = int(datetime.now().timestamp())
    data_name = os.path.basename(args.data_file).split('.')[0]

    model_checkpoint_name = "{}_{}_{}.pth".format(model.__class__.__name__, data_name, str(now))

    if args.wandb_logging:
        # Copy the namespace dict: vars(args) is args.__dict__ itself, so
        # editing it in place would overwrite args.model_name before args is
        # handed to the trainer.
        config = dict(vars(args))
        if args.model_type == 'decoder':
            config['model_name'] = model.__class__.__name__
        else:
            config['model_name'] = args.model_name
        config['model_checkpoint'] = model_checkpoint_name
        wandb.login()
        wandb.init(project=args.wandb_project_name, config=config)
        artifact = wandb.Artifact('Model', type='model')
        wandb.watch(model, log_freq=100)

    # -------------------------------------------------------------- training
    model, best_validation_texts = pretrained_model_trainer.trainer(
        args, model, train_loader, quiz_loader, val_loader, model_checkpoint_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.eval()

    # ------------------------------------------------------------ generation
    # One earlier text per author from the quiz partition, exposed as
    # 'old_text' next to each test row of the same author.
    val_df2 = val_df.drop_duplicates(subset=['author_id']).rename(columns={'text': 'old_text'})

    # Decode the token ids returned by the trainer and strip special tokens.
    best_validation_texts = np.concatenate(best_validation_texts, 0)
    special_tokens = [str(token) for token in tokenizer.special_tokens_map.values()]

    best_validation_texts2 = []
    for token_ids in best_validation_texts:
        text = tokenizer.decode(token_ids).strip()
        for token in special_tokens:
            text = text.replace(token, '')
        best_validation_texts2.append(text)

    test_df2 = pd.merge(test_df, val_df2[['author_id', 'old_text']],
                        how='left', on='author_id')[['author_id', 'text', 'old_text']]
    test_df2['reconstructed_text'] = best_validation_texts2
    test_df2 = test_df2.dropna().reset_index(drop=True)

    # "Cold start" generation: the prompt is only the first whitespace-
    # separated word of the reference text.
    # NOTE: a prompted variant using old_text (generate_pretrained_transformer)
    # existed previously but was disabled, so no 'generated_text' column is
    # produced.
    use_autocast = device.type == 'cuda'
    all_generated_texts = []

    for reference_text in tqdm(test_df2.text):
        input_prompt = " ".join(reference_text.split()[:1])
        # Move the prompt to the device selected above instead of a
        # hard-coded "cuda", so the script also runs on CPU-only hosts.
        input_tokens = tokenizer(input_prompt, return_tensors="pt")["input_ids"].to(device)
        with torch.cuda.amp.autocast(enabled=use_autocast):
            generation_output = model.generate(
                input_ids=input_tokens,
                max_new_tokens=40,
                do_sample=True,
                top_k=10,
                top_p=0.9,
                temperature=0.3,
                repetition_penalty=1.15,
                num_return_sequences=1,
                eos_token_id=tokenizer.eos_token_id,
            )
        all_generated_texts.append(tokenizer.decode(generation_output[0], skip_special_tokens=True))

    test_df2['generated_text_coldstart'] = all_generated_texts

    print(test_df2['generated_text_coldstart'])

    # ------------------------------------------------------------ evaluation
    CMI, M_index, B, LE, repeat_index = calculate_metrics(test_df2.generated_text_coldstart.values)

    if args.wandb_logging:
        # The CMI value returned by calculate_metrics is not logged; a
        # per-sample CMI is computed below with the language-ID model.
        wandb.log({"CS_test_M_index": M_index})
        wandb.log({"CS_test_Burstiness": B})
        wandb.log({"CS_test_Entropy": LE})

    # Token-level language-identification pipeline, loaded once and shared by
    # the CMI and ROUGE/BLEU computations.
    tokenizer_ = AutoTokenizer.from_pretrained("sagorsarker/codeswitch-hineng-lid-lince")
    model_ = AutoModelForTokenClassification.from_pretrained("sagorsarker/codeswitch-hineng-lid-lince")
    lid_model = pipeline('ner', model=model_, tokenizer=tokenizer_)

    test_df2['CMI_generated'] = [
        calculate_CMI(generated_text, lid_model)
        for generated_text in tqdm(test_df2.generated_text_coldstart)
    ]

    test_df2 = test_df2.dropna(subset=['old_text', 'text', 'generated_text_coldstart']).reset_index(drop=True)

    # Compare the language-tag sequence of each reference text with that of
    # its generation. The scorer is stateless, so it is built once.
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)

    rogue1 = []
    rogueL = []
    blues = []

    for reference_text, generated_text in zip(test_df2.text, test_df2.generated_text_coldstart):
        lid1 = " ".join(list(language_identification(lid_model, reference_text).values()))
        lid2 = " ".join(list(language_identification(lid_model, generated_text).values()))

        scores = scorer.score(lid1, lid2)

        rogue1.append(scores['rouge1'].fmeasure)
        rogueL.append(scores['rougeL'].fmeasure)

        try:
            blues.append(sentence_bleu([lid1.split()], lid2.split()))
        except Exception:
            # Best effort: a sample BLEU cannot score counts as 0 rather
            # than aborting the run. Unlike a bare except, this still lets
            # KeyboardInterrupt / SystemExit propagate.
            blues.append(0)

    test_df2['rouge1_cs'] = rogue1
    test_df2['rougeL_cs'] = rogueL
    test_df2['bleu_cs'] = blues

    if args.wandb_logging:
        wandb.log({"CS_rogue1": test_df2['rouge1_cs'].mean()})
        wandb.log({"CS_rogueL": test_df2['rougeL_cs'].mean()})
        wandb.log({"CS_bleu": test_df2['bleu_cs'].mean()})

    # ---------------------------------------------------------------- output
    # Create the output directory if needed so the write cannot fail on a
    # missing path after the whole run has completed.
    os.makedirs(args.model_save_path, exist_ok=True)
    generated_texts_path = os.path.join(args.model_save_path, 'generated_texts_{}.csv'.format(now))
    test_df2.to_csv(generated_texts_path, sep='\t', index=False)
    if args.wandb_logging:
        wandb.log({"generated_text": generated_texts_path})