from collections import OrderedDict
from torch import nn
import torch


def make_cnn():
    """Build the small CNN whose weight space is processed with an NFN.

    Architecture: three 3x3 ``Conv2d`` layers (1 -> 32 -> 32 -> 32 channels),
    each followed by ReLU, then global average pooling, flattening, and a
    32 -> 10 ``Linear`` head.
    """
    layers = [nn.Conv2d(1, 32, 3), nn.ReLU()]
    # Two further 32 -> 32 conv blocks, created in the same order as before
    # so parameter initialisation consumes the RNG identically.
    for _ in range(2):
        layers += [nn.Conv2d(32, 32, 3), nn.ReLU()]
    layers += [nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, 10)]
    return nn.Sequential(*layers)

def make_siren():
    """Build a small ReLU CNN (3 conv layers, global avg-pool, 32 -> 10 linear head).

    NOTE(review): despite its name, this is not a SIREN (a sinusoidal-
    activation MLP). It currently constructs exactly the same architecture as
    ``make_cnn``. Presumably a placeholder; confirm the intended network.
    """
    return nn.Sequential(
        nn.Conv2d(1, 32, 3), nn.ReLU(),
        nn.Conv2d(32, 32, 3), nn.ReLU(),
        nn.Conv2d(32, 32, 3), nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(32, 10)
    )

def strip_cnn_struture(input_network):
    """Return a new ``nn.Sequential`` with every ``Conv2d`` replaced by a ``Linear``.

    Each ``nn.Conv2d(in, out, k)`` becomes a freshly initialised
    ``nn.Linear(in, out)``. Conv weights are not copied and the kernel's
    spatial extent is dropped, so the weight shape goes from
    ``(out, in, k, k)`` to ``(out, in)``. The bias is kept only if the conv
    had one, so the (weight, bias) layout of the state dict is preserved.
    Every other layer is reused as the *same* module object, not a copy.

    ``nn.Linear`` acts on the last tensor dimension, so the result is not a
    drop-in replacement for the original on image inputs.

    Args:
        input_network: an iterable of modules, e.g. an ``nn.Sequential``.

    Returns:
        A new ``nn.Sequential`` built from the converted layers.
    """
    layers = [
        nn.Linear(layer.in_channels, layer.out_channels, bias=layer.bias is not None)
        if isinstance(layer, nn.Conv2d)
        else layer
        for layer in input_network
    ]
    return nn.Sequential(*layers)


def check_same(cnn1, cnn2, *, input_shape=(20, 1, 28, 28), rtol=1e-05, atol=1e-08):
    """Return True if two networks produce matching outputs on one random batch.

    Args:
        cnn1, cnn2: callables mapping a tensor to a tensor (e.g. ``nn.Module``).
        input_shape: shape of the standard-normal input batch. The default
            is a batch of 20 single-channel 28x28 images, which suits
            ``make_cnn`` (its first layer is ``Conv2d(1, ...)``).
        rtol, atol: tolerances forwarded to ``torch.allclose``. The defaults
            are ``torch.allclose``'s own defaults.

    Returns:
        ``bool`` result of ``torch.allclose(out1, out2, rtol, atol)``.
    """
    inp = torch.randn(input_shape)
    # Pure comparison: skip building an autograd graph for either forward pass.
    with torch.no_grad():
        out1 = cnn1(inp)
        out2 = cnn2(inp)
    return torch.allclose(out1, out2, rtol=rtol, atol=atol)


def sample_perm(state_dict: OrderedDict):
    """Randomly permute the hidden channels of a layer-wise (weight, bias) state dict.

    The entries are read as consecutive (weight, bias) pairs, one pair per
    layer, in insertion order; an odd entry count raises ``IndexError``. For
    every layer except the last:
    - the output axis (dim 0) of its weight and bias is shuffled with a fresh
      ``torch.randperm``;
    - the next layer's input axis (dim 1) is shuffled with the same
      permutation.
    The result is a new ``OrderedDict`` with the same keys in the same order.
    """
    entries = list(state_dict.items())
    n_entries = len(state_dict)
    out = OrderedDict()
    incoming_perm = None
    for start in range(0, n_entries, 2):
        (w_name, w), (b_name, b) = entries[start], entries[start + 1]
        # Re-order this layer's inputs to follow the previous layer's shuffled outputs.
        if incoming_perm is not None:
            w = w[:, incoming_perm]
        is_last_layer = start + 2 >= n_entries
        if not is_last_layer:
            outgoing_perm = torch.randperm(w.shape[0])
            w, b = w[outgoing_perm], b[outgoing_perm]
            incoming_perm = outgoing_perm
        out[w_name] = w
        out[b_name] = b
    return out

def sample_perm_scale(state_dict: OrderedDict, *, min_scale=0.5, max_scale=1.5):
    """Randomly permute and positively rescale hidden channels of a (weight, bias) state dict.

    The entries are read as consecutive (weight, bias) pairs, one pair per
    layer, in insertion order; an odd entry count raises ``IndexError``. For
    every layer except the last:
    - its output channels (dim 0 of weight and bias) are shuffled with a
      random permutation and multiplied by per-channel scales drawn uniformly
      from ``[min_scale, max_scale)``;
    - the next layer's input axis (dim 1) is shuffled the same way and
      divided by the same scales.

    This preserves the network function only when the ops between layers
    commute with positive per-channel scaling (e.g. ReLU, average pooling,
    flatten), as in ``make_cnn``. Scales must be strictly positive: a zero
    scale would divide by zero and a negative one would break ReLU.

    Args:
        state_dict: mapping of parameter name -> tensor, alternating weight/bias.
        min_scale, max_scale: range of the per-channel scale factors
            (keyword-only). Requires ``0 < min_scale <= max_scale``.

    Returns:
        A new ``OrderedDict`` with the same keys in the same order.

    Raises:
        ValueError: if the scale range is not strictly positive and ordered.
    """
    if not 0 < min_scale <= max_scale:
        raise ValueError(
            f"need 0 < min_scale <= max_scale, got min_scale={min_scale}, max_scale={max_scale}"
        )
    prev_perm = None
    prev_scale = None
    sd_list = list(state_dict.items())
    n_entries = len(sd_list)
    permuted_sd = OrderedDict()

    for i in range(0, n_entries, 2):
        weight_key, weight = sd_list[i]
        bias_key, bias = sd_list[i + 1]

        if prev_perm is not None:
            # Undo the previous layer's shuffle and scaling on this layer's input axis.
            weight = weight[:, prev_perm]
            in_shape = (1, -1) + (1,) * (weight.dim() - 2)
            weight = weight / prev_scale.view(in_shape)

        if i + 2 < n_entries:
            n_out = weight.shape[0]
            # Draw on the weight's device/dtype so CUDA or float64 weights work too.
            this_perm = torch.randperm(n_out, device=weight.device)
            this_scale = min_scale + (max_scale - min_scale) * torch.rand(
                n_out, device=weight.device, dtype=weight.dtype
            )
            out_shape = (-1,) + (1,) * (weight.dim() - 1)
            weight = weight[this_perm] * this_scale.view(out_shape)
            bias = bias[this_perm] * this_scale
            prev_perm, prev_scale = this_perm, this_scale

        permuted_sd[weight_key] = weight
        permuted_sd[bias_key] = bias

    return permuted_sd

def check_perm_symmetry():
    """Sanity check: a randomly channel-permuted CNN must compute the same function.

    Builds two ``make_cnn`` networks, loads a ``sample_perm``-shuffled copy of
    the first one's weights into the second, and returns ``check_same``'s verdict.
    """
    reference = make_cnn()
    permuted = make_cnn()
    shuffled_weights = sample_perm(reference.state_dict())
    permuted.load_state_dict(shuffled_weights)
    return check_same(reference, permuted)

def check_perm_scale_symmetry():
    """Sanity check: permuting *and* positively rescaling CNN channels preserves the function.

    Builds two ``make_cnn`` networks, loads a ``sample_perm_scale``-transformed
    copy of the first one's weights into the second, and returns
    ``check_same``'s verdict.
    """
    reference = make_cnn()
    transformed = make_cnn()
    new_weights = sample_perm_scale(reference.state_dict())
    transformed.load_state_dict(new_weights)
    return check_same(reference, transformed)
