__author__ = "Anon"
__version__ = "0.1"
import torch
torch.autograd.set_detect_anomaly(True)
import torch.nn as nn
import torch.optim as optim
from datetime import datetime
import os
from utils import ExperimentSettings, validate, get_net, WarmUpLR
from pathlib import Path
from datasets import get_train_valid_loader, get_test_loader, TransformComposer
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import argparse
import yaml
import matplotlib.pyplot as plt
import numpy as np
import itertools
import time
import torch.multiprocessing as mp
import torch.distributed as dist

def main(args):
    """Launch one distributed training worker per visible CUDA device.

    ``args.world_size`` is read as the number of nodes and overwritten with
    the total number of processes (nodes * GPUs on this node). One
    ``main_worker`` process is then spawned per local GPU.

    Args:
        args: Parsed command-line namespace. ``distributed`` is set to
            ``True`` and ``world_size`` is rewritten in place.

    Raises:
        RuntimeError: If no CUDA device is visible.
    """
    ngpus_per_node = torch.cuda.device_count()
    # Without this guard, spawning zero processes starts no worker at all and
    # the script ends without training and without any error message.
    if ngpus_per_node < 1:
        raise RuntimeError(
            'No CUDA device visible: distributed training needs at least one GPU.')

    args.distributed = True
    args.world_size = ngpus_per_node * args.world_size
    # mp.spawn passes the local process index as the first argument of main_worker.
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))


def main_worker(gpu, ngpus_per_node, args):
    """Train and periodically evaluate the network in one DDP process.

    The worker joins the process group, loads the YAML config, builds the
    transforms and the network, wraps the network in
    ``DistributedDataParallel`` and then, for every entry of
    ``config['NUM_BINS']``, runs a full SGD training schedule. The first
    process on each node additionally creates the checkpoint directory,
    writes TensorBoard scalars and saves the best weights.

    Args:
        gpu: Local process index supplied by ``mp.spawn``; used as the CUDA
            device index.
        ngpus_per_node: Number of worker processes (GPUs) per node.
        args: Command-line namespace. Reads ``rank``, ``world_size``,
            ``dist_backend``, ``dist_url``, ``config`` and ``workers``;
            writes ``gpu``, ``rank``, ``batch_size`` and ``workers``.

    Raises:
        NotImplementedError: If ``config['DATASET']`` is not one of the
            datasets listed below.
    """
    args.gpu = gpu
    # Global rank = node rank * processes per node + local index.
    args.rank = args.rank * ngpus_per_node + gpu

    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
    print(args.rank, args.gpu)

    # yaml.load() without an explicit Loader raises TypeError on PyYAML >= 6,
    # and the bare open() never closed the file. safe_load only builds plain
    # YAML types (dict/list/str/number/bool).
    with open(args.config) as config_file:
        config = yaml.safe_load(config_file)

    # Only the first process of each node logs and checkpoints.
    is_node_master = args.rank % ngpus_per_node == 0
    writer = None
    model_ckp_path = None
    # NOTE(review): best_acc is shared across all bins, so a later bin only
    # saves a checkpoint when it beats every earlier bin - confirm intended.
    best_acc = 0.
    if is_node_master:
        config['TIME_NOW'] = str(datetime.now().strftime('%Y-%m-%d--%H-%M'))
        settings = ExperimentSettings(config)

        ### CHECKPOINT PATHS
        checkpoint_path = os.path.join(
            'checkpoints', config['DATASET'],
            '{}_Pretrained-{}'.format(config['ARCH'], config['PRETRAINED']),
            '_'.join(config['TRANSFORMS']),
            'num_negatives_{}'.format(config['num_negatives']),
            'act_{}'.format(config['class_in_activation']),
            config['TIME_NOW'])
        log_path = os.path.join(checkpoint_path, 'logs')
        # '{bin}' is filled in with the current bin count when saving.
        model_ckp_path = os.path.join(checkpoint_path, '{bin}-best.pth')
        # Creating the nested log directory also creates checkpoint_path.
        Path(log_path).mkdir(parents=True, exist_ok=True)
        settings.dump(checkpoint_path)

        ### SUMMARY WRITER
        writer = SummaryWriter(log_path)

    ### TRANSFORMS
    train_transform_composer = TransformComposer(
        transforms=config['TRANSFORMS'], dataset=config['DATASET'],
        inp_size=config['IMG_SIZE'], re_size=config['RE_SIZE'])
    if config['DATASET'] in ['imgnet', 'cub200', 'pets', 'cub20', 'bmw10', 'cars']:
        test_transforms = ['RESIZE', 'CCROP', 'NORM']
    elif config['DATASET'] in ['cifar10', 'cifar100', 'stl10']:
        test_transforms = ['NORM']
    else:
        raise NotImplementedError('Dataset not implemented')
    test_transform_composer = TransformComposer(
        transforms=test_transforms, dataset=config['DATASET'],
        inp_size=config['IMG_SIZE'], re_size=config['RE_SIZE'])

    print(config)

    net = get_net(config['ARCH'], config['NUM_CLASSES'], config['PRETRAINED'],
                  config['LOAD_WEIGHTS_FROM'], config)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        net.cuda(args.gpu)
        # With one GPU per process the configured batch size and the loader
        # workers are divided across the processes of this node.
        args.batch_size = int(config['BATCH_SIZE'] / ngpus_per_node)
        args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
        net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu])
    else:
        net.cuda()
        # batch_size was previously left unset on this path, which made the
        # data-loader construction below fail with AttributeError.
        args.batch_size = config['BATCH_SIZE']
        net = torch.nn.parallel.DistributedDataParallel(net)

    ### Iterate over the configured bin counts
    for num_bin in config['NUM_BINS']:

        if config['pos_neg_balanced']:
            # Fixed class weights; an earlier variant derived them from
            # config['num_negatives'].
            weights = torch.tensor([0.25, 1]).cuda(args.gpu)
        else:
            weights = None
        criterion = nn.CrossEntropyLoss(weight=weights).cuda(args.gpu)

        transform = train_transform_composer.get_composite(num_bins=num_bin)
        test_transform = test_transform_composer.get_composite(num_bins=256)

        ### Dataloaders
        train_loader, train_sampler = get_train_valid_loader(
            dataset=config['DATASET'], batch_size=args.batch_size,
            num_workers=args.workers, transform=transform, pin_memory=True,
            opts=config, is_mgpu=True)
        test_b_size = 64
        test_loader = get_test_loader(dataset=config['DATASET'], batch_size=test_b_size,
                                      pin_memory=True, transform=test_transform)

        # Hyper-params: one (lr, epochs, milestones) triple per bin count,
        # keyed by the bin count as a string.
        lr, epochs, milestones = config['LEARNING_RATES'][str(num_bin)]
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.2)
        iter_per_epoch = len(train_loader)
        warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * 2)
        time_elapsed = 0.
        for epoch in range(epochs+1):
            # Forward the epoch number to the sampler returned by the loader factory.
            train_sampler.set_epoch(epoch)

            ### TRAIN STEP
            # tqdm advances the bar itself while being iterated; the former
            # extra progress.update(1) per batch double-counted the progress.
            progress = tqdm(train_loader, desc="Bin: {} & Epoch: {}".format(num_bin, epoch),
                            total=len(train_loader))
            net.train()
            start_time = time.time()
            epoch_loss = 0.
            # Pre-bind so the `del` below is valid even for an empty loader.
            data = images = cls_y = preds = None
            for data in progress:
                images, cls_y = data[0], data[1]
                optimizer.zero_grad()
                images = images.cuda(args.gpu, non_blocking=True)
                cls_y = cls_y.cuda(args.gpu, non_blocking=True)
                preds = net(images)
                cls_loss = criterion(preds, cls_y)

                cls_loss.backward()
                epoch_loss += cls_loss.item()
                optimizer.step()
            time_elapsed += time.time() - start_time
            print('Time per epoch: {} -- {}'.format(time_elapsed/(epoch+1), epoch_loss))
            # NOTE(review): the warm-up scheduler is sized in iterations
            # (iter_per_epoch * 2) but stepped once per epoch here - confirm
            # against WarmUpLR that this is the intended schedule.
            if epoch <= 2:
                warmup_scheduler.step()
            else:
                scheduler.step()

            # Drop the references to the last batch before validation.
            del images, data, cls_y, preds

            # Validate every 10th epoch and on every epoch of the last 10%.
            if is_node_master and (epoch > int(epochs*0.9) or epoch % 10 == 0):
                start_time = time.monotonic()
                # NOTE(review): validation runs through the DDP wrapper on a
                # single rank only - confirm this cannot stall on collectives.
                _, _, cls_t1, cls_t5, _, _ = validate(net, test_loader, config)
                print('validation time: {}{}'.format(time.monotonic()-start_time, cls_t1))
                ### Write to Tensorboard
                # The tag previously had no placeholder ('ACC'.format(num_bin)),
                # so all bins were written to the same tag.
                writer.add_scalars('ACC_{}'.format(num_bin),
                                   {'cls_t1': cls_t1, 'cls_t5': cls_t5}, global_step=epoch+1)
                ### Save best model
                if best_acc <= cls_t1:
                    # Save the unwrapped module's weights (no DDP 'module.' prefix).
                    torch.save(net.module.state_dict(), model_ckp_path.format(bin=num_bin))
                    best_acc = cls_t1

    # Flush pending TensorBoard events before the process exits.
    if writer is not None:
        writer.close()


if __name__ == '__main__':
    # Command-line entry point: parse the launch options and start training.
    parser = argparse.ArgumentParser(description='Training for supervised classification experiments')
    parser.add_argument('--config', '-c', help='Path to the training config.yaml', required=True)
    parser.add_argument('--world-size', default=1, type=int, help='number of nodes for distributed training')
    parser.add_argument('--rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--dist-url', default='tcp://127.0.0.1:23456', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')
    # The help text used to advertise a default of 4 while the actual default is 8.
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')
    # Unrecognised options are tolerated and ignored rather than rejected.
    args, _unknown = parser.parse_known_args()
    main(args)
