from __future__ import print_function
import os
import time
import datetime
import argparse
import numpy
import random
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter

import time
import data
import models
from utils import *
import options
import glob
import shutil
import logging
import youtokentome as yttm
from onlineoptim import SGDHD,SGDTLR
from tools import Dictionary, Corpus, repackage_hidden, batchify, get_batch
from torch.autograd import Variable
from collections import Counter

def create_path(path):
    """Prepare ``path`` as an empty checkpoint directory.

    * If ``path`` is not an existing directory, it is created.
    * If it exists and the first ``*.log`` file found directly inside it
      contains the marker ``"Step: 99999/100000"``, "File exists" is printed
      and the process exits so the existing contents are left untouched.
    * In every other case the directory is deleted and recreated empty.

    Args:
        path: Directory to create or reset.

    Raises:
        SystemExit: when the completion marker is found in the existing log.
    """
    if not os.path.isdir(path):
        os.makedirs(path)
        return

    # Glob once; as before, only the first match is inspected.
    log_files = glob.glob(f"{path}/*.log")
    if log_files:
        # The marker contains no newline, so scanning line by line is
        # equivalent to searching the whole text without loading it at once.
        with open(log_files[0], 'r') as f:
            finished = any("Step: 99999/100000" in line for line in f)
        if finished:
            print("File exists")
            # Raise SystemExit directly: quit() is injected by the `site`
            # module for interactive sessions and is not guaranteed to exist.
            raise SystemExit

    shutil.rmtree(path)
    os.makedirs(path)
    print("Removing old files")

def print_args(Args, args, logging):
    """Log the parsed options and save them to ``<args.cp_dir>/args.txt``.

    The report starts with the current timestamp, followed by one line per
    option (sorted by name).  When an option's value differs from the
    default registered on ``Args.parser``, the default is appended to that
    line.

    Args:
        Args: Object exposing a ``parser`` with a ``get_default(name)`` method.
        args: Parsed options namespace; must carry a ``cp_dir`` attribute.
        logging: Object with an ``info(message)`` method used to emit the report.
    """
    rows = [
        str(datetime.datetime.now()),
        '----------------- Options ---------------',
    ]
    for name, value in sorted(vars(args).items()):
        default = Args.parser.get_default(name)
        # Only annotate options that were changed from their default.
        note = '\t[default: %s]' % str(default) if value != default else ''
        rows.append('{:>25}: {:<30}{}'.format(str(name), str(value), note))
    rows.append('----------------- End -------------------')

    message = '\n'.join(rows)
    logging.info(message)

    target = os.path.join(args.cp_dir, 'args.txt')
    with open(target, 'wt') as args_file:
        args_file.write(message + '\n')

def _build_lr_scheduler(args, optimizer):
    """Create the LR scheduler selected by the ``args.is_*_ld`` flags.

    The flags are checked in a fixed priority order (preset, cosine, linear,
    cyclic, periodic); the first one that is set wins.

    Returns:
        The scheduler instance, or ``None`` when none of the flags is set.
    """
    schedulers = optim.lr_scheduler
    if args.is_preset_ld:
        return schedulers.ReduceLROnPlateau(optimizer,
                                            mode='min',
                                            factor=0.5,
                                            patience=args.patience,
                                            threshold=0.01,
                                            min_lr=1e-6)
    if args.is_cos_ld:
        return schedulers.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=0, last_epoch=-1)
    if args.is_linear_ld:
        lr_lambda = lambda x: 1-(x-1)/(args.epochs-1)
        return schedulers.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
    if args.is_cyclic_ld:
        # step_size_up counts scheduler steps; this scheduler is stepped once
        # per training batch in main().
        return schedulers.CyclicLR(optimizer,
                                   base_lr=args.base_lr,
                                   max_lr=args.lr,
                                   gamma=args.clr_gamma,
                                   step_size_up=args.step_size_up,
                                   mode=args.clr_type,
                                   cycle_momentum=False)
    if args.is_period_ld:
        return schedulers.StepLR(optimizer=optimizer, step_size=args.ld_period, gamma=args.ld_factor, last_epoch=-1)
    return None


def _average_loss(model, loader, criterion, device):
    """Return the ``AverageMeter`` average of ``criterion`` over ``loader``.

    Runs with gradients disabled.  The caller is responsible for putting the
    model in eval mode.  Each batch loss is recorded with ``inputs.shape[1]``
    as its count.
    """
    meter = AverageMeter()
    with torch.no_grad():
        for batch in loader:
            inputs, targets = (b.to(device) for b in batch)
            # The decoder is fed the targets without their last row along
            # dim 0 and is scored against the targets without their first row.
            targets_input = targets[:-1, :]

            inputs_mask, targets_mask, src_padding_mask, tgt_padding_mask = create_mask(inputs, targets_input, device)
            outputs = model(inputs, targets_input, inputs_mask, targets_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

            targets_out = targets[1:, :]
            loss = criterion(outputs.reshape(-1, outputs.shape[-1]), targets_out.reshape(-1))
            meter.update(loss.item(), inputs.shape[1])
    return meter.avg


def main():
    """Train the seq2seq Transformer, validating and scoring it periodically.

    Steps: parse the options and seed the RNGs; prepare the checkpoint
    directory and copy the project's ``*.py`` files into ``<cp_dir>/codes``;
    set up TensorBoard and file/console logging; load the BPE model and the
    train/dev/test loaders; build the model, Adam optimizer and LR scheduler;
    then run the training loop.

    After each training step, if ``model.step_counter`` is a multiple of
    ``args.val_freq``, the validation loss and the test BLEU score are
    computed and logged, and the model with the lowest validation loss so far
    is saved to ``<cp_dir>/best_model.pth.tar``.
    """
    Args = options.Options()
    args = Args.parse()
    device = torch.device("cuda" if args.use_cuda else "cpu")
    print('model will be trained on', device)
    torch.manual_seed(args.seed)
    numpy.random.seed(args.seed)
    random.seed(args.seed)

    # initialize directory
    args.cp_dir = f"checkpoints/{args.optimizer}_clr/run_ms_{args.lr}_{args.base_lr}_{args.step_size_up}_{args.seed}/"
    create_path(args.cp_dir)
    # Snapshot the source files next to the results, skipping any path that
    # contains "checkpoints", "data" or "results".
    for file in glob.glob("**/*.py", recursive=True):
        if "checkpoints" in file or "data" in file or "results" in file:
            continue
        os.makedirs(os.path.dirname(f"{args.cp_dir}/codes/{file}"), exist_ok=True)
        shutil.copy(file, f"{args.cp_dir}/codes/{file}")
    writer = SummaryWriter(log_dir=args.cp_dir)

    # initialize logging
    train_log = os.path.join(args.cp_dir, time.strftime("%Y%m%d-%H%M%S") + '.log')
    logging.basicConfig(
        format="%(name)s: %(message)s",
        level="INFO",
        handlers=[
            logging.FileHandler(train_log),
            logging.StreamHandler()
        ]
    )
    print_args(Args, args, logging)

    bpe_model = yttm.BPE(model="data/bpe.32000.model")
    logging.info(f"Vocab length: {bpe_model.vocab_size()}")

    train_loader = data.load("data/",
                             split='train',
                             batch_size=args.batch_size,
                             bpe_model=bpe_model,
                             workers=4)
    val_loader = data.load("data/",
                           split='dev',
                           batch_size=args.batch_size,
                           shuffle=False,
                           bpe_model=bpe_model)
    test_loader = data.load("data/",
                            split='test',
                            batch_size=1,
                            shuffle=False,
                            bpe_model=bpe_model)

    model = models.Seq2SeqTransformer(
        num_encoder_layers=3,
        num_decoder_layers=3,
        emb_size=512,
        nhead=8,
        vocab_size=bpe_model.vocab_size(),
        dim_feedforward=512,
        dropout=args.dropout_rate)
    model = model.to(device)
    model.initialize()

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=args.wd,
        amsgrad=False)
    optimizer.param_groups[0]['weight_decay'] = args.wd

    lr_scheduler = _build_lr_scheduler(args, optimizer)
    # Cosine / linear / step schedules are parameterised in epochs
    # (T_max=args.epochs, step_size=args.ld_period), so they advance once per
    # epoch. Cyclic steps per batch and plateau steps on validation loss.
    steps_per_epoch = isinstance(
        lr_scheduler,
        (optim.lr_scheduler.CosineAnnealingLR,
         optim.lr_scheduler.LambdaLR,
         optim.lr_scheduler.StepLR))

    # The loss module is stateless, so build it once instead of every batch.
    # ignore_index=1 is presumably the padding token id -- TODO confirm.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=1)

    trainloss = AverageMeter()
    best_loss = float('inf')
    try:
        for epoch in range(1, args.epochs+1):
            model.train()
            for batch_idx, batch in enumerate(train_loader):
                inputs, targets = (b.to(device) for b in batch)
                targets_input = targets[:-1, :]
                optimizer.zero_grad()

                inputs_mask, targets_mask, src_padding_mask, tgt_padding_mask = create_mask(inputs, targets_input, device)
                outputs = model(inputs, targets_input, inputs_mask, targets_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

                # Model-driven LR adaptation runs after the forward pass and
                # before backward(), keeping the original statement order.
                if args.is_auto_ld:
                    if args.optimizer == "adam_hd1":
                        model.update_lr_HD(args, device, optimizer)
                    elif args.optimizer == "adam_hd2":
                        model.update_lr_HD2(args, device, optimizer)
                    elif args.optimizer == "adam_tlr1":
                        model.update_lr_tlr(args, device, optimizer)
                    elif args.optimizer == "adam_tlr2":
                        model.update_lr_tlr2(args, device, optimizer)
                    else:
                        model.update_lr(args, device, optimizer)

                targets_out = targets[1:, :]
                batch_loss = criterion(outputs.reshape(-1, outputs.shape[-1]), targets_out.reshape(-1))
                batch_loss.backward()

                optimizer.step()
                if args.is_cyclic_ld:
                    # NOTE(review): step_counter is only advanced here for the
                    # cyclic schedule; presumably the model.update_lr* methods
                    # advance it when is_auto_ld is set -- confirm.
                    model.step_counter += 1
                    lr_scheduler.step()

                # Read the scalar once; every .item() call copies from device.
                loss_value = batch_loss.item()
                trainloss.update(loss_value, inputs.shape[1])

                writer.add_scalar('Train/train_loss', loss_value, model.step_counter)
                writer.add_scalar('Params/lr', optimizer.param_groups[0]['lr'], model.step_counter)

                if batch_idx % args.log_interval == 0:
                    logging.info('Train Step: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        model.step_counter, batch_idx * inputs.shape[1], len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), loss_value))

                # Validate and test every args.val_freq steps.
                if model.step_counter % args.val_freq == 0:
                    model.eval()
                    val_loss = _average_loss(model, val_loader, criterion, device)

                    writer.add_scalar('Val/val_loss', val_loss, model.step_counter)
                    logging.info('Step: {0}/{1}, Train_Loss = {2}, Val_Loss = {3}'.format(model.step_counter, args.ts, trainloss.avg, val_loss))

                    if val_loss < best_loss:
                        torch.save(model.state_dict(), f"{args.cp_dir}/best_model.pth.tar")
                        best_loss = val_loss

                    if args.is_preset_ld:
                        lr_scheduler.step(val_loss)

                    # Test BLEU score
                    with torch.no_grad():
                        score = compute_bleu(model, test_loader, bpe_model, device)
                    writer.add_scalar('Test/bleu_score', score, model.step_counter)
                    logging.info('Step: {0}/{1}, Bleu_score = {2}'.format(model.step_counter, args.ts, score))

                    # Back to training mode with a fresh running train loss.
                    model.train()
                    trainloss = AverageMeter()

            # Without this the cosine / linear / step schedulers would be
            # built but never applied.
            if steps_per_epoch:
                lr_scheduler.step()
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()





    # train_losses, val_losses = [], []
    # for epoch in range(1, args.epochs + 1):
    #     train_loss = model.train_one_epoch(args, device, train_loader, optimizer, lr_scheduler, epoch)
    #     val_loss = model.test_one_epoch(device, test_loader)
    #     train_losses += [train_loss]
    #     val_losses  += [val_loss]

    #     epoch_period = (time.time() - epoch_begin_time)/(epoch* 3600)
    #     print('--(epoch %d, %.2fh/%.2fh)'%(epoch, epoch_period, epoch_period*(args.epochs-epoch-1)),
    #             'Learning rate is', model.optim_lrs[-1], '\n')
    #     if len(model.angle_velocities) > 0:
    #         print(epoch, 'my angle velocity is', model.angle_velocities[-1])
    #     if lr_scheduler != None and not args.is_cyclic_ld:
    #         lr_scheduler.step()
        
    #     if args.save_model and epoch % 100 == 0:
    #         torch.save(model.state_dict(), "%s/%s_%s_%d.pt"%(args.expr_dir, args.dataset, args.model, epoch))
    #     if args.save_model:
    #         torch.save({
    #             'lrs': model.optim_lrs,
    #             'train_losses': train_losses,
    #             'val_losses': test_losses,
    #             'angle_velocities':model.angle_velocities,
    #             'angle_velocities_smooth':model.angle_velocities_smooth,
    #         }, '%s/%s_%s_stat.pt'%(args.expr_dir, args.dataset, args.model))
    #     print(epoch, "ends")

# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    main()