from typing import List
import torch
import torch.nn as nn
import numpy as np
from utils.math_tools import functional_permute_list, group_permutation_element, goal_linear_dim
import time

def pause(seconds):
    """Sleep for ``seconds`` seconds so demo output can be read step by step."""
    time.sleep(seconds)


def equivariant_filters_2(example_input, hypernet, printing=False):
    """Generate a rotation-equivariant 2x2 filter (and optional bias) from a hypernetwork.

    ``example_input`` is rotated by 0, 90, 180 and 270 degrees over dims 2 and 3.
    Each rotated copy is flattened per sample and passed once through
    ``hypernet``. Column 0 of each rotation's output becomes one entry of the
    filter, laid out as::

        [[rot0,   rot90 ],
         [rot270, rot180]]

    With this layout, rotating ``example_input`` by 90 degrees rotates the
    returned filter by 90 degrees.

    Args:
        example_input: batched tensor (batch is dim 0) whose dims 2 and 3 are
            rotated.
        hypernet: module exposing ``out_features`` (e.g. ``nn.Linear``) that maps
            the flattened input to 1 or 2 values per sample.
        printing: if True, sleep 1 s and print the per-rotation filter entries.

    Returns:
        If ``hypernet.out_features == 2``: ``(filters, bias)``.
            ``filters`` has shape (batch, 2, 2).
            ``bias`` has shape (batch,) and is the mean of column 1 over the four
            rotations, so it is rotation-invariant.
        If ``hypernet.out_features == 1``: ``filters`` of shape (batch, 2, 2).

    Raises:
        ValueError: if ``hypernet.out_features`` is neither 1 nor 2.
    """
    out_features = hypernet.out_features
    if out_features not in (1, 2):
        raise ValueError(
            "hypernet.out_features must be 1 or 2, got {}".format(out_features))

    batch = example_input.shape[0]
    # One forward pass per rotation; outputs[k] has shape (batch, out_features).
    outputs = [
        hypernet(torch.rot90(example_input, k, [2, 3]).reshape(batch, -1))
        for k in range(4)
    ]
    paras = [out[:, 0] for out in outputs]

    if printing:
        pause(1)

        for k, para in enumerate(paras):
            print("For {} rotation, the generated parameter piece is: {}".format(k * 90, para))

    # Stack along dim 1 so each sample keeps its own four entries; concatenating
    # along dim 0 would interleave different samples once batch > 1.
    filters = torch.stack([paras[0], paras[1], paras[3], paras[2]], dim=1)
    filters = filters.reshape(batch, 2, 2)

    if out_features == 1:
        return filters

    # Averaging the bias over all four rotations makes it rotation-invariant.
    bias = (outputs[0][:, 1] + outputs[1][:, 1] + outputs[3][:, 1] + outputs[2][:, 1]) / 4
    return filters, bias


def equivariant_filters_3(example_input, network, printing=False):
    """Build a rotation-equivariant 3x3 filter from four overlapping 2x2 pieces.

    Each of the four rotations (0/90/180/270 degrees over dims 2 and 3) of
    ``example_input`` goes through the following steps:

    * It is flattened per sample and fed to ``network``.
    * The result is reshaped to a (batch, 2, 2) piece.
    * The piece is rotated back by the same angle.

    The pieces occupy the corners of a 3x3 grid: rot0 top-left, rot90 top-right,
    rot180 bottom-right, rot270 bottom-left. A cell covered by two pieces gets
    their mean. The centre gets the mean of the two blended middle rows.

    Args:
        example_input: batched tensor (batch is dim 0) rotated over dims 2 and 3.
        network: callable whose per-sample output has 4 values (it is reshaped
            to (batch, 2, 2)).
        printing: if True, sleep 1 s and print the four rotated-back pieces.

    Returns:
        Tensor of shape (batch, 3, 3).
    """
    n = example_input.shape[0]

    pieces = []
    for quarter in range(4):
        if quarter == 0:
            flat = example_input.reshape(n, -1)
        else:
            flat = torch.rot90(example_input, quarter, [2, 3]).reshape(n, -1)
        piece = network(flat).reshape(n, 2, 2)
        if quarter:
            # Undo the input rotation so every piece is in the same frame.
            piece = torch.rot90(piece, -quarter, [1, 2])
        pieces.append(piece)
    top_left, top_right, bottom_right, bottom_left = pieces

    def blend_columns(left, right):
        # Place two (n, 2, 2) pieces side by side -> (n, 2, 3), averaging the shared column.
        shared = (left[:, :, [1]] + right[:, :, [0]]) / 2
        return torch.cat((left[:, :, [0]], shared, right[:, :, [1]]), dim=2)

    upper = blend_columns(top_left, top_right)
    lower = blend_columns(bottom_left, bottom_right)
    middle_row = (upper[:, [1], :] + lower[:, [0], :]) / 2
    output = torch.cat((upper[:, [0], :], middle_row, lower[:, [1], :]), dim=1)

    if printing:
        pause(1)

        for quarter, piece in enumerate(pieces):
            print("For {} rotation, the generated parameter piece is: \n{}".format(quarter * 90, piece))

    return output


def equivariant_linear(example_input, network, dims: int):
    """Assemble rotation-equivariant weights from orbit-wise network outputs.

    ``example_input`` is rotated by 0/90/180/270 degrees over dims 2 and 3.
    Each rotated copy is flattened per sample and passed once through
    ``network``. ``group_permutation_element(dims, -1)`` supplies index groups
    over the ``dims**2`` output positions, and output column ``i`` of
    ``network`` feeds group ``i``:

    * a group of length 1 receives the mean of column ``i`` over the four
      rotations (rotation-invariant position);
    * a group of length 4 receives column ``i`` of rotation ``k`` at position
      ``group[k]``.

    Intended use (from the original notes): ``goal_dims = reduced(previous_filter_dims)``,
    with a target layout of batch_size x output channel x
    (filter^2 * previous_output_channel). This function itself returns a
    (batch, dims**2) tensor.

    Args:
        example_input: batched tensor (batch is dim 0) rotated over dims 2 and 3.
        network: callable producing (batch, n_groups) outputs; presumably
            n_groups == goal_linear_dim(dims) -- confirm against callers.
        dims: side length of the square position grid.

    Returns:
        Tensor of shape (batch, dims**2).

    Raises:
        ValueError: if any of the ``dims**2`` positions is not covered by a
            group of length 1 or 4.
    """
    batch = example_input.shape[0]
    # Evaluate the network once per rotation instead of once per group.
    outputs = [
        network(torch.rot90(example_input, k, [2, 3]).reshape(batch, -1))
        for k in range(4)
    ]
    permutation_list: List[List[int]] = group_permutation_element(dims, -1)
    final = [None] * (dims ** 2)

    for i, one_permute in enumerate(permutation_list):
        para = [out[:, i] for out in outputs]
        if len(one_permute) == 1:
            # Invariant position: take the average over the four rotations.
            final[one_permute[0]] = sum(para) / 4
        elif len(one_permute) == 4:
            # Equivariant orbit: rotation k fills the k-th position of the orbit.
            for position, value in zip(one_permute, para):
                final[position] = value

    missing = [idx for idx, value in enumerate(final) if value is None]
    if missing:
        raise ValueError(
            "positions {} were not covered by any group of length 1 or 4".format(missing))
    return torch.stack(final, dim=1)



def _demo_2x2(batch_size):
    """Print the 2x2 filter and bias from equivariant_filters_2 for each input rotation."""
    example_input = torch.randn(batch_size, 1, 10, 10, requires_grad=True)
    print("To generate a 2x2 filter, each rotated version of input is in charge of the 1/4 of the filter,"
          " which is just one value. As well as one value for the bias. ")

    print("The NEP generator is set to be simply a linear layer. It outputs ")
    w_b_network = nn.Linear(100, 2)
    print()
    for quarter in range(4):
        rotated = torch.rot90(example_input, quarter, [2, 3])
        output, bias = equivariant_filters_2(rotated, w_b_network, printing=(quarter == 0))
        pause(1)

        print("\nRotating Input of {} degrees:".format(quarter * 90))
        print("Output filter:")
        print(output[0])
        print("Bias value:")
        print(bias)


def _demo_3x3(batch_size):
    """Print the 3x3 filter from equivariant_filters_3 for each input rotation."""
    example_input = torch.randn(batch_size, 1, 10, 10, requires_grad=True)
    w_network = nn.Linear(100, 4)

    for quarter in range(4):
        rotated = torch.rot90(example_input, quarter, [2, 3])
        output = equivariant_filters_3(rotated, w_network, printing=(quarter == 0))
        pause(1)

        print("\nRotating Input of {} degrees:".format(quarter * 90))
        print("Output filter:")
        print(output[0].detach().numpy(), " is the output 3x3 filters.")
        print()


def _demo_linear(batch_size):
    """Print equivariant_linear outputs per rotation, then permuted back for comparison."""
    example_filter_dim = 3
    goal_dims = goal_linear_dim(example_filter_dim)
    example_input = torch.rand(batch_size, 1, 10, 10, requires_grad=True)
    network = nn.Linear(100, goal_dims)

    # The network is deterministic, so compute each rotation's output once and
    # reuse it for both printouts.
    outputs = [
        equivariant_linear(torch.rot90(example_input, quarter, [2, 3]), network, example_filter_dim)
        for quarter in range(4)
    ]
    for quarter, output in enumerate(outputs):
        print("Input with rotation of {} degrees:".format(quarter * 90))
        print(output.detach().numpy(), " result.")
        print()
    print("For easy comparing, we also permutate back, and should be the same after permuting back.")
    for quarter, output in enumerate(outputs):
        print("Input with rotation of {} degrees:".format(quarter * 90))
        print(functional_permute_list(output, example_filter_dim, -1 * quarter).detach().numpy(),
              "permute back for visualization.")


def target_demo(demo):
    """Run one of the equivariance demos ('2x2', '3x3' or 'linear') and print its results.

    The function seeds ``random`` and ``torch`` with 0 and sets NumPy print
    precision to 3 before the demo runs, so output is reproducible. The demos
    use a batch size of 1 and a single input channel. They sleep between steps
    via ``pause``.

    Raises:
        AssertionError: if ``demo`` is not one of the supported names.
    """
    import random

    assert demo in ['2x2', '3x3', 'linear'], "Demo should be 2x2, 3x3, or linear."

    print("We are use the {} demo.".format(demo), " For simplicity, we assume the channel is 1. ")
    pause(1)
    np.set_printoptions(precision=3)
    random.seed(0)
    torch.manual_seed(0)
    batch_size = 1

    runners = {'2x2': _demo_2x2, '3x3': _demo_3x3, 'linear': _demo_linear}
    runners[demo](batch_size)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # `choices` rejects unknown demos with a normal CLI usage error instead of
    # letting target_demo fail on its assertion.
    parser.add_argument('--demo', type=str, default='linear', choices=['2x2', '3x3', 'linear'],
                        help='2x2, 3x3, linear')
    args = parser.parse_args()
    print("Note all rotation are counter-clockwise.")
    target_demo(args.demo)