import os

import matplotlib.pyplot as plt
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize, RandomHorizontalFlip, RandomCrop
from tqdm import tqdm

import cifar10_config


if __name__ == '__main__':
    # Train `n_seeds` independently seeded CIFAR-10 models of the type selected
    # by `model_type`. For every seed the final weights are written to
    # ./state_dicts/ and a plot of the per-batch loss traces to ./plots/.
    model_type = 'cnn_map'
    print("Model Type: {}".format(model_type))
    disable_tqdm = False
    # Prefer the GPU, but fall back to the CPU so the script still runs on
    # machines without CUDA instead of crashing at `model.to(device)`.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Device: {}".format(device))

    # Create the output directories up front: otherwise a missing directory is
    # only discovered when saving, after a full training run has completed.
    state_dict_dir = "./state_dicts"
    plot_dir = "./plots"
    for output_dir in (state_dict_dir, plot_dir):
        os.makedirs(output_dir, exist_ok=True)

    # Training-time augmentation (random crop with padding, horizontal flip)
    # followed by per-channel normalisation.
    transform = Compose([RandomCrop(32, padding=4), RandomHorizontalFlip(),
                         ToTensor(), Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616))])
    train_data = CIFAR10("./data/", train=True, transform=transform, download=True)

    n_seeds = 10
    n_classes = 10
    n_channels = 3
    n_height = 32
    n_width = 32
    n_features = n_channels * n_height * n_width
    max_precision = 50000
    # Positional arguments passed to the model constructor; the keyword
    # arguments and all hyper-parameters come from the `cifar10_config` entry
    # named after `model_type`.
    args = (n_classes, n_features, max_precision)
    config = getattr(cifar10_config, model_type)
    kwargs = config['kwargs']

    initial_seed = 12345
    for seed in range(n_seeds):
        print()
        print("Seed: {}/{}".format(seed + 1, n_seeds))
        # Seed before building the model and the loader so that initialisation
        # and shuffling are reproducible for each seed.
        torch.manual_seed(initial_seed + seed)
        model = config['model'](*args, **kwargs)
        model.to(device)

        train_data_loader = DataLoader(train_data, batch_size=config['n_batch'], shuffle=True)
        optimizer = SGD(model.parameters(), lr=config['lr'], momentum=config['momentum'],
                        weight_decay=config['weight_decay'])
        lr_scheduler = MultiStepLR(optimizer, milestones=config['milestones'], gamma=config['gamma'])
        nll_trace = []
        fkl_trace = []
        # The FKL term is optional: only models exposing a callable `fkl_loss`
        # contribute it. The check is loop-invariant, so do it once per seed.
        has_fkl_loss = callable(getattr(model, 'fkl_loss', None))
        model.train()
        for epoch in range(config['epochs']):
            print()
            print("Epoch: {}/{}".format(epoch + 1, config['epochs']))
            iterator = train_data_loader if disable_tqdm else tqdm(train_data_loader)
            for x, y in iterator:
                optimizer.zero_grad()
                raw_logits = model(x.to(device))
                nll = model.nll_loss(raw_logits, y.to(device))
                nll_trace.append(nll.item())
                loss = nll
                if has_fkl_loss:
                    # NOTE(review): the FKL term is scaled by the configured
                    # batch size, not by the actual size of this batch (the
                    # final batch of an epoch may be smaller) - confirm that
                    # this is the intended normalisation.
                    fkl = model.fkl_loss(raw_logits) / config['n_batch']
                    fkl_trace.append(fkl.item())
                    # Out-of-place add: `loss` aliases `nll`, and `+=` would
                    # mutate that tensor in place.
                    loss = loss + fkl
                loss.backward()
                optimizer.step()
            # The learning-rate schedule advances once per epoch.
            lr_scheduler.step()

        # Move the weights to the CPU before saving the state dict.
        model.to('cpu')
        torch.save(model.state_dict(),
                   os.path.join(state_dict_dir, "CIFAR10_{}_{}.pt".format(model_type, seed)))

        # Label the axes through explicit Axes objects. pyplot's implicit
        # "current axes" is the twin after twinx(), whose x-axis is hidden, so
        # plt.xlabel()/plt.ylabel() would lose the x-label and label the wrong
        # y-axis.
        fig = plt.figure()
        ax1 = fig.add_subplot(111)
        ax1.plot(nll_trace, color='tab:blue')
        ax1.set_title("CIFAR10, {}".format(model_type))
        ax1.set_xlabel("Batch Iterations, n_batch = {}, epochs = {}".format(config['n_batch'], config['epochs']))
        ax1.set_ylabel("NLL Loss", color='tab:blue')
        if fkl_trace:
            # Second y-axis with its own scale; only drawn when the model
            # actually produced an FKL trace.
            ax2 = ax1.twinx()
            ax2.plot(fkl_trace, color='tab:orange')
            ax2.set_ylabel("FKL Loss", color='tab:orange')
        fig.savefig(os.path.join(plot_dir, "CIFAR10_{}_{}.png".format(model_type, seed)))
        # Release the figure; otherwise one figure per seed stays alive in
        # pyplot's registry for the lifetime of the process.
        plt.close(fig)
