import fire
import torch
from time import time
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from torch import optim
from L0_regularization.models import L0MLP
from torchvision import datasets, transforms
from itertools import islice
import math
import sys


def get_device():
    """Pick the torch device to run on and log the choice to stderr.

    Returns:
        ``torch.device('cuda')`` when ``torch.cuda.is_available()`` is true,
        otherwise ``torch.device('cpu')``.
    """
    if not torch.cuda.is_available():
        print("Not Using CUDA", file=sys.stderr)
        return torch.device('cpu')

    import os

    print("Using CUDA", file=sys.stderr)
    # Also log which GPUs are visible to this process (prints None if unset).
    print(os.environ.get("CUDA_VISIBLE_DEVICES"), file=sys.stderr)
    return torch.device('cuda')

# Most of this is copied from torch.nn.Linear
class MaskedLinear(torch.nn.Module):
    """Linear layer whose weight matrix is gated elementwise by a fixed mask.

    Parameters registered on the module (in this order):
        weight:  ``(out_features, in_features)``, trainable.
        bias:    ``(out_features,)``, trainable.
        mask:    ``(out_features, in_features)``, ``requires_grad=False``;
                 reset to all ones, i.e. nothing is masked by default.
        qz_loga: ``(in_features,)``, initialised to zeros; not read by
                 ``forward`` (only used for L0 pruning).
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        weight_shape = (out_features, in_features)
        self.weight = nn.Parameter(torch.Tensor(*weight_shape))
        self.bias = nn.Parameter(torch.Tensor(out_features))
        # The mask must not change during training, hence no gradient.
        # It is still a Parameter so that it shows up in the state dict.
        self.mask = nn.Parameter(torch.Tensor(*weight_shape), requires_grad=False)
        # Only used for L0 pruning
        self.qz_loga = nn.Parameter(torch.zeros(in_features))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise weight and bias (same scheme as nn.Linear) and clear the mask."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        fan_in = init._calculate_fan_in_and_fan_out(self.weight)[0]
        limit = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -limit, limit)
        # An all-ones mask leaves every weight active.
        init.ones_(self.mask)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the affine map using ``mask * weight`` as the effective weight."""
        effective_weight = self.mask * self.weight
        return F.linear(input, effective_weight, self.bias)


def FF(layer_sizes, dropout):
    """Build a feed-forward classifier out of MaskedLinear layers.

    Args:
        layer_sizes: sequence of layer widths, input size first and output
            size last; consecutive pairs become one MaskedLinear each.
        dropout: drop probability for the nn.Dropout placed after each ReLU.

    Returns:
        An ``nn.Sequential`` of ``MaskedLinear -> ReLU -> Dropout`` groups,
        with the final MaskedLinear followed directly by ``LogSoftmax(dim=1)``.
        The linear layers therefore sit at indices 0, 3, 6, ...
    """
    modules = []
    final_idx = len(layer_sizes) - 2
    width_pairs = zip(layer_sizes[:-1], layer_sizes[1:])
    for idx, (n_in, n_out) in enumerate(width_pairs):
        modules.append(MaskedLinear(n_in, n_out))
        if idx == final_idx:
            # No activation/dropout after the output layer.
            continue
        modules += [nn.ReLU(), nn.Dropout(p=dropout)]

    modules.append(nn.LogSoftmax(dim=1))
    return nn.Sequential(*modules)


def load_lenet(model_dir, mode="eval", dropout=0):
    """Rebuild a 784-h1-h2-10 FF network from a saved state dict.

    Args:
        model_dir: path (or file-like object) handed to ``torch.load``; it
            must contain a state dict with keys ``'0.weight'`` and
            ``'3.weight'``, from which the two hidden sizes are read.
        mode: ``"eval"`` puts the model in evaluation mode; any other value
            puts it in training mode.
        dropout: dropout probability used when rebuilding the network.

    Returns:
        The reconstructed model with the checkpoint's parameters loaded.
    """
    # The network is always built on CPU, so map the checkpoint there too.
    # Without this, a checkpoint holding CUDA tensors cannot be deserialized
    # on a machine without a GPU.
    params = torch.load(model_dir, map_location="cpu")
    hidden_sizes = [params['0.weight'].shape[0], params['3.weight'].shape[0]]
    print(f'Found a model with layer sizes {hidden_sizes}', file=sys.stderr)

    model = FF([784] + hidden_sizes + [10], dropout=dropout)
    model.load_state_dict(params)
    if mode == "eval":
        model.eval()
    else:
        model.train()
    return model


def activations_lenet(model, img):
    """Lazily yield the outputs of ``model[0]``, ``model[3]`` and ``model[6]``.

    ``img`` is pushed through ``model[0]`` .. ``model[6]`` in order; the value
    produced at indices 0, 3 and 6 is yielded as soon as it is computed. For
    a network built by ``FF`` with two hidden layers those indices are the
    three MaskedLinear layers, so the yielded values are their raw outputs.
    """
    tap_points = (0, 3, 6)
    current = img
    for position in range(tap_points[-1] + 1):
        current = model[position](current)
        if position in tap_points:
            yield current


def get_layer_output_lenet(model, img, layer):
    """Return the output of the ``layer``-th tapped layer (1-based) for ``img``.

    The value is the ``layer``-th item produced by ``activations_lenet``;
    earlier items are computed and discarded.
    """
    outputs = activations_lenet(model, img)
    # Skip the first (layer - 1) activations, then take the next one.
    from_requested = islice(outputs, layer - 1, None)
    return next(from_requested)


def get_layer_params_lenet(param_dict, layer, include_qzloga=False):
    """Fetch one layer's entries from a state-dict-like mapping.

    Layers are numbered from 1 and map to key prefix ``3 * (layer - 1)``
    (layer 1 -> ``0``, layer 2 -> ``3``, layer 3 -> ``6``).

    Returns:
        ``(weight, bias, mask)``, or ``(weight, bias, mask, qz_loga)`` when
        ``include_qzloga`` is true.
    """
    prefix = 3 * (layer - 1)
    fields = ['weight', 'bias', 'mask']
    if include_qzloga:
        fields.append('qz_loga')
    return tuple(param_dict[f'{prefix}.{field}'] for field in fields)


def set_layer_params_lenet(param_dict, layer, weights, biases, mask=None):
    """Store one layer's weight, bias and mask in a state-dict-like mapping.

    Layers are numbered from 1 and map to key prefix ``3 * (layer - 1)``
    (layer 1 -> ``0``, layer 2 -> ``3``, layer 3 -> ``6``). When ``mask`` is
    omitted, an all-ones tensor with the shape of ``weights`` is stored.
    ``param_dict`` is modified in place; nothing is returned.
    """
    if mask is None:
        mask = torch.ones(weights.shape)
    prefix = 3 * (layer - 1)
    for field, value in (('weight', weights), ('bias', biases), ('mask', mask)):
        param_dict[f'{prefix}.{field}'] = value


def mnist_dataset(data_loc):
    """Create shuffled DataLoaders over the MNIST train and test splits.

    The data must already exist under ``data_loc`` (``download=False``).
    Images are only converted to tensors; the mean/std normalisation from
    the PyTorch MNIST example
    (https://github.com/pytorch/examples/blob/master/mnist/main.py,
    mean 0.1307 / std 0.3081) is not applied.

    Returns:
        ``(trainloader, valloader)`` with batch sizes 100 and 64.
    """
    preprocess = transforms.Compose([transforms.ToTensor()])
    shared_kwargs = dict(download=False, transform=preprocess)

    train_split = datasets.MNIST(data_loc, train=True, **shared_kwargs)
    val_split = datasets.MNIST(data_loc, train=False, **shared_kwargs)

    make_loader = torch.utils.data.DataLoader
    return (
        make_loader(train_split, batch_size=100, shuffle=True),
        make_loader(val_split, batch_size=64, shuffle=True),
    )


def train_mnist(hidden_sizes,
                out_dir,
                data_loc,
                epochs: int = 10,
                lr: float = 0.01,
                momentum: float = 0.5,
                init_chkpt=None,
                L1=0,
                dropout=0,
                L0=0,
                optimizer="SGD"):
    """Train an MNIST classifier, print its validation accuracy and save it.

    Args:
        hidden_sizes: widths of the hidden layers (list or tuple). Must be
            empty/falsy when ``init_chkpt`` is given, because the sizes are
            then read from the checkpoint.
        out_dir: destination handed to ``torch.save`` for the final state dict.
        data_loc: root directory of an already-downloaded MNIST dataset.
        epochs: number of passes over the training set.
        lr: learning rate.
        momentum: SGD momentum (ignored for Adam).
        init_chkpt: optional checkpoint to resume from via ``load_lenet``.
        L1: coefficient of an L1 penalty over all trainable parameters.
            Frozen parameters (e.g. ``MaskedLinear.mask``) are skipped: they
            receive no gradient and would only add a constant to the
            reported regularizer loss.
        dropout: dropout probability for the FF network.
        L0: L0 penalty strength; when non-zero an ``L0MLP`` is trained and
            converted to a plain FF network with ``derandomize`` afterwards.
            Requires ``dropout == 0`` and ``L1 == 0``.
        optimizer: ``"SGD"`` selects SGD with momentum; anything else Adam.
    """
    # Command-line parsers such as python-fire hand over a tuple for input
    # like "(300,100)"; the list concatenations below need a list.
    if isinstance(hidden_sizes, tuple):
        hidden_sizes = list(hidden_sizes)

    trainloader, valloader = mnist_dataset(data_loc)

    if init_chkpt:
        assert not hidden_sizes
        print(f"Loading chkpt from {init_chkpt}")
        model = load_lenet(init_chkpt, mode="train", dropout=dropout)
    else:
        if L0:
            assert dropout == 0 and L1 == 0
            # Note: N is the number of training examples. Lambda is divided by N.
            # NOTE(review): N is hard-coded to 50000 although the loader
            # iterates the whole train=True split — confirm this is intended.
            model = L0MLP(784, 10, layer_dims=hidden_sizes, weight_decay=0, N=50000, lambas=(0, L0, 0), temperature=2./3.)
        else:
            model = FF([784] + hidden_sizes + [10], dropout=dropout)
        model.train()

    device = get_device()
    model = model.to(device)
    print(model)

    criterion = nn.NLLLoss()

    # Keep the `optimizer` argument (a name) distinct from the optimizer object.
    if optimizer == "SGD":
        opt = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    else:
        opt = optim.Adam(model.parameters(), lr=lr)

    num_batches = len(trainloader)
    time0 = time()
    for e in range(epochs):
        running_loss = 0.0              # data (NLL) loss only
        running_regularized_loss = 0.0  # L1 / L0 penalty only
        for images, labels in trainloader:
            # Flatten MNIST images into a 784 long vector
            images = images.view(images.shape[0], -1)
            images, labels = images.to(device), labels.to(device)

            opt.zero_grad()

            output = model(images)
            data_loss = criterion(output, labels)
            loss = data_loss

            if L1:
                l1_norm = sum(p.abs().sum()
                              for p in model.parameters() if p.requires_grad)
                running_regularized_loss += L1 * l1_norm.item()
                loss = loss + L1 * l1_norm

            if L0:
                l0_penalty = model.regularization()
                running_regularized_loss += l0_penalty.item()
                loss = loss + l0_penalty

            loss.backward()
            opt.step()

            running_loss += data_loss.item()

        print(f"Epoch {e} - Training loss: {running_loss/num_batches}")
        print(f"Regularizer loss: {(running_regularized_loss)/num_batches}")

    print("\nTraining Time (in minutes) =",(time()-time0)/60)

    # Move the model back to cpu for validation
    model = model.to('cpu')
    model.eval()

    if L0:
        print(f'Expected L0: {model.get_exp_flops_l0()}')
        model = derandomize(model, [784] + hidden_sizes + [10])

    # Score whole batches at once instead of one image at a time.
    correct_count, all_count = 0, 0
    with torch.no_grad():
        for images, labels in valloader:
            logps = model(images.view(images.shape[0], -1))
            predictions = logps.argmax(dim=1)
            correct_count += int((predictions == labels).sum())
            all_count += labels.shape[0]

    print("Number Of Images Tested =", all_count)
    print("\nModel Accuracy =", (correct_count/all_count))

    torch.save(model.state_dict(), out_dir)

def derandomize(model, hidden_sizes):
    """Copy a trained L0 model into a freshly built FF network.

    Args:
        model: source model, which must be in eval mode; it has to expose
            ``layers[i]`` objects providing ``sample_weights()``, ``bias``
            and ``qz_loga``.
        hidden_sizes: the complete list of layer widths (input and output
            included) passed to ``FF`` for the new network.

    Returns:
        The new FF network (built with dropout 0). Its weights are the
        transposed ``sample_weights()`` of each source layer; ``bias`` and
        ``qz_loga`` are the source layer's own objects, not copies.
    """
    assert not model.training
    new_model = FF(hidden_sizes, dropout=0)
    n_linear = len(hidden_sizes) - 1
    for idx in range(n_linear):
        source = model.layers[idx]
        # Linear layers sit at every third position of the FF Sequential.
        target = new_model[3 * idx]
        target.weight = nn.Parameter(source.sample_weights().t())
        target.bias = source.bias
        target.qz_loga = source.qz_loga

    return new_model

# Script entry point: hand train_mnist to python-fire, which turns the
# function's parameters into command-line arguments.
if __name__ == "__main__":
    fire.Fire(train_mnist)
