import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torchvision.models
from cifar.shared_parts import *
from utils.math_tools import group_permutation_element, goal_linear_dim
from utils.model_tools import ShellParser, hyper_goal_output, equi_loader_general, ShellParser_Layers, param_matching_loss

def Rot_Similarity(x):
    """Return the mean of ``target_rot_similarity`` over every 2-D slice of ``x``.

    ``x`` is reshaped to ``(-1, x.shape[-2], x.shape[-1])`` so that each trailing
    matrix is scored on its own; the scores are stacked and averaged into a
    0-dim tensor.

    ``torch.stack`` only accepts tensors, so a score that comes back as a plain
    Python number (the placeholder scorer in this file returns the int ``0``)
    is promoted to a floating-point tensor on ``x``'s device first.
    """
    matrices = x.reshape(-1, x.shape[-2], x.shape[-1])
    # Integer inputs cannot be averaged by torch.mean, so fall back to the
    # default float dtype when promoting plain numbers.
    score_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    scores = []
    for matrix in matrices:
        score = target_rot_similarity(matrix)
        if not torch.is_tensor(score):
            score = torch.tensor(float(score), dtype=score_dtype, device=x.device)
        scores.append(score)
    # take the mean of the similarity
    return torch.mean(torch.stack(scores))

def target_rot_similarity(x):
    """Placeholder rotation-similarity score: ignores ``x`` and always yields 0."""
    placeholder_score = 0
    return placeholder_score


class BasicBlock_Fixed(nn.Module):
    """Two 3x3 conv + BatchNorm stages with an optional projection branch.

    A projection (1x1 conv + BatchNorm, stored as ``downsample``) is created only
    when ``stride != 1`` or the channel count changes.  Its output is added to
    the main path before the final ReLU.  When no projection exists nothing is
    added, i.e. there is no identity skip connection in this block.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path: conv -> BN -> ReLU -> conv -> BN.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Projection branch, only when the main path changes the tensor shape.
        needs_projection = not (stride == 1 and in_channels == out_channels)
        self.shortcut = needs_projection
        if needs_projection:
            projection_conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            self.downsample = nn.Sequential(projection_conv, nn.BatchNorm2d(out_channels))

    def forward(self, x):
        """Run the main path, add the projected input if present, then ReLU."""
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        if self.shortcut:
            main += self.downsample(x)
        return self.relu(main)

class BaseBlock(nn.Module):
    """Basic two-conv residual-style block (``expansion`` = 1).

    A 1x1 conv + BatchNorm projection (``downsample``) is built only when the
    stride is not 1 or the input width differs from ``expansion * planes``; it
    is the only thing ever added to the main path (no identity skip otherwise).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = (stride != 1) or (in_planes != out_planes)
        self.shortcut = needs_projection
        if needs_projection:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """conv-BN-ReLU, conv-BN, optional projected input added, ReLU."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        if self.shortcut:
            hidden += self.downsample(x)
        return F.relu(hidden)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck block whose output has ``4 * planes`` channels.

    The stride is applied in the 3x3 convolution.  A 1x1 conv + BatchNorm
    projection (``downsample``) is built only when the stride is not 1 or the
    input width differs from ``expansion * planes``; without it nothing is added
    to the main path.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        wide = self.expansion * planes

        # Reduce, transform (strided), expand.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, wide, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(wide)

        needs_projection = (stride != 1) or (in_planes != wide)
        self.shortcut = needs_projection
        if needs_projection:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, wide, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(wide),
            )

    def forward(self, x):
        """Apply the three conv stages, add the projected input if present, ReLU."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        if self.shortcut:
            hidden += self.downsample(x)
        return F.relu(hidden)


class ResNet(nn.Module):
    """ResNet backbone with a 3x3 stride-1 stem and four stages (64/128/256/512 planes).

    Args:
        block: block class; called as ``block(in_planes, planes, stride)`` and
            must expose an ``expansion`` attribute.
        num_blocks: number of blocks in each of the four stages.
        num_classes: width of the final linear layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 64  # running input width, advanced by _make_layer

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Stage 1 keeps the resolution; stages 2-4 start with a stride-2 block.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest stride 1."""
        block_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Stem, four stages, global average pool, flatten, linear classifier."""
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = self.avgpool(features)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)


def ResNet18_cifar10(num_classes = 10):
    """Build a ``ResNet`` of ``BaseBlock``s with two blocks in each of the four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BaseBlock, blocks_per_stage, num_classes=num_classes)

# Define Modified ResNet-18
class ShellNetwork(nn.Module):
    """Three-stage network of ``BasicBlock_Fixed`` blocks (64/128/256 channels).

    Stem: 3x3 conv + BatchNorm + ReLU.  Each stage holds two blocks, and stages
    2 and 3 start with a stride-2 block.  A global average pool and a linear
    layer of width ``num_classes`` finish the network.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        self.in_channels = 64  # running input width, advanced by _make_layer

        self.conv1 = nn.Conv2d(3, self.in_channels,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(64, 2, stride=1)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)

    def _make_layer(self, out_channels, num_blocks, stride):
        """Build a stage: one block with ``stride``, then stride-1 blocks."""
        blocks = [BasicBlock_Fixed(self.in_channels, out_channels, stride)]
        self.in_channels = out_channels
        blocks.extend(
            BasicBlock_Fixed(self.in_channels, out_channels)
            for _ in range(num_blocks - 1)
        )
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Stem, three stages, global average pool, flatten, linear layer."""
        features = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)
        pooled = self.avgpool(features)
        return self.fc(torch.flatten(pooled, 1))

class toy_shell(nn.Module):
    """Tiny conv-ReLU-conv-linear network.

    The linear layer expects 1568 flattened features (8 channels x 14 x 14),
    which is what the two unpadded convolutions produce for a 32x32 input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, 3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(4, 8, 3, stride=2)
        self.linear = nn.Linear(1568, 32)

    def forward(self, x):
        """Apply conv1, ReLU, conv2, flatten per sample, then the linear layer."""
        features = self.conv2(self.relu(self.conv1(x)))
        flat = features.view(features.size(0), -1)
        return self.linear(flat)

class Simple_Shell(nn.Module):
    """Nine-convolution network with LeakyReLU(0.1), two max-pool/dropout stages,
    a global average pool and a linear classifier.

    ``placeholder`` is accepted by both ``__init__`` and ``forward`` but is not
    used.  ``softmax`` is registered as a submodule but is not applied in
    ``forward``, so the network returns raw scores.
    """

    def __init__(self, num_classes=10, placeholder=False):
        super().__init__()
        # Stage 1: three padded 3x3 convs at 128 channels, then pool + dropout.
        self.conv1a = nn.Conv2d(3, 128, kernel_size=3, padding=1)
        self.lrelu1 = nn.LeakyReLU(negative_slope=0.1)
        self.conv1b = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.lrelu2 = nn.LeakyReLU(negative_slope=0.1)
        self.conv1c = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.lrelu3 = nn.LeakyReLU(negative_slope=0.1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.drop1 = nn.Dropout(p=0.5)

        # Stage 2: three padded 3x3 convs at 256 channels, then pool + dropout.
        self.conv2a = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.lrelu4 = nn.LeakyReLU(negative_slope=0.1)
        self.conv2b = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.lrelu5 = nn.LeakyReLU(negative_slope=0.1)
        self.conv2c = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.lrelu6 = nn.LeakyReLU(negative_slope=0.1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.drop2 = nn.Dropout(p=0.5)

        # Stage 3: one unpadded 3x3 conv, then two 1x1 convs narrowing to 128.
        self.conv3a = nn.Conv2d(256, 512, kernel_size=3)
        self.lrelu7 = nn.LeakyReLU(negative_slope=0.1)
        self.conv3b = nn.Conv2d(512, 256, kernel_size=1)
        self.lrelu8 = nn.LeakyReLU(negative_slope=0.1)
        self.conv3c = nn.Conv2d(256, 128, kernel_size=1)
        self.lrelu9 = nn.LeakyReLU(negative_slope=0.1)

        # Head.
        self.globalpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, placeholder=False):
        """Run every stage in order and return the linear layer's output."""
        feature_stages = (
            self.conv1a, self.lrelu1, self.conv1b, self.lrelu2, self.conv1c, self.lrelu3,
            self.pool1, self.drop1,
            self.conv2a, self.lrelu4, self.conv2b, self.lrelu5, self.conv2c, self.lrelu6,
            self.pool2, self.drop2,
            self.conv3a, self.lrelu7, self.conv3b, self.lrelu8, self.conv3c, self.lrelu9,
            self.globalpool,
        )
        for stage in feature_stages:
            x = stage(x)
        # One 128-wide row per sample; softmax is intentionally not applied.
        return self.fc(x.view(-1, 128))

class adjusted(nn.Module):
    """Six-convolution network sharing one ReLU module, with two max-pool/dropout
    stages, a 3x3 adaptive average pool and a linear output layer."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1a = nn.Conv2d(3, 128, kernel_size=3, padding=1)
        self.relu = nn.ReLU()  # reused after every convolution
        self.conv1b = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv1c = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.drop1 = nn.Dropout(p=0.5)

        self.conv2a = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv2b = nn.Conv2d(256, 256, kernel_size=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.drop2 = nn.Dropout(p=0.5)

        self.conv3 = nn.Conv2d(256, 128, kernel_size=2)
        self.globalpool = nn.AdaptiveAvgPool2d(3)
        self.fc = nn.Linear(128*3*3, num_classes)

    def forward(self, x):
        """Apply the stages in order, flatten per sample, then the linear layer."""
        pipeline = (
            self.conv1a, self.relu,
            self.conv1b, self.relu,
            self.conv1c, self.relu,
            self.pool1, self.drop1,
            self.conv2a, self.relu,
            self.conv2b, self.relu,
            self.pool2, self.drop2,
            self.conv3, self.relu,
            self.globalpool,
        )
        for stage in pipeline:
            x = stage(x)
        flat = x.view(x.shape[0], -1)
        return self.fc(flat)

class GCNNShell_cifar(nn.Module):
    """Eight unpadded convolutions (seven 3x3, one 4x4) followed by a three-layer MLP.

    ``linear0`` expects 32 * 15 * 15 flattened features, which is what the
    convolution stack produces for a 32x32 input.
    """

    def __init__(self):
        super().__init__()
        # Convolutional part: every conv is followed by its own ReLU module.
        self.conv0 = nn.Conv2d(3, 32, 3)
        self.relu0 = nn.ReLU()
        self.conv1 = nn.Conv2d(32, 32, 3)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 32, 3)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(32, 32, 3)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(32, 32, 3)
        self.relu4 = nn.ReLU()
        self.conv5 = nn.Conv2d(32, 32, 3)
        self.relu5 = nn.ReLU()
        self.conv6 = nn.Conv2d(32, 32, 3)
        self.relu6 = nn.ReLU()
        self.conv7 = nn.Conv2d(32, 32, 4)
        self.relu7 = nn.ReLU()

        # Fully connected part.
        self.linear0 = nn.Linear(32 * 15 * 15, 300)
        self.relu_linear0 = nn.ReLU()
        self.linear1 = nn.Linear(300, 60)
        self.relu_linear1 = nn.ReLU()
        self.linear2 = nn.Linear(60, 10)

    def forward(self, x):
        """Run the conv stack, flatten per sample, then run the MLP."""
        conv_stages = (
            self.conv0, self.relu0, self.conv1, self.relu1,
            self.conv2, self.relu2, self.conv3, self.relu3,
            self.conv4, self.relu4, self.conv5, self.relu5,
            self.conv6, self.relu6, self.conv7, self.relu7,
        )
        mlp_stages = (
            self.linear0, self.relu_linear0,
            self.linear1, self.relu_linear1,
            self.linear2,
        )
        for stage in conv_stages:
            x = stage(x)
        x = x.view(x.size(0), -1)
        for stage in mlp_stages:
            x = stage(x)
        return x


class HyperNetwork_cifar(nn.Module):
    """Small CNN + MLP hypernetwork for 3-channel 32x32 (CIFAR-sized) images.

    The output width is ``sum(hyper_goal_output(info))``, where ``info`` is the
    third value returned by ``ShellParser(target_network)``.

    Args:
        target_network: network to generate parameters for; it is only parsed
            here and is not stored on the module.
    """

    def __init__(self, target_network):
        super(HyperNetwork_cifar, self).__init__()
        _, _, info = ShellParser(target_network)
        self.hyper_output_size = sum(hyper_goal_output(info))
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 32*2, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(32*2, 32*4, kernel_size=4)
        # Spatial size for a 32x32 input: 32 -> 15 -> 7 -> 4, and conv3 emits
        # 32*4 = 128 channels, so the flattened width is 128 * 4 * 4 = 2048.
        # The previous in_features of 32*2*4*4 (= 1024) could not match any
        # square input and made forward() fail with a shape mismatch.
        self.fc1 = nn.Linear(32*4 * 4 * 4, 128*8)
        self.fc2 = nn.Linear(128*8, self.hyper_output_size)

    def forward(self, x):
        """Map a batch of images to one flat parameter vector per sample."""
        # The single ReLU module is reused after every convolution and after fc1.
        x = self.relu1(self.conv1(x))
        x = self.relu1(self.conv2(x))
        x = self.relu1(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = self.relu1(self.fc1(x))
        return self.fc2(x)

    # def __init__(self, hidden_dim=10):
    #     super(Default_Hyper, self).__init__()
    #     self.hidden_dim = hidden_dim
    #     self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2)
    #     self.relu1 = nn.ReLU()
    #     self.conv2 = nn.Conv2d(32, 32 * 2, kernel_size=3, stride=2)
    #     self.conv3 = nn.Conv2d(32 * 2, 32 * 4, kernel_size=4)
    #     self.cnn_dim = 32*4*16
    #     if hidden_dim != 0:
    #         self.fc1 = nn.Linear(self.cnn_dim, hidden_dim)
    #
    # def forward(self, x):
    #     network_cnn = nn.Sequential(self.conv1, self.relu1, self.conv2, self.relu1, self.conv3, self.relu1)
    #     x = network_cnn(x)
    #     if self.hidden_dim != 0:
    #         network_linear = nn.Sequential(self.fc1, self.relu1)
    #         x = x.view(x.size(0), -1)
    #         x = network_linear(x)
    #     return x


class Default_Hyper(nn.Module):
    """Wrap the ``option_4`` CNN and flatten its output to one row per sample."""

    def __init__(self, hidden_dim=10):
        super().__init__()
        # NOTE: the ``hidden_dim`` argument is accepted but not used; the
        # attribute is always set to 0.
        self.hidden_dim = 0
        self.cnn = option_4()

    def forward(self, x):
        """Return the CNN features reshaped to ``(batch, -1)``."""
        features = self.cnn(x)
        batch_size = features.size(0)
        return features.view(batch_size, -1)


class HyperNetwork_Custom(nn.Module):
    """Parameter generator whose backbone is selected by ``choice``.

    The output width is ``sum(hyper_goal_output(info))``, where ``info`` is the
    third value returned by ``ShellParser(target_network)``.

    Args:
        target_network: network to generate parameters for; only parsed here.
        choice: ``'resnet'`` (``ShellNetwork``), ``'itself'`` (``adjusted``) or
            ``'simple'`` (``Default_Hyper`` followed by one linear layer).

    Raises:
        ValueError: if ``choice`` is not one of the values above.  Previously an
            unknown choice left ``self.ppg`` undefined and only failed later,
            inside ``forward``, with an ``AttributeError``.
    """

    def __init__(self, target_network, choice='resnet'):
        super(HyperNetwork_Custom, self).__init__()
        _, _, info = ShellParser(target_network)
        hyper_output_size = sum(hyper_goal_output(info))
        if choice == 'resnet':
            self.ppg = ShellNetwork(num_classes=hyper_output_size)
        elif choice == 'itself':
            self.ppg = adjusted(num_classes=hyper_output_size)
        elif choice == 'simple':
            # Hard-coded flattened width fed into the linear layer below.
            # NOTE(review): presumably the output size of Default_Hyper's CNN - confirm.
            hidden_dim = 64*4
            print("The output of cnn part is", hidden_dim)
            self.ppg = nn.Sequential(Default_Hyper(),
                                     nn.Linear(hidden_dim, hyper_output_size))
        else:
            raise ValueError(
                f"Unknown choice {choice!r}; expected 'resnet', 'itself' or 'simple'")

    def forward(self, x):
        """Run the selected generator on ``x``."""
        return self.ppg(x)


class LowRankHyper(nn.Module):
    """Hypernetwork that emits its output as a product of two small matrices.

    The features from ``Default_Hyper`` feed two linear layers whose outputs
    are reshaped to ``(batch, d, r)`` and ``(batch, r, e)`` with
    ``r = intermediate_dim``.  Their product is flattened to ``d * e`` values
    per sample and cut back to ``hyper_output_size`` when ``d * e`` overshoots.

    Args:
        target_network: network to generate parameters for; only parsed here.
        matrix_dim: value for ``d``; 0 means ``ceil(sqrt(hyper_output_size))``.
        intermediate_dim: inner dimension of the factorisation.
    """

    def __init__(self, target_network, matrix_dim=0, intermediate_dim=4):
        super().__init__()
        _, _, info = ShellParser(target_network)
        self.hyper_output_size = sum(hyper_goal_output(info))
        # Hard-coded input width for the two factor layers.
        self.hidden_dim = 64*4
        self.ppg = Default_Hyper(hidden_dim=0)

        self.d = matrix_dim if matrix_dim != 0 else math.ceil(math.sqrt(self.hyper_output_size))
        self.e = math.ceil(self.hyper_output_size / self.d)
        self.fc2a = nn.Linear(self.hidden_dim, self.d * intermediate_dim)
        self.fc2b = nn.Linear(self.hidden_dim, self.e * intermediate_dim)
        # Number of surplus entries in the d x e product.
        self.extra = self.d * self.e - self.hyper_output_size

    def forward(self, x):
        """Return ``hyper_output_size`` generated values per sample."""
        features = self.ppg(x)
        batch_size = features.shape[0]
        left = self.fc2a(features).reshape(batch_size, self.d, -1)
        right = self.fc2b(features).reshape(batch_size, -1, self.e)
        flat = torch.matmul(left, right).view(-1, self.d * self.e)
        if self.extra > 0:
            return flat[:, :self.hyper_output_size]
        return flat


class CustomSharedHyper(nn.Module):
    """Shared backbone selected by an integer ``shared_choice``.

    Sets ``self.net`` (the backbone) and ``self.cnn_dim``.  For choices 1-7
    ``cnn_dim`` is a hard-coded constant; for choices 8 and 9 it is read from
    the backbone's ``out_dim`` attribute.

    Raises:
        NotImplementedError: if ``shared_choice`` matches none of the options.
    """

    def __init__(self, shared_choice):
        super().__init__()

        def sized_by_net(factory):
            # Build the backbone and take the width it reports itself.
            net = factory()
            return net, net.out_dim

        # (choice, builder) pairs, compared in this order.  Builders are lazy so
        # only the selected backbone's factory name is looked up and called.
        options = (
            (2, lambda: (largest(), 32*4*4)),
            (1, lambda: (medium(), 32*4*4)),
            (3, lambda: (small(), 32*3*3)),
            (4, lambda: (option_4(), 64 *2*2)),
            (5, lambda: (option_5(), 512)),
            (6, lambda: (option_6(), 512*2*2)),
            (7, lambda: (option_7(), 256)),
            (8, lambda: sized_by_net(lambda: option_8())),
            (9, lambda: sized_by_net(lambda: option_9())),
        )
        for key, build in options:
            if shared_choice == key:
                self.net, self.cnn_dim = build()
                break
        else:
            raise NotImplementedError("choose shared hyper")

    def forward(self,x):
        """Run the selected backbone on ``x``."""
        return self.net(x)



class individual_head(nn.Module):
    """Output head mapping shared features to ``final_dim`` values per sample.

    Three variants:

    * ``lora=True``: two linear layers produce factors reshaped to
      ``(batch, d, r)`` and ``(batch, r, e)`` with ``r = intermediate_dim``,
      ``d = ceil(sqrt(final_dim))`` and ``e = ceil(final_dim / d)``.  Their
      product is flattened and the ``d * e - final_dim`` surplus entries are
      dropped.  ``num_layer`` is ignored in this mode.
    * ``num_layer == 3``: flatten, then one linear layer.
    * ``num_layer == 2``: a 16->32 channel 4x4 convolution, flatten, then one
      linear layer.  ``previous_dim`` must equal the flattened conv output.

    Raises:
        ValueError: if ``lora`` is false and ``num_layer`` is neither 2 nor 3.
            Previously such a head was built without any layers and only
            failed later, inside ``forward``, with an ``AttributeError``.
    """

    def __init__(self, previous_dim, final_dim, num_layer=3, lora=False, intermediate_dim=1):
        super(individual_head, self).__init__()
        self.num_layer = num_layer
        self.lora = lora
        if lora:
            self.d = math.ceil(math.sqrt(final_dim))
            self.e = math.ceil(final_dim / self.d)
            self.fca = nn.Linear(previous_dim, self.d * intermediate_dim)
            self.fcb = nn.Linear(previous_dim, self.e * intermediate_dim)
            # Surplus entries of the d x e product that forward() drops.
            self.extra = self.d * self.e - final_dim
        elif num_layer == 3:
            self.fc1 = nn.Linear(previous_dim, final_dim)
        elif num_layer == 2:
            self.conv1 = nn.Conv2d(16, 32, kernel_size=4)
            self.fc1 = nn.Linear(previous_dim, final_dim)
        else:
            raise ValueError(
                f"Unsupported num_layer {num_layer!r} for a non-LoRA head; expected 2 or 3")

    def forward(self, x):
        """Return a ``(batch, final_dim)`` tensor."""
        if self.lora:
            x = x.view(x.size(0), -1)
            left = self.fca(x).reshape(x.shape[0], self.d, -1)
            right = self.fcb(x).reshape(x.shape[0], -1, self.e)
            x = torch.matmul(left, right).view(-1, self.d * self.e)
            if self.extra > 0:
                # Drop the trailing surplus entries.
                x = x[:, :-self.extra]
            return x

        if self.num_layer == 2:
            x = self.conv1(x)
        x = x.view(x.size(0), -1)
        return self.fc1(x)


class HyperNetwork_Head(nn.Module):
    """Shared backbone followed by one ``individual_head`` per generated layer.

    There is one head for every entry of ``cnn_details`` plus a final head for
    the linear layer described by ``linear_info``.  Both come from
    ``ShellParser`` (``general=1`` and ``general=2``); the meaning of their
    fields is defined there.
    """

    def __init__(self, target_network, shared_layer = 3,
                 lora = False, inter = 1, shared_choice = 0):
        """Build the backbone and the heads.

        ``shared_choice == 0`` selects ``SharedHyper(shared_layer)``; any other
        value selects ``CustomSharedHyper(shared_choice)``.  ``inter`` is passed
        as ``intermediate_dim`` to the convolution heads only; the linear head
        keeps ``individual_head``'s default.
        """
        super().__init__()
        self.cnn_count, self.cnn_details = ShellParser(target_network, general=1)
        self.linear_info = ShellParser(target_network, general=2)
        self.shared_layer = shared_layer
        self.cnn = (SharedHyper(shared_layer) if shared_choice == 0
                    else CustomSharedHyper(shared_choice))

        self.out_dims = []
        heads = []

        # One head per convolution: entries 0 and 1 times the squared
        # half-kernel (ceil of entry 2 over 2), plus entry 0 more values when
        # the bias flag (entry 4) is set.
        for i in range(len(self.cnn_details)):
            detail = self.cnn_details[i]
            half_kernel = math.ceil(detail[2] / 2)
            weight_dim = detail[0] * detail[1] * half_kernel ** 2
            self.out_dims.append(weight_dim)
            head_dim = weight_dim + detail[0] if detail[4] else weight_dim
            heads.append(individual_head(self.cnn.cnn_dim, head_dim,
                                         self.shared_layer, lora, intermediate_dim=inter))

        # Final head: the single linear layer.
        square = int(math.sqrt(self.linear_info[1] / self.linear_info[2]))
        goal_dim = goal_linear_dim(square)
        self.need_dims = goal_dim * self.linear_info[2]
        linear_dim = self.linear_info[0] * goal_dim * self.linear_info[2]
        if self.linear_info[3]:
            linear_dim += self.linear_info[0]
        heads.append(individual_head(self.cnn.cnn_dim, linear_dim, self.shared_layer, lora))

        self.heads = nn.ModuleList(heads)
        # only one linear layer. to be implemented: share the other linear.

    def forward(self, x):
        """Return ``[weight, bias_or_None, ...]`` with one pair per head, in head order."""
        shared = self.cnn(x)
        last_index = len(self.heads) - 1
        result = []
        for i, head in enumerate(self.heads):
            out = head(shared)
            if i == last_index:
                # The last head is the linear layer.
                rows = self.linear_info[0]
                cols = self.need_dims
                split = rows * cols
                weight = out[:, :split].view(out.size(0), rows, cols)
                has_bias = self.linear_info[3]
            else:
                detail = self.cnn_details[i]
                half_kernel = math.ceil(detail[2] / 2)
                split = self.out_dims[i]
                weight = out[:, :split].view(
                    (-1, detail[0], detail[1], half_kernel, half_kernel))
                has_bias = detail[4]
            result.append(weight)
            # Whatever follows the weight slice is the bias, when there is one.
            result.append(out[:, split:] if has_bias else None)
        return result


class FunctionalFullNetwork(nn.Module):
    """
    Please use this instead of FullNetwork, as this contains grad info from the hypernetwork to the target network.
    """

    def __init__(self, hypernetwork, target_network, n, mode='average', head = False,
                 reflection=False, extra_loss=False, extra_invariant= False):
        """Parse ``target_network`` and prepare the stateful layers used by ``forward``.

        Args:
            hypernetwork: module that generates the parameters.
            target_network: network whose layout is read through ``ShellParser``.
            n: stored as ``self.power`` and forwarded to ``equi_loader_general``.
            mode: ``'one'``, ``'batch'`` or ``'average'``.
            head: forwarded to ``equi_loader_general``.
            reflection: forwarded to ``equi_loader_general``.
            extra_loss: stored as ``self.extra_loss``.
            extra_invariant: stored as ``self.extra_invariant`` only when it is
                exactly ``True``.
        """
        super(FunctionalFullNetwork, self).__init__()
        self.hypernetwork = hypernetwork
        self.extra_loss = extra_loss
        self.parameter_count, self.layers_list, self.layers_info = ShellParser(target_network)
        self.goal_list = hyper_goal_output(self.layers_info)
        self.power = n
        self.mode = mode
        self.head = head
        self.reflection = reflection
        # NOTE(review): the attribute does not exist at all unless the flag is
        # exactly True - confirm no reader relies on it being present.
        if extra_invariant is True:
            self.extra_invariant = extra_invariant
        if reflection:
            print("we add reflection to embrace group D from Z")
        batch_norm = []
        average_pool = []
        max_pool = []
        leaky_relu = []

        # For the first linear layer, this will be used, but it does take quite
        # some time, so it is computed once here.
        for info in self.layers_info:
            if info is None:
                continue
            if len(info) == 4:
                # This should be the linear layer right after the CNN. There
                # should only be one.
                a = int(math.sqrt(info[1] / info[2]))
                self.permutation_list = group_permutation_element(a, -1)
                break

        # Instantiate the layers that are modules rather than generated
        # parameters, in the order they appear in the parsed layout.
        for i, layer in enumerate(self.layers_list):
            if layer == 'AdaptiveAvgPool2D':
                goal_dim = self.layers_info[i][0]
                average_pool.append(nn.AdaptiveAvgPool2d(goal_dim))
            elif layer == 'BatchNorm2D':
                n_features = self.layers_info[i]
                batch_norm.append(nn.BatchNorm2d(num_features=n_features))
            elif layer == 'MaxPool2D':
                kernel_size = self.layers_info[i][0]
                stride = self.layers_info[i][1][0]
                pad = self.layers_info[i][1][1]
                max_pool.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=pad))
            elif layer == 'LeakyReLU':
                negative_slope = self.layers_info[i][0]
                leaky_relu.append(nn.LeakyReLU(negative_slope=negative_slope))
        self.batch_norm = nn.ModuleList(batch_norm)
        self.average_pool = nn.ModuleList(average_pool)
        self.max_pool = nn.ModuleList(max_pool)
        # Only registered when the layout contains at least one LeakyReLU.
        if len(leaky_relu) != 0:
            self.leaky_relu = nn.ModuleList(leaky_relu)

    def filter_only(self, x, choice=0):
        """Pick one generated filter/bias entry per layer, for plotting.

        Parameters are generated exactly as in ``forward`` (from the batch mean
        when ``mode == 'one'``).  Only the first sample's parameters are
        inspected.  For filters, ``choice`` indexes the first two dimensions and
        falls back to the last entry when it is out of range.

        Returns:
            ``(filters_list, filter_bias, linear_list, linear_bias)``.

        Raises:
            ValueError: if a filter entry of the generated parameters is ``None``.
        """
        if self.mode == 'one':
            ### batch of input,  one para one hyper. save space
            cnn_para, linear_para = equi_loader_general(self.hypernetwork, self.layers_info,
                                                        torch.mean(x, dim=0, keepdim=True), self.permutation_list,
                                                        self.goal_list, self.power, self.head,
                                                        reflection=self.reflection)
        else:
            cnn_para, linear_para = equi_loader_general(self.hypernetwork, self.layers_info, x, self.permutation_list,
                                                        self.goal_list, self.power, self.head, self.reflection)

        filters_list, filter_bias = [], []
        linear_list, linear_bias = [], []

        # cnn_para alternates: even positions are filters, odd positions biases.
        for position, entry in enumerate(cnn_para):
            if position % 2:
                if entry is None:
                    filter_bias.append(None)
                else:
                    filter_bias.append(entry[0][choice][choice])
                continue

            if entry is None:
                raise ValueError("Should not have a None Filter")
            picked = entry[0]
            # Index twice with ``choice``, clamping to the last entry each time.
            for _ in range(2):
                picked = picked[-1] if choice >= len(picked) else picked[choice]
            filters_list.append(picked)

        linear_list.append(linear_para[0][0][choice])
        bias_entry = linear_para[1]
        linear_bias.append(None if bias_entry is None else bias_entry[0][choice])
        return filters_list, filter_bias, linear_list, linear_bias


    def forward(self, x, train=True):
        """
        generate one forth of the params with hypernetwork,
        pass the params to the equi-loader to get the full network params,
        then load the full network.
        """
        need_flatten = True
        skip_1 = False
        skip_2 = False
        cnn_counter, linear_counter, max_pool_counter, batch_norm_counter, adaptive_pool_counter, leaky_counter = 0, 0, 0, 0, 0, 0

        if self.mode == 'one':
            ### batch of input,  one para one hyper. save space
            cnn_para, linear_para = equi_loader_general(self.hypernetwork, self.layers_info,
                                                        torch.mean(x, dim=0, keepdim=True), self.permutation_list,
                                                        self.goal_list, self.power, self.head, reflection=self.reflection)

        else:
            cnn_para, linear_para = equi_loader_general(self.hypernetwork, self.layers_info, x, self.permutation_list,
                                                        self.goal_list, self.power, self.head, self.reflection)

        if self.extra_loss:
            loss_list = []
            for i in range(len(cnn_para)):
                if i % 2 == 0:  # filters
                    loss_list.append(Rot_Similarity(cnn_para[i]))
            rotation_loss = torch.mean(torch.stack(loss_list))


        temp_x = x
        for i, layer in enumerate(self.layers_list):
            if layer == 'Save':
                temp_x = x
            elif layer == 'Conv2D':
                if skip_1:
                    skip_1 = False
                else:
                    if self.mode == 'one':
                        if cnn_para[cnn_counter + 1] is not None:
                            x = F.conv2d(x, cnn_para[cnn_counter][0], bias=cnn_para[cnn_counter + 1][0],
                                         stride=self.layers_info[i][5], padding=self.layers_info[i][6])
                        else:
                            x = F.conv2d(x, cnn_para[cnn_counter][0], bias=None,
                                         stride=self.layers_info[i][5], padding=self.layers_info[i][6])
                        cnn_counter += 2
                    elif self.mode == 'batch':
                        # batch of input, batch of params, batch of output
                        x = self.cnn_forward(x, cnn_para[cnn_counter], cnn_para[cnn_counter + 1],
                                         self.layers_info[i])
                        cnn_counter += 2
                    elif self.mode == 'average':
                        # batch of input, batch para, one hyper. save time
                        x = self.average_cnn_forward(x, cnn_para[cnn_counter], cnn_para[cnn_counter + 1],
                                         self.layers_info[i])
                        cnn_counter += 2
                    else:
                        raise ValueError("Invalid mode")
                        pass
            elif layer == 'Linear':
                if need_flatten:
                    x = x.view(x.size(0), -1)
                    need_flatten = False
                if self.mode == 'one':
                    if linear_para[linear_counter + 1] is not None:
                        x = F.linear(x, linear_para[linear_counter][0], bias=linear_para[linear_counter + 1][0])
                    else:
                        x = F.linear(x, linear_para[linear_counter][0], bias=None)
                    linear_counter += 2
                elif self.mode == 'batch':
                    # batch of input, batch of params, batch of output
                    x = self.linear_forward(x, linear_para[linear_counter], linear_para[linear_counter + 1])
                    linear_counter += 2
                elif self.mode == 'average':
                    # batch of input, batch para, one hyper. save time
                    x = self.average_linear_forward(x, linear_para[linear_counter], linear_para[linear_counter + 1])
                    linear_counter += 2
                else:
                    raise ValueError("Invalid mode")
                    pass
            elif layer == 'ReLU':
                x = F.relu(x)
            elif layer == 'LeakyReLU':
                x = self.leaky_relu[leaky_counter](x)
                leaky_counter += 1
            elif layer == 'AdaptiveAvgPool2D':
                x = self.average_pool[adaptive_pool_counter](x)
                adaptive_pool_counter += 1
            elif layer == 'BatchNorm2D':
                if skip_2:
                    skip_2 = False
                else:
                    x = self.batch_norm[batch_norm_counter](x)
                    batch_norm_counter += 1
            elif layer == 'MaxPool2D':
                x = self.max_pool[max_pool_counter](x)
                max_pool_counter += 1
            elif layer == 'Dropout':
                x = F.dropout(x, p=self.layers_info[i][0], training=train)
            elif layer == 'Downsample':
                skip_1 = True
                skip_2 = True
                # Skip the next conv layer and the next BN layer in the main loop:
                # those two entries are the downsample (shortcut) branch, which is applied right here.
                if self.mode == 'batch':
                    shortcut = self.cnn_forward(temp_x, cnn_para[cnn_counter], cnn_para[cnn_counter + 1],
                                            self.layers_info[i + 1])
                elif self.mode == 'average':
                    shortcut = self.average_cnn_forward(temp_x, cnn_para[cnn_counter], cnn_para[cnn_counter + 1],
                                            self.layers_info[i + 1])
                elif self.mode == 'one':
                    if cnn_para[cnn_counter + 1] is not None:
                        shortcut = F.conv2d(temp_x, cnn_para[cnn_counter][0], bias=cnn_para[cnn_counter + 1][0],
                                            stride=self.layers_info[i + 1][5], padding=self.layers_info[i + 1][6])
                    else:
                        shortcut = F.conv2d(temp_x, cnn_para[cnn_counter][0], bias=None,
                                            stride=self.layers_info[i + 1][5], padding=self.layers_info[i + 1][6])
                else:
                    raise ValueError("Invalid mode")
                    pass
                cnn_counter += 2
                shortcut = self.batch_norm[batch_norm_counter](shortcut)
                batch_norm_counter += 1
                x = x + shortcut
                # x = F.relu(x + shortcut)
                # relu probably decrease.
            else:
                print("Unrecognized layer")
                pass
        if self.extra_loss:
            return x, rotation_loss
        else:
            return x

    def cnn_forward(self, x, cnn_filters, cnn_bias, extra_info):
        """Convolve every sample of ``x`` with its own set of filters.

        Sample ``i`` is convolved with ``cnn_filters[i]`` (and ``cnn_bias[i]``
        when a bias is given); the per-sample results are concatenated along
        dim 0.

        Args:
            x: Input batch, indexed along dim 0.
                (assumes a 4-D image batch -- TODO confirm against callers)
            cnn_filters: Per-sample conv weights; ``cnn_filters[i]`` is handed
                to ``F.conv2d`` as the weight for sample ``i``.
            cnn_bias: Per-sample biases indexed the same way, or ``None`` for
                a bias-free convolution.
            extra_info: Layer description; index 5 is used as the stride and
                index 6 as the padding.

        Returns:
            The per-sample convolution outputs concatenated along dim 0.
        """
        stride, padding = extra_info[5], extra_info[6]
        # Each sample keeps a leading axis of size one so that F.conv2d sees a
        # one-element batch and torch.cat can stack the results back together.
        per_sample_outputs = [
            F.conv2d(
                x[idx].unsqueeze(0),
                cnn_filters[idx],
                bias=None if cnn_bias is None else cnn_bias[idx],
                stride=stride,
                padding=padding,
            )
            for idx in range(x.shape[0])
        ]
        return torch.cat(per_sample_outputs, dim=0)

    def linear_forward(self, x, linear_weight, linear_bias):
        """Apply a separate linear map to every sample of ``x``.

        Sample ``i`` is transformed with ``linear_weight[i]`` (and
        ``linear_bias[i]`` when a bias is given); the per-sample results are
        concatenated along dim 0.

        Args:
            x: Input batch, indexed along dim 0.
                (assumes already-flattened features -- TODO confirm)
            linear_weight: Per-sample weights; ``linear_weight[i]`` is handed
                to ``F.linear`` as the weight for sample ``i``.
            linear_bias: Per-sample biases indexed the same way, or ``None``
                for a bias-free layer.

        Returns:
            The per-sample linear outputs concatenated along dim 0.
        """
        # Each sample keeps a leading axis of size one so the outputs can be
        # concatenated back into a batch.
        per_sample_outputs = [
            F.linear(
                x[idx].unsqueeze(0),
                linear_weight[idx],
                bias=None if linear_bias is None else linear_bias[idx],
            )
            for idx in range(x.shape[0])
        ]
        return torch.cat(per_sample_outputs, dim=0)
    def average_cnn_forward(self, x, cnn_filters, cnn_bias, extra_info):
        """Convolve the whole batch with the mean of the supplied filters.

        The filters (and the biases, when given) are averaged over dim 0 and
        the resulting single parameter set is applied to all of ``x`` in one
        ``F.conv2d`` call.

        Args:
            x: Input batch passed directly to ``F.conv2d``.
            cnn_filters: Stack of conv weights; averaged over dim 0.
            cnn_bias: Stack of biases averaged over dim 0, or ``None`` for a
                bias-free convolution.
            extra_info: Layer description; index 5 is used as the stride and
                index 6 as the padding.

        Returns:
            The convolution output for the whole batch.
        """
        stride, padding = extra_info[5], extra_info[6]
        mean_filters = torch.mean(cnn_filters, dim=0)
        mean_bias = None if cnn_bias is None else torch.mean(cnn_bias, dim=0)
        return F.conv2d(x, mean_filters, bias=mean_bias, stride=stride, padding=padding)

    def average_linear_forward(self, x, linear_weight, linear_bias):
        """Apply the mean of the supplied linear parameters to the whole batch.

        The weights (and the biases, when given) are averaged over dim 0 and
        the resulting single parameter set is applied to all of ``x`` in one
        ``F.linear`` call.

        Args:
            x: Input batch passed directly to ``F.linear``.
            linear_weight: Stack of weights; averaged over dim 0.
            linear_bias: Stack of biases averaged over dim 0, or ``None`` for
                a bias-free layer.

        Returns:
            The linear output for the whole batch.
        """
        mean_weight = torch.mean(linear_weight, dim=0)
        mean_bias = None if linear_bias is None else torch.mean(linear_bias, dim=0)
        return F.linear(x, mean_weight, bias=mean_bias)




if __name__ == '__main__':
    # Manual smoke test: parse a CIFAR ResNet-18 shell, build a hypernetwork
    # head for it, and print the parameter-matching loss before and after
    # wrapping the hypernetwork in a FunctionalFullNetwork.
    shell = ResNet18_cifar10()
    # These results are not used below; the calls are kept so the script still
    # exercises both parser entry points.
    cnn_original, linear_original = ShellParser_Layers(shell)
    conv_1, conv_2 = ShellParser(shell, general=1)

    print()
    example_cifar = torch.randn(3, 3, 32, 32)
    individual_count, layers, layers_info = ShellParser(shell)
    print(layers_info, "layers_info\n\n")
    print(layers, "layers\n\n")

    # Parse the torchvision ResNet-18 as well so both layouts can be compared
    # by eye in the printed output.
    shell_2 = torchvision.models.resnet18()
    _, layers_2, layers_info_2 = ShellParser(shell_2)
    print(layers_info_2, "layers_info\n\n")
    print(layers_2, "layers\n\n")

    goal_list = hyper_goal_output(layers_info)

    # Derive the group size from the 4-field entries of layers_info; when
    # several entries match, the last one wins.
    # NOTE(review): presumably the 4-field entries describe Linear layers --
    # confirm against ShellParser.
    group_size = None
    for entry in layers_info:
        if entry is not None and len(entry) == 4:
            group_size = int(math.sqrt(entry[1] / entry[2]))
    if group_size is None:
        # Previously the ShellParser result tuple was passed on silently in
        # this case because the same variable name was reused for both values.
        raise ValueError("layers_info contains no 4-field entry to derive the group size from")
    permutation_list = group_permutation_element(group_size, -1)

    hyper_network = HyperNetwork_Head(shell, shared_layer=3, lora=True, inter=2, shared_choice=0)
    cnn, linear = equi_loader_general(hyper_network, layers_info, example_cifar, permutation_list,
                                      goal_list, head=True)
    p_loss = param_matching_loss(shell, layers_info, permutation_list, head=True)
    loss = p_loss(hyper_network, example_cifar)
    print(loss)

    functional1 = FunctionalFullNetwork(hyper_network, shell, 0, mode='one', head=True)
    output1 = functional1(example_cifar)
    print(output1.shape)
    loss_1 = p_loss(functional1.hypernetwork, example_cifar)
    print(loss_1)

# for i in linear:
    #     print(i.shape)
    # for j in cnn:
    #     print(j.shape)
    # print(len(cnn_original), len(cnn), "cnn compare")
    # print(len(linear_original), len(linear),"linear compare")
    # assert len(cnn_original)== len(cnn)
    # assert  len(linear_original) == len(linear)
    # for i in range(len(cnn_original)):
    #     cnn[i] = torch.mean(cnn[i], dim=0, keepdim=False)
    #     print(tuple(cnn_original[i].shape)== tuple(cnn[i].shape),"cnn shape compare")
    # for i in range(len(linear_original)):
    #     print(tuple(linear_original[i].shape)==tuple(linear[i].shape),"linear shape compare")
    # print(cnn_original[0].shape)
    # print(cnn[0].shape)

    # for i in example_parameters:
    #     if i is not None:
    #         print(i.shape)
    #
    # for name, para in shell.named_parameters():
    #     print( para.shape, name)
    # total_params = sum(p.numel() for p in hyper_network.parameters())
    # print("Total number of parameters: \n{:,}".format(total_params))
    # print("vs")
    # print("{:,}".format(sum(individual_count)))
    # print(hyper_network.d, hyper_network.e)
    # goal_list = hyper_goal_output(layers_info)
    # output = shell(example_cifar)


# functional3 = FunctionalFullNetwork(hyper_network, shell, n, mode='batch')
    # print(output1.shape)
    # output3 = functional3(example_cifar)
    # print(output.shape, output1.shape, output3.shape)
