import torch
import os
from datetime import datetime
import fire
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms
import numpy as np
from tqdm import tqdm
from utils import accuracy, kl_loss
from mmce import calibration_mmce_w_loss, calibration_mmce_w_loss_entr
from uce import uceloss, uceloss_weighted

# Run timestamp used in checkpoint file names, formatted like '2024-01-31_13.45.07'.
start_time = datetime.now().strftime('%Y-%m-%d_%H.%M.%S')

# Fixed seeds for the NumPy and PyTorch random number generators.
np.random.seed(0)
torch.manual_seed(0)

# Let cuDNN benchmark and cache the fastest convolution algorithms.
torch.backends.cudnn.benchmark = True


# Supported values for train()'s string options; anything else raises ValueError.
# Maps each `bayes` option to the class name looked up in the project `models` module.
_NET_CLASS_NAMES = {
    'none': 'FrequentistNet',
    'mcdrop': 'MonteCarloDropoutNet',
    'gaussdrop': 'GaussianDropoutNet',
    'bbb': 'BBBNet',
    'swag': 'SWAGNet',
}
_BASE_MODELS = (
    'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    'densenet121', 'densenet169', 'densenet201',
)
_DATASETS = ('cifar10', 'cifar100', 'svhn')

# Per-channel normalization applied to every dataset: (x - 0.5) / 0.5.
_NORM_MEAN = [0.5, 0.5, 0.5]
_NORM_STD = [0.5, 0.5, 0.5]


def _require_choice(name, value, choices):
    """Raise ValueError unless *value* is one of *choices*.

    Used instead of ``assert`` so the check survives ``python -O``.
    """
    if value not in choices:
        raise ValueError(f"{name} must be one of {list(choices)}, got {value!r}")


def _build_datasets(dataset):
    """Return ``(train_set, valid_set, train_indices, valid_indices, num_classes)``.

    Both sets wrap the *training* split stored under '../data'. ``train_set``
    applies augmentation; ``valid_set`` only applies ToTensor + Normalize.
    Which samples land in which subset is decided by the index tensors loaded
    from './train_indices_<dataset>.pth' and './valid_indices_<dataset>.pth'.
    """
    if dataset == 'cifar100':
        dataset_cls, num_classes, split = datasets.CIFAR100, 100, {'train': True}
        augment = [
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
            transforms.RandomHorizontalFlip(),
        ]
    elif dataset == 'cifar10':
        dataset_cls, num_classes, split = datasets.CIFAR10, 10, {'train': True}
        augment = [
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
            transforms.RandomHorizontalFlip(),
        ]
    elif dataset == 'svhn':
        dataset_cls, num_classes, split = datasets.SVHN, 10, {'split': 'train'}
        # Unlike the CIFAR pipelines, SVHN gets no RandomHorizontalFlip.
        augment = [transforms.RandomCrop(32, padding=4, padding_mode='reflect')]
    else:
        raise ValueError(f"dataset must be one of {list(_DATASETS)}, got {dataset!r}")

    train_transform = transforms.Compose(augment + [
        transforms.ToTensor(),
        transforms.RandomErasing(p=0.5),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ])
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ])

    # Only train_set passes download=True; valid_set relies on the same files on disk.
    train_set = dataset_cls('../data', download=True, transform=train_transform, **split)
    valid_set = dataset_cls('../data', transform=valid_transform, **split)
    train_indices = torch.load(f'./train_indices_{dataset}.pth')
    valid_indices = torch.load(f'./valid_indices_{dataset}.pth')
    return train_set, valid_set, train_indices, valid_indices, num_classes


def _build_net(bayes, num_classes, base_model, device):
    """Instantiate the `models` class registered for *bayes* and move it to *device*."""
    import models  # project-local module, imported lazily (only when training)
    net_cls = getattr(models, _NET_CLASS_NAMES[bayes])
    return net_cls(num_classes, base_model).to(device)


def _batch_loss(logits, target, label_smoothing, lambda_uce, lambda_mmce):
    """Return ``(loss, xent, weighted_uce)`` for one batch.

    ``loss = xent + label_smoothing * kl_loss(logits)
                  + lambda_uce * uceloss(p, target)[0]
                  + lambda_mmce * sqrt(calibration_mmce_w_loss(p, target))``
    with ``p = softmax(logits)``. ``weighted_uce`` is ``lambda_uce * uce`` as a
    float (0.0 when ``lambda_uce`` is 0).

    Terms with weight 0 are skipped entirely. They contribute nothing to the
    loss value, and skipping them avoids wasted work. It also avoids NaN
    gradients: sqrt has an infinite derivative at 0, so ``0 * sqrt(0)`` would
    backpropagate ``0 * inf = NaN``. The terms are added in the same order as
    before, so the loss value is unchanged.
    """
    xent = F.cross_entropy(logits, target)
    loss = xent
    weighted_uce = 0.0
    if label_smoothing:
        loss = loss + label_smoothing * kl_loss(logits)
    if lambda_uce or lambda_mmce:
        probs = logits.softmax(dim=1)
        if lambda_uce:
            uce = uceloss(probs, target)[0]
            loss = loss + lambda_uce * uce
            weighted_uce = lambda_uce * uce.item()
        if lambda_mmce:
            loss = loss + lambda_mmce * calibration_mmce_w_loss(probs, target).sqrt()
    return loss, xent, weighted_uce


def _run_epoch(net, loader, device, label_smoothing, lambda_uce, lambda_mmce, optimizer=None):
    """Run one pass over *loader* and return a dict of epoch metrics.

    With an *optimizer*, the network is switched to train mode and updated after
    every batch. Without one, it is switched to eval mode and run with gradients
    disabled. Returned keys:
    'loss' (nan-mean total loss), 'nll' (nan-mean cross-entropy),
    'ucel' (mean of the per-batch lambda_uce * UCE), 'uce' (UCE over the whole
    pass), 'acc' (``accuracy`` over the whole pass).
    """
    training = optimizer is not None
    # Call train()/eval() explicitly rather than train(flag), in case a model overrides them.
    if training:
        net.train()
    else:
        net.eval()

    losses, nlls, weighted_uces = [], [], []
    all_targets, all_logits = [], []
    with torch.set_grad_enabled(training):
        for data, target in tqdm(loader):
            data, target = data.to(device), target.to(device)
            if training:
                optimizer.zero_grad()
            logits = net(data)
            loss, xent, weighted_uce = _batch_loss(
                logits, target, label_smoothing, lambda_uce, lambda_mmce)
            if training:
                loss.backward()

            losses.append(loss.item())
            nlls.append(xent.item())
            weighted_uces.append(weighted_uce)
            all_targets.append(target.detach())
            all_logits.append(logits.detach())
            if training:
                optimizer.step()

    targets = torch.cat(all_targets, dim=0)
    logits = torch.cat(all_logits, dim=0)
    return {
        'loss': np.nanmean(losses),
        'nll': np.nanmean(nlls),
        'ucel': np.mean(weighted_uces),
        'uce': uceloss(logits.softmax(dim=1), targets)[0].item(),
        'acc': accuracy(logits, targets),
    }


def _evaluate_swag(net, loader, device):
    """Return ``(nan-mean cross-entropy, accuracy)`` of *net* on *loader* under no_grad.

    Does not call ``net.eval()``. It runs in whatever mode ``net.train_swag()``
    leaves the network in (``net.train()`` is called just before it).
    NOTE(review): confirm that evaluating in that mode is intended.
    """
    losses, all_targets, all_logits = [], [], []
    with torch.no_grad():
        for data, target in tqdm(loader):
            data, target = data.to(device), target.to(device)
            logits = net(data)
            losses.append(F.cross_entropy(logits, target).item())
            all_targets.append(target.detach())
            all_logits.append(logits.detach())
    targets = torch.cat(all_targets, dim=0)
    logits = torch.cat(all_logits, dim=0)
    return np.nanmean(losses), accuracy(logits, targets)


def _save_checkpoint(filename, epoch, net, optimizer, train_losses, train_accuracies,
                     valid_losses, valid_uces, valid_accuracies):
    """Write model/optimizer state and the metric histories to *filename* via torch.save."""
    torch.save({
        'epoch': epoch,
        'state_dict': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'train_losses': train_losses,
        'train_accs': train_accuracies,
        'val_losses': valid_losses,
        'val_uces': valid_uces,
        'val_accs': valid_accuracies,
    }, filename)


def train(batch_size,
          epochs,
          bayes,
          base_model,
          dataset,
          valid_size=5000,
          lr=3e-4,
          label_smoothing=0.0,
          lambda_uce=0.0,
          lambda_mmce=0.0,
          gpu=0,
          snapshot_path="/media/fastdata/laves/iclr_snapshots"):
    """Train a (Bayesian) classifier with optional calibration regularizers.

    Each epoch does the following:
      * trains on the training subset;
      * steps a ReduceLROnPlateau scheduler on the mean *training* loss;
      * evaluates on the validation subset and prints the metrics;
      * saves a "best" checkpoint whenever the validation accuracy equals the
        running maximum.

    After the last epoch a final checkpoint is written. For ``bayes='swag'``,
    ``net.train_swag`` runs and is evaluated before that final checkpoint.

    Args:
        batch_size: Mini-batch size for both data loaders.
        epochs: Number of training epochs.
        bayes: One of 'none', 'mcdrop', 'gaussdrop', 'bbb', 'swag'; selects
            the network class from the project ``models`` module.
        base_model: Backbone name passed to the network class.
        dataset: 'cifar10', 'cifar100' or 'svhn'.
        valid_size: Currently unused. The split is read from pre-computed
            './train_indices_<dataset>.pth' / './valid_indices_<dataset>.pth'.
            Those files were produced by ``torch.randperm(len(train_set))``,
            taking the last ``valid_size`` indices as the validation set.
        lr: Adam learning rate.
        label_smoothing: Weight of ``kl_loss(logits)`` in the loss.
        lambda_uce: Weight of the ``uceloss`` term.
        lambda_mmce: Weight of ``sqrt(calibration_mmce_w_loss(...))``.
        gpu: CUDA device index (CPU is used when CUDA is unavailable).
        snapshot_path: Directory for checkpoints; created if missing.

    Raises:
        ValueError: If ``bayes``, ``base_model`` or ``dataset`` is unsupported.
    """
    print(start_time)

    device = torch.device('cuda:' + str(gpu) if torch.cuda.is_available() else 'cpu')

    _require_choice('bayes', bayes, _NET_CLASS_NAMES)
    _require_choice('base_model', base_model, _BASE_MODELS)
    _require_choice('dataset', dataset, _DATASETS)

    # Log the run configuration (all arguments plus the resolved device).
    args = locals()
    for key, val in args.items():
        print(key, '=', val)

    # Create the directory the checkpoints are actually written to.
    os.makedirs(snapshot_path, exist_ok=True)

    train_set, valid_set, train_indices, valid_indices, num_classes = _build_datasets(dataset)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                               sampler=SubsetRandomSampler(train_indices), pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size,
                                               sampler=SubsetRandomSampler(valid_indices), pin_memory=True)

    net = _build_net(bayes, num_classes, base_model, device)

    optimizer_net = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-6)
    lr_scheduler_net = optim.lr_scheduler.ReduceLROnPlateau(optimizer_net, patience=20, factor=0.1)

    train_losses = []
    train_accuracies = []
    valid_losses = []
    # NOTE(review): stores the per-epoch mean of lambda_uce * batch UCE (0 when
    # lambda_uce == 0), not the whole-set 'val_uce' that is printed. Confirm intended.
    valid_uces = []
    valid_accuracies = []

    for e in range(epochs):
        print("lr =", optimizer_net.param_groups[0]['lr'])
        tr = _run_epoch(net, train_loader, device, label_smoothing, lambda_uce, lambda_mmce,
                        optimizer=optimizer_net)
        lr_scheduler_net.step(tr['loss'])

        va = _run_epoch(net, valid_loader, device, label_smoothing, lambda_uce, lambda_mmce)

        print(f"Epoch {e}")
        print("    loss: {:.4f},      acc: {:.4f},     uce: {:.4f}"
              .format(tr['loss'], tr['acc'], tr['uce']))
        print(f"     nll: {tr['nll']:.4f},     ucel: {tr['ucel']:.4f}")
        print("val_loss: {:.4f},  val_acc: {:.4f}, val_uce: {:.4f}"
              .format(va['loss'], va['acc'], va['uce']))
        print(f" val_nll: {va['nll']:.4f}, val_ucel: {va['ucel']:.4f}")

        train_losses.append(tr['loss'])
        train_accuracies.append(tr['acc'])
        valid_losses.append(va['loss'])
        valid_uces.append(va['ucel'])
        valid_accuracies.append(va['acc'])

        # "Best" means the highest validation accuracy so far (ties also save).
        if valid_accuracies[-1] == np.max(valid_accuracies):
            filename = f"{snapshot_path}/{bayes}_{base_model}_{dataset}_best_{start_time}.pth.tar"
            print("Saving best weights so far with val_acc: {:.4f}, val_loss: {:.4f}"
                  .format(valid_accuracies[-1], valid_losses[-1]))
            _save_checkpoint(filename, e, net, optimizer_net, train_losses, train_accuracies,
                             valid_losses, valid_uces, valid_accuracies)

        if e == epochs - 1:
            if bayes == 'swag':
                net.train()
                # SWAG phase uses its own Adam optimizer with a much smaller learning rate.
                optimizer_swag = optim.Adam(net.parameters(), lr=3e-6, weight_decay=1e-6)
                net.train_swag(train_loader, optimizer_swag, device)
                swag_loss, swag_acc = _evaluate_swag(net, valid_loader, device)
                print(f"SWAG: val_loss: {swag_loss:.4f}, val_acc: {swag_acc:.4f}")

            filename = f"{snapshot_path}/{bayes}_{base_model}_{dataset}_{e}_{start_time}.pth.tar"
            print("Saving weights at epoch {:d}".format(e))
            _save_checkpoint(filename, e, net, optimizer_net, train_losses, train_accuracies,
                             valid_losses, valid_uces, valid_accuracies)


if __name__ == '__main__':
    # Expose train() as a command-line interface: its parameters become CLI flags.
    fire.Fire(component=train)
