from Model import Getmodel
from Dataset import *
from Optimizers import *
from stragety import Getloss, CrossEntropyLoss_new
from Valid import valid
import torch
from options import *
import time
from tqdm import tqdm
from Save import *
from Augmentation import get_aug
import pickle
import horovod
import horovod.torch as hvd
from LARC import LARC
from itertools import cycle
import numpy as np
from math import log10

import torch.backends.cudnn as cudnn
import torch.multiprocessing as mp
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.distributed
from torchvision import datasets, transforms, models
import os
import math
# from torch.utils.tensorboard import SummaryWriter

def train(epoch):
    """Train ``model`` for one epoch over ``train_loader``.

    Relies on module-level globals created in the ``__main__`` section:
    ``model``, ``optimizer``, ``train_loader``, ``train_sampler``, ``args``,
    ``aug``, ``num_class`` and ``file``.

    The per-batch objective is ``L(output, one_hot_target)`` plus, when
    ``args.topology`` is non-zero, ``args.beta`` times the mean over
    ``args.topology`` augmented views of ``L_Contrast(model(aug(data)), output)``.

    Args:
        epoch: Epoch index; used to reshuffle the distributed sampler and to
            drive the per-batch learning-rate schedule.
    """
    model.train()
    # Horovod: set epoch to sampler for shuffling.
    train_sampler.set_epoch(epoch)
    L = Getloss(args.stragety)
    L_Contrast = Getloss(args.stragety_Contrast)
    # Defined before the loop so the rank-0 log below works even for an empty loader.
    topology_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        adjust_learning_rate_new(epoch, batch_idx)
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        # one_hot produces its result on ``target``'s device, so the targets end up
        # wherever the batch is; no unconditional ``.cuda()`` (which broke CPU runs).
        target_cal = torch.nn.functional.one_hot(target, num_class).type(torch.float32)

        topology_loss = 0
        aug_loss = 0
        if args.topology != 0:
            for _ in range(args.topology):
                ima_aug = aug(data)
                aug_output = model(ima_aug)
                topology_loss += L_Contrast(aug_output, output)
                if args.experince_aug == 1:
                    aug_loss += L_Contrast(model(ima_aug), target_cal)
            topology_loss = topology_loss / args.topology
        elif args.experince_aug == 1:
            ima_aug = aug(data)
            aug_loss += L_Contrast(model(ima_aug), target_cal)

        # NOTE(review): ``aug_loss`` is computed (costing extra forward passes, which
        # in train mode may also update BatchNorm running statistics) but it is never
        # added to ``loss``, so ``args.experince_aug`` does not affect the gradient.
        # Confirm whether it is meant to be part of the objective.
        loss = L(output, target_cal) + args.beta * topology_loss
        loss.backward()
        optimizer.step()
    if hvd.rank() == 0:
        # Only the last batch's topology loss is logged.
        print(' ', file=file)
        print("topology_loss:", topology_loss, file=file)
        print(' ', file=file)

def test(dataloader,sampler):
    """Evaluate ``model`` on ``dataloader`` and average the metrics across workers.

    The calls to this function in this file are commented out in favour of
    ``valid``. It relies on the globals ``model``, ``args`` and ``num_class``, and
    on ``metric_average``, which is not defined in this file (presumably provided
    by one of the wildcard imports -- verify).

    Args:
        dataloader: Loader yielding ``(data, target)`` batches with integer class
            targets.
        sampler: The sampler behind ``dataloader``; ``len(sampler)`` is used as
            this worker's number of examples.

    Returns:
        ``(test_loss, test_accuracy)`` after ``metric_average`` across workers.
    """
    model.eval()
    test_loss = 0.
    test_accuracy = 0.
    L = Getloss(args.stragety)
    with torch.no_grad():
        for data, target in dataloader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            # one_hot stays on ``target``'s device, matching ``output``; the previous
            # unconditional ``.cuda()`` broke evaluation when CUDA is disabled.
            target_cal = torch.nn.functional.one_hot(target, num_class).type(torch.float32)
            # NOTE(review): this is divided by the sample count below, so it is only a
            # per-sample mean if L returns a *summed* batch loss -- confirm L's reduction.
            test_loss += L(output, target_cal)
            # Index of the highest score per row is the predicted class.
            pred = output.max(1, keepdim=True)[1]
            test_accuracy += pred.eq(target.view_as(pred)).cpu().float().sum()

        # Horovod: use the sampler to determine the number of examples in
        # this worker's partition.
        test_loss /= len(sampler)
        test_accuracy /= len(sampler)

        # Horovod: average metric values across workers.
        test_loss = metric_average(test_loss, 'avg_loss')
        test_accuracy = metric_average(test_accuracy, 'avg_accuracy')

        return test_loss, test_accuracy


# Horovod: using `lr = base_lr * hvd.size()` from the very beginning leads to worse final
# accuracy. Scale the learning rate `lr = base_lr` ---> `lr = base_lr * hvd.size()` during
# the first `args.warm_up` epochs. See https://arxiv.org/abs/1706.02677 for details.
# After the warm-up, the schedule follows `args.lr_policy`: 'linear' (piecewise-constant
# steps at epochs 5, 6 and 20), 'cosine' (cosine decay), or 'logarithmically' (a cyclic
# log-spaced schedule).
def adjust_learning_rate_new(epoch, batch_idx):
    """Set the learning rate of every optimizer param group for one batch.

    Relies on the module-level globals ``args``, ``optimizer`` and
    ``train_loader``. Every rate is scaled by
    ``hvd.size() * args.batches_per_allreduce``.

    * While ``epoch < args.warm_up`` the multiplier on ``args.base_lr`` ramps
      linearly, per batch, from about ``1 / hvd.size()`` up to 1.
    * Afterwards ``args.lr_policy`` selects:
        - ``'linear'``: piecewise-constant multiplier 1 / 0.31 / 0.1 / 0.03 with
          boundaries at epochs 5, 6 and 20 (despite the name);
        - ``'cosine'``: ``0.5 * (1 + cos(pi * t))`` with ``t`` the fraction of
          post-warm-up epochs elapsed;
        - ``'logarithmically'``: a 12-epoch cycle of log-spaced rates from
          ``args.base_lr`` down to 1e-4 and back up (used as the rate itself,
          not as a multiplier).

    Args:
        epoch: Current epoch index.
        batch_idx: Index of the current batch within the epoch.

    Raises:
        ValueError: If warm-up is over and ``args.lr_policy`` is not supported.
    """
    new_lr = None
    # Branch on the *integer* epoch. The old code overwrote ``epoch`` with the
    # fractional warm-up progress, which reaches ``args.warm_up`` on the last warm-up
    # batch and made the 'logarithmically' policy skip that batch's LR update.
    if epoch < args.warm_up:
        progress = epoch + float(batch_idx + 1) / len(train_loader)
        lr_adj = 1. / hvd.size() * (progress * (hvd.size() - 1) / args.warm_up + 1)
    elif args.lr_policy == 'linear':
        if epoch < 5:
            lr_adj = 1.
        elif epoch < 6:
            lr_adj = 0.31
        elif epoch < 20:
            lr_adj = 0.1
        else:
            lr_adj = 0.03
    elif args.lr_policy == 'cosine':
        lr_adj = 0.5 * (1. + math.cos(math.pi * (epoch - args.warm_up) / (args.Epoch - args.warm_up)))
    elif args.lr_policy == 'logarithmically':
        # base_lr -> 1e-4 (7 points) then 1e-4 -> base_lr (7 points); dropping the
        # duplicated 1e-4 (index 7) and the trailing base_lr leaves a 12-step cycle.
        lr_list = np.hstack((np.logspace(log10(args.base_lr), -4, 7), np.logspace(-4, log10(args.base_lr), 7)))
        lr_list = np.delete(lr_list, -1)
        lr_list = np.delete(lr_list, 7)
        lr = lr_list[(epoch - args.warm_up) % len(lr_list)]
        new_lr = lr * hvd.size() * args.batches_per_allreduce
    else:
        raise ValueError("Unsupported lr_policy: {!r}".format(args.lr_policy))

    if new_lr is None:
        new_lr = args.base_lr * hvd.size() * args.batches_per_allreduce * lr_adj
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr


if __name__ == '__main__':
    args, parser = parse_args()
    # `print_root` returns the log-file handle plus a value that is forwarded to the
    # Save helpers (picture / model_save / best_save / value_save). It is bound to
    # `run_time` rather than `time` so the imported `time` module is not shadowed.
    file, run_time = print_root(args)
    aug = get_aug(image_size=224)
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Batch size used by the train and test loaders.
    allreduce_batch_size = args.batchsize * args.batches_per_allreduce

    hvd.init()
    torch.manual_seed(args.seed)

    if args.cuda:
        # Horovod: pin GPU to local rank.
        torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(args.seed)

    cudnn.benchmark = True

    # Horovod-style resume epoch. It is fixed at 0, so the Horovod restore branch
    # below never runs; resuming is done through `args.resume_on` instead.
    resume_from_epoch = 0

    # Horovod: broadcast resume_from_epoch from rank 0 (which will have
    # checkpoints) to other ranks.
    resume_from_epoch = hvd.broadcast(torch.tensor(resume_from_epoch), root_rank=0,
                                      name='resume_from_epoch').item()

    # Horovod: print logs on the first worker.
    verbose = 1 if hvd.rank() == 0 else 0

    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(4)

    kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
    # When supported, use 'forkserver' to spawn dataloader workers instead of 'fork' to prevent
    # issues with Infiniband implementations that are not fork-safe
    if (kwargs.get('num_workers', 0) > 0 and hasattr(mp, '_supports_context') and
            mp._supports_context and 'forkserver' in mp.get_all_start_methods()):
        kwargs['multiprocessing_context'] = 'forkserver'

    train_dataset, test_dataset, val_dataset = Getdataset(args)

    # Horovod: use DistributedSampler to partition data among workers. Manually specify
    # `num_replicas=hvd.size()` and `rank=hvd.rank()`.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=allreduce_batch_size,
        sampler=train_sampler, **kwargs)

    if val_dataset is not None:
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batchsize,
                                                 sampler=val_sampler, **kwargs)

    if test_dataset is not None:
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            test_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=allreduce_batch_size,
                                                  sampler=test_sampler, **kwargs)

    # Set up model.
    num_class = Getnumclass(args.task)
    model = Getmodel(args.task, args.backbone, num_class, args.stragety, args.pretrain_param)

    # Resume from an explicit checkpoint. `args.resume_path` is used when the option
    # parser provides it; otherwise the historical hard-coded path is used.
    resume_checkpoint = None
    if args.resume_on == 1:
        resume_path = getattr(args, 'resume_path', None) or "./ch_log/9月19晚_接着有监督跑imagenet_自己的损失_224*8/resnet50_cosloss_newsigmoid_imagenet_BS_8*224_lr0.1_warm5_BN_(09月20日-05时10分)/checkpoints/_epoch8.pt"
        # Load onto the CPU so every rank does not materialise the tensors on the GPU
        # they were saved from; load_state_dict copies them into the model, and the
        # optimizer state is cast to its parameters' device when loaded below.
        resume_checkpoint = torch.load(resume_path, map_location='cpu')
        model.load_state_dict(resume_checkpoint['model_state_dict'])
        model.eval()

    num_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # Optimizer steps over the whole run: each step consumes
    # allreduce_batch_size samples on each of hvd.size() workers.
    max_step = float(args.Epoch) * len(train_dataset) / float(allreduce_batch_size * hvd.size())

    if hvd.rank() == 0:
        print(model, file=file)
        print_options(parser, args, file=file)
        print('number of parameters: ', num_param, file=file)
        print('all step: ', max_step, file=file)

    # By default, Adasum doesn't need scaling up learning rate.
    # For sum/average with gradient Accumulation: scale learning rate by batches_per_allreduce
    lr_scaler = args.batches_per_allreduce * hvd.size() if not args.use_adasum else 1

    if args.cuda:
        # Move model to GPU.
        model.cuda()
        # If using GPU Adasum allreduce, scale learning rate by local_size.
        if args.use_adasum and hvd.nccl_built():
            lr_scaler = args.batches_per_allreduce * hvd.local_size()

    # Horovod: scale learning rate by the number of GPUs.
    optimizer = optim.SGD(model.parameters(),
                          lr=(args.base_lr *
                              lr_scaler),
                          momentum=0.9, weight_decay=1e-4, nesterov=True)

    # Horovod: (optional) compression algorithm.
    compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

    # Horovod: wrap optimizer with DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters(),
        compression=compression,
        backward_passes_per_step=args.batches_per_allreduce,
        op=hvd.Adasum if args.use_adasum else hvd.Average,
        gradient_predivide_factor=args.gradient_predivide_factor)

    if args.lars:
        if hvd.rank() == 0:
            print("=> use LARS optimizer.", file=file)
        optimizer = LARC(optimizer=optimizer, trust_coefficient=.001, clip=False)

    # Restore from a previous checkpoint, if initial_epoch is specified.
    # Horovod: restore on the first worker which will broadcast weights to other workers.
    if resume_from_epoch > 0 and hvd.rank() == 0:
        filepath = args.checkpoint_format.format(epoch=resume_from_epoch)
        checkpoint = torch.load(filepath)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    if args.resume_on == 1:
        # Restore the optimizer state from the resume checkpoint.
        optimizer.load_state_dict(resume_checkpoint['optimizer_state_dict'])

    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    result = {
        'epoch': [],
        'train': {'acc': [], 'loss': []},
        'val': {'acc': [], 'loss': []},
        'test': {'acc': [], 'loss': []},
    }
    # A resumed run restarts at the epoch stored in the checkpoint; a fresh run at 0.
    start_epoch = resume_checkpoint['epoch'] if args.resume_on == 1 else 0

    # Best-checkpoint bookkeeping. Starting from -inf means a run resumed after
    # `begin_select_epoch` saves its first evaluated epoch as the best instead of
    # failing on an undefined `best_val_acc`.
    best_val_acc = float('-inf')
    best_epoch = start_epoch

    for epoch in tqdm(range(start_epoch, args.Epoch), file=file):
        train(epoch)

        L = Getloss(args.stragety)
        train_loss, train_acc = valid(train_loader, train_sampler, model, args, num_class, L)
        if test_dataset is not None:
            test_loss, test_acc = valid(test_loader, test_sampler, model, args, num_class, L)
        else:
            test_loss, test_acc = 0., 0.
        if val_dataset is not None:
            val_loss, val_acc = valid(val_loader, val_sampler, model, args, num_class, L)
        else:
            val_loss, val_acc = 0., 0.

        for split, split_loss, split_acc in (('val', val_loss, val_acc),
                                             ('train', train_loss, train_acc),
                                             ('test', test_loss, test_acc)):
            result[split]['acc'].append(split_acc)
            result[split]['loss'].append(split_loss)
        result['epoch'].append(epoch)

        if hvd.rank() == 0:
            print('\n[{:s}]\t\t      {:.6f}'.format('LR', optimizer.param_groups[0]['lr']), file=file)
            print('[{:s}]\t\tLoss: {:.8f}, {:s}: {:.4f}'.format('Train', train_loss, 'Acc', train_acc), file=file)
            print('[{:s}]\t\tLoss: {:.8f}, {:s}: {:.4f}'.format('Test', test_loss, 'Acc', test_acc), file=file)
            print('[{:s}]\t\tLoss: {:.8f}, {:s}: {:.4f}\n'.format('Val', val_loss, 'Acc', val_acc), file=file)

        if args.is_picture:
            picture(args, result, run_time)

        if args.save_best_function == 1:
            if epoch < args.begin_select_epoch:
                # Before model selection starts: plain periodic checkpoints.
                if epoch % args.save_epoch == 0:
                    model_save(args, model, optimizer, epoch, result, run_time)
            else:
                improvement = val_acc - best_val_acc
                # Save as best at the first selection epoch, on a large enough gain,
                # or on any gain once enough epochs have passed since the last best.
                if (epoch == args.begin_select_epoch
                        or improvement >= args.update_gate
                        or (improvement > 0 and epoch - best_epoch >= args.save_across_epoch)):
                    best_val_acc = val_acc
                    best_epoch = epoch
                    best_save(args, model, optimizer, epoch, result, run_time)
        elif epoch % args.save_epoch == 0:
            model_save(args, model, optimizer, epoch, result, run_time)

        value_save(result, args, run_time)

