import argparse
import time
from tqdm import tqdm
import os
import os.path as osp
import torch
import torch.distributed as dist
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.parallel import DataParallel as DP
from dataloader import HandLanbanDataset
from utils.logger import setup_logger
from utils.validation import validation
from models.loss import CrossEntropyLoss
from models.loss import LabanBiasedloss
from models.loss import HirerarchicalCrossEntropyLoss
from models.loss import AdaptiveWingLoss
from models.HandLabanNet_res152 import ResNet152_CMP
from models.MHLFormer import PreMultiViewAttentionCMP

from utils.distributed_train_util import process_input, process_output
import sys
from torch.utils.tensorboard import SummaryWriter

from torchviz import make_dot
import tensorflow as tf

# Default filesystem layout, resolved relative to the directory the script
# is launched from (os.getcwd()), not the directory containing this file.
root_dir = os.getcwd()
data_dir, log_file_dir = osp.join(root_dir, 'data'), osp.join(root_dir, 'logs')
checkpoint_dir = osp.join(data_dir, 'checkpoints')
tensorboard_dir = 'tensorboard/'  # relative path, also resolved against the cwd
print(root_dir)
def get_args():
    """Parse the command-line options for training.

    Returns:
        argparse.Namespace holding the parsed options.
    """

    def str2bool(value):
        """Convert a command-line string such as 'False' or '0' to a bool.

        ``type=bool`` must not be used here: ``bool('False')`` is True, so
        every non-empty string would parse as True and ``--aug False`` could
        never disable augmentation.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser(description='Training')
    # NOTE(review): the four distributed options below are not referenced by
    # main() in this file; training uses single-process DataParallel.
    parser.add_argument('--world_size', default=4, type=int)
    parser.add_argument('--node_rank', default=0, type=int)
    parser.add_argument('--ip', default='116.111.139.121', type=str)
    parser.add_argument('--port', default='22', type=str)
    parser.add_argument('--data_file_path', default=data_dir, type=str)
    parser.add_argument('--print_iter', default=1, type=int)
    parser.add_argument('--log_file_path', default=osp.join(log_file_dir, 'train'), type=str)
    parser.add_argument('--checkpoint_file_path', default=checkpoint_dir, type=str)
    parser.add_argument('--batch_size', default=4, type=int)
    parser.add_argument('--num_workers', default=64, type=int)
    parser.add_argument('--aug', default=True, type=str2bool, choices=[True, False])
    parser.add_argument('--clip_step', default=10, type=int)
    parser.add_argument('--optimizer', default='AdamW', type=str, choices=['AdamW', 'SGD'])
    parser.add_argument('--epochs', default=10, type=int)
    parser.add_argument('--scheduler', default='step', type=str, choices=['cos', 'step'])
    parser.add_argument('--lr', default=0.00001, type=float)
    # The following options are required=True, so argparse never falls back
    # to their defaults; the defaults only document typical values.
    parser.add_argument('--backbone', type=str, default='ResNet152', required=True)
    parser.add_argument('--model', type=str, default='ResNet152', required=True)
    parser.add_argument('--loss', type=str, default='CrossentropyLoss', required=True)
    parser.add_argument('--split_train', type=str, default='train', required=True)
    parser.add_argument('--split_val', type=str, default='val', required=True)

    return parser.parse_args()


def _build_model(model_name):
    """Construct the network selected by ``--model``.

    Args:
        model_name: value of ``args.model``.

    Returns:
        The model instance (the caller wraps it in DataParallel).

    Raises:
        ValueError: if ``model_name`` is not a supported model.
    """
    if model_name == 'ResNet152_CMP':
        return ResNet152_CMP()
    if model_name in ('PreMultiViewAttentionCMP', 'Debug'):
        # 'Debug' builds the same network as 'PreMultiViewAttentionCMP'.
        return PreMultiViewAttentionCMP()
    raise ValueError("Unsupported --model {!r}; expected 'ResNet152_CMP', "
                     "'PreMultiViewAttentionCMP' or 'Debug'".format(model_name))


def _build_criterion(loss_name):
    """Construct the loss selected by ``--loss``.

    Args:
        loss_name: value of ``args.loss``.

    Raises:
        ValueError: if ``loss_name`` is not a supported loss.
    """
    if loss_name == 'CrossentropyLoss':
        return CrossEntropyLoss()
    if loss_name == 'LabanBiasedloss':
        return LabanBiasedloss()
    if loss_name == 'AdaptiveWingLoss':
        return AdaptiveWingLoss()
    if loss_name == 'HirerarchicalCrossEntropyLoss':
        return HirerarchicalCrossEntropyLoss(alpha=0.1)
    raise ValueError('Unsupported --loss {!r}'.format(loss_name))


def main():
    """Train ``args.model`` with ``args.loss``, validating and checkpointing every epoch.

    On the main device this writes a log file under ``args.log_file_path``,
    TensorBoard events under ``tensorboard_dir``, and, under
    ``args.checkpoint_file_path``, one checkpoint per epoch plus
    ``best_{total,horizontal,vertical}_acc.pth`` state dicts for the best
    validation accuracy seen so far on each metric.
    """
    args = get_args()
    args.nprocs = torch.cuda.device_count()
    local_rank = torch.cuda.current_device()
    is_main_device = (local_rank == 0)
    cudnn.enabled = True
    cudnn.benchmark = True

    os.makedirs(args.log_file_path, exist_ok=True)
    log_path = osp.join(args.log_file_path, args.model + '_' + args.loss + '_epoch_' + str(args.epochs)
                        + '_lr_' + str(args.lr) + '.log')
    tensorboard_file_dir = osp.join(tensorboard_dir, args.model + '_' + args.backbone + '_' + args.loss)
    checkpoint_dirpath = osp.join(args.checkpoint_file_path, args.model + '_loss_' + args.loss)

    # logger / writer exist only on the main device; every use is guarded.
    logger = None
    writer = None
    if is_main_device:
        logger = setup_logger(output=log_path, name='Training')
        logger.info('Start training')
        writer = SummaryWriter(tensorboard_file_dir)
        os.makedirs(checkpoint_dirpath, exist_ok=True)
        logger.info('Creating train dataset')
    train_dataset = HandLanbanDataset(args.data_file_path, split=args.split_train,
                                      aug=args.aug, clip_step=args.clip_step)
    # NOTE(review): training without shuffling is unusual -- confirm the
    # dataset/sampling relies on sample order before changing this.
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                               num_workers=args.num_workers, shuffle=False)
    if is_main_device:
        logger.info('The train dataset is created successfully !')

    # The model is built on every device, independent of the logging branch.
    model = _build_model(args.model)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Replicate over every visible GPU instead of assuming exactly four.
    model = DP(model, device_ids=list(range(args.nprocs))).to(device)
    if is_main_device:
        logger.info('Model initialization succeeded. ')

    if args.optimizer == 'AdamW':
        optimizer = optim.AdamW([{'params': model.parameters(), 'initial_lr': args.lr}], args.lr)
        if is_main_device:
            logger.info('The parameters of the model are added to the AdamW optimizer. ')
    elif args.optimizer == 'SGD':
        optimizer = optim.SGD(model.parameters(), args.lr, momentum=0.9, weight_decay=1e-4)
        if is_main_device:
            logger.info('The parameters of the model are added to the SGD optimizer')

    if args.scheduler == 'cos':
        train_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.epochs, eta_min=0)
        if is_main_device:
            logger.info('The learning rate schedule for the optimizer has been set to CosineAnnealingLR.')
    elif args.scheduler == 'step':
        train_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[40, 70, 80], gamma=0.2)
        if is_main_device:
            logger.info('The learning rate schedule for the optimizer has been set to MutiStepLR.')

    criterion = _build_criterion(args.loss)
    if is_main_device:
        logger.info('The loss is: {} loss'.format(args.loss))

    # The validation split does not change between epochs, so build it once.
    val_loader = None
    if is_main_device:
        logger.info('Creating val dataset.')
        # NOTE(review): validation reuses args.aug; if `aug` enables random
        # augmentation, validation metrics will be noisy -- confirm intent.
        val_dataset = HandLanbanDataset(args.data_file_path, split=args.split_val,
                                        aug=args.aug, clip_step=args.clip_step)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                                 num_workers=args.num_workers, shuffle=False)
        logger.info('The val dataset is created successfully. ')

    metric_keys = ('total_laban_acc', 'horizontal_laban_acc', 'vertical_laban_acc')
    # (label used in the summary log, metric being maximised, file for its best weights)
    selection_criteria = (
        ('total', 'total_laban_acc', 'best_total_acc.pth'),
        ('horizontal', 'horizontal_laban_acc', 'best_horizontal_acc.pth'),
        ('vertical', 'vertical_laban_acc', 'best_vertical_acc.pth'),
    )
    # For each criterion: all three accuracies of the best epoch so far.
    # Initialised to 0.0 so the summary log is valid before any improvement.
    best = {key: dict.fromkeys(metric_keys, 0.0) for _, key, _ in selection_criteria}
    global_iteration = 0

    for epoch in range(args.epochs):
        start = time.time()
        if is_main_device:
            logger.info('Epoch {}/{} started.'.format(epoch + 1, args.epochs))
        model.train()
        with tqdm(total=len(train_loader), disable=not is_main_device) as pbar:
            for iteration, (inputs, targets) in enumerate(train_loader):
                pbar.update()
                inputs = process_input(inputs, local_rank, non_blocking=True)
                targets = process_output(targets, local_rank, non_blocking=True)

                outputs = model(inputs)
                loss = criterion(outputs, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if is_main_device:
                    if iteration % args.print_iter == 0:
                        loss_value = loss.item()
                        screen = ['[Epoch %d/%d]' % (epoch + 1, args.epochs),
                                  '[Batch %d/%d]' % (iteration + 1, len(train_loader)),
                                  '[lr %f]' % (optimizer.param_groups[0]['lr']),
                                  '[loss %0.4f]' % (loss_value)
                                  ]
                        logger.info(''.join(screen))
                        writer.add_scalar("loss/train", loss_value, global_iteration)
                    # Count every optimisation step, not only the logged ones.
                    global_iteration += 1

        if is_main_device:
            ckpt_save = {
                'net': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'schedule': train_scheduler.state_dict(),
                'last_epoch': epoch,
            }
            checkpoint_name = 'checkpoint_epoch_{}_{}_lr_{}.pth'.format(epoch + 1, args.epochs, args.lr)
            torch.save(ckpt_save, osp.join(checkpoint_dirpath, checkpoint_name))

            logger.info('Start evaluating:')
            acc_dict = validation(val_loader, model, local_rank)
            for _, key, filename in selection_criteria:
                if acc_dict[key] > best[key][key]:
                    best[key] = {k: acc_dict[k] for k in metric_keys}
                    torch.save(model.state_dict(), osp.join(checkpoint_dirpath, filename))

            for key in metric_keys:
                writer.add_scalar(key, acc_dict[key], global_iteration)

            logger.info('The accuracy of val set is: total_laban_acc : {}, horizontal_laban_acc: {}, vertical_laban_acc: {}'
                        .format(acc_dict['total_laban_acc'], acc_dict['horizontal_laban_acc'], acc_dict['vertical_laban_acc']))

        # Advance the LR schedule once per epoch, after this epoch's
        # optimizer.step() calls and after the checkpoint above was taken, so
        # the saved scheduler state still refers to this epoch.  This replaces
        # the deprecated scheduler.step(epoch=...) form.
        train_scheduler.step()

        # Estimate the remaining training time from this epoch's duration.
        finish = time.time()
        epoch_duration = finish - start
        remaining_epochs = args.epochs - epoch - 1
        hours, rem = divmod(epoch_duration * remaining_epochs, 3600)
        minutes, seconds = divmod(rem, 60)
        if is_main_device:
            logger.info('epoch {} training time consumed: {:.2f}s'.format(epoch + 1, epoch_duration))
            logger.info('Estimated remaining training time: {:02d}:{:02d}:{:02d}'.format(int(hours), int(minutes), int(seconds)))
            for label, key, _ in selection_criteria:
                b = best[key]
                logger.info('When taking --{}-- as standard:The best accuracy of val set is: total_laban_acc : {}, horizontal_laban_acc: {}, vertical_laban_acc: {}'
                            .format(label, b['total_laban_acc'], b['horizontal_laban_acc'], b['vertical_laban_acc']))
            writer.add_text("text_description", "This is an epoch {}".format(epoch), epoch)

    if is_main_device:
        writer.close()

# Run training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
        
                
    
    
        
    