from Model import Getmodel
from Dataset import Getdataset, Getnumclass
from Optimizers import *
from stragety import Getloss, CrossEntropyLoss_new
from Valid import valid_old
import torch
from options import *
import time
from tqdm import tqdm
from Save import *
from Augmentation import get_aug
import pickle
import horovod
import horovod.torch as hvd
from distutils.version import LooseVersion
# from torchkeras import summary
# from torch.utils.tensorboard import SummaryWriter
# from batchgenerators.utilities.file_and_folder_operations import *

def _evaluate_or_zero(model, dataset, args, lossf, num_class, aug):
    """Evaluate ``model`` on ``dataset`` via ``valid_old``.

    Returns whatever ``valid_old`` returns (unpacked by callers as
    ``(accuracy, loss, tloss)``), or ``(0, 0, 0)`` when ``dataset`` is
    ``None`` because that split was not provided by ``Getdataset``.
    """
    if dataset is None:
        return 0, 0, 0
    return valid_old(model, dataset, args, lossf, num_class, aug)


if __name__ == "__main__":
    # Parse command-line options; the parser is kept so options can be logged.
    args, parser = parse_args()

    # torch.manual_seed(args.seed)
    # torch.cuda.manual_seed(args.seed)

    num_class = Getnumclass(args.task)
    # Augmentation used by the topology-consistency loss and passed to valid_old.
    aug = get_aug(image_size=224)
    train_dataset, test_dataset, val_dataset = Getdataset(args)

    print('train_dataset', len(train_dataset))
    if test_dataset is not None:
        print('test_dataset', len(test_dataset))
    if val_dataset is not None:
        print('val_dataset', len(val_dataset))

    model = Getmodel(args.task, args.backbone, num_class, args.stragety, args.pretrain_param)

    checkpoint = None
    if args.resume_on == 1:
        # NOTE(review): the original hard-coded torch.load(''), which can never
        # succeed. `args.resume_path` is used when the option parser defines it;
        # otherwise the old empty path is kept. Confirm the option name.
        checkpoint = torch.load(getattr(args, 'resume_path', ''))
        model.load_state_dict(checkpoint['model_state_dict'])

    model = model.cuda()
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batchsize, shuffle=True)

    Hparam = get_Hparam(args, model)
    optimizer = Getoptim(args.CLoptimizer, Hparam)

    # The LR schedule is driven per step by adjust_learning_rate() below; the
    # epoch-level scheduler (define_scheduler) is intentionally not used.
    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Workaround kept from the original, which re-set this before every
        # step; setting it once after loading the state is sufficient.
        optimizer.param_groups[0]['capturable'] = True

    lossf = Getloss(args.stragety)
    # Built once instead of once per batch.
    lossf_topology = Getloss(args.stragety)

    # Per-epoch history, persisted by value_save()/picture().
    result = {'epoch': [], 'train': {'acc': [], 'loss': []}, 'val': {'acc': [], 'loss': []}, 'test': {'acc': [], 'loss': []}}
    # print_root returns the log file handle plus a value forwarded to
    # picture/model_save/value_save. It is named run_time so it no longer
    # shadows the imported `time` module.
    file, run_time = print_root(args)
    print(model, file=file)
    num_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of parameters: ', num_param, file=file)
    print_options(parser, args, file=file)

    # Step counts (as floats) used by adjust_learning_rate for warm-up/decay.
    max_step = float(args.Epoch) * len(train_dataset) / float(args.batchsize)
    warmup_steps = float(args.warm_up) * len(train_dataset) / float(args.batchsize)
    print('all step: ', max_step, file=file)

    if checkpoint is not None:
        start = checkpoint['epoch']
        # Same formula as max_step so the LR schedule resumes where it stopped.
        step = len(train_dataset) * start / args.batchsize
        print(step)
    else:
        start = 0
        step = 0

    for epoch in tqdm(range(start, args.Epoch), file=file):
        # Nothing else in this loop puts the model back into training mode,
        # and evaluation or a resumed checkpoint may have left it in eval mode.
        model.train()
        total_loss = 0

        for images, labels in train_loader:
            images = images.cuda()
            labels = labels.cuda()
            # one_hot keeps the device of `labels`, which is already CUDA.
            labels_cal = torch.nn.functional.one_hot(labels, num_class).type(torch.float32)
            outs = model(images)

            step += 1
            adjust_learning_rate(optimizer, max_step, warmup_steps, step, args)

            # Topology loss: average disagreement between predictions on
            # `args.topology` augmented views and on the clean batch.
            topology_loss = 0
            if args.topology != 0:
                for _ in range(args.topology):
                    ima_aug = aug(images)
                    topology_loss += lossf_topology(model(ima_aug), outs)
                topology_loss = topology_loss / args.topology

            exp_loss = lossf(outs, labels_cal)
            loss = exp_loss + args.beta * topology_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(optimizer.param_groups[0]['lr'], file=file)

        train_accurary, train_loss, train_tloss = _evaluate_or_zero(model, train_dataset, args, lossf, num_class, aug)
        # test/val splits may be absent (None); they report zeros then.
        test_accurary, test_loss, test_tloss = _evaluate_or_zero(model, test_dataset, args, lossf, num_class, aug)
        val_accurary, val_loss, val_tloss = _evaluate_or_zero(model, val_dataset, args, lossf, num_class, aug)

        print(' ', file=file)
        print('train_tloss:', train_tloss, 'test_tloss:', test_tloss, 'val_tloss', val_tloss, file=file)
        print(' ', file=file)

        result['val']['acc'].append(val_accurary)
        result['val']['loss'].append(val_loss)
        result['train']['acc'].append(train_accurary)
        result['train']['loss'].append(train_loss)
        result['test']['acc'].append(test_accurary)
        result['test']['loss'].append(test_loss)
        result['epoch'].append(epoch)

        print("{epoch_index} epoch's train_accurary is {trainacc},train_loss is {losst}, test_accurary is {testacc}.".format(epoch_index=epoch, trainacc=train_accurary, losst=total_loss, testacc=test_accurary))
        print('\n[{:s}]\t\t      {:.6f}'.format('LR', optimizer.param_groups[0]['lr']), file=file)
        print('[{:s}]\t\tLoss: {:.8f}, {:s}: {:.4f}'.format('Train', train_loss, 'Acc', train_accurary), file=file)
        print('[{:s}]\t\tLoss: {:.8f}, {:s}: {:.4f}'.format('Test', test_loss, 'Acc', test_accurary), file=file)
        print('[{:s}]\t\tLoss: {:.8f}, {:s}: {:.4f}\n'.format('Val', val_loss, 'Acc', val_accurary), file=file)

        if args.is_picture:
            picture(args, result, run_time)

        # Checkpoint every `save_epoch` epochs; `loss` is the last batch's loss.
        if epoch % args.save_epoch == 0:
            model_save(args, model, optimizer, epoch, loss, run_time)

        value_save(result, args, run_time)
