import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils.data import DataLoader

import time, datetime
import argparse
import os
from pathlib import Path

import utils
from utils import *
from models.transformations import TwoCropTransform
from losses import SupConLoss
from models import *
def _add_args(group, specs):
    """Register ``(flag, default, type)`` triples on ``group`` in the given order."""
    for flag, default, arg_type in specs:
        group.add_argument(flag, default=default, type=arg_type)


parser = argparse.ArgumentParser(
    description='Image Classification - CIFAR10 / CIFAR100 / TinyImageNet')

# Run / I/O options. Registration order is preserved so --help lists them unchanged.
_add_args(parser, [
    ('--gpu-id', 0, int),
    ('--data-dir', './data', str),
    ('--seed-num', 1, int),
])
parser.add_argument('--save', action='store_true', default=False)
_add_args(parser, [
    ('--res-dir', './result', str),
    ('--res-tag', 'vgg11', str),
    ('--log-dir', './runs', str),
    ('--ckpt', '', str),  # pretrained SupCon checkpoint, loaded for 'supLinear' nets
])

# Training hyper-parameters, grouped under a separate "params" heading in --help.
hyper = parser.add_argument_group('params')
_add_args(hyper, [
    ('--lr', 0.01, float),
    ('--step', 20, int),
    ('--decay-rate', .1, float),
    ('--n-epochs', 200, int),
    ('--batch-size', 128, int),
    ('--momentum', 0.9, float),
    ('--weight-decay', 5e-4, float),
    ('--net', 'vgg11_origin', str),
    ('--dataset', "cifar100", str),
    ('--beta', 2./3., float),
])
args = parser.parse_args()

print(f'net: {args.net} / dataset: {args.dataset}')

utils.set_random_seed(seed_num=args.seed_num)
# Fall back to CPU when CUDA is unavailable. Otherwise the script would fail
# later, and less clearly, at the first `.to(device)` call.
if torch.cuda.is_available():
    device = torch.device('cuda:{}'.format(args.gpu_id))
else:
    device = torch.device('cpu')
    print('CUDA is not available; falling back to CPU.')

# Per-dataset train/test transforms, looked up by '<dataset>_train' / '<dataset>_test'.
TRANSFORMATION = get_all_transformations()

transform_tr = TRANSFORMATION[args.dataset+'_train']
transform_ts = TRANSFORMATION[args.dataset+'_test']


NET = get_all_networks()

# Supervised-contrastive training consumes two augmented views per image.
if 'supCon' in args.net:
    transform_tr = TwoCropTransform(transform_tr)

DATASETS = get_all_datasets(args.data_dir, transform_tr, transform_ts)
dataset = DATASETS[args.dataset+'_dataset']
testset = DATASETS[args.dataset+'_testset']

trainloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
# Evaluation needs no gradients, so a larger batch is used for the test set.
testloader = DataLoader(testset, batch_size=args.batch_size * 2, shuffle=False)

class_name = dataset.classes
num_classes = len(class_name)

# Build the network. The constructor signature differs per family; AMA variants
# additionally take `beta`.
if 'vgg11' in args.net:
    net = NET[args.net](args, num_classes).to(device)
elif 'resnet50' in args.net:
    if 'AMA' in args.net:
        net = NET[args.net](num_classes=num_classes, beta=args.beta).to(device)
    elif 'supCon' in args.net:
        net = NET[args.net]().to(device)
    else:
        net = NET[args.net](num_classes=num_classes).to(device)
    if 'supLinear' in args.net:
        # Linear evaluation: features come from a pretrained SupCon encoder
        # loaded from --ckpt. Only `net` is given to the optimizer below.
        if not args.ckpt:
            raise ValueError("--ckpt must point to a pretrained SupCon checkpoint "
                             "when training a 'supLinear' network")
        pre_trained_model = NET['resnet50_supCon'](name='resnet50', head='linear')
        pre_trained_model.load_state_dict(torch.load(args.ckpt, map_location="cpu"))
        pre_trained_model.to(device)
elif 'dense' in args.net:
    if 'AMA' in args.net:
        net = NET[args.net](num_classes=num_classes, beta=args.beta).to(device)
    else:
        net = NET[args.net](num_classes=num_classes).to(device)
else:
    # Without this branch `net` would be undefined and the script would die
    # with a NameError at `net.parameters()` below.
    raise ValueError(f"Unsupported network '{args.net}': expected a name "
                     "containing 'vgg11', 'resnet50' or 'dense'")

if 'supCon' in args.net:
    criterion = SupConLoss(args)
else:
    criterion = nn.CrossEntropyLoss()

softmax = nn.Softmax(dim=1)
optimizer = optim.SGD(net.parameters(), lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay)
writer = SummaryWriter(args.log_dir)
# Single-element lists so that functions can update them in place without `global`.
best_dev_acc = [-1]
best_weights = [None]
tr_acc = 0.  # previous epoch's training accuracy, fed back into train()
print(net)


def train(dataloader, epoch, acc):
    """Run one training epoch of the module-level ``net``.

    The loss path is selected from substrings of ``args.net``: 'AMA',
    'supLinear', 'mixup', 'manifoldMixup', or plain cross-entropy otherwise.

    Args:
        dataloader: training DataLoader yielding ``(data, targets)`` batches.
        epoch: 1-based epoch index. Used only for the progress/ETA display.
        acc: training accuracy of the previous epoch. Forwarded to AMA
            networks as ``tr_acc``.

    Returns:
        ``(tr_loss, tr_acc)``: the per-sample average training loss and the
        training accuracy over ``dataloader.dataset``.
    """
    net.train()
    if 'supLinear' in args.net:
        pre_trained_model.eval()
    tr_loss = 0.
    correct = 0
    n_batches = len(dataloader)
    prev_time = time.time()

    for idx, (data, targets) in enumerate(dataloader):
        data, targets = data.to(device), targets.to(device)
        batch_size = data.size(0)
        optimizer.zero_grad()

        if 'AMA' in args.net:
            # AMA nets return their own (possibly extended) target tensor.
            output, targets = net(data, tr_acc=acc, label=targets)
            loss = criterion(output, targets)
        elif 'supLinear' in args.net:
            # The frozen encoder is run without autograd; only `net` is trained.
            with torch.no_grad():
                features = pre_trained_model.encoder(data)
            output = net(features)
            loss = criterion(output, targets)
        elif 'mixup' in args.net:
            data, targets_a, targets_b, lam = mixup_data(data, targets, 1.0, device)
            output = net(data)
            loss = mixup_criterion(criterion, output, targets_a, targets_b, lam)
        elif 'manifoldMixup' in args.net:
            output, targets_a, targets_b, lam = net(data, targets, device)
            loss = mixup_criterion(criterion, output, targets_a, targets_b, lam)
        else:
            output = net(data)
            loss = criterion(output, targets)

        loss.backward()
        optimizer.step()

        loss_val = loss.item()
        # The criterion returns a batch mean. Weight it by batch size so that
        # dividing by the dataset size below gives a true per-sample mean.
        tr_loss += loss_val * batch_size

        # Only the first `batch_size` rows correspond to the real inputs. AMA
        # may append extra rows; for other nets this slice is a no-op.
        pred = output[:batch_size].argmax(dim=1)
        correct += pred.eq(targets[:batch_size]).sum().item()

        batches_done = (epoch - 1) * n_batches + idx
        batches_left = args.n_epochs * n_batches - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()
        print("\r[epoch {:3d}/{:3d}] [batch {:4d}/{:4d}] loss: {:.6f} (eta: {})".format(
            epoch, args.n_epochs, idx + 1, n_batches, loss_val, time_left), end=' ')

    cnt = len(dataloader.dataset)
    tr_loss /= cnt
    epoch_acc = correct / cnt

    return tr_loss, epoch_acc


def supcon_train(dataloader, epoch):
    """Run one supervised-contrastive training epoch of the module-level ``net``.

    Each batch's ``data`` holds two augmented views (``data[0]``, ``data[1]``).
    The views are concatenated along the batch dimension, encoded, and
    re-paired into a ``(bsz, 2, ...)`` tensor for ``criterion``.

    Args:
        dataloader: training DataLoader yielding ``(views, targets)`` batches.
            Presumably the views come from the TwoCropTransform wrapper.
        epoch: 1-based epoch index. Used only for the progress/ETA display.

    Returns:
        ``(tr_loss, tr_acc)``: the per-sample average contrastive loss, and
        ``0.0`` for accuracy, because no classifier is evaluated here.
    """
    net.train()
    tr_loss = 0.
    n_batches = len(dataloader)
    prev_time = time.time()

    for idx, (data, targets) in enumerate(dataloader):
        data = torch.cat([data[0], data[1]], dim=0)
        data, targets = data.to(device), targets.to(device)
        bsz = targets.shape[0]
        optimizer.zero_grad()
        features = net(data)
        # Split the stacked views back apart and pair them: (bsz, 2, ...).
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        features = torch.stack([f1, f2], dim=1)

        loss = criterion(features, targets)
        loss.backward()
        optimizer.step()

        loss_val = loss.item()
        # Weight the batch loss by batch size so that dividing by the dataset
        # size gives a per-sample mean instead of a value shrunk by ~1/bsz.
        tr_loss += loss_val * bsz

        batches_done = (epoch - 1) * n_batches + idx
        batches_left = args.n_epochs * n_batches - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()
        print("\r[epoch {:3d}/{:3d}] [batch {:4d}/{:4d}] loss: {:.6f} (eta: {})".format(
            epoch, args.n_epochs, idx + 1, n_batches, loss_val, time_left), end=' ')

    cnt = len(dataloader.dataset)
    tr_loss /= cnt

    # Accuracy is not measured during contrastive pre-training.
    return tr_loss, 0.


def test(dataloader):
    """Evaluate the module-level ``net`` and remember the best-scoring weights.

    The forward pass is built from the network's sub-modules instead of
    ``net(data)``. Presumably this bypasses training-only forward logic,
    such as the AMA label handling; TODO confirm.

    Args:
        dataloader: evaluation DataLoader yielding ``(data, targets)`` batches.

    Returns:
        ``(ts_loss, ts_acc)``: the per-sample average loss and the accuracy
        over ``dataloader.dataset``.

    Side effects:
        When ``ts_acc`` beats ``best_dev_acc[0]``, stores a CPU copy of the
        state dict in ``best_weights[0]`` and updates ``best_dev_acc[0]``.

    Raises:
        ValueError: if ``args.net`` names no supported architecture family.
    """
    net.eval()
    ts_loss = 0.
    correct = 0

    with torch.no_grad():
        for data, targets in dataloader:
            data, targets = data.to(device), targets.to(device)

            if 'vgg' in args.net:
                features = net.features(data)
                features = features.view(features.size(0), -1)
                output = net.classifier(features)

            elif 'resnet' in args.net:
                # supLinear nets classify features from the frozen pretrained encoder.
                if 'supLinear' in args.net:
                    features = pre_trained_model.encoder(data)
                else:
                    features = net.encoder(data)
                output = net.fc(features)

            elif 'dense' in args.net:
                out = net.conv1(data)
                out = net.trans1(net.dense1(out))
                out = net.trans2(net.dense2(out))
                out = net.dense3(out)
                # flatten(1) keeps the batch dimension. torch.squeeze would
                # also drop it when the last batch holds a single sample.
                features = torch.flatten(F.avg_pool2d(F.relu(net.bn1(out)), 8), 1)
                output = net.fc(features)

            else:
                raise ValueError(f"test(): unsupported network '{args.net}'")

            loss = criterion(output, targets)
            # Weight the batch-mean loss by batch size for a per-sample average.
            ts_loss += loss.item() * data.size(0)
            pred = output.argmax(dim=1)
            correct += pred.eq(targets).sum().item()

    cnt = len(dataloader.dataset)
    ts_loss /= cnt
    ts_acc = correct / cnt
    if ts_acc > best_dev_acc[0]:
        # Unwrap a DataParallel-style `.module` wrapper if present, so the
        # saved keys carry no "module." prefix.
        model = getattr(net, "module", net)
        best_weights[0] = {k: v.to("cpu").clone() for k, v in model.state_dict().items()}
        best_dev_acc[0] = ts_acc
    return ts_loss, ts_acc

def _adjust_lr(epoch, lr):
    """Apply this script's LR schedule for ``epoch`` and return the current LR.

    CIFAR-10 uses ``utils.lr_scheduling``. Every other dataset decays the LR
    through ``utils.lr_decay`` at epochs 60, 120 and 160.
    """
    if args.dataset == 'cifar10':
        return utils.lr_scheduling(optimizer, epoch, args.lr)
    if epoch in (60, 120, 160):
        return utils.lr_decay(optimizer, lr)
    return lr


def _save_results(rows):
    """Write ``best_weights[0]`` and the per-epoch metric rows to ``args.res_dir``."""
    Path(args.res_dir).mkdir(parents=True, exist_ok=True)
    torch.save(best_weights[0], os.path.join(args.res_dir, "pytorch_model.bin"))
    res_filepath = Path(args.res_dir) / 'result_{}.npy'.format(args.res_tag)
    np.save(res_filepath, np.array(rows))


# Per-epoch rows: [train_loss, train_acc, test_loss, test_acc]. Written only with --save.
result = []

try:
    if 'supCon' in args.net:
        lr = args.lr
        # Contrastive training has no test metric, so the best weights are
        # chosen by the lowest training loss.
        best_train_loss = [1e9]
        for epoch in range(1, args.n_epochs + 1):
            lr = _adjust_lr(epoch, lr)
            tr_loss, tr_acc = supcon_train(trainloader, epoch)
            writer.add_scalar('train_loss', tr_loss, epoch)
            writer.add_scalar('train_acc', tr_acc, epoch)
            print("loss: {:.4f}, acc: {:.4f} ".format(tr_loss, tr_acc), end='')
            if args.save:
                result.append([tr_loss, tr_acc, 0., 0.])
            if best_train_loss[0] > tr_loss:
                best_train_loss[0] = tr_loss
                model = getattr(net, "module", net)
                best_weights[0] = {k: v.to("cpu").clone() for k, v in model.state_dict().items()}

        if args.save:
            _save_results(result)

    else:
        lr = args.lr
        for epoch in range(1, args.n_epochs + 1):
            lr = _adjust_lr(epoch, lr)

            # tr_acc is fed back so AMA networks can condition on last epoch's accuracy.
            tr_loss, tr_acc = train(trainloader, epoch, tr_acc)
            test_loss, test_acc = test(testloader)
            writer.add_scalar('train_loss', tr_loss, epoch)
            writer.add_scalar('train_acc', tr_acc, epoch)
            writer.add_scalar('test_loss', test_loss, epoch)
            writer.add_scalar('test_acc', test_acc, epoch)
            print("loss: {:.4f}, acc: {:.4f} ".format(tr_loss, tr_acc)
                  + "/ test_loss: {:.4f}, test_acc: {:.4f}".format(test_loss, test_acc), end='')
            if args.save:
                result.append([tr_loss, tr_acc, test_loss, test_acc])

        test_loss, test_acc = test(testloader)
        print(
            f"test_loss: {test_loss:.4f}, test_acc: {test_acc:.4f}")
        if args.save:
            _save_results(result)
finally:
    # Flush buffered TensorBoard events; without close() the tail of the
    # run can be lost at exit or on interrupt.
    writer.close()