#!/usr/bin/env python3
import torch
import scipy
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Compose, RandomAffine
import copy
from collections import defaultdict

class NeuralNetwork(nn.Module):
    """Fully connected network described by a sequence of layer widths.

    ``layers`` lists the width of every layer, input first and output last.
    Each consecutive pair of widths becomes an ``nn.Linear``; every linear
    layer except the last one is followed by the activation selected through
    ``activation_function_name`` ("ReLU", "Tanh", "Tanhshrink" or "SiLU").
    """

    def __init__(self, layers, activation_function_name):
        super().__init__()
        self.flatten = nn.Flatten()
        self.activations = {
            "ReLU": nn.ReLU,
            "Tanh": nn.Tanh,
            "Tanhshrink": nn.Tanhshrink,
            "SiLU": nn.SiLU,
        }

        modules = []
        # Hidden part: Linear -> activation for every width pair but the last.
        for fan_in, fan_out in zip(layers, layers[1:-1]):
            modules.append(nn.Linear(fan_in, fan_out, bias=True))
            modules.append(self.activations[activation_function_name]())
        # Output layer: produces the logits, no activation afterwards.
        modules.append(nn.Linear(layers[-2], layers[-1], bias=True))
        self.layer_stack = nn.Sequential(*modules)

    def forward(self, x):
        """Flatten ``x`` and return the output of the last linear layer."""
        return self.layer_stack(self.flatten(x))
    
def weights_init_xavier(m):
    """Initialise ``m`` in place if it is an ``nn.Linear``; ignore anything else.

    The weight matrix is filled with ``nn.init.xavier_uniform_`` and the bias,
    when the layer has one, is set to zero.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    if m.bias is None:
        return
    nn.init.zeros_(m.bias)

def weights_init_uniform(m):
    """Initialise ``m`` in place if it is an ``nn.Linear``; ignore anything else.

    The weight matrix is filled by ``nn.init.uniform_`` with its default
    bounds and the bias, when the layer has one, is set to zero.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.uniform_(m.weight)
    if m.bias is None:
        return
    nn.init.zeros_(m.bias)
    
def train_step(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    The model is put into training mode.  For every batch the loss is
    back-propagated, the optimizer takes a step and the gradients are cleared.
    The loss of every 100th batch (starting with the first) is printed
    together with a running sample counter.
    """
    size = len(dataloader.dataset)
    model.train()
    for step, (inputs, targets) in enumerate(dataloader):
        batch_loss = loss_fn(model(inputs), targets)

        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Only report on batches 0, 100, 200, ...
        if step % 100 != 0:
            continue
        loss = batch_loss.item()
        current = (step + 1) * len(inputs)
        print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5}]")

def get_performance(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    The model is switched to evaluation mode.

    Returns:
        ``(accuracy, mean_loss)``: the fraction of samples whose arg-max
        prediction equals the label, and the sum of the per-batch losses
        divided by the number of batches.
    """
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)
    model.eval()
    loss_total = 0
    hits = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            logits = model(inputs)
            loss_total += loss_fn(logits, targets).item()
            matches = logits.argmax(1) == targets
            hits += matches.type(torch.float).sum().item()
    mean_loss = loss_total / n_batches
    accuracy = hits / n_samples
    return accuracy, mean_loss

def get_effective_rank(matrix, return_singular_values=False):
    """Return the effective rank of ``matrix``.

    The effective rank (https://ieeexplore.ieee.org/abstract/document/7098875)
    is ``e ** H``, where ``H`` is the Shannon entropy of the singular values
    after they have been normalised to sum to one.

    Args:
        matrix: tensor accepted by ``torch.linalg.svdvals``.
        return_singular_values: when true, additionally return the
            (unnormalised, detached) singular values.

    Returns:
        The effective rank as a NumPy scalar, passed through
        ``np.nan_to_num`` so that a NaN result (e.g. singular values summing
        to zero) is reported as 0.  With ``return_singular_values=True`` a
        tuple ``(effective_rank, singular_values)`` is returned instead.
    """
    # Imported explicitly: a bare ``import scipy`` does not guarantee that
    # the ``scipy.stats`` submodule has been loaded.
    from scipy.stats import entropy

    singular_values = torch.linalg.svdvals(matrix).detach()
    # Out-of-place division leaves ``singular_values`` intact for the caller,
    # so no defensive copy is needed.
    distribution = singular_values / torch.sum(singular_values)
    erank = np.nan_to_num(torch.e ** entropy(distribution))
    if return_singular_values:
        return erank, singular_values
    return erank

def get_activations(model, layer, dataloader):
    """Collect outputs of ``model.layer_stack[layer]`` on batches of ``dataloader``.

    A forward hook appends the layer's output to a running tensor, one batch
    after the other (rows are stacked along dim 0).  As soon as there are at
    least as many rows as columns, a slice with as many rows as there are
    columns is cut out at an offset drawn with ``np.random.randint`` and
    returned.  If the dataloader is exhausted first, everything gathered so
    far is returned; an empty dataloader yields an empty tensor.

    The hook is removed in a ``finally`` block, so a failing forward pass
    cannot leave it attached to the model.
    """
    collected = [torch.tensor([])]
    widths = []

    def record_output(module, inputs, output):
        widths.append(output.shape[-1])
        if output.dim() > 2:
            # Swap dims 1 and 3, then merge the first three dims into rows.
            output = torch.transpose(output, 1, 3).flatten(0, 2)
        # torch.cat accepts the initial 1-D empty tensor next to any shape.
        collected[0] = torch.cat((collected[0], output), dim=0)

    handle = model.layer_stack[layer].register_forward_hook(record_output)
    try:
        for X, _ in dataloader:
            model(X)
            n_rows = collected[0].shape[0]
            n_cols = collected[0].shape[1]
            if n_rows < n_cols:
                continue
            # Offsets are multiples of the first recorded output width.
            n_offsets = n_rows // widths[0] - 1
            start = np.random.randint(0, n_offsets) * widths[0] if n_offsets > 0 else 0
            collected[0] = collected[0][start:start + n_cols]
            break
    finally:
        handle.remove()
    return collected[0]

def _reset_linear_parameters(m):
    """Re-draw PyTorch's built-in initialisation for an ``nn.Linear``."""
    if isinstance(m, nn.Linear):
        m.reset_parameters()


def init_model(model, scheme=None):
    """(Re-)initialise every ``nn.Linear`` inside ``model`` in place.

    Args:
        model: module whose linear layers are re-initialised.
        scheme: ``"xavier"``, ``"uniform"`` or ``"default"``.  When omitted,
            the value of the ``--weight_initialization`` command-line option
            (``args.weight_initialization``) is used.

    Raises:
        NotImplementedError: if the scheme is none of the names above.
    """
    if scheme is None:
        scheme = args.weight_initialization

    if scheme == "xavier":
        model.apply(weights_init_xavier)
    elif scheme == "uniform":
        model.apply(weights_init_uniform)
    elif scheme == "default":
        # Re-draw the layers' own default initialisation.  Doing nothing here
        # would make repeated calls return the very same weights, unlike the
        # other schemes, which produce a new draw on every call.
        model.apply(_reset_linear_parameters)
    else:
        raise NotImplementedError(f"The scheme {scheme} has not been implemented.")

import argparse

# Command-line interface.  Every option is consumed unconditionally by the
# rest of the script, so all of them are required: leaving one out now ends
# with a usage message instead of an unrelated-looking error further down.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--activation_function",
    required=True,
    help="name of the activation used between linear layers, e.g. ReLU, Tanh, Tanhshrink or SiLU",
)
parser.add_argument(
    "--weight_initialization",
    required=True,
    help="weight initialization scheme: xavier, uniform or default",
)
parser.add_argument(
    "--seed",
    type=int,
    required=True,
    help="integer seed for the torch and numpy random number generators",
)
parser.add_argument(
    "--n_neurons",
    nargs="+",
    type=int,
    required=True,
    help="layer widths, input size first and number of classes last",
)
args = parser.parse_args()

# Seed both RNGs before the model is constructed.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

model = NeuralNetwork(args.n_neurons, args.activation_function)

# Print a 50-column banner summarising the run configuration.
print("#" * 50)
banner_rows = (
    ("Activation function", args.activation_function),
    ("Weight initialization", args.weight_initialization),
    ("Seed", args.seed),
    ("N_neurons:", str(args.n_neurons)),
)
for label, value in banner_rows:
    print(f"# {label:<21}: {value:<24}#")
print("#" * 50 + "\n\n\n")

# Random affine augmentation is applied to training samples only; evaluation
# data is merely converted to tensors.
train_transforms = Compose([
    RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    ToTensor(),
])
test_transforms = Compose([
    ToTensor(),
])

emnist_options = dict(
    root="~/Documents/Torch_Dataset",
    download=False,
    split="balanced",
)
training_data = datasets.EMNIST(
    train=True,
    transform=train_transforms,
    **emnist_options,
)
# The same training images without augmentation; only used for validation.
plain_training_data = datasets.EMNIST(
    train=True,
    transform=test_transforms,
    **emnist_options,
)
test_data = datasets.EMNIST(
    train=False,
    transform=test_transforms,
    **emnist_options,
)
train_set, held_out_split = torch.utils.data.random_split(training_data, [101520, 11280])
# random_split returns Subsets of ``training_data``, which would run the
# training augmentation on the held-out samples too.  Point the held-out
# indices at the un-augmented copy so the validation loss that drives early
# stopping is measured on unmodified images, like the test loss.
validation_set = torch.utils.data.Subset(plain_training_data, held_out_split.indices)

batch_size = 64
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=1, persistent_workers=True)
validation_dataloader = DataLoader(validation_set, batch_size=batch_size, shuffle=True, num_workers=1, persistent_workers=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Effective rank of every layer's activations, averaged over NUM_REPEATS
# calls to init_model.  Each contribution is pre-divided by NUM_REPEATS, so
# the accumulated sums are already the means.
ranks_activations = defaultdict(int)

NUM_REPEATS = 100
for _ in range(NUM_REPEATS):
    init_model(model)
    for layer in range(len(model.layer_stack)):
        activations = get_activations(model, layer, train_dataloader)
        ranks_activations[layer] += get_effective_rank(activations) / NUM_REPEATS
total_rank = sum(ranks_activations.values())

print("Effective Rank".center(50, "#"))
# The accumulated values are NumPy scalars; convert them to plain floats so
# the list prints as bare numbers instead of ``np.float64(...)``-style reprs
# (which is what NumPy 2 shows for scalars inside containers).
print([round(float(rank), 3) for rank in ranks_activations.values()])
print(total_rank)
print(f"{'#'*50}\n\n")

# Reset both RNGs to the run seed before training starts.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

max_epochs = 5000
patience = 10

# Early stopping on the validation loss: remember the best weights seen so
# far and give up after `patience` consecutive epochs without improvement.
early_stopping_counter = 0
best_val_loss = float("inf")
best_model = copy.deepcopy(model.state_dict())
for epoch in range(max_epochs):
    epoch_title = f"Epoch {epoch}"
    rule = "-" * 29
    print(f"{epoch_title:>17}")
    print(rule)
    train_step(train_dataloader, model, loss_fn, optimizer)
    print(rule)

    val_accuracy, val_loss = get_performance(validation_dataloader, model, loss_fn)
    improved = val_loss < best_val_loss
    if improved:
        best_val_loss, early_stopping_counter = val_loss, 0
        best_model = copy.deepcopy(model.state_dict())
    else:
        early_stopping_counter += 1
    print(f"Validation loss: {val_loss:.7f}")

    if early_stopping_counter == patience:
        print(f"No improvement for {early_stopping_counter} epochs. Stop training.\n\n")
        break
    if early_stopping_counter > 0:
        print(f"No improvement for {early_stopping_counter} epochs.\n\n")

# Score the weights with the lowest validation loss on the test split.
model.load_state_dict(best_model)
test_accuracy, test_loss = get_performance(test_dataloader, model, loss_fn)
print(f"Test accuracy: {test_accuracy*100}, Test loss: {test_loss:.7f}")
