import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.autograd import Variable

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import time, datetime
import argparse
import numpy as np
from pathlib import Path

import utils
from utils import *
from verbose import Summary, AverageMeter, ProgressMeter, accuracy
from models import *
from dataloader import (
    dataloader_tiny_image_net,
    dataloader_tiny_image_net_supcon
)
from losses import SupConLoss

# Registry of selectable networks. Its keys become the allowed values of
# --net in get_parser(). get_all_networks() comes from one of the star
# imports above (utils or models); see there for its contents.
# NOTE(review): main_worker only implements 'ResNet50TinySupCon' and
# 'ResNet50TinySupLinear', even if this registry lists more networks.
NET = get_all_networks()


def get_parser():
    """Build the command-line parser for TinyImageNet training.

    Returns:
        argparse.ArgumentParser with general options (data/result paths,
        seed, save flag, checkpoint path) plus a ``params`` argument group
        holding the training hyper-parameters.
    """
    parser = argparse.ArgumentParser(description='Image Classification - TinyImageNet')
    parser.add_argument('--data-dir', default='./data', type=str,
                        help='root directory of the TinyImageNet data')
    parser.add_argument('--seed-num', default=1, type=int,
                        help='random seed')
    parser.add_argument('--save', action='store_true', default=False,
                        help='write checkpoints into --res-dir')
    parser.add_argument('--res-dir', default='./result', type=str,
                        help='directory where checkpoints are written')
    parser.add_argument('--ckpt', type=str,
                        help='SupCon checkpoint to load (needed for linear-probe training)')

    hyper = parser.add_argument_group('params')
    # Required: without a network the training workers cannot build a model.
    hyper.add_argument('--net', type=str, choices=list(NET.keys()), required=True,
                       help='network / training mode to run')
    hyper.add_argument('--lr', default=0.001, type=float,
                       help='initial learning rate')
    # Hyphenated spelling matches every other option; the original
    # underscore spelling is kept as an alias for existing command lines.
    hyper.add_argument('--n-epochs', '--n_epochs', dest='n_epochs', default=5, type=int,
                       help='number of training epochs')
    hyper.add_argument('--batch-size', default=32, type=int,
                       help='per-process batch size')
    # Default matches the 1e-4 SGD weight decay this script has always trained with.
    hyper.add_argument('--weight-decay', default=1e-4, type=float,
                       help='SGD weight decay')

    return parser


def main_worker(rank, world_size, args):
    """Run one DDP training process on CUDA device ``rank``.

    Launched once per GPU by ``mp.spawn``. Two modes are selected by
    ``args.net``:

    * ``'ResNet50TinySupCon'``: trains ``SupConResNetTiny`` with
      ``SupConLoss`` on pairs of views per sample. No classification
      accuracy is computed. With ``--save``, rank 0 writes the unwrapped
      model state dict from the last epoch.
    * ``'ResNet50TinySupLinear'``: loads a SupCon checkpoint from
      ``args.ckpt`` and uses its encoder as a fixed feature extractor. It
      trains a 200-class ``LinearClassifier`` on top with cross-entropy and
      evaluates on the validation split after every epoch.

    Args:
        rank: index of this process; also used as the CUDA device index.
        world_size: total number of processes, forwarded to ``utils.setup``.
        args: namespace produced by ``get_parser().parse_args()``.

    Raises:
        ValueError: if ``args.net`` is not one of the two supported modes,
            or ``args.ckpt`` is missing in linear-probe mode.
    """
    supported_nets = ('ResNet50TinySupCon', 'ResNet50TinySupLinear')
    if args.net not in supported_nets:
        # Any other --net would otherwise crash later with a NameError on
        # `net` / `criterion`; fail fast with a clear message instead.
        raise ValueError(
            f"main_worker supports --net in {supported_nets}, got {args.net!r}")
    is_supcon = args.net == 'ResNet50TinySupCon'
    if not is_supcon and not args.ckpt:
        raise ValueError("--ckpt is required when --net is 'ResNet50TinySupLinear'")

    utils.set_random_seed(seed_num=args.seed_num)

    print(f"Running basic DDP example on rank {rank}.")
    utils.setup(rank, world_size)
    try:
        # SupCon mode needs the loader that yields pairs of views (see supcon_train).
        make_loader = (dataloader_tiny_image_net_supcon if is_supcon
                       else dataloader_tiny_image_net)
        train_sampler, train_loader = make_loader(
            args.batch_size, root=args.data_dir, ctg='train')
        test_sampler, test_loader = make_loader(
            args.batch_size, root=args.data_dir, ctg='val')

        torch.cuda.set_device(rank)
        pre_trained_model = None  # only used in linear-probe mode
        if is_supcon:
            net = SupConResNetTiny(name='resnet50tiny', head='linear').cuda(rank)
            net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[rank])
        else:
            # Encoder from a SupCon checkpoint. Its parameters are not given
            # to the optimizer and it only runs under no_grad, so it stays fixed.
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            pre_trained_model = SupConResNetTiny(name='resnet50tiny', head='linear')
            pre_trained_model.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
            pre_trained_model = pre_trained_model.encoder.cuda(rank)
            pre_trained_model.eval()
            pre_trained_model = torch.nn.parallel.DistributedDataParallel(
                pre_trained_model, device_ids=[rank])
            net = LinearClassifier(name='resnet50tiny', num_classes=200).cuda(rank)
            net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[rank])

        # Honour --weight-decay (it used to be ignored in favour of a hard-coded 1e-4).
        optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9,
                              weight_decay=args.weight_decay)

        if is_supcon:
            # presumably read by SupConLoss for device placement -- TODO confirm
            args.gpu_id = rank
            criterion = SupConLoss(args)
        else:
            criterion = nn.CrossEntropyLoss().cuda(rank)

        def supcon_train(dataloader, epoch):
            """Run one SupCon epoch and return the epoch's (placeholder) top-1 / 100."""
            batch_time = AverageMeter('Time', ':6.3f')
            data_time = AverageMeter('Data', ':6.3f')
            losses = AverageMeter('Loss', ':.4e')
            top1 = AverageMeter('Acc@1', ':6.2f')
            top5 = AverageMeter('Acc@5', ':6.2f')
            progress = ProgressMeter(
                len(dataloader),
                [batch_time, data_time, losses, top1, top5],
                prefix="Epoch: [{}]".format(epoch))

            net.train()

            end = time.time()
            for idx, (data, targets) in enumerate(dataloader):
                data_time.update(time.time() - end)

                # Stack both views along the batch dimension for a single forward pass.
                data = torch.cat([data[0], data[1]], dim=0)
                data, targets = data.cuda(rank), targets.cuda(rank)

                bsz = targets.shape[0]
                features = net(data)
                # Split back into the two views and stack them on dim 1: (bsz, 2, ...).
                f1, f2 = torch.split(features, [bsz, bsz], dim=0)
                features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)

                loss = criterion(features, targets)
                # Classification accuracy is not defined for contrastive
                # embeddings. The Acc meters stay at 0 so the progress line has
                # the same columns as train().
                acc1, acc5 = torch.tensor([0.]), torch.tensor([0.])

                losses.update(loss.item(), data.size(0))
                top1.update(acc1[0], data.size(0))
                top5.update(acc5[0], data.size(0))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                batch_time.update(time.time() - end)
                end = time.time()

                if rank == 0 and idx % 10 == 0:
                    progress.display(idx + 1)

            top1.all_reduce()
            top5.all_reduce()

            # Epoch average (always 0 here) rather than the last batch's value.
            return top1.avg / 100.

        def train(dataloader, epoch):
            """Train the linear head for one epoch; return epoch top-1 accuracy / 100.

            accuracy() appears to report percentages, hence the division.
            """
            batch_time = AverageMeter('Time', ':6.3f')
            data_time = AverageMeter('Data', ':6.3f')
            losses = AverageMeter('Loss', ':.4e')
            top1 = AverageMeter('Acc@1', ':6.2f')
            top5 = AverageMeter('Acc@5', ':6.2f')
            progress = ProgressMeter(
                len(dataloader),
                [batch_time, data_time, losses, top1, top5],
                prefix="Epoch: [{}]".format(epoch))

            net.train()

            end = time.time()
            for idx, (data, targets) in enumerate(dataloader):
                data_time.update(time.time() - end)

                data, targets = data.cuda(rank), targets.cuda(rank)
                # The encoder is only a feature extractor: build no graph through it.
                with torch.no_grad():
                    features = pre_trained_model(data)
                output = net(features)

                loss = criterion(output, targets)
                acc1, acc5 = accuracy(output, targets, topk=(1, 5))

                losses.update(loss.item(), data.size(0))
                top1.update(acc1[0], data.size(0))
                top5.update(acc5[0], data.size(0))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                batch_time.update(time.time() - end)
                end = time.time()

                if rank == 0 and idx % 10 == 0:
                    progress.display(idx + 1)

            # Aggregate across ranks so the result covers the whole epoch, not
            # just the last batch seen by this rank (the previous return value).
            top1.all_reduce()
            top5.all_reduce()

            return top1.avg / 100.

        def test(dataloader):
            """Evaluate the linear head and return the all-reduced top-1 average."""
            batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
            losses = AverageMeter('Loss', ':.4e', Summary.NONE)
            top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
            top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
            progress = ProgressMeter(
                len(dataloader),
                [batch_time, losses, top1, top5],
                prefix="Test: ")

            net.eval()

            with torch.no_grad():
                end = time.time()
                for idx, (data, targets) in enumerate(dataloader):
                    data, targets = data.cuda(rank), targets.cuda(rank)

                    features = pre_trained_model(data)
                    output = net(features)
                    loss = criterion(output, targets)

                    acc1, acc5 = accuracy(output, targets, topk=(1, 5))
                    losses.update(loss.item(), data.size(0))
                    top1.update(acc1[0], data.size(0))
                    top5.update(acc5[0], data.size(0))

                    batch_time.update(time.time() - end)
                    end = time.time()

                    if rank == 0 and idx % 10 == 0:
                        progress.display(idx + 1)

            top1.all_reduce()
            top5.all_reduce()

            if rank == 0:
                progress.display_summary()

            return top1.avg

        if args.save:
            Path(args.res_dir).mkdir(parents=True, exist_ok=True)
            args.res_tag = '_'.join([args.net,
                                     f'e{args.n_epochs}',
                                     f'lr{args.lr}',
                                     f'bsz{args.batch_size}',
                                     f'seed{args.seed_num}']) + '.pth'

        best_tr_acc = 0.
        best_weights = None
        best_acc1 = 0
        lr = args.lr
        total_acc_tr, total_acc_ts = [], []  # only appended to on rank 0
        for epoch in range(args.n_epochs):
            train_sampler.set_epoch(epoch)
            test_sampler.set_epoch(epoch)

            # NOTE(review): epoch 0 also matches, so utils.lr_decay is applied
            # once before any training happens -- confirm that is intended.
            if epoch % 20 == 0:
                lr = utils.lr_decay(optimizer, lr, decay_rate=0.9)

            # float() accepts both tensors and Python floats from the meters.
            if is_supcon:
                tr_acc = float(supcon_train(train_loader, epoch))
            else:
                tr_acc = float(train(train_loader, epoch))
            if rank == 0:
                print()

            if not is_supcon:
                acc1 = test(test_loader)

                is_best = acc1 > best_acc1
                best_acc1 = max(acc1, best_acc1)

                if rank == 0:
                    total_acc_tr.append(tr_acc * 100.)
                    total_acc_ts.append(acc1)
                    if args.save and is_best:
                        utils.save_checkpoint({
                            'epoch': epoch + 1,
                            'state_dict': net.state_dict(),
                            'best_acc1': best_acc1,
                        }, True, dirname=args.res_dir, filename=args.res_tag)
            elif best_tr_acc <= tr_acc:
                # Only the SupCon save path uses best_weights. Its accuracy is
                # always 0, so every epoch ties and the last epoch's weights win.
                best_tr_acc = tr_acc
                model = net.module if hasattr(net, "module") else net
                best_weights = {k: v.to("cpu").clone() for k, v in model.state_dict().items()}

        if rank == 0 and args.save:
            if is_supcon:
                # Unwrapped (no DDP 'module.' prefix) state dict; presumably
                # what --ckpt expects in linear-probe mode -- TODO confirm.
                utils.save_checkpoint(
                    best_weights, False, dirname=args.res_dir, filename=args.res_tag)
            else:
                utils.save_checkpoint({
                    'epoch': args.n_epochs,
                    'state_dict': net.state_dict(),
                    'best_acc1': best_acc1,
                    'total_acc_tr': total_acc_tr,
                    'total_acc_ts': total_acc_ts,
                }, False, dirname=args.res_dir, filename=args.res_tag)
    finally:
        # Tear down the process group even if training raised.
        utils.cleanup()


if __name__ == "__main__":
    args = get_parser().parse_args()

    # One DDP worker per visible GPU. mp.spawn passes each worker its process
    # index as the first argument, which main_worker uses as rank and CUDA device.
    ngpus_per_node = torch.cuda.device_count()
    if ngpus_per_node < 1:
        # With nprocs=0, mp.spawn would launch no workers, and the script
        # would exit without training and without any error.
        raise SystemExit("No CUDA devices found; this script needs at least one GPU.")
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))


