#!/usr/bin/env python3
import torch
import scipy
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Compose, RandomAffine
import copy
from proxies import effective_rank, meco, zen, synflow, fisher, grasp, snip, grad_norm, zico

class NeuralNetwork(nn.Module):
    """Fully connected classifier: flatten the input, then Linear/activation pairs.

    Args:
        layers: Layer widths from input to output. Each consecutive pair of
            widths except the last becomes a ``Linear`` followed by the chosen
            activation. The last pair becomes the output ``Linear`` with no
            activation, so ``forward`` returns raw logits.
        activation_function_name: Key into ``self.activations`` ("ReLU",
            "Tanh", "Tanhshrink" or "SiLU"). It is only looked up when there is
            at least one hidden layer.
        weight_initialization: Stored on the instance unchanged. This class
            does not apply it.
    """

    def __init__(self, layers, activation_function_name, weight_initialization=None):
        super().__init__()
        self.flatten = nn.Flatten()
        self.layers = nn.Sequential()
        # Supported hidden-layer activations, keyed by name.
        self.activations = dict(
            ReLU=nn.ReLU,
            Tanh=nn.Tanh,
            Tanhshrink=nn.Tanhshrink,
            SiLU=nn.SiLU,
        )
        self.weight_initialization = weight_initialization

        # Hidden blocks (w0 -> w1), (w1 -> w2), ..., stopping one pair short of
        # the output layer. Linear layers are created in input-to-output order.
        for width_in, width_out in zip(layers[:-2], layers[1:-1]):
            self.layers.append(nn.Linear(width_in, width_out, bias=True))
            self.layers.append(self.activations[activation_function_name]())

        # Output projection: no activation after it.
        self.layers.append(nn.Linear(layers[-2], layers[-1], bias=True))

    def forward(self, x):
        """Flatten all dimensions after the first and return the logits."""
        return self.layers(self.flatten(x))
    
def weights_init_xavier(m):
    """Xavier-uniform initialise a Linear layer's weight and zero its bias.

    Meant to be passed to ``Module.apply``. Any module that is not an
    ``nn.Linear`` is left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias)

def weights_init_uniform(m):
    """Fill a Linear layer's weight from ``nn.init.uniform_`` and zero its bias.

    ``nn.init.uniform_`` is called with its default bounds. Meant to be passed
    to ``Module.apply``; non-Linear modules are ignored.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.uniform_(m.weight)
    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias)
    
def train_step(dataloader, model, loss_fn, optimizer, device=None):
    """Run one training epoch: a forward/backward/step per batch.

    Args:
        dataloader: Iterable of ``(X, y)`` batches. Its ``.dataset`` length is
            used only for the progress message.
        model: Module to train; switched to training mode here.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter, or ``"cuda"`` (the previous hard-coded
            value) if the model has no parameters.

    Prints the current batch loss every 100 batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cuda"
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        # Clear gradients before the backward pass so that gradients already on
        # the parameters when this function is called cannot leak into the
        # first update.
        optimizer.zero_grad()
        pred = model(X)
        loss = loss_fn(pred, y)
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss_value, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5}]")

def get_performance(dataloader, model, loss_fn, device=None):
    """Evaluate ``model`` on every batch of ``dataloader`` without gradients.

    Args:
        dataloader: Iterable of ``(X, y)`` batches.
        model: Classifier with class scores along dim 1. It is switched to
            eval mode here.
        loss_fn: Loss returning the *mean* over the batch, for example
            ``nn.CrossEntropyLoss()`` with its default reduction.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter, or ``"cuda"`` if the model has none.

    Returns:
        ``(accuracy, loss)``: the fraction of samples whose arg-max prediction
        equals the label, and the loss averaged over samples.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cuda"
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            batch_size = len(X)
            # Weight each batch mean by its size so the result is a true
            # per-sample mean. Averaging batch means would over-weight a
            # short final batch.
            total_loss += loss_fn(pred, y).item() * batch_size
            correct += (pred.argmax(1) == y).sum().item()
            seen += batch_size
    if seen == 0:
        raise ValueError("get_performance received a dataloader that yielded no samples")
    return correct / seen, total_loss / seen

def init_model(model, scheme=None):
    """(Re-)initialise ``model``'s parameters in place with a named scheme.

    Args:
        model: Module whose parameters are initialised.
        scheme: "xavier", "uniform" or "default". When None, the command-line
            value ``args.weight_initialization`` is used, as before.

    Raises:
        NotImplementedError: If ``scheme`` is not one of the names above.
    """
    if scheme is None:
        scheme = args.weight_initialization
    if scheme == "xavier":
        model.apply(weights_init_xavier)
    elif scheme == "uniform":
        model.apply(weights_init_uniform)
    elif scheme == "default":
        # Walk every submodule, not only direct children. Layers usually sit
        # inside containers (e.g. an nn.Sequential), and containers themselves
        # have no reset_parameters, so a children()-only loop can be a silent
        # no-op.
        for module in model.modules():
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()
    else:
        raise NotImplementedError(f"The scheme {scheme} has not been implemented.")

import argparse
# Every option is required: each is used unconditionally below (in the banner,
# for seeding, or for model construction), so a missing flag should fail here
# with a clear usage message instead of a TypeError later.
parser = argparse.ArgumentParser(
    description="Score an MLP with zero-cost NAS proxies, then train it on EMNIST with early stopping."
)
parser.add_argument(
    "--activation_function",
    required=True,
    choices=["ReLU", "Tanh", "Tanhshrink", "SiLU"],
    help="Hidden-layer activation (a key of NeuralNetwork.activations).",
)
parser.add_argument(
    "--weight_initialization",
    required=True,
    choices=["xavier", "uniform", "default"],
    help="Initialisation scheme applied by init_model.",
)
parser.add_argument("--seed", type=int, required=True, help="Seed for the torch and numpy RNGs.")
parser.add_argument(
    "--n_neurons",
    nargs="+",
    type=int,
    required=True,
    help="Layer widths from input to output (at least two values).",
)
args = parser.parse_args()
# NeuralNetwork indexes layers[-2] and layers[-1], so it needs two widths.
if len(args.n_neurons) < 2:
    parser.error("--n_neurons needs at least an input and an output width")

torch.manual_seed(args.seed)
np.random.seed(args.seed)

print("#"*50)
print(f"# {'Activation function':<21}: {args.activation_function:<24}#")
print(f"# {'Weight initialization':<21}: {args.weight_initialization:<24}#")
print(f"# {'Seed':<21}: {args.seed:<24}#")
print(f"# {'N_neurons:':<21}: {str(args.n_neurons):<24}#")
print(f"{'#'*50}\n\n\n")

# Training images get small random affine distortions. Evaluation images
# (validation and test) are only converted to tensors.
train_transforms = Compose([
    RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    ToTensor(),
])
test_transforms = Compose([
    ToTensor(),
])

# download=False: the dataset must already exist under this root.
dataset_root = "~/Documents/Torch_Dataset"

training_data = datasets.EMNIST(
    root=dataset_root,
    train=True,
    download=False,
    split="balanced",
    transform=train_transforms,
)
# The same training images without augmentation. The validation subset indexes
# into this view, so validation loss (which drives early stopping and
# best-model selection) is measured on clean images rather than randomly
# distorted ones. This costs a second in-memory copy of the training split.
validation_source = datasets.EMNIST(
    root=dataset_root,
    train=True,
    download=False,
    split="balanced",
    transform=test_transforms,
)
test_data = datasets.EMNIST(
    root=dataset_root,
    train=False,
    download=False,
    split="balanced",
    transform=test_transforms,
)

# Hold out 10% of the training images for validation. With 112800 training
# images this reproduces the previous fixed 101520/11280 split (same RNG draw,
# same indices) without hard-coding the sizes.
n_validation = len(training_data) // 10
n_train = len(training_data) - n_validation
train_set, _validation_split = torch.utils.data.random_split(training_data, [n_train, n_validation])
validation_set = torch.utils.data.Subset(validation_source, _validation_split.indices)

statistics = {

}


for measure in ["effective_rank", "meco_opt", "zico", "zen", "synflow", "fisher", "grasp", "snip", "grad_norm"]:
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    model = NeuralNetwork(args.n_neurons, args.activation_function, args.weight_initialization)
    train_dataloader = DataLoader(train_set, batch_size=1, shuffle=True, num_workers=1, persistent_workers=True)
    # 16 is the batch size used here https://github.com/pym1024/SWAP/blob/main/correlation.py#L26
    swap_dataloader = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=1, persistent_workers=True)
    nas_dataloader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=1, persistent_workers=True)
    zico_dataloader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=1, persistent_workers=True)
    init_model(model)
    if measure == "effective_rank":
        model.eval()
        statistics[measure] = effective_rank.get_average_score_effective_rank(model, train_dataloader, repetitions=32)
    elif measure == "meco_opt":
        inputs, _ = next(iter(nas_dataloader))
        statistics[measure] = meco.get_score(model, inputs, "cpu", "meco_opt")
    elif measure == "zico":
        statistics[measure] = zico.getzico(model, zico_dataloader, torch.nn.CrossEntropyLoss())
    elif measure == "zen":
        inputs, _ = next(iter(nas_dataloader))
        statistics[measure] = zen.compute_nas_score(None, model, 1e-2, inputs.shape[2], 16, 32)["avg_nas_score"]
    elif measure == "synflow":
        statistics[measure] = synflow.compute_nas_score(model, nas_dataloader)
    elif measure == "fisher":
        statistics[measure] = fisher.compute_nas_score(model, nas_dataloader)
    elif measure == "grasp":
        statistics[measure] = grasp.compute_nas_score(model, nas_dataloader, 47)
    elif measure == "snip":
        statistics[measure] = snip.compute_nas_score(model, nas_dataloader)
    elif measure == "grad_norm":
        statistics[measure] = grad_norm.compute_nas_score(model, nas_dataloader)
    else:
        raise NotImplementedError(f"Measure {measure} is not implemented!")
    
print("Proxies".center(100, "#"))
print(statistics)
print("#"*100)

# Re-seed so training starts from the same RNG state no matter how much
# randomness the proxy loop above consumed.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

batch_size = 64
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=10, persistent_workers=True, prefetch_factor=100)
# Loss and accuracy are accumulated over the whole split, so sample order does
# not matter for evaluation. Iterating in order avoids drawing a fresh random
# permutation on every validation/test pass.
validation_dataloader = DataLoader(validation_set, batch_size=batch_size, shuffle=False, num_workers=10, persistent_workers=True, prefetch_factor=100)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Constructed with the same arguments as the networks scored in the proxy loop.
model = NeuralNetwork(args.n_neurons, args.activation_function, args.weight_initialization)
init_model(model)
# Move the model to the GPU before building the optimizer, as the torch.optim
# documentation recommends, so the optimizer is built over the parameters as
# they will be used.
model.to("cuda")
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

max_epochs = 5000
# Number of consecutive epochs without a new best validation loss after which
# training stops.
patience = 10

early_stopping_counter = 0
best_val_loss = float("inf")
# Weights from the epoch with the lowest validation loss so far. They are
# restored before the final test evaluation.
best_model = copy.deepcopy(model.state_dict())
for epoch in range(max_epochs):
    print(f"{f'Epoch {epoch}':>17}")
    print("-"*29)
    train_step(train_dataloader, model, loss_fn, optimizer)
    print(f"{'-'*29}")

    val_accuracy, val_loss = get_performance(validation_dataloader, model, loss_fn)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model.state_dict())
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1
    print(f"Validation loss: {val_loss:.7f}")
    # >= rather than ==, so the stop still triggers if the counter ever
    # overshoots patience.
    if early_stopping_counter >= patience:
        print(f"No improvement for {early_stopping_counter} epochs. Stop training.\n\n")
        break
    elif early_stopping_counter > 0:
        print(f"No improvement for {early_stopping_counter} epochs.\n\n")

model.load_state_dict(best_model)
test_accuracy, test_loss = get_performance(test_dataloader, model, loss_fn)
print(f"Test accuracy: {test_accuracy*100}, Test loss: {test_loss:.7f}")
