from math import exp, cos, pi, sqrt

import torch
from torch import nn, Tensor


class UnitaryMatrixMultiplication(nn.Module):
    """Mix the feature axis of the input with a learned orthogonal matrix.

    The mixing matrix is ``matrix_exp(-A)`` where ``A = (W - W^T) / 2`` is the
    skew-symmetric part of an unconstrained square weight ``W``. For a
    real-valued weight this matrix is orthogonal (the real case of unitary),
    so the transform preserves the norm along the feature axis.

    Args:
        num_features: Size of the feature axis (dimension 1 of the input).
        device: Optional device for the weight (forwarded to ``torch.empty``).
        dtype: Optional dtype for the weight (forwarded to ``torch.empty``).
    """

    def __init__(self, num_features: int, device=None, dtype=None):
        super().__init__()
        self.num_features = num_features
        shape = (num_features, num_features)
        self.weight = nn.Parameter(torch.empty(shape, device=device, dtype=dtype))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise the raw weight with Kaiming-uniform (a = sqrt(5))."""
        nn.init.kaiming_uniform_(self.weight, a=sqrt(5))

    def forward(self, x: Tensor) -> Tensor:
        """Apply the orthogonal matrix along dimension 1 of ``x``.

        ``x`` is indexed as ``(batch, num_features, *rest)``; the output has
        the same shape.
        """
        # Skew-symmetric part of the weight: skew^T == -skew.
        skew = (self.weight - self.weight.transpose(0, 1)) / 2
        # exp(-skew)^T == exp(skew) == exp(-skew)^{-1}, hence orthogonal.
        rotation = torch.linalg.matrix_exp(-skew)
        # out[b, j, ...] = sum_i rotation[i, j] * x[b, i, ...]
        return torch.einsum("ij,bi...->bj...", rotation, x)

    def __repr__(self):
        return "{}(num_features={})".format(type(self).__name__, self.num_features)


class NonNegativeLinear(nn.Linear):
    """``nn.Linear`` whose weight is projected to non-negative values.

    On every forward pass the stored weight is replaced by its absolute value
    before the affine map is applied, so the layer both uses and keeps
    non-negative weights. The bias is left unconstrained.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: If ``False``, the layer has no additive bias.
        device: Optional device for the parameters (forwarded to ``nn.Linear``).
        dtype: Optional dtype for the parameters (forwarded to ``nn.Linear``).
    """

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__(
            in_features, out_features, bias=bias, device=device, dtype=dtype
        )

    def forward(self, input):
        # Project in place through ``.data``: autograd does not record the
        # abs() (same as the previous ``.data = abs(...)`` reassignment), but
        # no new weight tensor is allocated on every call.
        self.weight.data.abs_()
        return super().forward(input)


class GumbelScheduler:
    """Anneal a numeric attribute of a model over training iterations.

    Each :meth:`step` advances the iteration counter and writes the scheduled
    value with ``setattr(model, name_attr, value)``. The class name suggests a
    Gumbel-softmax temperature (NOTE(review): confirm with callers), but any
    float-valued attribute works.
    """

    def __init__(
        self,
        model,
        name_attr,
        start_factor,
        end_factor,
        total_iters,
        annealing_type="linear",
        last_iter=-1,
    ):
        """
        A class to manage the evolution of a parameter over training iterations using various annealing strategies.

        Parameters:
            model (object): The model whose attribute needs to be updated.
            name_attr (str): The name of the attribute to be updated in the model.
            start_factor (float): The initial value of the attribute.
            end_factor (float): The final value of the attribute.
            total_iters (int): The total number of training iterations. A
                non-positive value means the schedule is already finished.
            annealing_type (str): One of 'linear', 'constant', 'exponential'
                or 'cosine'. Default is 'linear'.
            last_iter (int): The index of the last iteration. Default is -1,
                so the first call to ``step`` evaluates iteration 0.
        """
        self.model = model
        self.name_attr = name_attr
        self.start_factor = start_factor
        self.end_factor = end_factor
        self.total_iters = total_iters
        self.annealing_type = annealing_type
        self.last_iter = last_iter

    def get_value(self):
        """
        Compute the value of the parameter based on the current annealing strategy and iteration index.

        - 'constant': always ``start_factor``.
        - 'linear': straight line from ``start_factor`` (iteration 0) to
          ``end_factor`` (``total_iters``), then held at ``end_factor``.
        - 'cosine': half-cosine from ``start_factor`` to ``end_factor``, then
          held at ``end_factor``.
        - 'exponential': ``end + (start - end) * exp(-6 * t / total_iters)``;
          keeps approaching ``end_factor`` after ``total_iters``.

        Returns:
            float: The computed value of the parameter.

        Raises:
            ValueError: If ``annealing_type`` is not one of the types above.
        """
        annealing_type = self.annealing_type
        if annealing_type == "constant":
            return self.start_factor
        if annealing_type not in ("linear", "exponential", "cosine"):
            raise ValueError(
                "Invalid annealing type. Choose from 'linear', 'constant', 'exponential', or 'cosine'."
            )

        # Before the first step() (last_iter == -1) evaluate iteration 0.
        last_iter = max(self.last_iter, 0)
        # A non-positive horizon means the schedule is already over; this
        # also avoids a ZeroDivisionError below.
        if self.total_iters <= 0:
            return self.end_factor

        delta = self.start_factor - self.end_factor
        if annealing_type == "exponential":
            # Reaches start-to-end gap * exp(-6) (~0.25%) at total_iters.
            return self.end_factor + delta * exp(-6.0 * last_iter / self.total_iters)

        # Linear and cosine hold at end_factor once the horizon is passed.
        # Clamping the iteration (instead of caching a latched value) stays
        # correct if last_iter is later reset or jumps past total_iters.
        last_iter = min(last_iter, self.total_iters)
        if annealing_type == "linear":
            value = self.end_factor + delta * (1 - last_iter / self.total_iters)
            # Guard against floating-point rounding stepping outside the range.
            low, high = sorted((self.start_factor, self.end_factor))
            return min(max(value, low), high)

        # cosine
        return self.end_factor + 0.5 * delta * (
            1 + cos(pi * last_iter / self.total_iters)
        )

    def step(self):
        """
        Advance to the next iteration and write the scheduled value to the model attribute.
        """
        self.last_iter += 1
        setattr(self.model, self.name_attr, self.get_value())


def train_phase(module, changed_layers):
    """
    Put only the selected sub-modules of ``module`` in training mode.

    Every descendant of ``module`` (but not ``module`` itself) is visited and
    switched with ``train(True)`` if its dotted name equals an entry of
    ``changed_layers`` or lies underneath one, and ``train(False)`` otherwise.
    Matching is per dotted component, so ``"layer1"`` selects ``"layer1"`` and
    ``"layer1.conv"`` but not ``"layer10"``.

    Args:
    - module: torch.nn.Module: The module whose descendants are switched.
    - changed_layers: iterable of str (or a single str): Dotted names of the
      sub-modules to train, built from ``named_children()`` names joined with
      ``"."`` (e.g. ``"encoder.layer1"``). An empty string selects every
      sub-module.
    """
    # Materialise once: a bare string would otherwise be iterated per
    # character, and a generator would be exhausted after the first child.
    if isinstance(changed_layers, str):
        layers = [changed_layers]
    else:
        layers = list(changed_layers)
    exact = set(layers)
    # "" is the root path, so it is a prefix of every name.
    prefixes = tuple(f"{layer}." if layer else "" for layer in layers)

    stack = [(module, "")]
    while stack:
        parent, prefix = stack.pop()
        for name, child in parent.named_children():
            child_name = f"{prefix}.{name}" if prefix else name
            # train() also recurses into child's descendants; those are popped
            # later and get their own mode, so the per-name decision wins.
            child.train(mode=child_name in exact or child_name.startswith(prefixes))
            stack.append((child, child_name))


def test_training_children_layers(module):
    """
    Yield ``(dotted_name, submodule)`` for every descendant of ``module``.

    Despite its ``test_`` prefix this is an inspection helper, not a test; it
    is marked with ``__test__ = False`` below so pytest does not collect it.
    All descendants are yielded (not only direct children); ``module`` itself
    is not. The walk uses an explicit stack: all direct children of a node are
    yielded together, and the most recently pushed child is expanded next.

    Args:
    - module: torch.nn.Module: The module whose descendants are iterated over.

    Yields:
    Tuple[str, torch.nn.Module]: The dotted name (``named_children()`` names
    joined with ``"."``) and the module object of each descendant.
    """
    pending = [(module, "")]
    while pending:
        parent, prefix = pending.pop()
        for name, child in parent.named_children():
            child_name = f"{prefix}.{name}" if prefix else name
            yield child_name, child
            pending.append((child, child_name))


# The name matches pytest's default ``test_*`` pattern; opt out of collection
# so importing this helper into a test module does not create a bogus test.
test_training_children_layers.__test__ = False


# # train_phase(model, model.changed_layers)
# model.train()
# for name, module in test_training_children_layers(model):
#     if module.training:
#         print(f"Name: {name} | {module.training}")
#     # print(f"Name: {name} | {module.training}")
