import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as ttf
import math
from utils.math_tools import group_permutation_element, find_closest_divisor, goal_linear_dim
from mnist.shared_mnists import *
from utils.model_tools import ShellParser, hyper_goal_output, equi_loader_general


class ShellNetwork(nn.Module):
    """Three-conv / three-linear shell (target) network for 1x28x28 inputs.

    Spatial size: 28 -conv3-> 26 -conv3-> 24 -conv2-> 23, so conv2 yields a
    10x23x23 map that is flattened into the MLP (500 -> 100 -> 10 logits).
    There is no activation between conv2 and linear0.
    """

    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv2d(1, 64, 3)
        self.relu0 = nn.ReLU()
        self.conv1 = nn.Conv2d(64, 32, 3)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 10, 2)
        # Side length of conv2's output map for a 28x28 input (see class docstring).
        self.final_dim = 23
        self.linear0 = nn.Linear(10 * self.final_dim * self.final_dim, 500)
        self.relu3 = nn.ReLU()
        self.linear1 = nn.Linear(500, 100)
        self.relu4 = nn.ReLU()
        self.linear2 = nn.Linear(100, 10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) logits."""
        # Apply the registered layers directly; no need to build a fresh
        # nn.Sequential container on every call, and no stdout side effects.
        x = self.conv2(self.relu1(self.conv1(self.relu0(self.conv0(x)))))
        x = x.view(x.size(0), -1)
        return self.linear2(self.relu4(self.linear1(self.relu3(self.linear0(x)))))


class Simple_Shell(nn.Module):
    """Minimal shell (target) network for 1x28x28 inputs.

    conv0 (2x2, no bias) -> ReLU -> conv1 (2x2) turns 28x28 into an 8x26x26
    map (28 -> 27 -> 26), which is flattened into linear0 -> ReLU -> linear1,
    producing 10 logits. There is no activation after conv1.
    """

    def __init__(self):
        super().__init__()
        # Keep the attribute/registration order unchanged: external parsers of
        # this module presumably walk its layers in definition order.
        self.conv0 = nn.Conv2d(1, 16, 2, bias=False)
        self.relu0 = nn.ReLU()
        self.conv1 = nn.Conv2d(16, 8, 2)
        self.linear0 = nn.Linear(8 * 26 * 26, 100)
        self.relu1 = nn.ReLU()
        self.linear1 = nn.Linear(100, 10)

    def forward(self, x):
        """Return (batch, 10) logits for a (batch, 1, 28, 28) input."""
        feature_map = self.conv1(self.relu0(self.conv0(x)))
        flat = feature_map.view(feature_map.size(0), -1)
        hidden = self.relu1(self.linear0(flat))
        return self.linear1(hidden)

class GCNNShell_mnist(nn.Module):
    """Eight-conv shell (target) network for 1x28x28 inputs.

    Seven 3x3 convs and one 4x4 conv (20 channels each, no padding, ReLU after
    every conv) shrink 28x28 to 11x11 (28 - 7*2 = 14, then 14 - 3 = 11).
    The head depends on the flags:

    * ``version=True``: flatten -> Linear(20*11*11, 100) -> ReLU -> Linear(100, out_channels)
    * ``norm=True``: global average pool -> flatten -> Linear(20, out_channels)
    * otherwise: flatten -> Linear(20*11*11, out_channels)

    ``version`` takes precedence over ``norm``.
    """

    def __init__(self, out_channels = 10, norm = False, version = False):
        super().__init__()
        self.conv0 = nn.Conv2d(1, 20, 3)
        self.relu0 = nn.ReLU()
        self.conv1 = nn.Conv2d(20, 20, 3)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(20, 20, 3)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(20, 20, 3)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(20, 20, 3)
        self.relu4 = nn.ReLU()
        self.conv5 = nn.Conv2d(20, 20, 3)
        self.relu5 = nn.ReLU()
        self.conv6 = nn.Conv2d(20, 20, 3)
        self.relu6 = nn.ReLU()
        self.conv7 = nn.Conv2d(20, 20, 4)
        self.relu7 = nn.ReLU()
        self.norm = norm
        self.version = version
        if version:
            self.linear0 = nn.Linear(20 * 11 * 11, 100)
            self.relu_linear0 = nn.ReLU()
            self.linear1 = nn.Linear(100, out_channels)
        elif norm:
            self.pool  = nn.AdaptiveAvgPool2d(1)
            self.linear0 = nn.Linear(20, out_channels)
        else:
            self.linear0 = nn.Linear(20 * 11 * 11, out_channels)

    def forward(self, x):
        """Return (batch, out_channels) logits for a (batch, 1, 28, 28) input."""
        for layer in (self.conv0, self.relu0, self.conv1, self.relu1, self.conv2, self.relu2,
                      self.conv3, self.relu3, self.conv4, self.relu4, self.conv5, self.relu5,
                      self.conv6, self.relu6, self.conv7, self.relu7):
            x = layer(x)
        if self.version:
            # Two-layer MLP head built in __init__ for this configuration.
            x = x.view(x.size(0), -1)
            return self.linear1(self.relu_linear0(self.linear0(x)))
        if self.norm:
            # Pool over the spatial dims *before* flattening: AdaptiveAvgPool2d
            # needs a (N, C, H, W) tensor, and linear0 expects C = 20 features.
            x = self.pool(x)
        x = x.view(x.size(0), -1)
        return self.linear0(x)


class HyperNetwork(nn.Module):
    """
    Plain CNN hypernetwork that emits every parameter of a target network at once.

    ``target_network`` (required) is inspected with ``ShellParser``; the output
    width is the total returned by ``hyper_goal_output`` for its layers.
    For a (batch, 1, 28, 28) input the conv stack maps 28 -> 13 -> 6 -> 3,
    giving 32*3*3 features for the two-layer MLP head.
    """

    def __init__(self, input_size, target_network):
        super(HyperNetwork, self).__init__()
        self.input_size = input_size
        _, _, layer_info = ShellParser(target_network)
        self.hyper_output_size = sum(hyper_goal_output(layer_info))
        print(f"output size is {self.hyper_output_size} .")
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2)
        # A single stateless ReLU instance is reused after every layer.
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=4)
        self.fc1 = nn.Linear(32 * 3 * 3, 128)
        self.fc2 = nn.Linear(128, self.hyper_output_size)

    def forward(self, x):
        """Return a (batch, hyper_output_size) tensor of generated parameters."""
        activation = self.relu1
        for conv in (self.conv1, self.conv2, self.conv3):
            x = activation(conv(x))
        flat = x.view(x.size(0), -1)
        return self.fc2(activation(self.fc1(flat)))

class Embed_HyperNetwork_mnist(nn.Module):
    """Chunked-embedding hypernetwork.

    The target parameter vector (``hyper_output_size`` values) is produced as
    ``chunks`` pieces of ``chunk_dim`` values each. For every chunk a learned
    embedding (``embedding_module``) is scaled by a per-sample scalar from
    ``input_embedder`` and mapped by the shared ``weight_generator`` to that
    chunk's parameters. The ``extra`` padding from rounding ``chunk_dim`` up
    is trimmed off the end.
    """

    def __init__(self, input_size, target_network, embedding_dim: int, chunks: int = 0):
        super(Embed_HyperNetwork_mnist, self).__init__()
        self.input_size = input_size
        parameter_count, names, info = ShellParser(target_network)
        self.hyper_output_size = sum(hyper_goal_output(info))
        # chunks == 0 means "let find_closest_divisor pick the chunk count".
        if chunks == 0:
            chunks = find_closest_divisor(input_size, self.hyper_output_size, embedding_dim)
        self.chunks = chunks
        self.chunk_dim = math.ceil(self.hyper_output_size / chunks)
        self.embedding_module = nn.Embedding(chunks, embedding_dim)
        self.input_embedder = nn.Linear(input_size, chunks)
        # Places to add extra capacity:
        # 1. input embedder can be a CNN.
        # 2. weight generator stays linear but can be multi-layered.
        self.weight_generator = nn.Linear(embedding_dim, self.chunk_dim)
        # Number of padded outputs beyond hyper_output_size (dropped in forward).
        self.extra = self.chunks * self.chunk_dim - self.hyper_output_size

    def forward(self, x):
        """Return generated parameters of shape (batch, hyper_output_size).

        ``x`` is reshaped to (batch, input_size) before use.
        """
        x = x.reshape(-1, self.input_size)
        # One scalar per (sample, chunk): [batch, chunks, 1].
        chunk_scale = self.input_embedder(x).unsqueeze(-1)
        # Every sample uses the same chunk ids 0..chunks-1, so look the
        # embeddings up once ([chunks, embed_dim]) and let broadcasting expand
        # them over the batch instead of stacking per-sample index tensors.
        chunk_ids = torch.arange(self.chunks, device=chunk_scale.device)
        embedding = chunk_scale * self.embedding_module(chunk_ids)
        # [batch, chunks, embed_dim] -> [batch, chunks, chunk_dim]
        partial_params = self.weight_generator(embedding)
        # -> [batch, chunks * chunk_dim]
        partial_params = partial_params.view(x.shape[0], -1)
        if self.extra > 0:
            partial_params = partial_params[:, :self.hyper_output_size]
        return partial_params

class LowRankHyperNetwork(nn.Module):
    """CNN hypernetwork whose output is a low-rank matrix product.

    Instead of a single huge final Linear layer, two heads produce factors
    A (d x r) and B (r x e), with r = ``intermediate_dim``. The flattened
    product A @ B (d*e values, with d*e >= hyper_output_size) is truncated to
    the ``hyper_output_size`` parameters the target network needs.
    """

    def __init__(self, input_size, target_network, matrix_dim=0, intermediate_dim=4):
        super(LowRankHyperNetwork, self).__init__()
        self.intermediate_dim = intermediate_dim
        self.input_size = input_size
        _, _, layer_info = ShellParser(target_network)
        self.hyper_output_size = sum(hyper_goal_output(layer_info))
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=4)
        self.fc1 = nn.Linear(32 * 3 * 3, 128)
        # matrix_dim == 0 selects a roughly square d x e layout.
        self.d = matrix_dim if matrix_dim != 0 else math.ceil(math.sqrt(self.hyper_output_size))
        self.e = math.ceil(self.hyper_output_size / self.d)
        self.fc2a = nn.Linear(128, self.d * self.intermediate_dim)
        self.fc2b = nn.Linear(128, self.e * self.intermediate_dim)
        # Surplus entries of the d x e product that are dropped in forward.
        self.extra = self.d * self.e - self.hyper_output_size

    def forward(self,x):
        """Return a (batch, hyper_output_size) tensor of generated parameters."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu1(conv(x))
        hidden = self.relu1(self.fc1(x.view(x.size(0), -1)))
        left = self.fc2a(hidden).reshape(-1, self.d, self.intermediate_dim)
        right = self.fc2b(hidden).reshape(-1, self.intermediate_dim, self.e)
        flat = torch.matmul(left, right).view(-1, self.d * self.e)
        return flat[:, :self.hyper_output_size] if self.extra > 0 else flat

class HyperNetwork_Head(nn.Module):
    """Multi-head hypernetwork: a shared feature extractor plus one head per target layer.

    One head is built for every Conv2d layer reported by
    ``ShellParser(target_network, general=1)``, plus a final head for the linear
    layer reported by ``ShellParser(target_network, general=2)``. Conv heads emit
    filters with spatial side ``ceil(kernel / 2)``; the linear head emits a
    ``(linear_info[0], need_dims)`` weight. Bias values, when enabled, are
    appended after the weights in each head's output.
    """

    def __init__(self, target_network, lora = False, inter = 1, shared_choice = 1):
        """``lora``/``inter`` are forwarded to ``individual_head``; ``inter`` is only passed to the conv heads."""
        super(HyperNetwork_Head, self).__init__()
        self.cnn_count, self.cnn_details = ShellParser(target_network, general=1)
        self.linear_info = ShellParser(target_network, general=2)
        self.cnn = CustomSharedHyper(shared_choice)
        feature_dim = self.cnn.cnn_dim
        head_list = []
        self.out_dims = []
        for detail in self.cnn_details:
            half_kernel = math.ceil(detail[2] / 2)
            weight_count = detail[0] * detail[1] * half_kernel ** 2
            self.out_dims.append(weight_count)
            # detail[4] flags a bias: one extra value per output channel.
            head_width = weight_count + detail[0] if detail[4] else weight_count
            head_list.append(individual_head(feature_dim, head_width, lora, intermediate_dim=inter))
        side = int(math.sqrt(self.linear_info[1] / self.linear_info[2]))
        goal_dim = goal_linear_dim(side)
        linear_dim = self.linear_info[0] * goal_dim * self.linear_info[2]
        self.need_dims = goal_dim*self.linear_info[2]
        if self.linear_info[3]:
            linear_dim = linear_dim + self.linear_info[0]
        # The linear head is always the last one; forward relies on that.
        head_list.append(individual_head(feature_dim, linear_dim, lora))
        self.heads = nn.ModuleList(head_list)
        print(len(head_list), "# heads")

    def forward(self, x):
        """Return a flat list ``[w0, b0, w1, b1, ..., w_lin, b_lin]``; a missing bias is ``None``."""
        features = self.cnn(x)
        params = []
        last_index = len(self.heads) - 1
        for idx, head in enumerate(self.heads):
            out = head(features)
            if idx < last_index:
                detail = self.cnn_details[idx]
                side = math.ceil(detail[2] / 2)
                n_weights = self.out_dims[idx]
                params.append(out[:, :n_weights].view((-1, detail[0], detail[1], side, side)))
                params.append(out[:, n_weights:] if detail[4] else None)
            else:
                rows = self.linear_info[0]
                cols = self.need_dims
                cut = rows * cols
                params.append(out[:, :cut].view(out.size(0), rows, cols))
                params.append(out[:, cut:] if self.linear_info[3] else None)
        return params

class FunctionalFullNetwork(nn.Module):
    """
    Please use this instead of FullNetwork, as this contains grad info from the hypernetwork to the target network.

    The hypernetwork output is expanded by ``equi_loader_general`` into the full
    set of target-network parameters, which are then applied functionally
    (``F.conv2d`` / ``F.linear``) following the layer sequence reported by
    ``ShellParser``, so gradients flow back into the hypernetwork.
    """

    def __init__(self, hypernetwork, target_network, n, head = False):
        super(FunctionalFullNetwork, self).__init__()
        self.hypernetwork = hypernetwork
        self.head = head
        self.parameter_count, self.layers_list, self.layers_info = ShellParser(target_network)
        self.goal_list = hyper_goal_output(self.layers_info)
        self.power = n
        # One AdaptiveAvgPool2d per such target layer, consumed in order by target_forward.
        average_pool = []
        for i, layer in enumerate(self.layers_list):
            if layer == 'AdaptiveAvgPool2D':
                goal_dim = self.layers_info[i][0]
                average_pool.append(nn.AdaptiveAvgPool2d(goal_dim))

        self.average_pool = nn.ModuleList(average_pool)

        # The permutation list used for the first linear layer is slow to build,
        # so it is computed once here. The first 4-element info entry is used;
        # a = sqrt(info[1] / info[2]) is presumably the side length of the
        # flattened feature map feeding that layer.
        # NOTE(review): if no 4-element entry exists, self.permutation_list is
        # never set and forward/filter_only will raise AttributeError.
        for info in self.layers_info:
            if info is None:
                continue
            if len(info) == 4:
                a = int(math.sqrt(info[1] / info[2]))
                self.permutation_list = group_permutation_element(a, -1)
                break

    def filter_only(self, x, choice=None):
        """Return generated filters (and optionally biases / linear params) for inspection.

        :param x: input batch passed to ``equi_loader_general``; only sample 0 is used.
        :param choice: ``-1`` to return all conv filters of sample 0 as a list;
            otherwise a two-element index pair (default ``[0, 0]``) used to pick
            one slice of each generated tensor. For conv filters an
            out-of-range index falls back to the last slice.
        :return: ``filters_list`` when ``choice == -1``, else
            ``(filters_list, filter_bias, linear_list, linear_bias)``.
        :raises ValueError: if a conv filter entry is ``None``.
        """
        if choice is None:
            choice = [0, 0]
        cnn_para, linear_para = equi_loader_general(self.hypernetwork, self.layers_info, x,
                                                    self.permutation_list,self.goal_list, self.power, head=self.head)
        filters_list = []
        filter_bias = []
        linear_list = []
        linear_bias = []
        if choice == -1:
            for i, filters in enumerate(cnn_para):
                if i % 2 == 0:  # filters
                    if filters is not None:
                        filters = filters[0]
                        filters_list.append(filters)
            return filters_list
        else:
            for i, filters in enumerate(cnn_para):
                if i % 2 == 0:  # filters
                    if filters is not None:
                        filters = filters[0]
                    else:
                        raise ValueError("Should not have a None Filter")
                    if choice[1] >= len(filters):
                        filters = filters[-1]
                    else:
                        filters = filters[choice[1]]
                    if choice[0] >= len(filters):
                        filters = filters[-1]
                    else:
                        filters = filters[choice[0]]
                    filters_list.append(filters)
                else: # bias
                    if filters is not None:
                        filters = filters[0]
                        filter_bias.append(filters[choice[1]])
                    else:
                        filter_bias.append(None)
            linear_list.append(linear_para[0][0][choice[1]])
            if linear_para[1] is None:
                linear_bias.append(None)
            else:
                # NOTE(review): the linear weight is indexed with choice[1] but the
                # bias with choice[0] — confirm this asymmetry is intended.
                linear_bias.append(linear_para[1][0][choice[0]])
            return filters_list, filter_bias, linear_list, linear_bias

    def forward(self, x):
        """
        Generate target-network parameters with the hypernetwork,
        expand them to the full parameter set with equi_loader_general,
        then run the target network functionally on each sample.
        :param x: input batch; each sample is viewed as (1, 1, 28, 28).
        :return: per-sample outputs concatenated along dim 0.
        """
        result = []
        cnn_para, mlp_para = equi_loader_general(self.hypernetwork, self.layers_info, x, self.permutation_list,
                                                        self.goal_list, self.power, self.head)
        # Each sample gets its own generated parameters, so the target network
        # is run one sample at a time.
        # NOTE(review): the per-sample view is hard-coded to 1x28x28 (MNIST).
        for counter in range(x.shape[0]):
            result.append(self.target_forward(cnn_para, mlp_para, x[counter].view(1, 1, 28, 28), counter))
        return torch.cat(result, dim=0)

    @staticmethod
    def _per_sample(params, counter):
        """Return ``params[counter]``, or ``None`` when the parameter (e.g. a bias) is absent."""
        return None if params is None else params[counter]

    def target_forward(self, cnn_para, mlp_para, x, counter):
        """Run the target network on one sample using generated parameters.

        :param cnn_para: flat list alternating [weight, bias, weight, bias, ...]
            for the Conv2D layers; each entry is indexed by sample, and a bias
            entry may be ``None``.
        :param mlp_para: same layout for the Linear layers.
        :param x: the single sample, e.g. shaped (1, C, H, W).
        :param counter: index of this sample within the batch.
        :return: the final activations flattened to (1, -1).
        """
        need_flatten = True
        cnn_counter = 0
        linear_counter = 0
        adaptive_pool_counter = 0

        for layer_type in self.layers_list:
            if layer_type == 'Conv2D':
                x = F.conv2d(x, cnn_para[cnn_counter][counter],
                             bias=self._per_sample(cnn_para[cnn_counter + 1], counter))
                cnn_counter += 2
            elif layer_type == 'ReLU':
                x = F.relu(x)
            elif layer_type == 'Flatten':
                x = x.view(x.size(0), -1)
                need_flatten = False
            elif layer_type == 'Linear':
                # Flatten implicitly before the first Linear if no Flatten layer preceded it.
                if need_flatten:
                    x = x.view(x.size(0), -1)
                    need_flatten = False
                # A bias-less Linear layer has a None bias entry; handle it like Conv2D does.
                x = F.linear(x, mlp_para[linear_counter][counter],
                             bias=self._per_sample(mlp_para[linear_counter + 1], counter))
                linear_counter += 2
            elif layer_type == 'AdaptiveAvgPool2D':
                x = self.average_pool[adaptive_pool_counter](x)
                adaptive_pool_counter += 1
            # Any other layer type is skipped.
        x = x.view(x.size(0), -1)
        return x


if __name__ == '__main__':
    # Smoke test on random MNIST-shaped data. It builds three hypernetwork
    # variants for the Simple_Shell target and prints their output shapes,
    # then runs the full hypernetwork -> equi-loader -> functional target pipeline.
    shell = Simple_Shell()
    _, _, layers_info = ShellParser(shell)
    input_size = 28 * 28
    example_MNIST = torch.rand(2, 1, 28, 28)
    hyper_network = HyperNetwork(input_size, shell)
    embed_hyper_network = Embed_HyperNetwork_mnist(input_size, shell, 4, 0)
    low_rank_hyper_network = LowRankHyperNetwork(input_size, shell, 0, 4)
    total_params = sum(p.numel() for p in low_rank_hyper_network.parameters())
    print(f"Total number of parameters: {total_params}")
    n = 1
    for network in (hyper_network, embed_hyper_network, low_rank_hyper_network):
        example_params = network(example_MNIST)
        print("Example params shape is ", example_params.shape)

    goal_list = hyper_goal_output(layers_info)

    # 26 is the side of Simple_Shell's 8x26x26 feature map (28 -> 27 -> 26 with two 2x2 convs).
    cnn_para_1, mlp_para_1 = equi_loader_general(low_rank_hyper_network, layers_info, example_MNIST,
                                                 group_permutation_element(26, -1), goal_list, n)

    print("Using Equi-Loader general: CNN para and Linear Para have length for n={}, ".format(n),
          len(cnn_para_1), len(mlp_para_1))
    functional_full_network = FunctionalFullNetwork(low_rank_hyper_network, shell, n)
    final_output = functional_full_network(example_MNIST)
    print("Final output shape is ", final_output.shape)
    _, predicted = torch.max(final_output, 1)
    print("Predicted output is ", predicted)
