import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable

import argparse
import time, datetime
import numpy as np
from pathlib import Path
import os

import utils
from utils import *
from models import *
from losses import *

# Registry of available model constructors keyed by name; its keys are the
# valid --net choices.
NET = get_all_networks()

parser = argparse.ArgumentParser(description='Image Classification - CIFAR10LT / CIFAR100LT')
parser.add_argument('--gpu-id', default=0, type=int,
                    help='index of the CUDA device to use')
parser.add_argument('--data-dir', default='./data', type=str,
                    help='root directory passed to the dataset builders')
parser.add_argument('--net', default='resnet32_origin', type=str, choices=list(NET.keys()),
                    help='network name (also selects the training procedure)')
parser.add_argument('--seed-num', default=1, type=int,
                    help='random seed')
parser.add_argument('--save', action='store_true', default=False,
                    help='save best weights and per-epoch metrics to --res-dir')
parser.add_argument('--res-dir', default='./result', type=str,
                    help='output directory used with --save')
parser.add_argument('--res-tag', default='vgg11', type=str,
                    help='tag used in the metrics file name (result_<tag>.npy)')
parser.add_argument('--log-dir', default='./runs', type=str,
                    help='TensorBoard log directory')
parser.add_argument('--imb-factor', default=0.1, type=float,
                    help='imbalance factor passed to the dataset builders')
parser.add_argument('--ckpt', default=None, type=str,
                    help='pre-trained resnet32_supCon state_dict (required for supLinear nets)')

hyper = parser.add_argument_group('params')
hyper.add_argument('--lr', default=0.05, type=float,
                   help='initial SGD learning rate (the warm-up schedule overrides it from epoch 1)')
# '--n-epochs' is accepted as an alias for consistency with the other
# hyphenated flags; dest stays 'n_epochs' (taken from the first long option).
hyper.add_argument('--n_epochs', '--n-epochs', default=300, type=int,
                   help='number of training epochs')
hyper.add_argument('--batch-size', default=128, type=int,
                   help='training batch size (the test loader uses twice this)')
hyper.add_argument('--momentum', default=0.9, type=float,
                   help='SGD momentum')
hyper.add_argument('--weight-decay', default=5e-4, type=float,
                   help='SGD weight decay')
hyper.add_argument('--dataset', default='cifar100', type=str,
                   help='dataset key prefix, e.g. cifar100')
hyper.add_argument('--beta', default=2./3., type=float,
                   help='beta hyper-parameter passed to AMA networks')

args = parser.parse_args()

# supLinear nets load a pre-trained SupCon encoder from --ckpt; fail here with
# a clear message instead of inside torch.load(None) later on.
if 'supLinear' in args.net and args.ckpt is None:
    parser.error('--ckpt is required when --net is a supLinear network')

t_start = time.time()

utils.set_random_seed(seed_num=args.seed_num)
# Use the requested GPU when CUDA is available; otherwise fall back to CPU
# instead of crashing at the first .to(device) call.
if torch.cuda.is_available():
    device = torch.device('cuda:{}'.format(args.gpu_id))
else:
    device = torch.device('cpu')
    print('CUDA is not available; running on CPU.')

# Transform registry keyed '<dataset>_train' / '<dataset>_test'.
TRANSFORMATION = get_all_transformations()

transform_tr = TRANSFORMATION[args.dataset+'_train']
transform_ts = TRANSFORMATION[args.dataset+'_test']
if 'supCon' in args.net:
    # SupCon training consumes two augmented views of every image
    # (supcon_train reads data[0] and data[1]).
    transform_tr = TwoCropTransform(transform_tr)

# Dataset registry keyed '<dataset>_dataset' / '<dataset>_testset'.
DATASETS = get_all_datasets(args.data_dir, transform_tr, transform_ts, args.imb_factor)
dataset = DATASETS[args.dataset+'_dataset']
testset = DATASETS[args.dataset+'_testset']

trainloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
testloader = DataLoader(testset, batch_size=args.batch_size * 2, shuffle=False)

class_name = dataset.classes
num_classes = len(class_name)

print(f'net: {args.net} / dataset: {args.dataset} / num_classes: {num_classes}')

# Constructor signatures differ per family: AMA nets also take beta, SupCon
# nets take no arguments, everything else takes num_classes.
# Module.to() moves the module in place and returns it, so no second
# .to(device) is needed.
if "AMA" in args.net:
    net = NET[args.net](num_classes=num_classes, beta=args.beta).to(device)
elif "supCon" in args.net:
    net = NET[args.net]().to(device)
else:
    net = NET[args.net](num_classes=num_classes).to(device)

if "supLinear" in args.net:
    # Pre-trained SupCon model whose .encoder produces the features the linear
    # net is trained on. It is never optimized (the optimizer only sees
    # net.parameters()), so keep it in eval mode from the start.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    pre_trained_model = NET['resnet32_supCon']()
    pre_trained_model.load_state_dict(torch.load(args.ckpt, map_location="cpu"))
    pre_trained_model.to(device)
    pre_trained_model.eval()

# Reference CrossEntropyLoss through torch.nn so it does not depend on `nn`
# leaking in through the star imports.
if 'supCon' in args.net:
    criterion = SupConLoss(args)
else:
    criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay)
writer = SummaryWriter(args.log_dir)
# Single-element lists act as mutable module-level cells, so test() can
# update them without a `global` statement.
best_dev_acc = [-1]
best_weights = [None]
# Previous-epoch training accuracy; fed to AMA nets through train(acc=...).
tr_acc = 0.
print(net)


def supcon_train(dataloader, epoch):
    """Run one epoch of supervised-contrastive (SupCon) training.

    Each batch yields ``(data, targets)``, where ``data[0]`` and ``data[1]``
    are two augmented views of the same images. Both views go through ``net``
    in a single forward pass. The resulting features are paired per sample
    and passed to ``criterion`` together with the labels.

    Args:
        dataloader: training DataLoader built with a two-view transform.
        epoch: 1-based epoch index, used only for the progress/ETA line.

    Returns:
        (tr_loss, tr_acc): the per-sample mean training loss, and 0.0.
        No classifier is trained here, so accuracy is not measured; it is
        returned only to mirror ``train``.
        NOTE(review): the per-sample mean assumes ``criterion`` returns a
        batch-mean loss -- confirm for SupConLoss.
    """
    net.train()
    tr_loss = 0.
    n_batches = len(dataloader)

    prev_time = time.time()
    for idx, (data, targets) in enumerate(dataloader):
        # Stack the two views along the batch axis: rows [0, bsz) are view 1,
        # rows [bsz, 2*bsz) are view 2.
        data = torch.cat([data[0], data[1]], dim=0)
        data, targets = data.to(device), targets.to(device)
        bsz = targets.shape[0]
        optimizer.zero_grad()
        features = net(data)
        # Split back into the two views and stack them on a new dim 1, so each
        # sample's two embeddings sit side by side.
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
        loss = criterion(features, targets)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Weight each batch-mean loss by its batch size, so dividing by the
        # dataset size below gives a per-sample mean (the last batch may be
        # smaller than the others).
        tr_loss += batch_loss * bsz

        # verbose: idx is 0-based, so idx + 1 batches of this epoch are done.
        batches_done = (epoch - 1) * n_batches + idx + 1
        batches_left = args.n_epochs * n_batches - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()
        print("\r[epoch {:3d}/{:3d}] [batch {:4d}/{:4d}] loss: {:.6f} (eta: {})".format(
            epoch, args.n_epochs, idx + 1, n_batches, batch_loss, time_left), end=' ')

    cnt = len(dataloader.dataset)
    tr_loss /= cnt
    # Accuracy is not measured for contrastive pre-training.
    tr_acc = 0.
    return tr_loss, tr_acc


def train(dataloader, epoch, acc):
    """Run one epoch of supervised training and return (loss, accuracy).

    How the forward pass and loss are computed depends on ``args.net``:

    * ``AMA``: the net receives the previous accuracy ``acc`` and the labels.
      It returns its own outputs and (possibly modified) targets.
    * ``supLinear``: features come from ``pre_trained_model.encoder`` under
      ``torch.no_grad()``; only ``net`` is trained on top of them.
    * ``manifoldMixup``: the net mixes internally. It returns the outputs,
      both target sets and the mixing coefficient for ``mixup_criterion``.
    * ``mixup``: inputs are mixed with ``mixup_data(..., 1.0, device)``
      before a plain forward pass.
    * otherwise: plain forward pass with ``criterion``.

    Args:
        dataloader: training DataLoader yielding ``(data, targets)``.
        epoch: 1-based epoch index, used only for the progress/ETA line.
        acc: training accuracy from the previous epoch, forwarded to AMA nets.

    Returns:
        (tr_loss, tr_acc): the per-sample mean loss and the accuracy over
        the dataset. For the mixup variants, accuracy is measured against the
        original (unmixed) labels.
    """
    if 'supLinear' in args.net:
        pre_trained_model.eval()
    net.train()
    tr_loss = 0.
    correct = 0
    n_batches = len(dataloader)
    prev_time = time.time()
    for idx, (data, targets) in enumerate(dataloader):
        data, targets = data.to(device), targets.to(device)
        bsz = data.shape[0]

        optimizer.zero_grad()
        if 'AMA' in args.net:
            output, targets = net(data, tr_acc=acc, label=targets)
            loss = criterion(output, targets)
        elif 'supLinear' in args.net:
            with torch.no_grad():
                features = pre_trained_model.encoder(data)
            output = net(features.detach())
            loss = criterion(output, targets)
        elif 'manifoldMixup' in args.net:
            # Plain tensors replace the deprecated autograd.Variable wrappers.
            output, targets_a, targets_b, lam = net(data, targets, device)
            loss = mixup_criterion(criterion, output, targets_a, targets_b, lam)
        elif 'mixup' in args.net:
            data, targets_a, targets_b, lam = mixup_data(data, targets, 1.0, device)
            output = net(data)
            loss = mixup_criterion(criterion, output, targets_a, targets_b, lam)
        else:
            output = net(data)
            loss = criterion(output, targets)

        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Loss is a batch mean (CrossEntropyLoss defaults to reduction='mean').
        # Weight it by the batch size so dividing by the dataset size below
        # gives a per-sample mean.
        tr_loss += batch_loss * bsz

        if 'AMA' in args.net:
            # The AMA output is sliced to the input batch size, presumably
            # because it can carry extra rows; score only the real samples.
            pred = output[:data.shape[0]].argmax(dim=1, keepdim=True)
            correct += pred.eq(targets[:data.shape[0]].view_as(pred)).sum().item()
        else:
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(targets.view_as(pred)).sum().item()

        # verbose: idx is 0-based, so idx + 1 batches of this epoch are done.
        batches_done = (epoch - 1) * n_batches + idx + 1
        batches_left = args.n_epochs * n_batches - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()
        print("\r[epoch {:3d}/{:3d}] [batch {:4d}/{:4d}] loss: {:.6f} (eta: {})".format(
            epoch, args.n_epochs, idx + 1, n_batches, batch_loss, time_left), end=' ')

    cnt = len(dataloader.dataset)
    tr_loss /= cnt
    tr_acc = correct / cnt
    return tr_loss, tr_acc


def test(dataloader):
    """Evaluate ``net`` on ``dataloader`` and track the best accuracy so far.

    Side effect: when the accuracy exceeds ``best_dev_acc[0]``, a CPU copy of
    the network's state_dict is stored in ``best_weights[0]`` and
    ``best_dev_acc[0]`` is updated. The copy is taken from ``net.module``
    when the net is wrapped (e.g. DataParallel).

    NOTE(review): the training loop passes the *test* loader here, so the
    "best" weights are selected on the test set.
    NOTE(review): for supLinear nets this calls ``net.fc(features)``, while
    ``train`` calls ``net(features)``. Confirm the two paths are equivalent.

    Args:
        dataloader: evaluation DataLoader yielding ``(data, targets)``.

    Returns:
        (ts_loss, ts_acc): per-sample mean loss and accuracy over the dataset.
    """
    net.eval()
    if 'supLinear' in args.net:
        # Don't rely on train() having switched the encoder to eval mode.
        pre_trained_model.eval()
    ts_loss = 0.
    correct = 0

    with torch.no_grad():
        for data, targets in dataloader:
            data, targets = data.to(device), targets.to(device)
            if 'supLinear' in args.net:
                features = pre_trained_model.encoder(data)
                output = net.fc(features)
            else:
                output = net(data)

            loss = criterion(output, targets)
            # Batch-mean loss weighted by batch size -> per-sample mean below.
            ts_loss += loss.item() * data.shape[0]

            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(targets.view_as(pred)).sum().item()

    cnt = len(dataloader.dataset)
    ts_loss /= cnt
    ts_acc = correct / cnt
    if ts_acc > best_dev_acc[0]:
        model = net.module if hasattr(net, "module") else net
        best_weights[0] = {k: v.to("cpu").clone() for k, v in model.state_dict().items()}
        best_dev_acc[0] = ts_acc
    return ts_loss, ts_acc


# Per-epoch metrics [tr_loss, tr_acc, ts_loss, ts_acc]; written to disk only
# with --save.
result = []


def _adjust_lr(epoch, lr):
    """Apply the hard-coded learning-rate schedule for ``epoch``; return the new LR.

    For epochs 1-10 the LR grows by 0.015 per epoch (reaching ~0.15) and is
    written into every param group. This overrides the optimizer's initial
    ``args.lr``. At epochs 360 and 380 the LR is decayed with
    ``utils.lr_decay(..., decay_rate=0.1)``.

    NOTE(review): with the default ``--n_epochs 300`` the decay epochs are
    never reached.
    """
    if epoch <= 10:
        lr += 0.015
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    if epoch in (360, 380):
        lr = utils.lr_decay(optimizer, lr, decay_rate=0.1)
    return lr


def _snapshot_weights():
    """Return a CPU copy of ``net``'s state_dict, unwrapping ``.module`` if present."""
    model = net.module if hasattr(net, "module") else net
    return {k: v.to("cpu").clone() for k, v in model.state_dict().items()}


def _save_results():
    """Write the best weights and the per-epoch metric table under ``args.res_dir``."""
    res_dir = Path(args.res_dir)
    res_dir.mkdir(parents=True, exist_ok=True)
    torch.save(best_weights[0], res_dir / "pytorch_model.bin")
    res_filepath = res_dir / 'result_{}.npy'.format(args.res_tag)
    np.save(res_filepath, np.array(result))


lr = 0.0
try:
    if 'supCon' in args.net:
        # Contrastive pre-training has no test pass: keep the weights with the
        # lowest training loss.
        best_train_loss = [float('inf')]
        for epoch in range(1, args.n_epochs + 1):
            lr = _adjust_lr(epoch, lr)
            tr_loss, tr_acc = supcon_train(trainloader, epoch)
            writer.add_scalar('train_loss', tr_loss, epoch)
            writer.add_scalar('train_acc', tr_acc, epoch)
            print("loss: {:.4f}, acc: {:.4f} ".format(tr_loss, tr_acc), end='')
            result.append([tr_loss, tr_acc, 0., 0.])

            if best_train_loss[0] > tr_loss:
                best_train_loss[0] = tr_loss
                best_weights[0] = _snapshot_weights()
    else:
        for epoch in range(1, args.n_epochs + 1):
            lr = _adjust_lr(epoch, lr)
            # tr_acc (module-level) carries the previous epoch's accuracy into
            # train() for AMA nets.
            tr_loss, tr_acc = train(trainloader, epoch, acc=tr_acc)
            # test() also updates best_dev_acc / best_weights.
            ts_loss, ts_acc = test(testloader)
            writer.add_scalar('train_loss', tr_loss, epoch)
            writer.add_scalar('train_acc', tr_acc, epoch)
            writer.add_scalar('test_loss', ts_loss, epoch)
            writer.add_scalar('test_acc', ts_acc, epoch)
            print("loss: {:.4f}, acc: {:.4f} ".format(tr_loss, tr_acc)
                  + "/ test_loss: {:.4f}, test_acc: {:.4f}".format(ts_loss, ts_acc), end='')
            result.append([tr_loss, tr_acc, ts_loss, ts_acc])

    print("\n[ Elapsed Time: {:.4f} ]".format(time.time() - t_start))

    if args.save:
        _save_results()
finally:
    # Flush pending TensorBoard events and release the event file, even when
    # training is interrupted.
    writer.close()

