import random
#from tkinter import LAST
import tqdm
import numpy as np

import torch
import torch.optim as optim
from AutoRegressiveWrapper import AutoRegressiveWrapper
from LIPARTransfomer import LIPARTransformer
import Utils
import sys
import math
import os
from transformers import AutoTokenizer  # pip install transformers
import prep_text8
from data_utils import Corpus

# ------------------------------------------------------------------------ 
# Acknowledgement: some parts of code are adapted from Andrej Karpathy and Phil Wang
#-------------------------------------------------------------------------

# ------Architecture constants------------
# Training hyper-parameters and model dimensions read by the functions below.
DO_WORD_LEVEL_MODELING = True # set to False for character-level modeling, True for word-level
NUM_BATCHES = 300000 # int(2e6); upper bound of the training loop and T_max of the cosine scheduler
BATCH_SIZE = 4 # passed to the Utils data-loader factories
GRADIENT_ACCUMULATE_EVERY = 2 # forward/backward passes performed per optimizer step
LEARNING_RATE = 5e-5  # AdamW learning rate; original note: "3e-4 first 100000"
VALIDATE_EVERY  = 5000 # validation/test perplexity is computed every this many iterations (word level only)
GENERATE_EVERY  = 10000 # only referenced by the commented-out generation code in main()
GENERATE_LENGTH = 256 # only referenced by the commented-out generation code in main()
SEQ_LENGTH = 1024 # was 1024; model sequence length and default perplexity window size
RESUME_TRAINING = False # set to False to start training from the beginning; True reloads SAVE_MODEL_NAME
LAST_BEST_PERPLEXITY = 999#21.76 -- main() tracks its own best valid/test perplexities

EMBEDDING_SIZE = 768 # passed to the model as `dim`
NUM_HEADS = 6#12
NUM_LAYERS = 6#12
SEGMENT_LEN = 256  # segment length passed to the model and to the autoregressive wrapper
NUM_SEGMENTS = 4  # This is the number of segments the SEQ_LENGTH is divided into (SEQ_LENGTH / SEGMENT_LEN)
# perplex = 31.44 190K for 768 Embedding, 6 Layers, 6 heads, segment_len = 256, seq_len=1024
SAVE_MODEL_NAME = 'checkpoint/LIPAR_1024_768Latent_6_6_256_model.pt' # checkpoint path written by save_model() and read on resume
#---------------------------
# Created at import time; used by decode_tokens_word() and handed to the WikiText-103 loader in main().
tokenizer_word = AutoTokenizer.from_pretrained("bert-base-cased",truncation=True, max_length=SEQ_LENGTH) # for word level modeling

#following functions are for character level modeling----------
def decode_token_char(token):
    """Map an integer token id to its character.

    Ids of 32 or below are rendered as a space (code point 32), so control
    characters never appear in decoded text.
    """
    code_point = token if token > 32 else 32
    return chr(code_point)

def decode_tokens_char(tokens):
    """Convert a sequence of character-level token ids into a string."""
    chars = [decode_token_char(tok) for tok in tokens]
    return ''.join(chars)
#------------------------------------------------------------------------

def decode_tokens_word(tokens):
    """Decode word-level token ids to text with the module-level ``tokenizer_word``."""
    text = tokenizer_word.decode(tokens)
    return text

def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def configure_optimizers(mymodel):
    """
    Build the AdamW optimizer and cosine learning-rate scheduler for ``mymodel``.

    Parameters are split by name into two groups: weights of ``torch.nn.Linear``
    modules receive weight decay (0.1), while every bias and the weights of
    ``torch.nn.LayerNorm`` / ``torch.nn.Embedding`` modules do not (0.0).

    Returns a tuple ``(optimizer, scheduler)``.
    Raises AssertionError if a parameter ends up in both groups or in neither.
    """
    decayed_types = (torch.nn.Linear, )
    undecayed_types = (torch.nn.LayerNorm, torch.nn.Embedding)
    decay, no_decay = set(), set()

    # named_modules() and named_parameters() both recurse, so each tensor is
    # visited once per ancestor module; only the visit from its direct owner
    # can match the isinstance checks, and the sets absorb the duplicates.
    for module_name, module in mymodel.named_modules():
        prefix = module_name + '.' if module_name else ''
        for param_name, _ in module.named_parameters():
            full_name = prefix + param_name
            if param_name.endswith('bias'):
                # biases are never decayed
                no_decay.add(full_name)
                continue
            if not param_name.endswith('weight'):
                continue
            if isinstance(module, decayed_types):
                decay.add(full_name)
            elif isinstance(module, undecayed_types):
                no_decay.add(full_name)

    # every parameter must be in exactly one of the two groups
    param_dict = dict(mymodel.named_parameters())
    both = decay & no_decay
    missing = param_dict.keys() - (decay | no_decay)
    assert len(both) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(both), )
    assert len(missing) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                % (str(missing), )

    # sorted names give a deterministic parameter order inside each group
    decay_params = [param_dict[name] for name in sorted(decay)]
    no_decay_params = [param_dict[name] for name in sorted(no_decay)]
    optimizer = torch.optim.AdamW(
        [
            {"params": decay_params, "weight_decay": 0.1},
            {"params": no_decay_params, "weight_decay": 0.0},
        ],
        lr=LEARNING_RATE,
        betas=(0.9, 0.95),
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_BATCHES)
    return optimizer, scheduler

def compute_valid_test_perplexities(model, valid_set, test_set,device,max_len=SEQ_LENGTH):
    """
    Compute perplexity on the validation set and on the test set.

    Args:
        model: model handed to ``compute_perplexity_huggingface``.
        valid_set: validation dataset (its ``.data`` tensor is evaluated).
        test_set: test dataset (its ``.data`` tensor is evaluated).
        device: device the evaluation windows are moved to.
        max_len: window length forwarded to the perplexity routine;
            defaults to ``SEQ_LENGTH``.

    Returns:
        Tuple ``(ppl_valid, ppl_test)``.
    """
    # Forward the caller's max_len; it used to be ignored in favour of SEQ_LENGTH.
    ppl_valid = compute_perplexity_huggingface(model, valid_set, device, max_len=max_len)
    ppl_test = compute_perplexity_huggingface(model, test_set, device, max_len=max_len)
    return ppl_valid, ppl_test


def compute_perplexity_huggingface(model,test_set,device,max_len=SEQ_LENGTH, stride=512):
    """
    Sliding-window perplexity of ``model`` over ``test_set.data``.

    The token tensor is flattened into a single stream and cut into windows of
    ``max_len + 1`` tokens whose start positions advance by ``stride``.  Only
    full-length windows are evaluated; a shorter tail is skipped.  The model is
    called on each window and its return value is used as that window's loss.
    Perplexity is ``exp`` of the mean of those losses.

    Args:
        model: callable module; ``model(window)`` must return a scalar loss tensor.
        test_set: object exposing the token ids as a tensor in ``.data``.
        device: device each window is moved to before the forward pass.
        max_len: window length minus one (each window holds ``max_len + 1`` tokens).
        stride: distance between consecutive window starts.

    Returns:
        0-dim tensor holding the perplexity.

    Raises:
        ValueError: if the data holds fewer than ``max_len + 1`` tokens, so that
            no window can be evaluated.

    Side effects:
        Puts ``model`` into eval mode (not restored here) and prints the number
        of evaluated windows.
    """
    model.eval()
    # reshape (rather than view) also accepts non-contiguous data
    encodings = test_set.data.reshape(1, -1)
    total_data_len = encodings.size(1)
    window = max_len + 1

    nlls = []
    with torch.no_grad():
        # The range stops at the last start position that still fits a full
        # window, so no length checks are needed inside the loop.
        for begin_loc in tqdm.tqdm(range(0, total_data_len - window + 1, stride)):
            input_ids = encodings[:, begin_loc:begin_loc + window].to(device)
            # Clone so the model never receives a view into the dataset tensor
            # (when the data already lives on `device`, `.to` does not copy).
            nlls.append(model(input_ids.clone()))

    print(len(nlls))
    if not nlls:
        raise ValueError(
            "cannot compute perplexity: data has %d tokens but a window needs %d"
            % (total_data_len, window)
        )
    return torch.exp(torch.stack(nlls).mean())

# def save_model(model, i, optim, fname):
#     # ---------save the latest model---------
#     print("----------saving model-----------------")
#     checkpoint_data = {
#     'epoch': i,
#     'state_dict': model.state_dict(),
#     'optimizer': optim.state_dict()
#     }
#     ckpt_path = os.path.join("checkpoint/" + fname) #transAM_WK_model_best.pt")
#     torch.save(checkpoint_data, ckpt_path)
#     # revert model to training mode
#     model.train()
def save_model(model, i, optim, scheduler, fname, last_best_valid_perplexity,
               last_best_test_perplexity):
    """
    Write a training checkpoint to ``fname`` and put the model back in train mode.

    Args:
        model: module whose ``state_dict()`` is stored under ``'state_dict'``.
        i: current iteration, stored under ``'epoch'``.
        optim: optimizer whose state is stored under ``'optimizer'``.
        scheduler: LR scheduler whose state is stored under ``'lr_sched'``.
        fname: destination path (anything ``torch.save`` accepts).
        last_best_valid_perplexity: stored under ``'last_best_valid_perplexity'``.
        last_best_test_perplexity: stored under ``'last_best_test_perplexity'``.
    """
    # ---------save the latest model---------
    print("----------saving model-----------------")
    checkpoint_data = {
        'epoch': i,
        'state_dict': model.state_dict(),
        'optimizer': optim.state_dict(),
        'lr_sched': scheduler.state_dict(),
        'last_best_valid_perplexity': last_best_valid_perplexity,
        'last_best_test_perplexity': last_best_test_perplexity,
    }
    ckpt_path = fname
    # Create the target directory up front so the save cannot fail on a
    # missing folder; file-like destinations are passed through untouched.
    if isinstance(ckpt_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(ckpt_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    torch.save(checkpoint_data, ckpt_path)
    # revert model to training mode
    model.train()

def f1():
    """Debug helper: print the module-level ``LAST_BEST_PERPLEXITY``."""
    current = LAST_BEST_PERPLEXITY
    print(current)

def main():
    """
    Train the LIPAR language model, validating and checkpointing along the way.

    Steps:
      1. Build ``LIPARTransformer`` wrapped in ``AutoRegressiveWrapper`` on CUDA.
      2. Fetch data loaders from ``Utils`` (WikiText-103 loader for word level,
         enwik8 loader for character level, chosen by ``DO_WORD_LEVEL_MODELING``).
      3. Create the optimizer/scheduler and, when ``RESUME_TRAINING`` is set,
         restore model, optimizer, scheduler and best perplexities from
         ``SAVE_MODEL_NAME``.
      4. Train up to ``NUM_BATCHES`` iterations.  Every ``VALIDATE_EVERY``
         iterations (word level only) compute validation and test perplexity
         and save a checkpoint whenever validation perplexity improves.

    Returns None.
    """
    best_valid_perplexity = 999#21.76
    best_test_perplexity = 999
    #prep_text8.prepare_text8() # prepare text8 data

    # 256 for plain character level modeling; 204 is based on transformer XL code for enwik8
    num_tokens = 204
    if DO_WORD_LEVEL_MODELING:
        num_tokens = 28996 # bert-base_cased for wikitext-103
    LIPAR_model = LIPARTransformer(
        dim = EMBEDDING_SIZE,
        num_tokens = num_tokens,
        num_layers = NUM_LAYERS,
        heads = NUM_HEADS,
        sequence_len = SEQ_LENGTH,
        segment_len = SEGMENT_LEN,
        num_segments = NUM_SEGMENTS,
        ff_dropout=0.05,
        attn_dropout = 0.05
     ).cuda()

    model = AutoRegressiveWrapper(LIPAR_model, segment_len=SEGMENT_LEN)
    model.cuda()
    pcount = count_parameters(model)
    print("count of parameters in the model = ", pcount/1e6, " million")

    if DO_WORD_LEVEL_MODELING:
        train_loader, val_loader, test_loader, val_dataset, test_dataset = Utils.get_loaders_wikitext_103(tokenizer_word, SEQ_LENGTH, BATCH_SIZE)
    else: # char level modeling
        # train_loader, val_loader, test_loader, val_dataset = Utils.get_loaders_enwiki8(SEQ_LENGTH, BATCH_SIZE)
        train_loader, val_loader, test_loader, val_dataset = Utils.get_loaders_enwiki8_basedon_transformerXL(SEQ_LENGTH, BATCH_SIZE)
        #train_loader, val_loader, test_loader, val_dataset = Utils.get_loaders_text8(SEQ_LENGTH, BATCH_SIZE)

    optim, sched = configure_optimizers(model)

    # --------training---------
    if not RESUME_TRAINING:
        start = 0
    else:
        # The order of the statements below matters: constructing the scheduler
        # performs an initial step that rewrites the param-group learning
        # rates, and the explicit assignments afterwards override group 0 only.
        checkpoint_data = torch.load(SAVE_MODEL_NAME)
        model.load_state_dict(checkpoint_data['state_dict'])
        optim.load_state_dict(checkpoint_data['optimizer'])
        for param_group in optim.param_groups:  # if lr needs to be changed
            param_group['lr'] = LEARNING_RATE
        start = checkpoint_data['epoch']
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, NUM_BATCHES, last_epoch=checkpoint_data["epoch"]-1)
        sched.load_state_dict(checkpoint_data['lr_sched'])
        optim.param_groups[0]["lr"] = LEARNING_RATE
        print('lr=',optim.param_groups[0]["lr"])
        # NOTE(review): hard-coded override of the first param group's lr,
        # applied right after printing a different value -- confirm intent.
        optim.param_groups[0]["lr"] = 2e-5
        best_valid_perplexity = checkpoint_data['last_best_valid_perplexity']
        best_test_perplexity = checkpoint_data['last_best_test_perplexity']

    for i in tqdm.tqdm(range(start,NUM_BATCHES), mininterval = 10., desc = 'training'):
        model.train()
        # Gradients from the accumulation passes are summed (not averaged)
        # before the clipped optimizer step.
        for __ in range(GRADIENT_ACCUMULATE_EVERY):
            loss = model(next(train_loader))
            loss.backward()
        if (i%100 == 0):
            # reports the loss of the last accumulation pass only
            print(f'training loss: {loss.item()} -- iteration = {i}')

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()
        optim.zero_grad()
        # NOTE(review): the first param group's lr is forced to 2e-5 after every
        # step regardless of LEARNING_RATE, the second group is not touched, and
        # sched.step() is never called in this loop, so the cosine schedule does
        # not advance.  Looks like a manual override -- confirm before changing.
        optim.param_groups[0]["lr"] = 2e-5

        if (i % VALIDATE_EVERY == 0) and DO_WORD_LEVEL_MODELING:
            valid_perplexity, test_perplexity = compute_valid_test_perplexities(model,val_dataset,test_dataset, 'cuda')
            # Update the best test perplexity BEFORE saving so the checkpoint
            # stores the current value rather than the one from the previous
            # evaluation.
            if test_perplexity < best_test_perplexity:
                best_test_perplexity = test_perplexity
            if valid_perplexity < best_valid_perplexity:
                best_valid_perplexity = valid_perplexity
                print('saving best validation model, valid_perplex=',valid_perplexity, ' test perplexity=',test_perplexity)
                save_model(model,i,optim,sched,SAVE_MODEL_NAME,best_valid_perplexity, best_test_perplexity)
            print('-------------BEST test perplexity=',best_test_perplexity, ' best valid Perplex=',best_valid_perplexity, ' current valid perplex=',valid_perplexity,
                  ' current test Perplex=', test_perplexity)


        # Disabled: character-level validation (needs `total_loss = 0` initialised
        # before the loop if re-enabled).
        # if ((i+0) % VALIDATE_EVERY == 0) and (DO_WORD_LEVEL_MODELING == False):
        #     model.eval()
        #     total_len2 = 0
        #     total_loss2 = 0
        #     val_count = 50  # 1000 -number of validations to compute average BPC
        #     with torch.no_grad():
        #         for v in range(0,val_count):
        #             zz = next(test_loader)
        #             loss = model(next(test_loader))
        #             total_loss += loss.item()
        #             loss_m = loss.mean()
        #             total_loss2 += SEQ_LENGTH * loss_m.item() #loss.float().item() #seq_len
        #             total_len2 += SEQ_LENGTH
        #         print(f'----------validation loss: {total_loss/val_count}')
        #         print(f'Perplexity : {math.exp(total_loss/val_count)}, BPC: {total_loss/val_count*np.log2(2.7173)}')
        #         bpc2 = (total_loss2/total_len2)/math.log(2)
        #         print("BPC 2 = ", bpc2)
        #         total_loss = 0

        # Disabled: periodic sample generation.
        # if (i+1) % GENERATE_EVERY == 0:  
        #     model.eval()
        #     inp = random.choice(val_dataset)[:-1]
        #     if DO_WORD_LEVEL_MODELING == True:
        #         input_start_sequence = decode_tokens_word(inp)
        #     else:
        #         input_start_sequence = decode_tokens_char(inp)
        #     print("----------start input------------------")
        #     print(f'%s \n\n', (input_start_sequence))
        #     print("----------end of start input-----------")
        #     sample = model.generate(inp, GENERATE_LENGTH)
        #     if DO_WORD_LEVEL_MODELING == True:
        #         output_str = decode_tokens_word(sample)
        #     else:
        #         output_str = decode_tokens_char(sample)
        #     print("----------generated output-------------")
        #     print(output_str)
        #     print("----------end generated output---------")
        # if i % 1000 == 0:  #0 
        #     save_model(model,i,optim,"transAM_WK_model.pt")
        #     model.train()

if __name__ == "__main__":
    # main() returns None on success, which maps to exit status 0.
    exit_status = main() or 0
    sys.exit(int(exit_status))

