import torch
import torch.nn as nn
import math
from utils.model_tools import ShellParser
import torchvision.transforms as transforms

"""
For MLP weights we can't simply rotate with torch.rot90() as we can for CNN filters; hence the helpers below, which rotate by permuting flattened indices.
find_sqrt: return the square root of a number, raising a ValueError if it is not an integer.
math_permute_list: compute the index list needed to permute MLP weights so that they match the corresponding CNN rotation.
rotate_martix: torch.rot90() written out by hand, only to understand the rotation. Don't use it.
permute_flatten: based on the pattern from rotate_martix, permute a flattened tensor directly. Not efficient; use permute_MLP instead.
"""


def find_sqrt(num):
    """
    Return the integer square root of ``num``.

    :param num: number expected to be a perfect square (in this module, the
        length of a flattened d x d grid).
    :return: ``int(math.sqrt(num))``.
    :raises ValueError: if ``math.sqrt(num)`` is not a whole number (a negative
        ``num`` also raises ValueError, from ``math.sqrt`` itself).
    """
    root = math.sqrt(num)
    if not root.is_integer():
        raise ValueError("Dimension of input tensor length is not a perfect square. Got {} instead.".format(num))
    return int(root)

def functional_permute_list(tensor: torch.Tensor, d: int, direction: int):
    """
    Rotate d x d grids flattened (row-major) into the last dimension of ``tensor``.

    Each quarter turn is a single gather ``tensor[..., perm]`` with
    ``perm = math_permute_list(d, 1)`` (or ``-1`` for negative directions), applied
    ``abs(direction)`` times. On the unflattened grid this matches
    ``torch.rot90(grid, direction)``.

    :param tensor: tensor of rank >= 1 whose last dimension is expected to hold
        d * d entries (any leading dimensions are left untouched).
    :param d: side length of the square grid.
    :param direction: signed number of quarter turns. 0 returns ``tensor`` itself.
    :return: the permuted tensor (a new tensor unless ``direction == 0``).
    :raises ValueError: if ``tensor`` is 0-dimensional and ``direction != 0``.
    """
    if direction == 0:
        return tensor
    if tensor.dim() == 0:
        raise ValueError("Cannot permute a 0-dimensional tensor.")
    # Build the one-step permutation once instead of on every iteration.
    step = math_permute_list(d, 1 if direction > 0 else -1)
    for _ in range(abs(direction)):
        # `...` targets the last dimension for any rank; the old 2-D/3-D-only branches
        # silently returned None for other ranks.
        tensor = tensor[..., step]
    return tensor


def math_permute_list(d, direction, view_mode=0):
    """
    Return the index list that rotates a row-major flattened d x d grid by a quarter turn.

    Used as a gather (``tensor[:, permute_list]`` or ``tensor[..., permute_list]``),
    position ``i`` (row ``i // d``, column ``i % d``) takes the entry at
    ``permute_list[i]``. The result matches ``torch.rot90(grid, direction)`` on the
    unflattened grid. The lists for ``direction=1`` and ``direction=-1`` are inverses
    of each other.

    :param d: side length; d*d is the length of the flattened grid.
    :param direction: -1, 1 or 0 (0 gives the identity list).
    :param view_mode: offset added to every index. Set view_mode = 1 to visualize
        starting from index 1 (handy when figuring out the permutation by hand);
        keep view_mode = 0 for use in ``tensor[:, permute_list]``.
    :return: list of d*d ints.
    :raises ValueError: if direction is not -1, 0 or 1 (previously returned None).
    """
    indices = range(d * d)
    if direction == -1:
        # (row, col) takes the entry at (d - 1 - col, row).
        return [d * (d - 1 - i % d) + i // d + view_mode for i in indices]
    if direction == 0:
        return [i + view_mode for i in indices]
    if direction == 1:
        # (row, col) takes the entry at (col, d - 1 - row).
        return [d * (i % d) + (d - 1 - i // d) + view_mode for i in indices]
    raise ValueError("Direction more than ±90 is not implemented. Use direction = -1, 0, or 1.")


def group_permutation_element(d, direction):
    """
    Write the quarter-turn grid permutation in cycle notation.

    math_permute_list already gives the gather list used to switch columns
    (``tensor[:, perm]``). This returns the same permutation the mathematical way,
    as disjoint cycles.

    The direction is negated before the list is built. ``math_permute_list(d, -direction)``
    is the inverse of ``math_permute_list(d, direction)``, so following
    ``i -> perm[i]`` traces where the element at position ``i`` moves under a
    ``direction`` rotation.

    :param d: side length of the d x d grid (the permutation acts on d*d indices).
    :param direction: -1, 0 or 1.
    :return: list of cycles (lists of ints). Cycles are ordered by their smallest
        element, each starts at its smallest element, and each follows ``i -> perm[i]``.
        ``d == 0`` gives ``[]``.
    """
    permute_list = math_permute_list(d, -direction)
    visited = set()  # O(1) membership; the old flattened-list scan was O(n^2) per cycle
    result = []
    for start in range(len(permute_list)):
        if start in visited:
            continue
        cycle = []
        current = start
        # A permutation cycle closes back at `start`, which is already visited.
        while current not in visited:
            visited.add(current)
            cycle.append(current)
            current = permute_list[current]
        result.append(cycle)
    return result

def unnormalize_image(images, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """
    Undo per-channel normalisation, returning ``x * std[c] + mean[c]`` for each channel c.

    :param images: tensor of shape (C, H, W) or batched (..., C, H, W). For inputs of
        rank <= 3, dim 0 is treated as the channel axis (as before).
    :param mean: per-channel means, zipped with the channels.
    :param std: per-channel standard deviations, zipped with the channels.
    :return: a new, detached tensor. ``images`` itself is not modified. Previously
        ``detach()`` shared storage with ``images``, so the in-place ops overwrote
        the caller's tensor.
    """
    restored = images.detach().clone()
    # For batched input the channel axis is -3 (as in torchvision's Normalize), not 0.
    # Iterating dim 0 would apply the channel statistics to whole images.
    channels = restored if restored.dim() <= 3 else restored.movedim(-3, 0)
    for channel, m, s in zip(channels, mean, std):
        channel.mul_(s).add_(m)  # views into `restored`, so it is updated in place
    return restored


def goal_filter_dim(original_dim):
    """
    Side length the hypernetwork must generate for a CNN filter.

    The hypernetwork only generates a quarter of each filter, i.e. half of the side
    length, rounded up when the side is odd. (Per the original note, the overlapping
    intersections for odd sizes are averaged out elsewhere.)

    :param original_dim: side length of the CNN filter (integer-valued).
    :return: ceil(original_dim / 2) as an int.
    :raises ValueError: if original_dim is not integer-valued (previously returned None).
    """
    if int(original_dim) != original_dim:
        raise ValueError("original_dim must be an integer. Got {} instead.".format(original_dim))
    original_dim = int(original_dim)
    # (n + 1) // 2 == ceil(n / 2) for every integer n, with no float round-trip
    # (int(n / 2) lost precision for very large n).
    return (original_dim + 1) // 2


def goal_linear_dim(original_dim):
    """
    Number of values the hypernetwork must generate for a d x d block of linear weights.

    Under a quarter-turn rotation every cycle of the flattened grid has length 4,
    except the single fixed (invariant) centre when d is odd. We generate a fourth
    of the permuting part plus one value for the invariant part.

    :param original_dim: side length d (integer-valued).
    :return: d**2 / 4 if d is even, (d**2 - 1) / 4 + 1 if d is odd. Both equal
        ceil(d**2 / 4). Note that (d**2 - 1) is always divisible by 4 for odd d.
    :raises ValueError: if original_dim is not integer-valued (previously returned None).
    """
    if int(original_dim) != original_dim:
        raise ValueError("original_dim must be an integer. Got {} instead.".format(original_dim))
    original_dim = int(original_dim)
    # Integer arithmetic: the old float division lost precision for very large d.
    return (original_dim * original_dim + 3) // 4


def rotate_martix(matrix, direction):
    """
    Just writing out the torch.rot90() function to understand the rotation.

    Only use for writing the code for permute_MLP. Don't use it.

    :param matrix: square 2-D tensor.
    :param direction: 1 gives ``torch.rot90(matrix, 1)``, -1 gives
        ``torch.rot90(matrix, -1)``, and 0 returns ``matrix`` itself (any shape).
    :return: a new rotated tensor (or ``matrix`` for direction 0).
    :raises ValueError: for a direction other than -1, 0 or 1 (previously returned
        None), or for non-square / non-2-D input (previously produced a partially
        filled result or an IndexError).
    """
    if direction == 0:
        return matrix
    if direction not in (-1, 1):
        raise ValueError("Direction more than ±90 is not implemented. Use direction = -1, 0, or 1.")
    if matrix.dim() != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError("rotate_martix only supports square 2D matrices. Got shape {} instead.".format(tuple(matrix.shape)))
    d: int = matrix.shape[0]
    c = torch.zeros_like(matrix)
    for i in range(d):
        for j in range(d):
            if direction == -1:
                c[i][j] = matrix[d - 1 - j][i]  # clockwise quarter turn
            else:
                c[i][j] = matrix[j][d - 1 - i]  # counter-clockwise quarter turn
    return c


def permute_flatten(vec, direction):
    """
    Apply target_permute to a flattened grid (1-D) or to each row of a 2-D tensor.

    :param vec: 1-D tensor of length d*d, or 2-D tensor whose rows have length d*d.
    :param direction: -1, 0 or 1, forwarded to target_permute.
    :return: for 1-D input, whatever target_permute returns. For 2-D input, a new
        tensor with each row permuted.
    :raises ValueError: if vec is not 1-D or 2-D. The previous
        ``assert vec.dim() == 1 or 2`` was always true and is stripped under -O.
    """
    if vec.dim() == 1:
        return target_permute(vec, direction)
    if vec.dim() != 2:
        raise ValueError("Only support 1D or 2D tensor. Got {}D tensor instead.".format(vec.dim()))
    out = torch.zeros_like(vec)
    for i in range(vec.shape[0]):
        out[i] = target_permute(vec[i], direction)
    return out


def target_permute(vec, direction) -> torch.Tensor:
    """
    Rotate a row-major flattened d x d grid stored along dim 0 of ``vec`` by a quarter turn.

    ``result[i] = vec[perm[i]]`` with ``perm = math_permute_list(d, direction)``. For
    inputs of rank > 1, whole slices along dim 0 are moved.

    :param vec: tensor whose dim 0 has length d*d.
    :param direction: -1, 0 or 1.
    :return: ``vec`` itself for direction 0 (no copy), otherwise a new tensor.
    :raises ValueError: if ``vec.shape[0]`` is not a perfect square (from find_sqrt,
        checked first), or if direction is not -1, 0 or 1.
    """
    d = find_sqrt(vec.shape[0])
    if direction == 0:
        return vec
    if direction not in (-1, 1):
        raise ValueError("Direction more than ±90 is not implemented. Use direction = -1, 0, or 1.")
    # One vectorised gather instead of d*d scalar tensor assignments. The index
    # formulas are the ones math_permute_list shares with this function.
    index = torch.as_tensor(math_permute_list(d, direction), dtype=torch.long, device=vec.device)
    return vec[index]


def manual_calculate_parameters(input_dim, total_parameters, embedding_dim, chunk):
    """
    For the embedding case, compute the hypernetwork parameter count for a given chunk count.

    Count = (embedding_dim + 1) * total_parameters / chunk
            + chunk * (embedding_dim + 1 + input_dim)

    :return: the count as a float (true division is used); ``chunk == 0`` raises
        ZeroDivisionError.
    """
    per_unit = embedding_dim + 1
    output_side = per_unit * total_parameters / chunk
    embedding_side = chunk * (per_unit + input_dim)
    return output_side + embedding_side

def find_closest_divisor(input_dim, total_parameters, embedding_dim):
    '''
    Estimate the chunk count that minimises the hypernetwork parameter count.

    The count (see manual_calculate_parameters) has the form A / chunk + B * chunk, with
    A = (embedding_dim + 1) * total_parameters and B = embedding_dim + 1 + input_dim.
    Its real-valued minimiser is sqrt(A / B). This returns that root rounded with
    Python's round() (ties to even). It is therefore only an approximation, and,
    despite the name, not necessarily a divisor of total_parameters.

    An earlier, now-removed variant searched outward from this root for the nearest
    exact divisors of total_parameters and kept whichever neighbour gave the lower
    parameter count. Currently we minimise as much as possible, so the root is used directly.

    :return: int chunk count, at least 1.
    '''
    argmin = math.sqrt((embedding_dim + 1) * total_parameters / (embedding_dim + 1 + input_dim))
    # round() yields 0 when argmin < 0.5. A zero chunk count would divide by zero in
    # manual_calculate_parameters, and for a convex A/c + B*c with argmin < 0.5 the
    # best positive integer is 1.
    return max(1, round(argmin))


class RotateTransform:
    """
    Transform that rotates its input by a fixed angle (in degrees) inside a transforms pipeline.

    The rotation itself is delegated to ``transforms.functional.rotate``.
    """

    def __init__(self, angle):
        # Stored as-is and passed straight to transforms.functional.rotate on each call.
        self.angle = angle

    def __call__(self, x):
        rotate = transforms.functional.rotate
        rotated = rotate(x, self.angle)
        return rotated


if __name__ == '__main__':
    print("Welcome to math_tools.py")
    # Demo: cycle decomposition of the quarter-turn permutation on a side x side grid.
    # 5408 / 8 == 676 == 26 ** 2, so side == 26. Other quick checks to try here:
    # math_permute_list(3, 1, view_mode=1) and group_permutation_element(3, 1).
    side = int(math.sqrt(5408 / 8))
    print(side)
    cycles = group_permutation_element(side, -1)
    print(cycles)
    print(len(cycles))
    print()
