import torch
import torch.nn as nn
from utils.math_tools import unnormalize_image
import os
import torchvision.models as models
from cifar.model_cifar import (ShellNetwork, GCNNShell_cifar, Simple_Shell, LowRankHyper,
                               HyperNetwork_cifar, HyperNetwork_Custom, FunctionalFullNetwork
                                 , HyperNetwork_Head, ResNet18_cifar10)
from utils.model_tools import ShellParser, param_matching_loss
from utils.math_tools import RotateTransform, find_closest_divisor
from utils.model_tools import plot
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import tqdm
import numpy as np
import datetime
import random

from constraint import GroupConvolution_layer, Lift_layer

class equivariant_maxpool(nn.Module):
    """Spatial max-pooling for 5-D group feature maps.

    The input is expected to be laid out as
    ``(batch, channels, group_size, height, width)``.  The channel and group
    axes are folded together so a regular ``nn.MaxPool2d`` can pool over the
    two trailing spatial axes, and are then unfolded again.  Every
    ``(channel, group)`` slice is therefore pooled independently.

    Args:
        kernel_size: Window size forwarded to ``nn.MaxPool2d``.
        stride: Stride forwarded to ``nn.MaxPool2d``.
    """

    def __init__(self, kernel_size=2, stride=2):
        super().__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride)

    def forward(self, x):
        """Pool ``x`` over its last two axes.

        Args:
            x: Tensor of shape ``(batch, channels, group_size, height, width)``.

        Returns:
            Tensor of shape ``(batch, channels, group_size, new_h, new_w)``.

        Raises:
            ValueError: If ``x`` does not have exactly five dimensions.
        """
        # Unpack height and width separately so non-square maps are handled.
        batch, channels, group_size, height, width = x.shape
        # reshape (rather than view) also accepts non-contiguous inputs.
        folded = x.reshape(batch, channels * group_size, height, width)
        pooled = self.maxpool(folded)
        new_height, new_width = pooled.shape[-2:]
        return pooled.reshape(batch, channels, group_size, new_height, new_width)

class Simple_Shell_GCNN(nn.Module):
    """Group-convolutional classifier built from ``Lift_layer`` and
    ``GroupConvolution_layer`` blocks.

    The network has three convolutional stages (128 -> 256 -> 512/256/128
    channels), each convolution followed by ``LeakyReLU(0.1)``.  The first two
    stages end with ``equivariant_maxpool`` and ``Dropout(0.5)``; the third is
    followed by a global average pool, a linear layer and a softmax.

    Args:
        num_classes: Number of output classes of the final linear layer.
        group: Group argument forwarded to every lift / group-convolution
            layer (presumably the rotation group order -- confirm against
            ``constraint``).

    NOTE(review): ``forward`` returns softmax probabilities.  ``cifar_train``
    in this file feeds the output to ``nn.CrossEntropyLoss``, which applies
    log-softmax itself -- confirm the extra softmax is intended.
    """

    # Order in which the feature-extraction sub-modules are applied in forward.
    _FEATURE_ORDER = (
        "conv1a", "lrelu1", "conv1b", "lrelu2", "conv1c", "lrelu3",
        "pool1", "drop1",
        "conv2a", "lrelu4", "conv2b", "lrelu5", "conv2c", "lrelu6",
        "pool2", "drop2",
        "conv3a", "lrelu7", "conv3b", "lrelu8", "conv3c", "lrelu9",
        "globalpool",
    )

    def __init__(self, num_classes=10, group=4):
        super().__init__()
        slope = 0.1
        drop_p = 0.5
        same_pad = dict(kernel_size=3, padding=1)

        # Stage 1: lift the 3-channel image, then two group convolutions.
        self.conv1a = Lift_layer(group, 3, 128, **same_pad)
        self.lrelu1 = nn.LeakyReLU(negative_slope=slope)
        self.conv1b = GroupConvolution_layer(group, 128, 128, **same_pad)
        self.lrelu2 = nn.LeakyReLU(negative_slope=slope)
        self.conv1c = GroupConvolution_layer(group, 128, 128, **same_pad)
        self.lrelu3 = nn.LeakyReLU(negative_slope=slope)
        self.pool1 = equivariant_maxpool(kernel_size=2, stride=2)
        self.drop1 = nn.Dropout(p=drop_p)

        # Stage 2: three group convolutions at 256 channels.
        self.conv2a = GroupConvolution_layer(group, 128, 256, **same_pad)
        self.lrelu4 = nn.LeakyReLU(negative_slope=slope)
        self.conv2b = GroupConvolution_layer(group, 256, 256, **same_pad)
        self.lrelu5 = nn.LeakyReLU(negative_slope=slope)
        self.conv2c = GroupConvolution_layer(group, 256, 256, **same_pad)
        self.lrelu6 = nn.LeakyReLU(negative_slope=slope)
        self.pool2 = equivariant_maxpool(kernel_size=2, stride=2)
        self.drop2 = nn.Dropout(p=drop_p)

        # Stage 3: one unpadded 3x3 convolution followed by two 1x1 convolutions.
        self.conv3a = GroupConvolution_layer(group, 256, 512, kernel_size=3)
        self.lrelu7 = nn.LeakyReLU(negative_slope=slope)
        self.conv3b = GroupConvolution_layer(group, 512, 256, kernel_size=1)
        self.lrelu8 = nn.LeakyReLU(negative_slope=slope)
        self.conv3c = GroupConvolution_layer(group, 256, 128, kernel_size=1)
        self.lrelu9 = nn.LeakyReLU(negative_slope=slope)

        # Classification head.
        self.globalpool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(128, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, place_holder=False):
        """Run the network on ``x`` and return per-class probabilities.

        Args:
            x: Input batch handed to ``Lift_layer`` (assumed to be
                ``(batch, 3, H, W)`` images -- TODO confirm).
            place_holder: Unused; accepted so the model can be called with the
                same two-argument form as the other networks in this file.

        Returns:
            Tensor of shape ``(batch, num_classes)`` after softmax over dim 1.
        """
        for layer_name in self._FEATURE_ORDER:
            x = getattr(self, layer_name)(x)
        # The global pool leaves 128 values per sample; flatten for the head.
        flat = x.view(-1, 128)
        logits = self.fc(flat)
        return self.softmax(logits)

def visualize(layers, layers_info, image, network, root):
    """Interactively plot the filters ``network`` produces for ``image``.

    For every channel index typed by the user, the convolution filters returned
    by ``network.filter_only`` are plotted twice: once for ``image`` and once
    for ``image`` rotated by 90 degrees over its last two axes.  Plots are
    written to ``root + "<channel>/"``.  The loop ends when the user answers
    ``n`` to the continue prompt.

    Args:
        layers: Unused; kept for interface compatibility.
        layers_info: Unused; kept for interface compatibility.
        image: Input tensor handed to ``network.filter_only``.
        network: Object exposing ``filter_only(image, channel)`` that returns a
            4-tuple whose first element is an iterable of filters.
        root: Output directory prefix (expected to end with a path separator).
    """
    base_dir = root

    def _save_filters(source_image, channel, out_dir, name_pattern):
        # Only the convolution filters are plotted; the other outputs are unused.
        conv_filters, _, _, _ = network.filter_only(source_image, channel)
        for index, bank in enumerate(conv_filters):
            plot(bank, saving_path=out_dir + name_pattern.format(index))

    keep_going = True
    while keep_going:
        channel = int(input("Choose a channel to visualize: "))
        print("First, the original image and its corresponding filter:")
        out_dir = base_dir + "{}/".format(channel)
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        _save_filters(image, channel, out_dir, "original_filters_{}.png")
        print("finished plotting filters.")

        print("Now the 90 degree rotated image and its corresponding filter:")
        rotated = torch.rot90(image, 1, [-2, -1])
        _save_filters(rotated, channel, out_dir, "rot_filters_{}.png")

        print("continue to choose a different channel?: (y/n)")
        keep_going = input().lower() != "n"

def _cifar_save_locations(head, lora, shared_choice, shared_layer, inter_dim,
                          shell_choice, full_type, n, hyper_choice, low_rank_dim,
                          reflection, data_choice):
    """Return ``(hyper_save_location, full_save_location)`` checkpoint paths."""
    prefix = ''
    tail = ''
    if head:
        if not lora:
            prefix = 'head{}_'.format(shared_layer)
        elif shared_choice == 0:
            prefix = 'head{}_inter{}_lora_'.format(shared_layer, inter_dim)
        else:
            prefix = 'head{}_choice{}_inter{}_lora_'.format(shared_layer, shared_choice, inter_dim)
    elif lora:
        tail = '_low{}_{}_{}ppg'.format(low_rank_dim, inter_dim, hyper_choice)
    elif hyper_choice == 'resnet':
        tail = '_hardppg'
    elif hyper_choice == 'self':
        tail = '_selfppg'

    dataset_dir = 'cifar100' if data_choice else 'cifar10'
    locations = []
    for kind in ('hyp', 'full'):
        path = 'saved_pth/cifar/{}{}{}_{}_{}{}.pth'.format(
            prefix, kind, shell_choice, full_type, n, tail)
        if reflection:
            path = path.replace('.pth', '_refl.pth')
        locations.append(path.replace('cifar', dataset_dir))
    return locations[0], locations[1]


def cifar_train(lora=False, load_previous=False, n=0, shell_choice='resnet', set_up=0,
                l_r=0.0005, b_s=128, inter_dim=1, weight_decay=0, head=False,
                shared_choice=0, shared_layer=3, reflection=True,
                pretrained=False, data_choice=False, seed=73, visualize_filters=False):
    """Train (or visualise) a hypernetwork-generated model on CIFAR-10/100.

    Presets selected by ``set_up``:

    * ``0``  -- interactive: the shell and the full-network mode are read
      from standard input.
    * ``1``  -- ResNet18 shell, full-network mode ``"one"``.
    * ``3``  -- 13-layer CNN shell (``"simple"``), full-network mode ``"batch"``.
    * ``5``  -- 13-layer CNN shell (``"simple"``), full-network mode ``"one"``.
    * ``10`` -- plain ``Simple_Shell`` baseline, no hypernetwork.
    * ``11`` -- plain ``Simple_Shell_GCNN`` baseline, no hypernetwork.

    Any other value raises ``ValueError``.

    Args:
        lora: Use ``LowRankHyper`` (or the low-rank option of
            ``HyperNetwork_Head`` when ``head`` is set).
        load_previous: Load the network state dict from the full save location
            before training.
        n: Forwarded to ``FunctionalFullNetwork`` and embedded in the
            checkpoint file names.
        shell_choice: Shell architecture name; overridden by presets 1/3/5 and
            by the interactive prompt of preset 0.
        set_up: Preset number, see above.
        l_r: Learning rate given to Adam.
        b_s: Batch size of the train and test loaders.
        inter_dim: Intermediate dimension of the low-rank hypernetwork.
        weight_decay: Weight decay given to Adam; also used as the per-epoch
            multiplicative decay of the returned learning rate.
        head: Use ``HyperNetwork_Head``; requires ``shared_layer`` in ``{2, 3}``.
        shared_choice: Forwarded to ``HyperNetwork_Head``.
        shared_layer: Forwarded to ``HyperNetwork_Head``.
        reflection: Forwarded to ``FunctionalFullNetwork``; adds a ``_refl``
            suffix to the checkpoint names.
        pretrained: Load pretrained shell weights and add a parameter-matching
            loss term.
        data_choice: ``True`` selects CIFAR-100, otherwise CIFAR-10.
        seed: Seed for ``random``, NumPy and PyTorch.
        visualize_filters: Enter the interactive filter visualisation instead
            of training; the process exits afterwards via ``sys.exit()``.

    Returns:
        The local learning-rate value after it has been multiplied by
        ``(1 - weight_decay)`` once per epoch.

    Raises:
        ValueError: If ``set_up`` or the resolved ``shell_choice`` is unsupported.
        AssertionError: If ``head`` is set and ``shared_layer`` is not 2 or 3.
    """
    print("setting seed.")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    # otherwise we might get CUDNN status not supported warning message.

    if head:
        assert shared_layer in [2, 3], " only 2 or 3 for shared layers."
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    hyper_choice = 'simple'
    full_type = "one"
    # Presets 10 and 11 train the shell directly, without a hypernetwork.
    baseline = set_up in (10, 11)

    if set_up != 0:
        print("Set up = {}.".format(set_up))
    if set_up == 1:
        print("Using resnet, using one full network, and using easier ppg.")
        shell_choice = "resnet18"
    elif set_up == 3:
        shell_choice = "simple"
        print("Using 13-layer CNN, batch of full networks, and using easier ppg.")
        full_type = 'batch'
    elif set_up == 5:
        shell_choice = "simple"
        print("Using 13-layer CNN, one full network, and using easier ppg.")
        full_type = 'one'
    elif set_up != 0 and not baseline:
        # Only presets that no later branch handles are rejected; 0, 10 and 11
        # are handled further down and must be allowed through.
        raise ValueError("please follow readme.")

    if shared_choice == 1:
        print("Using medium difficulty for shared")
    elif shared_choice == 2:
        print("Using hard difficulty for shared")
    print()
    batch = b_s

    if set_up == 0:
        answer = input("Please choose simple, resnet18, gcnn or custom: ").lower()
        if answer in ("simple", "resnet18", "gcnn", "custom"):
            shell_choice = answer

    if data_choice:
        print("Using CIFAR100 dataset.")
        out_class = 100
    else:
        print("Using CIFAR10 dataset.")
        out_class = 10

    if set_up == 10:
        print("using base model")
        shell = Simple_Shell(num_classes=out_class)
    elif set_up == 11:
        print("using GCNN")
        shell = Simple_Shell_GCNN(num_classes=out_class)
    elif shell_choice == 'simple':
        shell = Simple_Shell(num_classes=out_class)
    elif shell_choice == 'gcnn':
        shell = GCNNShell_cifar()
    elif shell_choice == 'custom':
        shell = ShellNetwork()
    elif shell_choice == 'resnet18':
        shell = ResNet18_cifar10(num_classes=out_class)
        print("Don't use the original resnet18, as it is designed for ImageNet.")
    elif shell_choice == 'resnet34':
        shell = models.resnet34(num_classes=out_class)
    else:
        raise ValueError("shell_choice should be 'simple', 'gcnn', 'custom', "
                         "'resnet18' or 'resnet34'.")

    normalize = dict(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(**normalize),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(**normalize),
    ])

    force_stride = False
    _, layers, layers_info = ShellParser(shell, force_stride=force_stride)
    linear_info = ShellParser(shell, 2)
    print("Linear info:", linear_info)
    epoch = 50

    if baseline:
        print("no need for hyper")
    elif head:  # possibly head and lora
        print("using hypernetwork, with one head per layer. (all conv layers and one linear layers)")
        hyper = HyperNetwork_Head(shell, shared_layer=shared_layer,
                                  lora=lora, inter=inter_dim, shared_choice=shared_choice)
    elif lora:
        hyper = LowRankHyper(shell, intermediate_dim=inter_dim,
                             hyper_choice=hyper_choice).to(device)
    else:
        hyper = HyperNetwork_Custom(shell, choice=hyper_choice).to(device)

    learning_rate = l_r

    if set_up == 0:
        full_type = input("Please choose full network mode: batch, one or average: ").lower()

    if baseline:
        network = shell.to(device)
    else:
        network = FunctionalFullNetwork(hyper, shell, n, mode=full_type, head=head,
                                        reflection=reflection).to(device)

    if pretrained:
        print("loading the pretrained model.")
        shell.load_state_dict(torch.load('saved_pth/cifar/{}_shell.pth'.format(shell_choice)))
        p_loss = param_matching_loss(shell, layers_info, network.permutation_list, head=True)
    else:
        p_loss = None
    print("Epoch {} and learning rate {}.".format(epoch, learning_rate))

    low_rank_dim = None
    if lora and not head:
        print("The chosen dim:({}x{}) ({}x{}) ".format(hyper.d, inter_dim, inter_dim, hyper.e))
        low_rank_dim = hyper.d
    hyper_save_location, full_save_location = _cifar_save_locations(
        head, lora, shared_choice, shared_layer, inter_dim, shell_choice,
        full_type, n, hyper_choice, low_rank_dim, reflection, data_choice)
    # Make sure the checkpoint directory exists for both datasets so that
    # torch.save does not fail on a fresh checkout.
    os.makedirs(os.path.dirname(full_save_location), exist_ok=True)

    if load_previous:
        print("Loading previous hyper network.")
        network.load_state_dict(torch.load(full_save_location))

    print("Saving location:", full_save_location)
    # Load the training and test splits of the selected CIFAR dataset.
    dataset_cls = datasets.CIFAR100 if data_choice else datasets.CIFAR10
    train_dataset = dataset_cls(root='./data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch, shuffle=True)
    test_dataset = dataset_cls(root='./data', train=False, download=True, transform=test_transform)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch, shuffle=False)

    if visualize_filters:
        data_name = 'cifar100' if data_choice else 'cifar10'
        print("Visualizing filters.")
        while True:
            choice = int(input("Choose a number of input to visualize: "))
            root = 'visualize/{}/{}/'.format(data_name, choice)
            if not os.path.exists(root):
                os.makedirs(root)
            image = train_dataset[choice][0].unsqueeze(0).to(device)
            print("image shape", image.shape)
            visualize(layers, layers_info, image, network, root)
            print("choose a different input image?: (y/n)")
            if input().lower() == "n":
                break
        import sys
        sys.exit()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate, weight_decay=weight_decay)
    current_time = datetime.datetime.now().strftime("%m-%d %H:%M:%S")
    print("Starting Time:", current_time)
    bar = tqdm.tqdm(np.arange(epoch))
    for counter in bar:
        # NOTE(review): this only updates the local value that is returned;
        # the optimizer keeps the learning rate it was constructed with.
        learning_rate = learning_rate * (1 - weight_decay)
        train_loss = []
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = network(images, True)
            loss = criterion(outputs, labels)
            if pretrained:
                loss += p_loss(network.hypernetwork, images)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        if baseline:
            bar.set_description("Epoch {}".format(counter))
            continue

        avg_loss = np.mean(train_loss)
        # Test the network
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                outputs = network(images, False)
                _, predicted = torch.max(outputs, -1)
                predicted = predicted.detach().cpu()
                total += labels.size(0)
                correct += (predicted == labels.data).sum().item()
        accuracy = 100 * correct / total
        print()
        bar.set_description(
            "Loss: {:.4f}, Test Acc: {:.2f}%".format(avg_loss, accuracy))
        torch.save(network.state_dict(), full_save_location)
        torch.save(network.hypernetwork.state_dict(), hyper_save_location)
        print("saved")
    return learning_rate

