import math
from abc import ABC
from typing import List, Dict

import torch
from rtdl_revisiting_models import FTTransformer
from torch import Tensor
from torch.nn import (
    Module,
    Sequential,
    Sigmoid,
    LeakyReLU,
    Linear,
    Tanh,
    ReLU,
    Embedding,
    Conv1d,
    AdaptiveAvgPool1d,
    ConvTranspose1d,
    ConvTranspose2d,
    Conv2d,
    MultiheadAttention,
    GELU,
)

from algorithms.nn.utils import hyperspherical_to_cartesian
from utils.dynamically_load_class import Configurable


class ConfigurableModule(Module, Configurable, ABC):
    """Abstract base combining ``torch.nn.Module`` with ``Configurable``.

    Adds no behaviour of its own; it only fixes the base-class combination
    shared by the configurable networks defined in this file.
    """


class ConstructableModel(Module):
    """Module whose parameters can be exported as one flat tensor.

    Subclasses are expected to override :meth:`from_parameter_tensor` to
    perform the reverse operation.
    """

    def model_parameter_tensor(self):
        """Return a copy of the children's parameters as a single 1-D tensor.

        The direct children are visited in registration order and each child
        contributes all of its parameters, flattened. Parameters registered
        directly on this module (not on a child) are not visited.
        """
        flattened = [
            parameter.reshape(-1)
            for child in self.children()
            for parameter in child.parameters()
        ]
        # ``clone`` so the result never shares storage with a live parameter.
        return torch.cat(flattened).clone()

    def from_parameter_tensor(self, parameter_tensor: Tensor) -> "ConstructableModel":
        """Build a model from a flat parameter tensor; must be overridden."""
        raise NotImplementedError()


class BaseSequentialModel(ConstructableModel):
    """Thin wrapper that exposes a ``Sequential`` as a ``ConstructableModel``."""

    def __init__(self, sequential_model: Sequential):
        super().__init__()
        self.model = sequential_model

    def forward(self, x):
        """Apply the wrapped ``Sequential`` to ``x``."""
        return self.model(x)

    def from_parameter_tensor(self, parameter_tensor: Tensor) -> "ConstructableModel":
        """Load a flat parameter tensor into the wrapped model.

        ``parameter_tensor`` is flattened and moved to the device/dtype of the
        first existing parameter, then consumed slice by slice in the order
        ``children() -> parameters()``.

        NOTE(review): the wrapped ``Sequential`` is overwritten in place and
        the returned ``BaseSequentialModel`` wraps that very same object, so
        ``self`` is modified too. Elements beyond the total parameter count
        are ignored.
        """
        reference = next(self.model.parameters())
        flat = parameter_tensor.flatten().to(
            device=reference.device, dtype=reference.dtype
        )

        offset = 0
        for layer in self.model.children():
            for parameter in layer.parameters():
                count = parameter.numel()
                chunk = flat[offset : offset + count]
                # Clone so the parameter does not alias the incoming tensor.
                parameter.data = chunk.clone().reshape(parameter.shape)
                offset += count

        return BaseSequentialModel(self.model)


def ada_in_linear(feature, mean_style, std_style, eps=1e-5):
    """Apply adaptive instance normalisation to ``feature``.

    Each ``(batch, channel)`` slice is standardised over all of its trailing
    dimensions and then rescaled/shifted with the supplied style statistics.

    Args:
        feature: Tensor of shape ``(B, C, *spatial)`` with at least one
            trailing dimension (the classic case is ``(B, C, H, W)``).
            Non-contiguous tensors are accepted.
        mean_style: Shift; must broadcast against ``(B, C, N)`` where ``N`` is
            the product of the trailing dimensions.
        std_style: Scale; same broadcasting requirement as ``mean_style``.
        eps: Added to the per-slice standard deviation to avoid division by
            zero.

    Returns:
        Tensor with the same shape as ``feature``.

    Raises:
        ValueError: If ``feature`` has fewer than three dimensions.
    """
    if feature.dim() < 3:
        raise ValueError(
            "feature must have shape (B, C, *spatial) with at least one "
            f"trailing dimension, got {tuple(feature.shape)}"
        )

    original_shape = feature.shape
    batch, channels = original_shape[:2]

    # ``reshape`` (not ``view``) so permuted / sliced inputs work as well.
    flat = feature.reshape(batch, channels, -1)

    std_feat = torch.std(flat, dim=2, keepdim=True) + eps
    mean_feat = torch.mean(flat, dim=2, keepdim=True)

    adain = std_style * (flat - mean_feat) / std_feat + mean_style
    return adain.reshape(original_shape)


class LinearResDown(Module):
    """Two-branch, spectrally-normalised linear block that changes the width.

    The "left" branch is three LeakyReLU -> Linear stages, the "right" branch
    two plain Linear layers; both map ``input_size`` to ``output_size`` and
    their outputs are summed.

    Args:
        input_size: Number of input features.
        output_size: Number of output features; any falsy value (``None`` or
            ``0``) selects ``input_size // 2``.
    """

    def __init__(self, input_size, output_size=None):
        super().__init__()
        output_size = output_size or input_size // 2
        spectral = torch.nn.utils.spectral_norm

        # Left branch: activation precedes each linear layer.
        self.relu1_left = LeakyReLU()
        self.fc1_left = spectral(Linear(input_size, input_size))
        self.relu2_left = LeakyReLU()
        self.fc2_left = spectral(Linear(input_size, input_size))
        self.relu3_left = LeakyReLU()
        self.fc3_left = spectral(Linear(input_size, output_size))

        # Right branch: purely linear projection.
        self.fc1_right = spectral(Linear(input_size, input_size))
        self.fc2_right = spectral(Linear(input_size, output_size))

    def forward(self, batch):
        """Return ``right(batch) + left(batch)``."""
        left = self.fc1_left(self.relu1_left(batch))
        left = self.fc2_left(self.relu2_left(left))
        left = self.fc3_left(self.relu3_left(left))

        right = self.fc2_right(self.fc1_right(batch))
        return right + left


class LinearResUP(Module):
    """Two-branch, spectrally-normalised linear block that widens its input.

    Both branches map ``input_size`` to ``input_size * scale`` features: the
    "left" branch is purely linear, the "right" branch interleaves LeakyReLU
    activations. The branch outputs are summed.

    Args:
        input_size: Number of input features.
        scale: Integer widening factor.
    """

    def __init__(self, input_size, scale: int = 2):
        super().__init__()
        widened = input_size * scale
        spectral = torch.nn.utils.spectral_norm

        # Left branch: linear up-projection followed by a linear layer.
        self.upsample_left = spectral(Linear(input_size, widened))
        self.fc_left = spectral(Linear(widened, widened))

        # Right branch: activations around the same kind of projections.
        self.relu1_right = LeakyReLU()
        self.upsample_right = spectral(Linear(input_size, widened))
        self.fc1_right = spectral(Linear(widened, widened))
        self.relu2_right = LeakyReLU()
        self.fc2_right = spectral(Linear(widened, widened))

    def forward(self, batch):
        """Return ``right(batch) + left(batch)``."""
        left = self.fc_left(self.upsample_left(batch))

        right = self.upsample_right(self.relu1_right(batch))
        right = self.fc1_right(right)
        right = self.fc2_right(self.relu2_right(right))

        return right + left


class GeneratorByDimSize(Module):
    """MLP generator mapping ``2 * dim_size`` inputs to ``dim_size`` outputs.

    Layout: Linear -> LeakyReLU -> spectral Linear -> LeakyReLU ->
    spectral Linear -> Tanh, so every output lies in ``(-1, 1)``.
    """

    def __init__(self, dim_size: int):
        super().__init__()
        spectral = torch.nn.utils.spectral_norm
        hidden = 50

        self.l1 = Linear(dim_size * 2, hidden)
        self.l2 = LeakyReLU()
        self.l3 = spectral(Linear(hidden, hidden * dim_size))
        self.l4 = LeakyReLU()
        self.l5 = spectral(Linear(hidden * dim_size, dim_size))
        self.l6 = Tanh()

    def forward(self, batch):
        """Run ``batch`` through the six layers in order."""
        out = batch
        for layer in (self.l1, self.l2, self.l3, self.l4, self.l5, self.l6):
            out = layer(out)
        return out


class DiscriminatorByInputSize(Module):
    """MLP discriminator mapping ``dim`` inputs to one sigmoid score.

    Layout: Linear -> LeakyReLU -> spectral Linear -> LeakyReLU ->
    spectral Linear -> Sigmoid. The attribute names skip ``l6``; they are kept
    as-is because they determine the ``state_dict`` keys.
    """

    def __init__(self, dim: int):
        super().__init__()
        spectral = torch.nn.utils.spectral_norm
        hidden = 50

        self.l1 = Linear(dim, hidden)
        self.l2 = LeakyReLU()
        self.l3 = spectral(Linear(hidden, hidden))
        self.l4 = LeakyReLU()
        self.l5 = spectral(Linear(hidden, 1))
        self.l7 = Sigmoid()

    def forward(self, batch):
        """Run ``batch`` through the layers in order and return the score."""
        out = batch
        for layer in (self.l1, self.l2, self.l3, self.l4, self.l5, self.l7):
            out = layer(out)
        return out


# COPIED CODE ------------------------------------------------------ #
def init_weights(net, init="ortho"):
    """Initialise the weights of ``net`` and record its parameter count.

    Every ``Conv1d``, ``Conv2d``, ``Linear``, ``ConvTranspose1d`` and
    ``ConvTranspose2d`` reachable through ``net.modules()`` has its ``weight``
    re-initialised in place; biases are left untouched.

    Args:
        net: Module to initialise. Its ``param_count`` attribute is set to the
            total number of elements over ``net.parameters()``.
        init: ``"ortho"`` (orthogonal), ``"N02"`` (normal, mean 0, std 0.02)
            or ``"glorot"``/``"xavier"`` (Xavier uniform). Any other value
            leaves the weights unchanged and prints a notice once.
    """
    initializers = {
        "ortho": torch.nn.init.orthogonal_,
        "N02": lambda weight: torch.nn.init.normal_(weight, 0, 0.02),
        "glorot": torch.nn.init.xavier_uniform_,
        "xavier": torch.nn.init.xavier_uniform_,
    }
    initializer = initializers.get(init)

    if initializer is None:
        # Best effort: report once instead of once per matching layer.
        print("Init style not recognized...")
    else:
        target_types = (Conv1d, Conv2d, Linear, ConvTranspose2d, ConvTranspose1d)
        for module in net.modules():
            if isinstance(module, target_types):
                initializer(module.weight)

    # ``net.modules()`` yields every ancestor of a layer and
    # ``module.parameters()`` recurses, so summing per module counted nested
    # parameters once per ancestor. Count each parameter exactly once instead.
    net.param_count = sum(p.numel() for p in net.parameters())


class ResBlock(Module):
    """Pre-activation residual block of constant width.

    Computes ``x + Linear(ReLU(Linear(ReLU(x))))`` where both linear layers
    keep ``layer`` features.
    """

    def __init__(self, layer):
        super().__init__()
        width = layer
        residual_path = [
            ReLU(),
            Linear(width, width, bias=True),
            ReLU(),
            Linear(width, width, bias=True),
        ]
        self.fc = Sequential(*residual_path)

    def forward(self, x):
        """Return the input plus the residual branch output."""
        return x + self.fc(x)


class GlobalModule(Module):
    """``GlobalBlock`` followed by average pooling over the last axis.

    The pooled axis is reduced to length one and then squeezed away, so a
    ``(batch, planes, length)`` input yields ``(batch, planes)``.
    """

    def __init__(self, planes):
        super().__init__()
        self.emb = 32  # stored but not read inside this class
        interaction = GlobalBlock(planes)
        pooling = AdaptiveAvgPool1d(1)
        self.blocks = Sequential(interaction, pooling)

    def forward(self, x):
        """Apply the block and drop the pooled (size-1) last axis."""
        pooled = self.blocks(x)
        return pooled.squeeze(2)


class GlobalBlock(Module):
    """Residual dot-product self-attention over the last axis of the input.

    Query, key, value and output projections are each ``ReLU`` followed by a
    kernel-size-1 ``Conv1d`` that keeps ``planes`` channels.
    """

    def __init__(self, planes):
        super().__init__()
        self.emb = 32  # stored but not read inside this class
        self.planes = planes

        def pointwise():
            # ReLU + 1x1 convolution: a per-position linear map on channels.
            return Sequential(
                ReLU(),
                Conv1d(planes, planes, kernel_size=1, padding=0, bias=True),
            )

        self.query = pointwise()
        self.key = pointwise()
        self.value = pointwise()
        self.output = pointwise()

    def forward(self, x):
        """Return ``x`` plus the attention-mixed features.

        ``x`` is indexed as ``(batch, planes, length)``; the attention matrix
        is ``(batch, length, length)`` and is normalised over its last axis.
        """
        queries = self.query(x).transpose(1, 2)
        keys = self.key(x)
        values = self.value(x).transpose(1, 2)

        scores = torch.bmm(queries, keys) / math.sqrt(self.planes)
        attention = torch.softmax(scores, dim=2)

        mixed = torch.bmm(attention, values).transpose(1, 2)
        return x + self.output(mixed)


class SplineNet(Module):
    """Spline embedding followed by an attention/MLP head.

    Args:
        input_size: Number of input features.
        output_size: Number of output features.
        device: Forwarded to ``SplineEmbedding`` for its index-offset tensor.

    The embedding is built with ``delta=10`` and the head with a hidden width
    of 256.
    """

    def __init__(self, input_size: int, output_size: int, device: int):
        super(SplineNet, self).__init__()
        self.embedding = SplineEmbedding(input_size, 10, device)
        self.head = SplineHead(input_size, output_size, 256)

    def forward(self, x, normalize=True):
        """Embed ``x`` and run the head on the (clamped) input and embedding.

        Args:
            x: Input batch; presumably ``(batch, input_size)`` - confirm
                against callers.
            normalize: When true, squash ``x`` with ``tanh`` first.

        Returns:
            Output of ``SplineHead``.
        """
        if normalize:
            x = torch.tanh(x)

        # The embedding table only has entries for x in [-1, 1). The upper
        # bound was already clamped; clamp the lower bound too so that
        # un-normalised inputs below -1 do not yield negative table indices.
        x = torch.clamp(x, min=-1.0, max=1 - 1e-3)

        x_emb = self.embedding(x)
        x = self.head(x, x_emb)

        return x


class SplineEmbedding(Module):
    """Piecewise-linear (spline) embedding of scalar features.

    Every feature owns ``2 * delta + 1`` learnable knot vectors of size 32,
    placed ``1 / delta`` apart. An input value is embedded by linearly
    interpolating between the two knots that surround it.

    Args:
        input_size: Number of scalar features per sample.
        delta: Knots per unit interval.
        device: Device on which the index-offset tensor is first created.
    """

    def __init__(self, input_size: int, delta: int, device):
        super(SplineEmbedding, self).__init__()

        self.delta = delta
        self.input_size = input_size
        self.emb = 32
        self.device = device

        # Registered as a non-persistent buffer so it follows ``.to()`` /
        # ``.cuda()`` together with the embedding table while staying out of
        # ``state_dict`` (existing checkpoints keep loading unchanged).
        ind_offset = (
            torch.arange(self.input_size, dtype=torch.int64).to(device).unsqueeze(0)
        )
        self.register_buffer("ind_offset", ind_offset, persistent=False)

        self.b = Embedding(
            (2 * self.delta + 1) * self.input_size, self.emb, sparse=True
        )

    def forward(self, x):
        """Return interpolated knot embeddings of shape ``(n, input_size, 32)``.

        ``x`` is expected to be ``(n, input_size)`` with values in ``[-1, 1)``
        so that both neighbouring knot indices fall inside the table.
        """
        n = len(x)

        # Lower knot: grid position, flat table index, and grid value.
        xl = (x * self.delta).floor()
        xli = self.input_size * (xl.long() + self.delta) + self.ind_offset
        xl = xl / self.delta
        xli = xli.view(-1)

        # Upper knot: one grid step above the lower one.
        xh = (x * self.delta + 1).floor()
        xhi = self.input_size * (xh.long() + self.delta) + self.ind_offset
        xh = xh / self.delta
        xhi = xhi.view(-1)

        bl = self.b(xli).view(n, self.input_size, self.emb)
        bh = self.b(xhi).view(n, self.input_size, self.emb)

        # Grid spacing; dividing by it turns distances into weights in [0, 1].
        delta = 1 / self.delta

        x = x.unsqueeze(2)
        xl = xl.unsqueeze(2)
        xh = xh.unsqueeze(2)

        h = bh / delta * (x - xl) + bl / delta * (xh - x)
        return h


class SplineHead(Module):
    """Head combining raw inputs with a pooled summary of their embeddings.

    Args:
        input_size: Number of raw input features.
        output_size: Number of output features.
        layers: Width of the hidden layers (an integer, despite the name).
    """

    def __init__(self, input_size: int, output_size: int, layers: int):
        super().__init__()
        self.emb = 32
        self.input_size = input_size
        self.global_interaction = GlobalModule(self.emb)

        hidden = layers
        self.fc = Sequential(
            # Raw features concatenated with the pooled embedding summary.
            Linear(self.emb + self.input_size, hidden, bias=True),
            ResBlock(hidden),
            ResBlock(hidden),
            ReLU(),
            Linear(hidden, output_size, bias=True),
        )

        init_weights(self, init="ortho")

    def forward(self, x, x_emb):
        """Return ``fc(cat([x, pooled(x_emb)]))``.

        ``x_emb`` has its last two axes swapped before being handed to the
        global-interaction module; the result is concatenated to ``x`` along
        dimension 1.
        """
        pooled = self.global_interaction(x_emb.transpose(2, 1))
        combined = torch.cat([x, pooled], dim=1)
        return self.fc(combined)


class MultipleOptimizer:
    """Drive several optimizers through a single ``zero_grad``/``step`` API."""

    def __init__(self, *op):
        # Kept as the tuple of positional arguments, in the order given.
        self.optimizers = op

    def zero_grad(self):
        """Call ``zero_grad()`` on every wrapped optimizer, in order."""
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def step(self):
        """Call ``step()`` on every wrapped optimizer, in order."""
        for optimizer in self.optimizers:
            optimizer.step()


class BasicNetwork(ConfigurableModule):
    """Small ReLU MLP mapping ``dims`` features back to ``dims`` features.

    Layer widths are ``dims -> 10 -> 15 -> dims``. The network is moved to
    ``device`` and cast to ``dtype`` (``float64`` by default) on construction.
    """

    def __init__(
        self, dims: int, device: int = None, dtype: torch.dtype = torch.float64
    ):
        super().__init__()
        stack = [
            Linear(dims, 10),
            ReLU(),
            Linear(10, 15),
            ReLU(),
            Linear(15, dims),
        ]
        self.network = Sequential(*stack).to(device=device, dtype=dtype)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.network(x)


class SimpleAttention(ConfigurableModule):
    """Linear q/k/v projections, multi-head attention, and a linear read-out.

    Args:
        dims: Size of the input and output feature axis.
        attention_head: Number of attention heads.
        hidden: Attention embedding width; a falsy value selects
            ``attention_head * dims``.
        device: Device every sub-module is moved to.
    """

    def __init__(
        self, dims: int, attention_head: int = 3, hidden: int = None, device: int = None
    ):
        super().__init__()
        if not hidden:
            hidden = attention_head * dims

        self.q = Linear(dims, hidden).to(device=device)
        self.k = Linear(dims, hidden).to(device=device)
        self.v = Linear(dims, hidden).to(device=device)
        self.attention = MultiheadAttention(hidden, attention_head).to(device=device)
        self.output_layer = Linear(hidden, dims).to(device=device)

    def forward(self, data: Tensor) -> Tensor:
        """Attend over ``data`` and project back to ``dims`` features.

        A 1-D input is promoted to a single row before attention; the result
        keeps that leading axis of length one.
        NOTE(review): inputs are passed to ``MultiheadAttention`` without a
        ``batch_first`` setting - confirm which axis callers intend as the
        sequence.
        """
        q, k, v = self.q(data), self.k(data), self.v(data)
        if data.dim() == 1:
            q, k, v = (projection.reshape(1, -1) for projection in (q, k, v))

        attended, _ = self.attention(q, k, v)
        return self.output_layer(attended)


class SingleFTTransformer(ConfigurableModule):
    """One-block ``FTTransformer`` mapping ``dims`` continuous features to ``dims`` outputs.

    Uses a block width of 12, 3 attention heads and no dropout anywhere.
    """

    def __init__(self, dims: int, device: int = None):
        super().__init__()
        heads = 3
        block_width = 12
        transformer = FTTransformer(
            n_cont_features=dims,
            cat_cardinalities=[],
            d_block=block_width,
            d_out=dims,
            n_blocks=1,
            attention_n_heads=heads,
            attention_dropout=0.0,
            ffn_d_hidden_multiplier=2.0,
            ffn_dropout=0.0,
            residual_dropout=0.0,
        )
        self.transformer = transformer.to(device=device)

    def forward(self, data: Tensor) -> Tensor:
        """Run the transformer; a 1-D input yields a 1-D output of equal shape."""
        if data.dim() != 1:
            return self.transformer(data, None)

        # Temporarily add a batch axis, then restore the caller's shape.
        batched = data.reshape(1, -1)
        return self.transformer(batched, None).reshape(data.shape)


class BigLinearNetwork(ConfigurableModule):
    """Deep MLP of ``Linear`` + ``GELU`` stages mapping ``dims`` to ``dims``.

    Args:
        dims: Size of the input and output feature axis.
        layers: Hidden layer widths. ``None`` selects the widths returned by
            :meth:`object_default_values`. Any sequence of ints is accepted.
        device: Device the network is moved to.

    Note that ``GELU`` follows every linear layer, including the last one.
    """

    def __init__(self, dims: int, layers: List[int] = None, device: int = None):
        super().__init__()
        if layers is None:
            # The signature default used to crash on ``[dims] + None``; fall
            # back to the documented defaults instead.
            layers = self.object_default_values()["layers"]

        widths = [dims, *layers, dims]
        self.network = Sequential(
            *[
                Sequential(Linear(input_dim, output_dim), GELU())
                for input_dim, output_dim in zip(widths, widths[1:])
            ]
        )
        self.network = self.network.to(device=device)

    def forward(self, data: Tensor):
        """Apply the MLP to ``data``."""
        return self.network(data)

    @classmethod
    def object_default_values(cls) -> Dict:
        """Return the default constructor values (hidden layer widths)."""
        return {"layers": [50, 100, 100, 100, 70, 50]}


class BasicHessianNetwork(ConfigurableModule):
    """Small ReLU MLP that outputs a ``dims x dims`` matrix per input.

    Layer widths are ``dims -> 10 -> 15 -> dims ** 2``. The network is moved
    to ``device`` and cast to ``dtype`` (``float64`` by default).
    """

    def __init__(
        self, dims: int, device: int = None, dtype: torch.dtype = torch.float64
    ):
        super().__init__()
        self.dims = dims
        stack = [
            Linear(dims, 10),
            ReLU(),
            Linear(10, 15),
            ReLU(),
            Linear(15, dims**2),
        ]
        self.network = Sequential(*stack).to(device=device, dtype=dtype)

    def forward(self, x):
        """Return ``(dims, dims)`` for 1-D input, else ``(batch, dims, dims)``."""
        if x.dim() > 1:
            target_shape = (x.shape[0], self.dims, self.dims)
        else:
            target_shape = (self.dims, self.dims)
        return self.network(x).reshape(target_shape)


class HessianFTTransformer(ConfigurableModule):
    """One-block ``FTTransformer`` that outputs a ``dims x dims`` matrix per input.

    Args:
        dims: Number of continuous input features; the output has
            ``dims ** 2`` values reshaped to a square matrix.
        embedding_size: Transformer block width (``d_block``).
        attention_head: Number of attention heads.
        device: Device the transformer is moved to.

    Attention and feed-forward dropout are 0.1; residual dropout is 0.
    """

    def __init__(
        self,
        dims: int,
        embedding_size: int = 12,
        attention_head: int = 3,
        device: int = None,
    ):
        super().__init__()
        self.dims = dims
        transformer = FTTransformer(
            n_cont_features=dims,
            cat_cardinalities=[],
            d_block=embedding_size,
            d_out=dims**2,
            n_blocks=1,
            attention_n_heads=attention_head,
            attention_dropout=0.1,
            ffn_d_hidden_multiplier=2.0,
            ffn_dropout=0.1,
            residual_dropout=0.0,
        )
        self.transformer = transformer.to(device=device)

    def forward(self, data: Tensor) -> Tensor:
        """Return ``(dims, dims)`` for 1-D input, else ``(batch, dims, dims)``."""
        if data.dim() == 1:
            # Add a batch axis for the transformer, then drop it again.
            flat = self.transformer(data.reshape(1, -1), None)
            return flat.reshape(self.dims, self.dims)

        batch = data.shape[0]
        return self.transformer(data, None).reshape(batch, self.dims, self.dims)


class PolarGradient(ConfigurableModule):
    """Predict a vector as one radius plus ``dims - 1`` angles.

    Two independent ReLU MLPs (``dims -> 10 -> 15 -> out``) produce the radius
    and the angles; the result is converted with
    ``hyperspherical_to_cartesian``.
    """

    def __init__(self, dims: int, device: int = None):
        super().__init__()

        def mlp(output_size):
            # Same trunk for both heads; only the output width differs.
            return Sequential(
                Linear(dims, 10),
                ReLU(),
                Linear(10, 15),
                ReLU(),
                Linear(15, output_size),
            ).to(device=device)

        self.radius = mlp(1)
        self.angles = mlp(dims - 1)

    def forward(self, x):
        """Return the Cartesian conversion of the predicted polar coordinates."""
        radius = self.radius(x)
        # Squash the raw angle outputs into (-pi / 2, pi / 2).
        angles = (torch.sigmoid(self.angles(x)) - 0.5) * torch.pi
        polar_coordinates = torch.cat([radius, angles], dim=-1)
        return hyperspherical_to_cartesian(polar_coordinates)


class ModelToTrain(BaseSequentialModel):
    """Single bias-free linear layer ``dims -> 1`` wrapped as a sequential model.

    ``dtype`` is forwarded to ``Linear`` as its ``dtype`` argument (the ``int``
    annotation notwithstanding) and the model is then moved to ``device``.
    """

    def __init__(self, device: int, dtype: int, dims: int):
        projection = Linear(dims, 1, bias=False, dtype=dtype)
        super().__init__(sequential_model=Sequential(projection).to(device=device))


class BasicSurrogateModel(ConfigurableModule):
    """Small ReLU MLP mapping ``dims`` features to a single output value.

    Layer widths are ``dims -> 10 -> 15 -> 1``; the network is moved to
    ``device`` on construction.
    """

    def __init__(self, dims: int, device: int = None):
        super().__init__()
        stack = [
            Linear(dims, 10),
            ReLU(),
            Linear(10, 15),
            ReLU(),
            Linear(15, 1),
        ]
        self.network = Sequential(*stack).to(device=device)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.network(x)
