import os.path
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import Loss
import Layers as L
import Encoder_Decoder as ED
import AMP_Dataset as Data
import time
import random
import argparse

# Pick the compute device once at import time: CUDA when available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("The current device is: ", device)


# if device == 'cuda':
#        DataLoader={key: tensor.to(device) for key, tensor in DataLoader.items()}

def run_epoch(aepoch, model, opt, lossfun):
    """Run one pass over the module-level ``DataLoader`` and accumulate the loss.

    When ``opt`` is given, every batch is back-propagated and the optimizer is
    stepped. When ``opt`` is ``None`` the pass is evaluation-only: autograd is
    disabled and no backward pass is run, so no gradients are left behind on
    the model parameters.

    Args:
        aepoch: Index of the current epoch. Currently unused; kept so existing
            callers keep working.
        model: Model invoked as ``model.forward(src, tgt, src_mask, tgt_mask)``.
        opt: Optimizer wrapper exposing ``step()`` and ``optimizer.zero_grad()``,
            or ``None`` to skip back-propagation and parameter updates.
        lossfun: Callable ``lossfun(output, target)`` returning a single-element
            loss tensor.

    Returns:
        A ``(total_loss, total_tokens)`` tuple. ``total_loss`` is the sum over
        batches of the token-normalised loss multiplied by the batch's mean
        ``ntokens``; ``total_tokens`` is the sum of those mean ``ntokens``.
        Both are plain Python numbers (``0`` when there are no batches).
    """
    training = opt is not None
    total_tokens = 0
    total_loss = 0
    # NOTE: ``DataLoader`` and ``device`` are module-level names; each batch is a
    # dict with the keys 'src', 'tgt', 'src_mask', 'tgt_mask', 'tgt_y', 'ntokens'.
    for batch in DataLoader:
        src_v = batch['src'].to(device)
        tgt_v = batch['tgt'].to(device)
        src_mask_v = batch['src_mask'].to(device)
        tgt_mask_v = batch['tgt_mask'].to(device)

        # Only track gradients when an optimizer will actually consume them.
        with torch.set_grad_enabled(training):
            output = model.forward(src_v, tgt_v, src_mask_v, tgt_mask_v)
            # Flatten to (positions, last_dim) so it lines up with the flat target.
            output = output.contiguous().view(-1, output.size(-1))
            target = batch['tgt_y'].contiguous().view(-1).to(device)
            # Normalise by the first entry of ``ntokens`` for this batch.
            loss = lossfun(output, target) / batch['ntokens'][0]

        if training:
            loss.backward()
            opt.step()
            opt.optimizer.zero_grad()

        # ``.item()`` yields Python floats regardless of which device the
        # tensors live on (``.numpy()`` would fail for CUDA tensors).
        mean_ntokens = torch.mean(batch['ntokens'].type(torch.float)).item()
        total_loss += loss.item() * mean_ntokens
        total_tokens += mean_ntokens
    return total_loss, total_tokens


def check_file(path, model_name):
    """Make sure directory ``path`` exists and return the model file path in it.

    Args:
        path: Output directory; created together with any missing parents.
        model_name: File name to place inside ``path``.

    Returns:
        ``path`` followed by ``model_name``, with a single '/' inserted when
        ``path`` does not already end with one.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # exist_ok=True removes the check-then-create race of a separate
    # os.path.exists() test and fails early if ``path`` is a regular file.
    os.makedirs(path, exist_ok=True)
    separator = '' if path.endswith('/') else '/'
    return path + separator + model_name


if __name__ == '__main__':
    # Script entry point: parse CLI options, build the data loader, model,
    # optimizer and loss, then train and periodically checkpoint the model.
    start_time = time.time()
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--epoch', type=int, default=1000)
    parser.add_argument('-q', '--query_data', type=str, default=None)
    parser.add_argument('-v', '--value_data', type=str, default=None)
    parser.add_argument('-b', '--batch_size', type=int, default=1024)
    parser.add_argument('-s', '--seed', type=int, default=37)
    parser.add_argument('-m', '--path_model_output', type=str, default='./models/')
    parser.add_argument('-n', '--model_name', type=str, default='amp_models.pt')
    parser.add_argument('-l', '--layers', type=int, default=2, help='number of layers')
    args = parser.parse_args()

    # Seed every random number generator in use so runs are reproducible.
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    batch_size = args.batch_size
    epoch_num = args.epoch
    # Creates the output directory if needed and returns the checkpoint path.
    path_model_output = check_file(args.path_model_output, args.model_name)

    # NOTE: run_epoch() reads the module-level name ``DataLoader``, so this
    # variable must keep exactly this name.
    DataLoader, src_vocab, tgt_vocab = Data.BuildDataLoader(batch_size, args.query_data, args.value_data)
    Transformer = ED.make_model(src_vocab=len(src_vocab), tgt_vocab=len(tgt_vocab), N=args.layers, d_model=512,
                                d_ff=2048,
                                h=8)
    Transformer.to(device)
    opt = Loss.get_std_opt(Transformer)
    lossfun = Loss.LabelSmoothingKLDivLoss(size=len(tgt_vocab), padding_idx=0, smoothing=0.1)

    # training
    for epoch in tqdm(range(epoch_num), desc="Epoch", leave=True, dynamic_ncols=True):
        Transformer.train()
        r_loss, r_token = run_epoch(epoch, Transformer, opt, lossfun)

        # Evaluation pass over the same loader in eval mode. No optimizer is
        # passed, so the weights are not updated while measuring the loss.
        Transformer.eval()
        e_loss, e_token = run_epoch(epoch, Transformer, None, lossfun)
        # Defensive: make sure no gradients carry over from the evaluation
        # pass into the next optimizer step.
        Transformer.zero_grad()

        if epoch % 50 == 0 or epoch == (epoch_num - 1):
            print("Epoch: ", epoch)
            print(f'Training results: {r_loss} / {r_token}  => the loss: {r_loss / r_token}')
            print(f'Evaluation results: {e_loss} / {e_token}  => the loss: {e_loss / e_token}')
        if epoch % 100 == 0:
            # Periodic checkpoint; overwrites the same file each time.
            print('The model is saving in the following path: ')
            print('Path => ', path_model_output)
            torch.save(Transformer.state_dict(), path_model_output)

    # Final checkpoint once all epochs have finished.
    print('The model is saving in the following path: ')
    print('Path => ', path_model_output)
    torch.save(Transformer.state_dict(), path_model_output)

    end_time = time.time()
    print("The total training time is : ", end_time - start_time)
