import abc
import math

import numpy as np
import torch
from torch import Tensor

from algorithms.space.basic_callable_space import BasicCallableSpace
from algorithms.space.utils import is_multiple_points
from problems.types import Suites
from utils.python import generate_orthonormal_matrix


def unif(N, inseed):
    """Return ``int(N)`` reproducible pseudo-random numbers for a given seed.

    The absolute value of ``inseed`` is used and raised to at least ``1.0``.
    A multiplicative congruential recurrence is advanced 40 times; the last
    32 states fill a shuffle table. Each sample is then read from a table
    slot selected by the previously drawn value, and that slot is refilled
    with the next state of the recurrence.

    Args:
        N: Number of samples to draw (converted with ``int``).
        inseed: Seed; only its magnitude matters.

    Returns:
        A NumPy array of the drawn values divided by 2147483647, with any
        exact zero replaced by ``1e-99``.
    """

    def _advance(state):
        # One step of the recurrence, written in the quotient/remainder form
        # so every intermediate value stays representable.
        quotient = math.floor(state / 127773.0)
        state = 16807.0 * (state - quotient * 127773.0) - 2836.0 * quotient
        if state < 0:
            state = state + 2147483647.0
        return state

    state = np.abs(inseed)
    if state < 1.0:
        state = 1.0

    # Warm-up: the first 8 states are discarded ...
    for _ in range(8):
        state = _advance(state)
    # ... and the following 32 fill the shuffle table from the back.
    table = 32 * [0.0]
    for slot in reversed(range(32)):
        state = _advance(state)
        table[slot] = state
    carry = table[0]

    samples = []
    for _ in range(int(N)):
        state = _advance(state)
        slot = int(math.floor(carry / 67108865.0))
        carry = table[slot]
        table[slot] = state
        samples.append(carry / 2.147483647e9)

    result = np.asarray(samples)
    # Exact zeros are replaced so callers can safely take logarithms.
    result[result == 0] = 1e-99
    return result


def gauss(N, seed):
    """Return ``N`` seed-determined samples built from ``2 * N`` uniforms.

    The first ``N`` uniforms from ``unif`` give the radius
    ``sqrt(-2 * log(u))`` and the next ``N`` give the angle ``2 * pi * u``;
    each sample is ``radius * cos(angle)``. Exact zeros in the result are
    replaced by ``1e-99``.
    """
    uniforms = unif(2 * N, seed)
    radius = np.sqrt(-2 * np.log(uniforms[:N]))
    angle = 2 * np.pi * uniforms[N : 2 * N]
    samples = radius * np.cos(angle)
    samples[samples == 0] = 1e-99
    return samples


def compute_rotation(seed, dim):
    """Return a ``dim`` x ``dim`` matrix with orthonormal rows fixed by ``seed``.

    A ``dim * dim`` vector drawn from ``gauss`` is reshaped into a square
    matrix whose rows are then orthonormalised in order: each row has its
    projections onto all earlier rows removed and is scaled to unit length.

    The matrix depends only on ``seed`` and ``dim``. The function does not
    touch torch's global random generator: an earlier version drew an unused
    torch matrix and, through ``torch.seed()``, re-seeded the global
    generator non-deterministically on every call.

    Args:
        seed: Seed forwarded to ``gauss``.
        dim: Number of rows and columns.

    Returns:
        A ``(dim, dim)`` NumPy array.
    """
    B = np.reshape(gauss(dim * dim, seed), (dim, dim))
    for i in range(dim):
        # Remove the components along the already-normalised rows.
        for j in range(i):
            B[i] = B[i] - np.dot(B[i], B[j]) * B[j]
        B[i] = B[i] / (np.sum(B[i] ** 2) ** 0.5)
    return B


class TensorSpace(BasicCallableSpace):
    """Base class for objectives evaluated directly on torch tensors.

    Stores the parameters ``m``, ``b`` and ``c`` together with the input
    bounds, and exposes the subclass-provided ``tensor_func`` as the callable
    environment. Gradients and Hessians are derived from that callable with
    ``torch.func``.
    """

    def __init__(
        self,
        m: Tensor,
        b: Tensor,
        c: Tensor,
        input_lower_bounds: Tensor,
        input_upper_bounds: Tensor,
        **kwargs,
    ):
        """Store the parameters and bounds; ``kwargs`` go to the base class."""
        super().__init__(**kwargs)
        self.m, self.b, self.c = m, b, c
        self.input_upper_bounds = input_upper_bounds
        self.input_lower_bounds = input_lower_bounds

    def is_goal_reached(self):
        """Always report that the goal has not been reached."""
        return False

    @property
    def device(self):
        """Device on which ``m`` lives."""
        return self.m.device

    @abc.abstractmethod
    def tensor_func(self, data):
        """Evaluate the objective on ``data``; subclasses must override."""
        raise NotImplementedError()

    @property
    def callable_env(self):
        """A plain callable that forwards its argument to ``tensor_func``."""
        return lambda data: self.tensor_func(data)

    @property
    def suite(self) -> Suites:
        """Suite identifier; always ``Suites.EXTERNAL``."""
        return Suites.EXTERNAL

    @property
    def func_id(self) -> int:
        """Function identifier; always ``1``."""
        return 1

    @property
    def func_instance(self) -> int:
        """Instance identifier; always ``0``."""
        return 0

    def __repr__(self):
        return f"{type(self).__name__}-{self.dimension}"

    def __str__(self):
        return f"{self.__repr__()}, remaining budget: {self.num_of_samples}"

    @property
    def upper_bound(self):
        """Upper input bounds passed at construction."""
        return self.input_upper_bounds

    @property
    def lower_bound(self):
        """Lower input bounds passed at construction."""
        return self.input_lower_bounds

    def g_func(self, x):
        """Gradient of the objective at ``x``.

        When ``x`` has more than one dimension the gradient is mapped over
        the leading dimension with ``torch.vmap``.
        """
        gradient = torch.func.grad(self.callable_env)
        if len(x.shape) > 1:
            gradient = torch.vmap(gradient)
        return gradient(x)

    def h_func(self, x):
        """Hessian of the objective at ``x``.

        When ``x`` has more than one dimension the Hessian is mapped over
        the leading dimension with ``torch.vmap``.
        """
        hessian = torch.func.hessian(self.callable_env)
        if len(x.shape) > 1:
            hessian = torch.vmap(hessian)
        return hessian(x)


class MatrixSpace(TensorSpace):
    """Second-order polynomial ``x . (x @ m) + b . x + c``."""

    def tensor_func(self, data):
        """Evaluate the polynomial for one point or for a batch of rows."""
        if len(data.shape) <= 1:
            # Single point: plain vector-matrix-vector product.
            quadratic = data @ self.m @ data
            linear = (self.b * data).sum()
        else:
            # Batch: reduce over dimension 1 to get one value per row.
            quadratic = ((data @ self.m) * data).sum(dim=1)
            linear = (self.b * data).sum(dim=1)
        return quadratic + linear + self.c


class EggHolderTensorSpace(TensorSpace):
    """Egg-holder style objective summed over consecutive coordinate pairs."""

    def tensor_func(self, data):
        """Evaluate the objective for one point or a batch of points.

        The input is scaled by 100, split into even-indexed (``x1``) and
        odd-indexed (``x2``) coordinates, and the two sine terms are summed
        over the pairs. A 1-D input yields a scalar; a batched input yields
        one value per point.

        NOTE(review): the two slices only line up when the number of
        coordinates is even — confirm callers guarantee this.
        """
        data = data * 100
        # One flag drives both the slicing and the reduction axis. The
        # reduction previously tested `> 2`, which made a 2-D batch sum over
        # the batch axis instead of over the coordinate pairs.
        batched = len(data.shape) > 1
        x1 = data[:, 0::2] if batched else data[0::2]
        x2 = data[:, 1::2] if batched else data[1::2]

        term1 = -(x2 + 47) * torch.sin(torch.sqrt(torch.abs(x2 + x1 / 2 + 47)))
        term2 = -x1 * torch.sin(torch.sqrt(torch.abs(x1 - (x2 + 47))))

        return (term1 + term2).sum(dim=1 if batched else 0)


class CocoTensorSpace(TensorSpace):
    """Base for objectives composed of a data transformation and a raw function.

    ``tensor_func`` first applies ``data_transformation`` and then the
    subclass-provided ``_tensor_func``. The class also offers the shared
    building blocks ``big_lambda``, ``t_asy`` and ``t_osz`` and two matrices
    ``q`` and ``r`` obtained from ``generate_orthonormal_matrix``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Two independently generated matrices, matching m's dtype and device.
        self.q = generate_orthonormal_matrix(self.dimension, self.m.dtype, self.device)
        self.r = generate_orthonormal_matrix(self.dimension, self.m.dtype, self.device)

    def big_lambda(self, data, alpha: int):
        """Diagonal matrix whose entries are ``alpha ** (0.5 * t)``.

        ``t`` runs linearly from 0 to 1 over the size of the last dimension
        of ``data``.
        """
        dim = data.shape[-1]
        return torch.eye(dim, device=self.device, dtype=data.dtype) * (
            alpha ** (0.5 * torch.linspace(0, 1, dim, device=data.device))
        )

    def t_asy(self, data, beta: float):
        """Raise positive entries to ``1 + beta * t * sqrt(x)``; keep the rest.

        ``t`` runs linearly from 0 to 1 along the last dimension of ``data``.

        The power is evaluated on a copy in which non-positive entries are
        replaced by 1. Forward values are unchanged, but ``sqrt``/``pow`` are
        never evaluated on non-positive numbers, so the branch discarded by
        ``torch.where`` cannot inject NaN into gradients.
        """
        positive = data > 0
        safe = torch.where(positive, data, torch.ones_like(data))
        power = (
            1
            + (beta * torch.linspace(0, 1, data.shape[-1], device=data.device))
            * safe.sqrt()
        )
        return torch.where(positive, safe**power, data)

    def t_osz(self, data):
        """Apply the sign-preserving oscillating transformation element-wise.

        With ``x_hat = 10 * log(|x|)`` the result for non-zero ``x`` is
        ``sign(x) * exp(x_hat + 0.49 * (sin(c1 * x_hat) + sin(c2 * x_hat))) ** 0.1``
        where ``(c1, c2)`` is ``(1, 0.79)`` for positive and ``(0.55, 0.31)``
        for non-positive entries. Zeros are returned unchanged.

        The logarithm is taken of ``|x|`` with zeros replaced by 1. Forward
        values are unchanged, but ``log(0)`` is never evaluated, so gradients
        at exactly zero stay finite.
        """
        nonzero = data != 0
        magnitude = torch.where(nonzero, data.abs(), torch.ones_like(data))
        x_hat = torch.log(magnitude) * 10
        x_sign = torch.sign(data)
        c1 = torch.where(
            data > 0,
            torch.ones_like(data, device=self.device),
            torch.ones_like(data, device=self.device) * 0.55,
        )
        c2 = torch.where(
            data > 0,
            torch.ones_like(data, device=self.device) * 0.79,
            torch.ones_like(data, device=self.device) * 0.31,
        )
        x = x_sign * (
            torch.exp(x_hat + 0.49 * (torch.sin(c1 * x_hat) + torch.sin(c2 * x_hat)))
            ** 0.1
        )
        return torch.where(nonzero, x, data)

    def tensor_func(self, data):
        """Transform ``data`` and evaluate the raw function on the result."""
        data = self.data_transformation(data)
        return self._tensor_func(data)

    def data_transformation(self, data):
        """Default transformation: the affine map ``data @ m + b``."""
        return data @ self.m + self.b

    @abc.abstractmethod
    def _tensor_func(self, data):
        """Raw objective on already transformed data; subclasses override."""
        raise NotImplementedError()


class EllipsoidSpace(CocoTensorSpace):
    """Sum of squares of the element-wise scaled and shifted input."""

    def data_transformation(self, data):
        """Scale element-wise by ``b`` and then subtract ``c``."""
        scaled = data * self.b
        return scaled - self.c

    def _tensor_func(self, data):
        """Sum the squared entries along the coordinate axis."""
        axis = 1 if is_multiple_points(data) else 0
        return data.pow(2).sum(dim=axis)


class SharpRidgeSpace(CocoTensorSpace):
    """First coordinate squared plus 100 times the norm of the remaining ones."""

    def _tensor_func(self, data):
        """Evaluate ``x_0 ** 2 + 100 * sqrt(sum(x_1.. ** 2))``.

        A 1-D input yields a scalar; an input with more than one dimension is
        treated as rows of points and yields one value per row.
        """
        if len(data.shape) > 1:
            first_squared = data[:, 0] ** 2
            # `Tensor.sqrt` is element-wise and takes no `dim`; the reduction
            # has already happened in `sum`. Passing `dim=1` made every
            # batched call raise a TypeError.
            rest_norm = (data[:, 1:] ** 2).sum(dim=1).sqrt()
        else:
            first_squared = data[0] ** 2
            rest_norm = (data[1:] ** 2).sum().sqrt()
        return first_squared + 100 * rest_norm


class RastriginSpace(CocoTensorSpace):
    """Rastrigin-type objective on a warped and linearly mixed input."""

    def data_transformation(self, data):
        """Shift by ``b``, apply ``t_osz`` then ``t_asy(., 0.2)``, then mix.

        The mixing matrix is ``r @ big_lambda(data, 10) @ q``.
        """
        shifted = data - self.b
        warped = self.t_asy(self.t_osz(shifted), 0.2)
        mixing = self.r @ self.big_lambda(data, 10) @ self.q
        return (mixing @ warped.T).T

    def _tensor_func(self, data):
        """Evaluate ``10 * (D - sum(cos(2 * pi * z))) + ||z|| ** 2 + c``.

        ``D`` is the size of the last dimension of ``data``.
        """
        axis = 1 if is_multiple_points(data) else 0
        cosine_total = torch.cos(2 * torch.pi * data).sum(dim=axis)
        oscillation = 10 * (data.shape[-1] - cosine_total)
        squared_norm = torch.norm(data, dim=axis) ** 2
        return oscillation + squared_norm + self.c


class StepEllipsoidSpace(CocoTensorSpace):
    """Ellipsoid-type objective evaluated on a stepwise-rounded input."""

    def data_transformation(self, data):
        """Shift by ``b`` and multiply by ``big_lambda(data, 10) @ r``."""
        conditioning = self.big_lambda(data, 10) @ self.r
        centred = data - self.b
        return (conditioning @ centred.T).T

    def _tensor_func(self, data):
        """Round the input in steps, mix with ``q`` and take a weighted sum.

        Entries with magnitude above 0.5 are floored to integers; all other
        entries are floored to one decimal. The result is
        ``0.1 * max(|data_0| * 1e-4, sum(w * z ** 2)) + c`` with weights
        ``w = 10 ** t`` and ``t`` running linearly from 0 to 1.
        """
        axis = 1 if is_multiple_points(data) else 0

        # The two masks are complementary, so exactly one term is active
        # for every entry.
        coarse = data.abs() > 0.5
        integer_steps = torch.floor(data) * coarse
        decimal_steps = (torch.floor(10 * data) / 10) * (~coarse)
        z = (integer_steps + decimal_steps) @ self.q

        weights = 10 ** torch.linspace(0, 1, data.shape[-1], device=data.device)
        first_coordinate = data[..., 0].abs() * 1e-4
        weighted_sum = ((z**2) * weights).sum(dim=axis)

        candidates = torch.stack((first_coordinate, weighted_sum), dim=axis)
        largest = torch.max(candidates, dim=axis).values
        return 0.1 * largest + self.c


class EllipsoidalSpace(CocoTensorSpace):
    """Weighted sum of squares of the oscillation-transformed, shifted input."""

    def data_transformation(self, data):
        """Shift by ``b`` and apply ``t_osz``."""
        return self.t_osz(data - self.b)

    def _tensor_func(self, data):
        """Evaluate ``sum(w * z ** 2) + c`` with weights ``w = 10 ** t``.

        ``t`` runs linearly from 0 to 1 along the last dimension of ``data``.
        """
        # The weights are created in data's dtype: `@` does not type-promote,
        # so default-dtype weights raise for e.g. float64 input.
        sum_multipliers_power = 10 ** torch.linspace(
            0, 1, data.shape[-1], device=data.device, dtype=data.dtype
        )
        return (data**2) @ sum_multipliers_power + self.c


class BucheRastriginSpace(RastriginSpace):
    """Rastrigin variant using a per-coordinate scaling instead of mixing."""

    def data_transformation(self, data):
        """Shift by ``b``, apply ``t_osz`` and scale each coordinate.

        The scale is ``10 ** (0.5 * t)`` with ``t`` running linearly from 0
        to 1; every coordinate at an odd index is scaled by a further
        factor of 10.
        """
        exponents = 0.5 * torch.linspace(0, 1, data.shape[-1], device=data.device)
        scales = 10**exponents
        scales[1::2] *= 10
        warped = self.t_osz(data - self.b)
        return scales * warped


class SchaffersF7Space(CocoTensorSpace):
    """Schaffers-F7-type objective on sums of squares of neighbouring coordinates."""

    def data_transformation(self, data):
        """Rotate, warp, condition and combine neighbouring coordinates.

        The shifted input is multiplied by ``r``, passed through
        ``t_asy(., 0.5)`` and multiplied by ``big_lambda(data, 10) @ q``.
        The result holds ``z_i ** 2 + z_{i+1} ** 2`` for each neighbouring
        pair, so its last dimension is one shorter than the input's.
        """
        # `t_asy` ramps its exponent along the LAST dimension, so it must see
        # coordinates there. Feeding it the (dim, batch) product directly
        # ramped over the batch axis and made batched results differ from
        # per-point results.
        rotated = (self.r @ (data - self.b).T).T
        warped = self.t_asy(rotated, 0.5)
        z = (self.big_lambda(data, 10) @ self.q @ warped.T).T
        s = z[..., :-1] ** 2 + z[..., 1:] ** 2
        return s

    def _tensor_func(self, data):
        """Evaluate ``mean(s ** 0.25 * (sin(50 * s ** 0.1) ** 2 + 1)) ** 2 + c``."""
        data_sum = (data**0.25 * (torch.sin(50 * data**0.1) ** 2 + 1)).mean(dim=-1)
        return data_sum**2 + self.c


class LinearSpace(CocoTensorSpace):
    """Linear objective whose slope signs follow a randomly drawn ``b``."""

    def __init__(self, *args, **kwargs):
        """Initialise the base class, then replace ``b`` by random ``+-5`` entries.

        The ``b`` passed in is used only for its shape, dtype and device.
        """
        super().__init__(*args, **kwargs)
        magnitude = torch.ones_like(self.b) * 5
        # randint yields 0 or 1, which is mapped to -1 or +1.
        signs = torch.randint(2, self.b.shape, device=self.b.device) * 2 - 1
        self.b = magnitude * signs

    def _tensor_func(self, data):
        """Evaluate ``5 * sum(|s|) + data @ s + c``.

        ``s = sign(b) * 10 ** t`` with ``t`` running linearly from 0 to 1
        along the last dimension of ``data``.
        """
        slopes = 10 ** torch.linspace(0, 1, data.shape[-1], device=data.device)
        s = torch.sign(self.b) * slopes
        offset = 5 * s.abs().sum()
        return offset + data @ s + self.c
