import torch
import torch.nn as nn
from mnist.model_mnist import ShellNetwork, GCNNShell_mnist, Simple_Shell, HyperNetwork, FunctionalFullNetwork, \
    LowRankHyperNetwork, HyperNetwork_Head
from utils.model_tools import ShellParser
from utils.math_tools import RotateTransform
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import tqdm
import numpy as np
import os
import datetime
import random
from utils.model_tools import plot
from augmentation import calculate_weight_similarity,similarity
from utils.math_tools import unnormalize_image

def mnist_train(lora=False, Load_Previous=False, n=0, shell_choice='gcnn', l_r=0.00075,
                matrix_dim=2, inter_dim=4,  visualize_filters=False, epoch=50,
                head=False, shared_choice = 0, test_aug=False, reflection=True):
    """Train (or only evaluate) a hypernetwork-generated MNIST classifier.

    A "shell" network fixes the architecture; a hypernetwork produces its weights and
    only the hypernetwork's parameters are optimised (Adam). Training uses MNIST train
    images randomly rotated by up to 45 degrees; after every epoch accuracy is reported
    on the plain test set, a randomly rotated (up to 180 degrees) test set and a test set
    rotated by exactly 90 degrees. The full network and the hypernetwork state dicts are
    saved at the end.

    :param lora: Low-rank approximation. Matrix_dim = decomposition dimension, Inter_dim = intermediate dimension.
    :param Load_Previous: load a previously saved hypernetwork state dict before running.
    :param n: group size = 4*2^n
    :param shell_choice: gcnn, simple or custom.
    :param l_r: Adam learning rate.
    :param matrix_dim: passed to LowRankHyperNetwork (only used when lora and not head).
    :param inter_dim: intermediate dimension for the low-rank hypernetworks.
    :param visualize_filters: interactively plot generated filters for chosen test images,
        then terminate the process via sys.exit().
    :param epoch: number of training epochs.
    :param head: use HyperNetwork_Head instead of a single-output hypernetwork.
    :param shared_choice: forwarded to HyperNetwork_Head.
    :param test_aug: evaluation only (forces Load_Previous); no training, nothing saved.
    :param reflection: in test_aug mode, average the logits of each image and of its
        copy flipped along dim 2 when evaluating on the rotated test set.
    :raises ValueError: if shell_choice is not one of 'simple', 'gcnn', 'custom'.
    """

    if test_aug:
        # Evaluation-only mode needs trained weights.
        Load_Previous = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    show_info = True  # Whether to show the network information or not.
    if shell_choice == 'simple':
        shell = Simple_Shell()
    elif shell_choice == 'gcnn':
        shell = GCNNShell_mnist()
    elif shell_choice == 'custom':
        shell = ShellNetwork()
    else:
        raise ValueError("shell_choice should be 'simple', 'gcnn' or 'custom'.")

    root = 'saved_pth/mnist/'
    os.makedirs(root, exist_ok=True)

    input_size = 28 * 28
    parameters, layers, layers_info = ShellParser(shell)
    if head:
        hyper = HyperNetwork_Head(shell, lora=lora, inter=inter_dim,
                                  shared_choice=shared_choice).to(device)
        if lora:
            hyper_save_location = 'saved_pth/mnist/head_choice{}_inter{}_lora_hyp_{}_{}.pth'.format(
                shared_choice, inter_dim, shell_choice, n)
            full_save_location = 'saved_pth/mnist/head_choice{}_inter{}_lora_full_{}_{}.pth'.format(
                shared_choice, inter_dim, shell_choice, n)
        else:
            hyper_save_location = 'saved_pth/mnist/head_choice{}_hyp{}_{}.pth'.format(shared_choice, shell_choice, n)
            full_save_location = 'saved_pth/mnist/head_choice{}_full{}_{}.pth'.format(shared_choice, shell_choice, n)
    elif lora:
        print("Using low-rank approximation on the whole output, not individual heads.")
        hyper = LowRankHyperNetwork(input_size, shell, matrix_dim=matrix_dim, intermediate_dim=inter_dim).to(device)
        print("The chosen matrix info: ({}x{}) and ({}x{})".format(hyper.d, hyper.intermediate_dim,
                                                                    hyper.intermediate_dim,
                                                                    hyper.e))
        hyper_save_location = 'saved_pth/mnist_pth/{}_hyper_low_{}_{}_{}.pth'.format(shell_choice, n, hyper.d,
                                                                                     hyper.intermediate_dim)
        full_save_location = 'saved_pth/mnist_pth/{}_full_low_{}_{}_{}.pth'.format(shell_choice, n, hyper.d,
                                                                                   hyper.intermediate_dim)
    else:
        hyper = HyperNetwork(input_size, shell).to(device)
        hyper_save_location = 'saved_pth/mnist_pth/{}_hyper_network_{}.pth'.format(shell_choice, n)
        full_save_location = 'saved_pth/mnist_pth/{}_full_network_{}.pth'.format(shell_choice, n)

    print("saved location of hypernet", hyper_save_location)
    if Load_Previous:
        print("Loading saved model for hypernetwork.")
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        hyper.load_state_dict(torch.load(hyper_save_location, map_location=device))

    learning_rate = l_r
    print("Training the network with learning rate:", learning_rate)
    print("saved location", full_save_location)
    # Two ways of initializing the full network:
    # shell.layer.data = parameter would erase grad information, functional keeps grad information,
    # Thus we choose the second one.
    network = FunctionalFullNetwork(hyper, shell, n, head=head).to(device)
    if show_info:
        print("Showing infos.")
        print("Here are the parameters of the shell network:\n", parameters)
        print("Recognized layers:", layers)
        print("Recognized layers info:", layers_info)

    # Training set: train images rotated by up to +/-45 degrees. Per the original author,
    # this range is meant to suffice for the full-range rot_loader given the model's
    # 90-degree rotations. download=True here also fetches the test split files.
    partial_range_loader = _mnist_loader(True, (transforms.RandomRotation(45),), download=True)
    test_data = _mnist_dataset(False)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=True)
    rot_loader = _mnist_loader(False, (transforms.RandomRotation(180),))
    ninety_loader = _mnist_loader(False, (RotateTransform(90),))

    if test_aug:
        # NOTE(review): this "augmented test set" is built from the *training* split
        # (train=True); confirm that is intended.
        transformed_loader = _mnist_loader(
            True, (transforms.RandomAffine(degrees=180, translate=(0.3, 0.3)),))

    if visualize_filters:
        data_name = 'mnist'
        print("Visualizing filters.")
        while True:
            choice = int(input("Choose a number of input to visualize: "))
            root = 'visualize/{}/{}/'.format(data_name, choice)
            print("It will be saved in the directory: ", root)
            image = test_data[choice][0].unsqueeze(0).to(device)
            visualize(image, network, root)
            print("choose a different input image?: (y/n)")
            if input().lower() == "n":
                break
        import sys
        sys.exit()

    if test_aug:
        print("Testing augmentation for one of the previous rebuttals.")
        accuracy = _evaluate_accuracy(network, transformed_loader, device)
        print("The accuracy of the augmented test set is: {:.2f}%".format(accuracy))
        # Accuracy on the randomly rotated (up to 180 degrees) test set, optionally with
        # reflection test-time augmentation.
        accuracy_rot = _evaluate_accuracy(network, rot_loader, device, reflection=reflection)
        print("The accuracy of the original test set is: {:.2f}%".format(accuracy_rot))
        return

    criterion = nn.CrossEntropyLoss()
    # Only the hypernetwork is trained; the shell's weights are generated by it.
    optimizer = torch.optim.Adam(network.hypernetwork.parameters(), lr=learning_rate)
    current_time = datetime.datetime.now().strftime("%m-%d %H:%M:%S")
    print("Starting Time:", current_time)
    bar = tqdm.tqdm(range(epoch))

    for _ in bar:
        network.train()
        train_loss = []
        for images, labels in partial_range_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = network(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
        avg_loss = np.mean(train_loss)

        accuracy = _evaluate_accuracy(network, test_loader, device)
        accuracy_rot = _evaluate_accuracy(network, rot_loader, device)
        accuracy_ninety = _evaluate_accuracy(network, ninety_loader, device)
        print()
        bar.set_description(
            "Loss: {:.4f}, Test Acc: {:.2f}%, Test Acc Rot: {:.2f}%, Test Acc 90°: {:.2f}%".format(avg_loss, accuracy,
                                                                                                   accuracy_rot,
                                                                                                   accuracy_ninety))
    # The non-head save paths live under saved_pth/mnist_pth/, which nothing else
    # creates; make sure every target directory exists before saving.
    for location in (full_save_location, hyper_save_location):
        os.makedirs(os.path.dirname(location), exist_ok=True)
    torch.save(network.state_dict(), full_save_location)
    torch.save(network.hypernetwork.state_dict(), hyper_save_location)
    print("Saving the regular training to network.pth and hyper_network.pth")
    print()


def _mnist_dataset(train, augmentations=(), download=False):
    """Build an MNIST dataset under 'data' with ToTensor -> augmentations -> Normalize."""
    return datasets.MNIST('data', train=train, download=download,
                          transform=transforms.Compose(
                              [transforms.ToTensor(), *augmentations,
                               transforms.Normalize((0.1307,), (0.3081,))]))


def _mnist_loader(train, augmentations=(), download=False, batch_size=32):
    """Wrap _mnist_dataset in a shuffled DataLoader."""
    return torch.utils.data.DataLoader(_mnist_dataset(train, augmentations, download),
                                       batch_size=batch_size, shuffle=True)


def _evaluate_accuracy(network, loader, device, reflection=False):
    """Return top-1 accuracy (percent) of `network` on `loader`, or NaN for an empty loader.

    The network is switched to eval mode for the pass and its previous mode restored.
    When `reflection` is true, each batch's logits are averaged with the logits of the
    same batch flipped along dim 2. The *input* is flipped; the logits have no spatial
    axis to flip.
    """
    was_training = network.training
    network.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                images = images.to(device)
                outputs = network(images)
                if reflection:
                    outputs = (outputs + network(images.flip(2))) / 2
                predicted = outputs.argmax(dim=-1).cpu()
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        network.train(was_training)
    return 100.0 * correct / total if total else float('nan')

def visualize(image, network, root):
    """Plot the CNN filters generated for `image` and for its 90-degree rotation.

    For each index in range(8), a directory ``root + "<index>/"`` is created (if
    missing). It receives the input image ("original.png"), the filters returned by
    ``network.filter_only(image, [0, index])`` ("original_filters_<layer>.png"), the
    input rotated by 90 degrees over its last two dims ("rotated.png") and the filters
    generated for that rotated input ("rot_filters_<layer>.png").

    :param image: input tensor passed to ``network.filter_only``; the caller in this file
        passes a single test image with a leading batch dimension.
    :param network: object exposing ``filter_only(image, choice)`` returning a 4-tuple whose
        first element is an iterable of per-layer CNN filters.
    :param root: output path prefix; joined by plain string concatenation, so it should
        end with '/'.
    """
    # The rotated input does not depend on the loop index: compute it once.
    image_rotated = torch.rot90(image, 1, [-2, -1])
    for index in range(8):
        # NOTE(review): the meaning of `choice` is defined by filter_only; presumably
        # index selects which generated filter set to return. TODO confirm.
        choice = [0, index]
        print("First, the original image and its corresponding filter:")
        out_dir = root + "{}/".format(index)
        os.makedirs(out_dir, exist_ok=True)
        plot(image, "original image and its corresponding filter", saving_path=out_dir + "original.png")
        cnn_filters = network.filter_only(image, choice)[0]
        for i, filters in enumerate(cnn_filters):
            plot(filters, "filters for {} layer.".format(i), saving_path=out_dir + "original_filters_{}.png".format(i))
        print("finished plotting filters.")
        print("Now the 90 degree rotated image and its corresponding filter:")
        plot(image_rotated, "90 degree rotated image and its corresponding filter",
             saving_path=out_dir + "rotated.png")
        cnn_filters = network.filter_only(image_rotated, choice)[0]
        for i, filters in enumerate(cnn_filters):
            plot(filters, "filters for {} layer. (rot-input)".format(i), saving_path=out_dir + "rot_filters_{}.png".format(i))
