import math
from typing import Callable, Iterable, Tuple

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Adam
from torch.utils.data import TensorDataset, RandomSampler, DataLoader

from algorithms.convergence_algorithms.typing import BoundedEvaluatedSpace
from algorithms.mapping.base import InputMapping, OutputMapping
from algorithms.stopping_condition.environment import EnvGoalIsReached, NoMoreBudget
from algorithms.stopping_condition.trsut_region import TrustRegionStopCondition

# Module-level stopping-condition instances, constructed once when this module is imported.
GOAL_IS_REACHED_STOPPING_CONDITION = EnvGoalIsReached()
NO_MORE_BUDGET_STOPPING_CONDITION = NoMoreBudget()
TRUST_REGION_STOPPING_CONDITION = TrustRegionStopCondition()

# Divisor for the cone angle in ``cone_explore`` (``alpha = math.pi / ANGLE``);
# ANGLE = 2 gives alpha = pi / 2.
ANGLE = 2


def reset_all_weights(model: Module) -> None:
    """
    Re-initialise, in place, every submodule of ``model`` that knows how to do so.

    ``Module.apply`` visits every submodule as well as ``model`` itself; each visited
    module that exposes a callable ``reset_parameters`` has it invoked with autograd
    disabled. Modules without such a method are left untouched.

    refs:
        - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6
        - https://stackoverflow.com/questions/63627997/reset-parameters-of-a-neural-network-in-pytorch
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    def _reset_if_supported(module: Module) -> None:
        with torch.no_grad():
            resetter = getattr(module, "reset_parameters", None)
            if not callable(resetter):
                return
            module.reset_parameters()

    # ``apply`` controls the traversal order, which in turn fixes the order in which
    # random initialisers consume the RNG; keep using it rather than ``modules()``.
    model.apply(fn=_reset_if_supported)


def ball_perturb(
    ball_center: Tensor,
    eps: float,
    num_samples: int,
    dtype: torch.dtype = torch.float64,
    device: int = None,
) -> Tensor:
    """
    Sample ``num_samples`` points at random distances up to ``eps`` around ``ball_center``.

    Each point is ``ball_center + eps * m * u``, where ``u`` is a random unit direction
    (a normalised standard-normal vector) and ``m`` is drawn from U[0, 1). Note that the
    radius, not the volume, is uniform, so points are denser near the center.

    :param ball_center: center point; its last dimension is the space dimension
        (presumably a 1-D tensor of shape ``(d,)`` — confirm with callers).
    :param eps: maximum perturbation radius.
    :param num_samples: number of points to draw.
    :param dtype: dtype of the sampled noise.
    :param device: device for the sampled noise (anything accepted by torch tensor
        factories, e.g. an index, a string, or ``torch.device``).
    :return: ``ball_center`` broadcast against a ``(num_samples, d)`` tensor of offsets.
    """
    ball_dim_size = ball_center.shape[-1]

    # Sample directly with the requested dtype/device instead of going through the
    # legacy ``torch.FloatTensor`` constructor, which allocates an uninitialised CPU
    # float32 buffer that then has to be copied to the target dtype/device.
    perturb = torch.randn(num_samples, ball_dim_size, device=device, dtype=dtype)
    mag = torch.rand(num_samples, 1, device=device, dtype=dtype)
    # Small epsilon keeps the division finite for a (practically impossible) zero draw.
    perturb = perturb / (torch.norm(perturb, dim=1, keepdim=True) + 1e-8)

    return ball_center + eps * mag * perturb


def ball_perturb_between_radius(
    ball_center: Tensor,
    min_radius: float,
    max_radius: float,
    num_samples: int,
    dtype: torch.dtype = torch.float64,
    device: int = None,
) -> Tensor:
    """
    Sample ``num_samples`` points in the shell between ``min_radius`` and ``max_radius``.

    Each point is ``ball_center + r * u``, where ``u`` is a random unit direction
    (a normalised standard-normal vector) and ``r`` is drawn uniformly from
    ``[min_radius, max_radius)``.

    :param ball_center: center point; its last dimension is the space dimension
        (presumably a 1-D tensor of shape ``(d,)`` — confirm with callers).
    :param min_radius: inner radius of the shell.
    :param max_radius: outer radius of the shell.
    :param num_samples: number of points to draw.
    :param dtype: dtype of the sampled noise.
    :param device: device for the sampled noise (anything accepted by torch tensor
        factories, e.g. an index, a string, or ``torch.device``).
    :return: ``ball_center`` broadcast against a ``(num_samples, d)`` tensor of offsets.
    """
    ball_dim_size = ball_center.shape[-1]

    # Draw each random quantity exactly once, directly with the requested dtype/device,
    # so no random numbers are generated only to be overwritten.
    perturb = torch.randn(num_samples, ball_dim_size, device=device, dtype=dtype)
    mag = torch.rand(num_samples, 1, device=device, dtype=dtype)
    mag = min_radius + (max_radius - min_radius) * mag
    # Small epsilon keeps the division finite for a (practically impossible) zero draw.
    perturb = perturb / (torch.norm(perturb, dim=1, keepdim=True) + 1e-8)

    return ball_center + mag * perturb


def cone_explore(
    ball_center: Tensor,
    eps: float,
    num_samples: int,
    grad: Tensor,
    dtype: torch.dtype = torch.float64,
    device: int = None,
):
    """
    Sample points in a cone pointing from ``ball_center`` against ``grad``.

    Random unit directions are drawn and folded into the half-space facing ``grad``
    (the absolute cosine with ``grad`` is used). Each direction's angle to ``grad`` is
    rescaled by ``alpha / (pi / 2)`` with ``alpha = math.pi / ANGLE``, so every direction
    lies within ``alpha`` of ``grad``. The points are ``ball_center - eps * m * direction``
    with ``m`` drawn from U[0, 1).

    :param ball_center: center point; its last dimension is the space dimension
        (presumably a 1-D tensor of shape ``(d,)`` — confirm with callers).
    :param eps: maximum step length.
    :param num_samples: number of points to draw.
    :param grad: direction of the cone axis (the step goes against it); presumably
        shape ``(d,)``. It is cast to the sampling dtype/device.
    :param dtype: dtype of the sampled noise.
    :param device: device for the sampled noise (anything accepted by torch tensor
        factories, e.g. an index, a string, or ``torch.device``).
    :return: ``(num_samples, d)`` tensor of points when ``ball_center`` is 1-D.
    """
    alpha = math.pi / ANGLE
    ball_dim_size = ball_center.shape[-1]
    ball_center = ball_center.unsqueeze(0)

    # Sample directly on the target device/dtype. The legacy ``torch.FloatTensor``
    # constructor only accepts CPU devices and raised for e.g. device="cuda".
    x = torch.randn(num_samples, ball_dim_size, device=device, dtype=dtype)
    mag = torch.rand(num_samples, 1, device=device, dtype=dtype)

    x = x / (torch.norm(x, dim=1, keepdim=True) + 1e-8)
    # Align grad with the samples so ``x @ grad`` cannot fail on a dtype/device mismatch.
    grad = grad.to(device=x.device, dtype=x.dtype)
    grad = grad / (torch.norm(grad) + 1e-8)

    cos = (x @ grad).unsqueeze(1)

    # Component of each direction orthogonal to grad; the epsilon avoids NaN when a
    # sample is (numerically) parallel to grad and this component vanishes.
    dp = x - cos * grad.unsqueeze(0)
    dp = dp / (torch.norm(dp, dim=1, keepdim=True) + 1e-8)

    # Angle to grad in [0, pi/2] (abs folds directions into grad's half-space);
    # clamp keeps acos inside its domain.
    acos = torch.acos(torch.clamp(torch.abs(cos), 0, 1 - 1e-8))
    scaled_angle = acos * alpha / (math.pi / 2)

    cone = torch.sin(scaled_angle) * dp + torch.cos(scaled_angle) * grad
    return ball_center - eps * mag * cone


def samples_in_ball(
    ball_center: Tensor,
    eps: float,
    data: Tensor,
    num_of_points: int,
    dtype: torch.dtype = torch.float64,
) -> Tuple[Tensor, Tensor]:
    """
    Draw up to ``num_of_points`` random rows of ``data`` lying strictly within ``eps`` of ``ball_center``.

    Returns ``(samples, remaining)``:
      - ``samples``: the drawn in-ball rows, in random order;
      - ``remaining``: every row outside the ball, followed by the in-ball rows that
        were not drawn (also in random order).

    When ``data`` is empty, ``samples`` is an empty 1-D tensor of ``dtype`` on
    ``data``'s device and ``data`` itself is returned unchanged as ``remaining``.
    """
    if data.numel() == 0:
        return torch.tensor([], dtype=dtype, device=data.device), data

    inside_mask = (ball_center - data).pow(2).sum(dim=1).sqrt() < eps
    in_ball = data[inside_mask]
    outside = data[~inside_mask]

    shuffled = in_ball[torch.randperm(len(in_ball))]
    chosen = shuffled[:num_of_points]
    leftover = shuffled[num_of_points:]
    return chosen, torch.cat((outside, leftover))


def random_sampler_loader_from_tensor(
    data: Tensor,
    batch_size: int,
    samples: int,
    num_of_split: int = 1,
    dtype: torch.dtype = torch.float64,
):
    """
    Build a ``DataLoader`` that yields ``samples`` random batches of ``batch_size`` rows.

    ``data`` is split along dim 0 with ``Tensor.chunk(num_of_split)``. The dataset is the
    concatenation of the prefixes ``chunks[:num_of_split]``, ``chunks[:num_of_split - 1]``,
    ..., ``chunks[:1]``, so chunk ``k`` appears ``num_of_split - k`` times and earlier rows
    are sampled more often. Rows are drawn with replacement.

    :param data: rows to sample from.
    :param batch_size: rows per batch.
    :param samples: number of batches per pass (the sampler draws ``batch_size * samples`` rows).
    :param num_of_split: number of chunks used for the weighting described above.
    :param dtype: dtype of the empty seed tensor; the dataset dtype is the type promotion
        of ``dtype`` and ``data.dtype``.
    :return: a ``DataLoader`` over a single-tensor ``TensorDataset``.
    """
    parts_of_data = data.chunk(num_of_split)
    # Gather all pieces first and concatenate once; growing the tensor inside the loop
    # re-copied everything accumulated so far on every iteration.
    pieces = [
        part
        for prefix_len in range(num_of_split, 0, -1)
        for part in parts_of_data[:prefix_len]
    ]
    # The empty seed tensor takes part in dtype promotion exactly as before.
    seed = torch.tensor([], device=data.device, dtype=dtype)
    full_data = torch.cat((seed, *pieces))

    opt_dataset = TensorDataset(full_data)
    opt_sampler = RandomSampler(opt_dataset, replacement=True, num_samples=batch_size * samples)
    return DataLoader(opt_dataset, sampler=opt_sampler, batch_size=batch_size)


def default_discriminator_optimizer(model: Module):
    """Create the default Adam optimizer (learning rate 5e-3) over ``model``'s parameters."""
    learning_rate = 0.005
    return Adam(params=model.parameters(), lr=learning_rate)


def default_generator_optimizer(model: Module):
    """Create the default Adam optimizer (learning rate 1e-4) over ``model``'s parameters."""
    learning_rate = 0.0001
    return Adam(params=model.parameters(), lr=learning_rate)


def sample_input_to_generator(
    num_samples: int, dim: int, dtype: torch.dtype = torch.float64, device: int = None
) -> Tensor:
    """
    Draw generator inputs uniformly from [-1, 1).

    :param num_samples: number of rows.
    :param dim: half of the row width; each row has ``2 * dim`` entries.
    :param dtype: dtype of the result.
    :param device: device of the result.
    :return: tensor of shape ``(num_samples, 2 * dim)``.
    """
    unit_noise = torch.rand((num_samples, dim * 2), device=device, dtype=dtype)
    # Map [0, 1) onto [-1, 1).
    return unit_noise.mul(2).sub(1)


def bind_space_with_input_mapping(space: BoundedEvaluatedSpace, input_mapping: InputMapping):
    """
    Wrap ``space`` so it can be evaluated on data expressed in mapped coordinates.

    The returned callable first undoes the mapping with ``input_mapping.inverse``, then
    converts the result with ``space.denormalize``, and finally calls ``space`` with it
    plus any extra positional/keyword arguments.
    """

    def evaluate(data, *args, **kwargs):
        unmapped = input_mapping.inverse(data)
        return space(space.denormalize(unmapped), *args, **kwargs)

    return evaluate


def bind_space_with_output_mapping(space: Callable, output_mapping: OutputMapping):
    """
    Wrap ``space`` so its results are passed through ``output_mapping.map``.

    The returned callable forwards all arguments to ``space`` unchanged and returns the
    mapped values.
    """

    def evaluate(data, *args, **kwargs):
        return output_mapping.map(space(data, *args, **kwargs))

    return evaluate


def bind_space_with_mapping(
    space: BoundedEvaluatedSpace,
    input_mapping: InputMapping,
    output_mapping: OutputMapping,
    normalize: bool,
):
    """
    Compose the input and output mapping wrappers around ``space``.

    The returned callable is
    ``bind_space_with_output_mapping(bind_space_with_input_mapping(space, input_mapping), output_mapping)``.

    :param normalize: accepted for backward compatibility but currently unused.
        ``bind_space_with_input_mapping`` takes only ``(space, input_mapping)``, and
        forwarding ``normalize`` to it made every call raise ``TypeError``.
    """
    # NOTE(review): ``normalize`` is ignored because the input binding always applies
    # ``space.denormalize``; confirm with callers whether normalize=False should skip it.
    input_bound = bind_space_with_input_mapping(space, input_mapping)
    return bind_space_with_output_mapping(input_bound, output_mapping)


def float_range(first: float, second: float = None, skip: float = 1) -> Iterable[float]:
    """
    Lazily yield evenly spaced numbers, like ``range`` but accepting floats.

    ``float_range(stop)`` yields ``0, skip, 2 * skip, ...`` while below ``stop``;
    ``float_range(start, stop)`` yields ``start, start + skip, ...`` while below ``stop``.
    The k-th value is computed as ``start + k * skip`` rather than by repeated addition,
    so rounding error does not accumulate (``float_range(0, 1, 0.1)`` yields ten values).

    :raises ValueError: on first iteration if ``skip`` is not positive while
        ``start < stop``, since the sequence would never terminate.
    """
    # Compare with None rather than truthiness so an explicit stop of 0 is honoured.
    if second is None:
        start, stop = 0, first
    else:
        start, stop = first, second

    if skip <= 0 and start < stop:
        raise ValueError(f"skip must be positive, got {skip}")

    step_index = 0
    value = start
    while value < stop:
        yield value
        step_index += 1
        value = start + step_index * skip


def distance_between_tensors(tensor1: Tensor, tensor2: Tensor) -> float:
    """
    Return the Euclidean distance between two points as a Python float.

    Squared differences are summed over dim 1 if either tensor has more than one
    dimension, otherwise over dim 0. The result is extracted with ``.item()``, so the
    reduction must leave exactly one element (e.g. two ``(d,)`` vectors, or a ``(1, d)``
    row against a ``(d,)`` vector); otherwise ``.item()`` raises.
    """
    reduce_dim = 1 if tensor1.dim() > 1 or tensor2.dim() > 1 else 0
    return (tensor1 - tensor2).pow(2).sum(dim=reduce_dim).sqrt().item()


def angle_between_tensors(tensor1: Tensor, tensor2: Tensor) -> Tensor:
    """
    Return the angle between two vectors, in degrees, as a tensor.

    Cosine similarity is computed along dim 0. It is clamped to [-1, 1] before ``acos``
    because floating-point rounding can push it slightly outside that interval for
    (anti-)parallel vectors, which would otherwise produce NaN.
    """
    cos_sim = torch.nn.functional.cosine_similarity(tensor1, tensor2, dim=0)
    angle_rad = torch.acos(torch.clamp(cos_sim, -1.0, 1.0))
    return torch.rad2deg(angle_rad)
