import abc
import copy

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

from algorithms.convergence_algorithms.utils import ball_perturb
from algorithms.mapping.trust_region import StaticShrinkingTrustRegion


class RotationTrustRegion(StaticShrinkingTrustRegion):
    """Trust region whose tanh squashing is combined with a linear map.

    ``rotation_matrix`` starts as ``None`` (plain tanh / arctanh mapping) and
    is replaced on every :meth:`squeeze` by whatever
    :meth:`create_rotation_matrix` returns.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # No linear map until the first squeeze.
        self.rotation_matrix = None

    def normalize_to_real(self, data: Tensor) -> Tensor:
        """Apply the rotation matrix (if any) to ``data`` and squash with tanh.

        ``data`` may be a single point (1-D) or a batch of row vectors (2-D).
        """
        if self.rotation_matrix is not None:
            # Row-vector form of ``(R @ data.T).T``; also valid for 1-D input,
            # where ``Tensor.T`` is deprecated.
            data = data @ self.rotation_matrix.T
        return torch.tanh(data)

    def normalize_to_unreal(self, data: Tensor) -> Tensor:
        """Invert :meth:`normalize_to_real`: arctanh, then the inverse matrix.

        Inputs are clamped strictly inside (-1, 1) first so that arctanh stays
        finite.
        """
        # The margin must be representable in the tensor's dtype: a fixed
        # 1e-16 rounds to exactly +/-1 in float32, which lets arctanh return
        # +/-inf. ``eps / 2`` is the gap between 1 and the next value below it.
        dtype = data.dtype if data.is_floating_point() else torch.get_default_dtype()
        limit = 1 - torch.finfo(dtype).eps / 2
        data = torch.clamp(data, min=-limit, max=limit)
        data = torch.arctanh(data)
        if self.rotation_matrix is not None:
            rotation_matrix_inverse = torch.inverse(self.rotation_matrix)
            data = data @ rotation_matrix_inverse.T
        return data

    def squeeze(self, best_result, grad_net=None, epsilon=None, **kwargs):
        """Recompute the rotation matrix, then delegate to the parent squeeze.

        ``grad_net`` and ``epsilon`` are forwarded unchanged to
        :meth:`create_rotation_matrix`; any remaining keyword arguments go to
        the parent implementation together with ``best_result``.
        """
        # The new matrix is built while the previous one is still installed.
        self.rotation_matrix = self.create_rotation_matrix(grad_net, epsilon)
        super().squeeze(best_result, **kwargs)

    # NOTE(review): ``abstractmethod`` is only enforced at instantiation if the
    # base class uses ``abc.ABCMeta`` - confirm. The explicit raise below is
    # the fallback either way.
    @abc.abstractmethod
    def create_rotation_matrix(self, grad_net: Module, epsilon: float) -> Tensor:
        """Return the matrix to install as ``rotation_matrix`` on squeeze."""
        raise NotImplementedError()


class GradientScalingTrustRegion(RotationTrustRegion):
    """Rotation trust region using a diagonal matrix built from one gradient."""

    def create_rotation_matrix(self, grad_net: Module, epsilon: float) -> Tensor:
        """Return ``diag(max(g) / g)`` where ``g = grad_net(unreal(mu))``.

        ``epsilon`` is accepted for interface compatibility and is not used.
        """
        curr_point = self.normalize_to_unreal(self.mu)
        # Detach so the stored matrix does not keep grad_net's autograd graph
        # alive or make every later normalisation track gradients.
        curr_grad = grad_net(curr_point).detach()
        # NOTE(review): a zero gradient component yields an infinite scaling
        # entry here - confirm grad_net cannot output exact zeros.
        scaling = curr_grad.max() / curr_grad
        return torch.diag(scaling).to(device=self.device)


class TrainedRotationTrustRegion(RotationTrustRegion):
    """Rotation trust region whose matrix is fitted with Adam on each squeeze.

    Subclasses supply the trainable tensor (:meth:`create_parameters`) and the
    way it is turned into a square matrix (:meth:`create_training_matrix`).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Training hyper-parameters are fixed here; they are not constructor
        # arguments, so everything passed in goes to the parent class.
        self.num_steps = 1000
        self.batch_size = 64
        self.dont_lower_scaling = True

    @abc.abstractmethod
    def create_parameters(self) -> Tensor:
        """Return a fresh leaf tensor (``requires_grad=True``) to optimise."""
        raise NotImplementedError()

    @abc.abstractmethod
    def create_training_matrix(self, params: Tensor) -> Tensor:
        """Build the square matrix that ``params`` currently represents."""
        raise NotImplementedError()

    def create_rotation_matrix(self, grad_net: Module, epsilon: float) -> Tensor:
        """Fit the parameters and return the resulting matrix.

        For ``num_steps`` passes over 1000 points produced by
        ``ball_perturb(self.mu, epsilon, ...)``, the candidate matrix is
        installed on a shallow copy of this region, the points are mapped with
        ``normalize_to_unreal``, and ``grad_net`` is evaluated on them. The
        loss is the MSE between the row-normalised network outputs and an
        all-ones tensor.

        Returns:
            The trained matrix, detached from the autograd graph.
        """
        params = self.create_parameters()

        # presumably 1000 points within ``epsilon`` of ``mu`` - verify against
        # ball_perturb.
        samples = ball_perturb(self.mu, epsilon, 1000, device=self.device)
        dataset = TensorDataset(samples)
        # RandomSampler draws a new permutation every time the loader is
        # iterated, so a single loader serves every pass.
        loader = DataLoader(
            dataset, sampler=RandomSampler(dataset), batch_size=self.batch_size
        )
        opt = Adam([params])
        # Work on a shallow copy so this region's own rotation_matrix is left
        # untouched while training.
        tr_copy = copy.copy(self)
        for _ in range(self.num_steps):
            for batch in loader:
                real_points = batch[0]
                tr_copy.rotation_matrix = self.create_training_matrix(params)

                unreal_samples = tr_copy.normalize_to_unreal(real_points)
                gradients = grad_net(unreal_samples)
                normalized_gradients = gradients / gradients.norm(dim=1).unsqueeze(1)
                # mse_loss already reduces to a scalar mean by default.
                loss = torch.nn.functional.mse_loss(
                    normalized_gradients, torch.ones_like(normalized_gradients)
                )
                # Accumulate gradients only into ``params``: a plain backward()
                # would also write into grad_net's parameter ``.grad`` fields,
                # which this optimiser never clears.
                loss.backward(inputs=[params])
                opt.step()
                opt.zero_grad()

        # Detach so the stored matrix does not keep the training graph alive.
        final_matrix = self.create_training_matrix(params).detach()
        if not self.dont_lower_scaling:
            min_val = final_matrix.diag().min().clip(min=1e-4)
            # Build the identity on the matrix's own device/dtype. The
            # element-wise product keeps only the diagonal, divided by min_val.
            identity = torch.eye(
                final_matrix.shape[0],
                device=final_matrix.device,
                dtype=final_matrix.dtype,
            )
            final_matrix = final_matrix * (identity / min_val)
        return final_matrix


class TrainedMatrixTrustRegion(TrainedRotationTrustRegion):
    """Trained trust region whose learned matrix is a full symmetric matrix."""

    def create_parameters(self) -> Tensor:
        """Return a square, uniformly random trainable tensor sized by ``mu``."""
        side = self.mu.shape[0]
        return torch.rand((side, side), device=self.device, requires_grad=True)

    def create_training_matrix(self, params: Tensor) -> Tensor:
        """Symmetrise ``params`` by averaging it with its transpose."""
        transposed = params.T
        return (params + transposed) / 2


class ScaledTrainedRotationTrustRegion(TrainedRotationTrustRegion):
    """Trained trust region that learns one scaling factor per coordinate."""

    def create_parameters(self) -> Tensor:
        """Return a uniformly random trainable vector with one entry per ``mu`` entry."""
        n_entries = self.mu.shape[0]
        return torch.rand(n_entries, device=self.device, requires_grad=True)

    def create_training_matrix(self, params: Tensor) -> Tensor:
        """Place ``params`` on the diagonal of an otherwise zero square matrix."""
        # Broadcasting multiplies column j of the identity by params[j].
        identity = torch.eye(self.mu.shape[0], device=params.device)
        return identity * params
