import fire
import torch
from torchvision import datasets, transforms, models
from torch import optim
from time import time
from torch import nn
from L0_regularization.models import L0MLP
from lenet import FF, derandomize, get_device
from itertools import islice
import sys


def activations_vgg(model, img):
    """Yield the classifier activations of a VGG model for ``img``.

    The input is passed through ``model.features`` and ``model.avgpool``,
    reshaped into rows of 25088 values, and then fed through
    ``model.classifier[0]`` ... ``model.classifier[9]`` in order.

    Yields:
    - the running activation right after classifier entries 0, 3, 6 and 9,
      i.e. four tensors in total.
    """
    pooled = model.avgpool(model.features(img))
    activation = pooled.view(-1, 25088)
    for index in range(10):
        activation = model.classifier[index](activation)
        # Entries 0, 3, 6 and 9 are exactly the multiples of three below 10.
        if index % 3 == 0:
            yield activation


def get_layer_output_vgg(model, img, layer):
    """Return the ``layer``-th activation produced by ``activations_vgg``.

    Args:
    - model: the model handed to ``activations_vgg``
    - img: the input handed to ``activations_vgg``
    - layer: 1-indexed position of the wanted activation

    returns - the selected activation. ``StopIteration`` propagates when
    fewer than ``layer`` activations are produced.
    """
    # Skip the first ``layer - 1`` activations lazily; later ones are never computed.
    remaining = islice(activations_vgg(model, img), layer - 1, None)
    return next(remaining)


def load_vgg(path, mode="eval", dropout=0, device=None):
    """ Load a VGG16 whose classifier is an FF network from a saved state dict.

    The hidden sizes of the classifier are read from the shapes of the saved
    ``classifier.{0,3,6}.weight`` tensors; the output size is fixed to 10.

    Args:
    - path: location of the saved state dict
    - mode: "eval" puts the model in evaluation mode; any other value puts it
      in training mode
    - dropout: dropout value forwarded to the FF classifier
    - device: forwarded to ``torch.load`` as ``map_location``. The returned
      model is not moved; callers do that themselves.
    returns - the loaded model in the requested mode
    """
    # Read the checkpoint once and reuse it for both shape discovery and loading.
    state_dict = torch.load(path, map_location=device)
    hidden_sizes = [state_dict[f'classifier.{n}.weight'].shape[0] for n in (0, 3, 6)]
    print(f'Found a model with layer sizes {hidden_sizes}', file=sys.stderr)

    # load_state_dict is strict by default, so every weight is overwritten by
    # the checkpoint; fetching the pretrained ImageNet weights would be wasted.
    model = models.vgg16(pretrained=False)
    model.classifier = FF([25088] + hidden_sizes + [10], dropout=dropout)
    model.load_state_dict(state_dict)

    if mode == "eval":
        model.eval()
    else:
        model.train()
    return model


def get_layer_params_vgg(param_dict, layer, include_qzloga=False):
    """Look up the stored tensors of one classifier layer.

    Args:
    - param_dict: mapping with ``classifier.<n>.<field>`` keys
    - layer: 1-indexed layer number; layer 1 maps to n=0, layer 2 to n=3,
      layer 3 to n=6
    - include_qzloga: also return the ``qz_loga`` entry when true

    returns - the tuple (weight, bias, mask), extended by qz_loga on request
    """
    prefix = f'classifier.{3 * (layer - 1)}'
    fields = ['weight', 'bias', 'mask']
    if include_qzloga:
        fields.append('qz_loga')
    return tuple(param_dict[f'{prefix}.{field}'] for field in fields)


def set_layer_params_vgg(param_dict, layer, weights, biases, mask=None):
    """Store the tensors of one classifier layer in ``param_dict`` (in place).

    Args:
    - param_dict: mapping with ``classifier.<n>.<field>`` keys; it is modified
    - layer: 1-indexed layer number; layer 1 maps to n=0, layer 2 to n=3,
      layer 3 to n=6
    - weights: value stored under ``classifier.<n>.weight``
    - biases: value stored under ``classifier.<n>.bias``
    - mask: value stored under ``classifier.<n>.mask``; defaults to an
      all-ones tensor shaped like ``weights``
    """
    n = 3 * (layer - 1)
    if mask is None:
        # ones_like keeps the default mask on the same device and in the same
        # dtype as the weights it accompanies.
        mask = torch.ones_like(weights)
    param_dict[f'classifier.{n}.weight'] = weights
    param_dict[f'classifier.{n}.bias'] = biases
    param_dict[f'classifier.{n}.mask'] = mask


def cifar10_dataset(data_loc):
    """Build the CIFAR-10 training and validation data loaders.

    Images are resized to 224x224, converted to tensors and normalized with
    mean 0.5 and standard deviation 0.5 per channel. The data must already be
    present under ``data_loc`` because downloading is disabled.

    returns - (trainloader, valloader), both shuffled with batch size 32
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # One loader per split: training split first, then the held-out split.
    loaders = []
    for is_train in (True, False):
        split = datasets.CIFAR10(data_loc, download=False, train=is_train, transform=preprocess)
        loaders.append(torch.utils.data.DataLoader(split, batch_size=32, shuffle=True))

    trainloader, valloader = loaders
    return trainloader, valloader


def train_vgg(ff_hidden_sizes, out_dir, data_loc, epochs: int = 15, lr: float = 0.001, momentum: float = 0.9, init_chkpt=None, L1=0, dropout=0, L0=0, freeze_convolutions=True, optimizer="SGD"):
    """ Train VGG16 on CIFAR10 using the given hyper-parameters. Analogous to scripts/lenet.py, which may be used for reference.

    Args:
    - ff_hidden_sizes: List[int] - the hidden sizes of the three final dense layers of VGG16. In the original model, this is 4096,4096, 1000, but this will change after pruning. Must be empty/None when init_chkpt is given.
    - out_dir: str - the location to store the trained model (passed directly to torch.save)
    - data_loc: str - the directory location of the cifar-10 data, provided by scripts/cifar_data.py
    - epochs: int - the number of epochs to train
    - lr: float - the learning rate
    - momentum - the momentum to use with the optimizer, if applicable (only SGD uses it)
    - init_chkpt - the initial weights to use for the model, if continuing training
    - L1 - the L1 regularization to be applied to the classifier weights
    - dropout - the dropout regularization to be applied to the feed-forward layers
    - L0 - the L0 regularization strength; when non-zero an L0MLP classifier is trained and requires dropout == 0 and L1 == 0
    - freeze_convolutions - when true, the convolutional feature layers are frozen and only the classifier is optimized
    - optimizer - "SGD" selects SGD with momentum; any other value selects Adam
    """

    trainloader, valloader = cifar10_dataset(data_loc)

    device = get_device()

    if init_chkpt:
        assert not ff_hidden_sizes
        print(f"Loading chkpt from {init_chkpt}")
        # Map the checkpoint onto the training device so that a checkpoint
        # saved on a different device can still be loaded.
        vgg16_model = load_vgg(init_chkpt, mode="train", dropout=dropout, device=device)
    else:
        # Fire parses "a,b,c" as a tuple; the concatenations below need a list.
        ff_hidden_sizes = list(ff_hidden_sizes)
        vgg16_model = models.vgg16(pretrained=True)
        if L0:
            assert dropout == 0 and L1 == 0
            # Note: N is the number of training examples. Lambda is divided by N.
            # NOTE(review): the CIFAR-10 training split has 50000 images; 60000
            # looks carried over from MNIST. Left unchanged because altering it
            # rescales the L0 penalty - confirm before changing.
            clf = L0MLP(25088, 10, layer_dims=ff_hidden_sizes, weight_decay=0, N=60000, lambas=(0, L0, 0), temperature=2./3.)
        else:
            # Honor the requested dropout (it used to be hard-coded to 0 here).
            clf = FF([25088] + ff_hidden_sizes + [10], dropout=dropout)
        vgg16_model.classifier = clf
        vgg16_model.train()

    # freezing feature layers
    if freeze_convolutions:
        for param in vgg16_model.features.parameters():
            param.requires_grad = False

    print(vgg16_model)
    vgg16_model.to(device)

    # setting up loss criteria and optimizer
    # NLLLoss: the model output is presumably log-probabilities - confirm against FF / L0MLP.
    criterion = nn.NLLLoss()

    params = vgg16_model.classifier.parameters() if freeze_convolutions else vgg16_model.parameters()

    # Keep the `optimizer` argument (a name) distinct from the optimizer object.
    if optimizer == "SGD":
        opt = optim.SGD(params, lr=lr, momentum=momentum)
    else:
        opt = optim.Adam(params, lr=lr)

    # training loop
    vgg16_model.train()
    num_batches = len(trainloader)

    time0 = time()
    for e in range(epochs):
        running_loss = 0
        running_regularized_loss = 0
        for images, labels in trainloader:

            images, labels = images.to(device), labels.to(device)

            # Training pass
            opt.zero_grad()

            output = vgg16_model(images)
            loss = criterion(output, labels)

            # L1 weight decay
            if L1:
                regularization_loss = 0
                for param in vgg16_model.classifier.parameters():
                    regularization_loss += torch.sum(torch.abs(param))
                running_regularized_loss += L1 * regularization_loss.item()
                loss += L1 * regularization_loss

            if L0:
                regularization_loss = vgg16_model.classifier.regularization()
                running_regularized_loss += regularization_loss.item()
                loss += regularization_loss

            loss.backward()
            opt.step()

            # running_loss includes the regularizer; it is subtracted when reporting.
            running_loss += loss.item()

        print(f"Epoch {e} - Training loss: {(running_loss - running_regularized_loss)/num_batches}")
        print(f"Regularizer loss: {(running_regularized_loss)/num_batches}")

    print("\nTraining Time (in minutes) =",(time()-time0)/60)

    # evaluation loop
    vgg16_model.eval()

    if L0:
        print(f'Expected L0: {vgg16_model.classifier.get_exp_flops_l0()}')
        # Replace the L0 classifier with the result of derandomize before evaluating and saving.
        vgg16_model.classifier = derandomize(vgg16_model.classifier, [25088] + ff_hidden_sizes + [10]).to(device)

    correct_count, all_count = 0, 0
    for images, labels in valloader:
        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            logps = vgg16_model(images)

        pred_labels = logps.argmax(dim=1)
        correct_count += (pred_labels == labels).sum().item()
        all_count += pred_labels.shape[0]

    print("Number Of Images Tested =", all_count)
    print("\nModel Accuracy =", (correct_count/all_count))

    # saving the model
    torch.save(vgg16_model.state_dict(), out_dir)


# Script entry point: hand train_vgg to Fire, which builds the command-line
# interface from the function and invokes it.
if __name__ == "__main__":
    fire.Fire(train_vgg)
