import fire
import torch
from torch import nn
from torch import optim
from lenet import load_lenet, get_layer_output_lenet, mnist_dataset
from vgg import load_vgg, get_layer_output_vgg, cifar10_dataset


def layer_kd(pruned_model, orig_model, out_dir, data_loc, model_type, layer, epochs=10, lr=0.01, momentum=0.5):
    """Fine-tune a pruned model so one layer's output matches the original model's.

    For every training batch, the output of ``layer`` is taken from both
    models, each output is L2-normalised along dimension 1, and the pruned
    model is updated with SGD to minimise the MSE between the two normalised
    outputs. When training finishes, the pruned model's ``state_dict`` is
    saved with ``torch.save``.

    Args:
        pruned_model: Passed to the loader selected by ``model_type``
            (presumably a checkpoint path -- confirm against the loader).
            The loaded model is the one being trained.
        orig_model: Passed to the same loader. The loaded model only
            provides target activations; its parameters are never given to
            the optimizer.
        out_dir: Destination handed directly to ``torch.save``. Despite the
            name, it is used as the save target itself, not as a directory
            in which a file is created.
        data_loc: Passed to the dataset helper selected by ``model_type``.
        model_type: ``"lenet"`` or ``"vgg"``; selects the dataset, loader and
            layer-output helpers. Inputs are flattened to
            ``(batch, -1)`` for every type except ``"vgg"``.
        layer: Layer selector forwarded unchanged to the layer-output helper.
        epochs: Number of passes over the training loader.
        lr: SGD learning rate.
        momentum: SGD momentum.

    Raises:
        ValueError: If ``model_type`` is neither ``"lenet"`` nor ``"vgg"``.
    """
    import os

    if model_type == "lenet":
        dataset, load, get_layer_output = mnist_dataset, load_lenet, get_layer_output_lenet
    elif model_type == "vgg":
        dataset, load, get_layer_output = cifar10_dataset, load_vgg, get_layer_output_vgg
    else:
        # ValueError is still an Exception, so broad handlers keep working.
        raise ValueError(f"Unknown model type {model_type}")

    trainloader, _ = dataset(data_loc)

    # Pick the compute device.
    if torch.cuda.is_available():
        print("Using CUDA")
        print(os.environ.get("CUDA_VISIBLE_DEVICES"))
        device = torch.device('cuda')
    else:
        print("Not Using CUDA")
        device = torch.device('cpu')

    criterion = nn.MSELoss()
    orig_model = load(orig_model).to(device)
    pruned_model = load(pruned_model).to(device)

    # The original model is a fixed reference: eval mode keeps any dropout /
    # batch-norm layers deterministic so the targets do not fluctuate.
    orig_model.eval()

    optimizer = optim.SGD(pruned_model.parameters(), lr=lr, momentum=momentum)

    flatten_input = model_type != "vgg"

    for e in range(epochs):
        running_loss = 0
        # Labels are not needed: the target is the original model's activation.
        for images, _ in trainloader:
            if flatten_input:
                images = images.view(images.shape[0], -1)
            images = images.to(device)

            # normalize() divides by max(norm, eps), so an all-zero activation
            # stays zero instead of becoming NaN through 0 / 0.
            with torch.no_grad():
                orig_hidden = nn.functional.normalize(
                    get_layer_output(orig_model, images, layer), p=2, dim=1
                )
            pruned_hidden = nn.functional.normalize(
                get_layer_output(pruned_model, images, layer), p=2, dim=1
            )

            optimizer.zero_grad()
            loss = criterion(pruned_hidden, orig_hidden)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Sum of the per-batch losses for this epoch.
        print(f"Epoch {e} - Training loss: {running_loss}")

    torch.save(pruned_model.state_dict(), out_dir)

if __name__ == "__main__":
    # Command-line entry point: python-fire maps CLI arguments onto
    # layer_kd's parameters.
    fire.Fire(component=layer_kd)
