import torch
from torch import nn as nn
from torch.nn import functional as F

from margflow.nn.mlp import MLP, FourierFeatureLayer


class MLPNoWeights(nn.Module):
    """ReLU MLP whose weights and biases are supplied to ``forward`` rather than owned.

    The module registers no parameters. ``layer_shapes`` lists, for every
    linear layer in order, its weight shape ``(out_features, in_features)``
    followed by its bias shape ``(out_features,)``:

    * one input layer ``in_dim -> hid_dim``,
    * ``n_layers`` hidden layers ``hid_dim -> hid_dim``,
    * one output layer ``hid_dim -> out_dim``.

    Args:
        in_dim: size of the last dimension of ``x``.
        hid_dim: width of every hidden layer.
        n_layers: number of ``hid_dim -> hid_dim`` layers between the input
            and output layers.
        out_dim: size of the last dimension of the result.
        skip_connection: if True, the pre-activation output of the input
            layer is added to the pre-activation output of the layer just
            before the output layer. NOTE(review): with ``n_layers == 0``
            these are the same layer, so its pre-activation is doubled.
    """

    def __init__(self, in_dim, hid_dim, n_layers, out_dim, skip_connection=False):
        super().__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.out_dim = out_dim
        self.skip_connection = skip_connection

        self._build_network()

    def _build_network(self):
        """Populate ``self.layer_shapes`` with alternating weight / bias shapes."""
        input_layer = [(self.hid_dim, self.in_dim), (self.hid_dim,)]
        hidden_layers = [(self.hid_dim, self.hid_dim), (self.hid_dim,)] * self.n_layers
        output_layer = [(self.out_dim, self.hid_dim), (self.out_dim,)]
        self.layer_shapes = input_layer + hidden_layers + output_layer

    def forward(self, x, params):
        """Run the MLP on ``x`` using the externally supplied ``params``.

        Args:
            x: input whose last dimension is ``in_dim``; expected shape is
                ``(mb_x, mb_c, in_dim)``.
            params: one tensor per entry of ``layer_shapes`` (weight, bias,
                weight, bias, ...). Each is reshaped with ``.view`` to
                ``(x.shape[0], 1, *shape)`` (biases get an extra trailing 1),
                so it must hold ``x.shape[0] * prod(shape)`` elements.

        Returns:
            The output with the trailing singleton axis removed; for a 3-D
            input this is ``(mb_x, mb_c, out_dim)``.
        """
        batch = x.shape[0]
        n_linear = len(self.layer_shapes) // 2
        last_hidden = n_linear - 2

        # Treat every feature vector as a column so a batched matmul applies the layer.
        h = x.unsqueeze(-1)
        skip = None
        for layer_idx in range(n_linear):
            w_shape = self.layer_shapes[2 * layer_idx]
            b_shape = self.layer_shapes[2 * layer_idx + 1]
            w = params[2 * layer_idx].view(batch, 1, *w_shape)
            b = params[2 * layer_idx + 1].view(batch, 1, *b_shape, 1)

            h = torch.matmul(w, h) + b

            if self.skip_connection:
                # Remember the input layer's pre-activation ...
                if layer_idx == 0:
                    skip = h
                # ... and add it back just before the output layer.
                if layer_idx == last_hidden:
                    h = h + skip

            # Every layer except the output layer is followed by a ReLU.
            if layer_idx < n_linear - 1:
                h = F.relu(h)

        return h.squeeze(dim=-1)


class Hypernetwork(nn.Module):
    """Generate the weights of an ``MLPNoWeights`` target network from ``context``.

    ``context`` is lifted by a FourierFeatureLayer, projected by a linear
    layer, and mapped by an MLP to a flat vector holding every parameter of
    the target network. That vector is cut into per-layer chunks, following
    ``self.network.layer_shapes``, and used to evaluate ``x``.

    Args:
        in_dim: input size of the target network.
        hid_dim: hidden width of the target network.
        n_layers: number of hidden-to-hidden layers of the target network.
        out_dim: output size of the target network.
        hypernet_in_dim: first argument of the context FourierFeatureLayer;
            presumably the feature size of ``context`` (TODO confirm).
        hypernet_hid_dim: width of the context embedding and hypernetwork MLP.
        hypernet_n_layers: ``n_layers`` of the hypernetwork MLP.
        fourier_dim: ``fourier_dim`` of the context FourierFeatureLayer. Its
            output must have ``2 * fourier_dim`` features, because
            ``embed_cond`` expects that many.
        dropout: dropout probability passed to the hypernetwork MLP.
        skip_connection: forwarded to the target ``MLPNoWeights``.
        fourier_scale: ``scale`` of the context FourierFeatureLayer.
            Defaults to 1, the previously hard-coded value.
    """

    def __init__(
        self,
        in_dim: int,
        hid_dim: int,
        n_layers: int,
        out_dim: int,
        hypernet_in_dim: int,
        hypernet_hid_dim: int,
        hypernet_n_layers: int,
        fourier_dim: int = 256,
        dropout: float = 0.0,
        skip_connection: bool = False,
        fourier_scale: float = 1,
    ):
        super().__init__()

        self.network = MLPNoWeights(
            in_dim=in_dim,
            hid_dim=hid_dim,
            n_layers=n_layers,
            out_dim=out_dim,
            skip_connection=skip_connection,
        )

        # Total number of generated parameters. This is a plain Python int, so
        # the MLP receives an integer out_dim rather than a 0-d tensor.
        n_params = sum(torch.Size(shape).numel() for shape in self.network.layer_shapes)

        self.fourier_layer = FourierFeatureLayer(hypernet_in_dim, fourier_dim, scale=fourier_scale)
        self.embed_cond = nn.Linear(2 * fourier_dim, hypernet_hid_dim)
        self.hyper_network = MLP(
            in_dim=hypernet_hid_dim,
            hid_dim=hypernet_hid_dim,
            n_layers=hypernet_n_layers,
            out_dim=n_params,
            dropout=dropout,
        )

    def forward(self, x, context):
        """Evaluate the target network on ``x`` with weights generated from ``context``.

        Args:
            x: input to the target network. ``MLPNoWeights`` reshapes every
                generated parameter chunk with ``x.shape[0]`` as its leading
                size, so ``context`` must produce one parameter row per
                ``x.shape[0]``.
            context: conditioning input of the hypernetwork.

        Returns:
            The result of ``self.network(x, params)``.
        """
        context_embed = self.embed_cond(self.fourier_layer(context))
        # Assumed shape (batch, n_params), given the column slicing below.
        flat_params = self.hyper_network(context_embed)

        # Cut the flat vector into one chunk per weight / bias, in layer order.
        params = []
        offset = 0
        for shape in self.network.layer_shapes:
            size = torch.Size(shape).numel()  # int, no tensor allocation
            params.append(flat_params[:, offset : offset + size])
            offset += size
        return self.network(x, params)


class CondFourier(nn.Module):
    """MLP conditioned on a Fourier-feature embedding of ``context`` at every hidden layer.

    ``context`` is mapped by a FourierFeatureLayer plus a linear layer to a
    vector of size ``hidden_dims[0] // 3``. That vector is concatenated to the
    activations after every hidden layer except the last. Optionally ``x`` is
    also lifted with Fourier features before the first layer.

    Args:
        in_dim: size of the last dimension of ``x``.
        cond_dim: first argument of the context FourierFeatureLayer;
            presumably the feature size of ``context`` (TODO confirm).
        out_dim: size of the last dimension of the output.
        hidden_dims: widths of the hidden layers; must be non-empty.
        fourier_dim: ``fourier_dim`` of the context FourierFeatureLayer. Its
            output must have ``2 * fourier_dim`` features, because
            ``embed_cond`` expects that many.
        fourier_sigma: ``scale`` of the context FourierFeatureLayer. It can
            significantly affect convergence rate and generalization.
        fourier_dim_x: if an int, ``x`` is also passed through a
            FourierFeatureLayer with scale 1. ``cat([x, features])`` is then
            projected to ``hidden_dims[0]`` before the first hidden layer.
        dropout: dropout probability applied after every hidden layer.
    """

    def __init__(
        self,
        in_dim,
        cond_dim,
        out_dim,
        hidden_dims=(256, 256, 256),  # tuple: avoids a shared mutable default
        fourier_dim=256,
        fourier_sigma=0.01,
        fourier_dim_x=None,
        dropout=0.0,
    ):
        super().__init__()
        # Submodules are created in a fixed order so that parameter
        # initialization under a fixed RNG seed stays reproducible.
        self.dropout_layer = torch.nn.Dropout(p=dropout)
        self.fourier_dim_x = fourier_dim_x
        if fourier_dim_x is not None:
            assert isinstance(fourier_dim_x, int)
            fourier_sigma_x = 1  # this should be fine-tuned depending on the task
            self.fourier_layer_x = FourierFeatureLayer(
                in_dim, fourier_dim_x, scale=fourier_sigma_x
            )
            # Collapse cat([x, fourier(x)]) to the first hidden width.
            self.embed_x = nn.Linear(2 * fourier_dim_x + in_dim, hidden_dims[0])

        self.cond_dim_post_fourier = hidden_dims[0] // 3
        self.fourier_layer_cond = FourierFeatureLayer(cond_dim, fourier_dim, scale=fourier_sigma)
        # Collapse the context Fourier features to cond_dim_post_fourier.
        self.embed_cond = nn.Linear(2 * fourier_dim, self.cond_dim_post_fourier)

        self.layers = nn.ModuleList()
        prev_dim = hidden_dims[0] if fourier_dim_x is not None else in_dim
        last_idx = len(hidden_dims) - 1
        for i, h_dim in enumerate(hidden_dims):
            self.layers.append(nn.Linear(prev_dim, h_dim))
            # Every hidden layer but the last feeds its output, concatenated
            # with the condition embedding, into the next layer.
            prev_dim = h_dim + self.cond_dim_post_fourier if i < last_idx else h_dim

        self.output_layer = nn.Linear(prev_dim, out_dim)

    @staticmethod
    def _expand_condition(c_embed, x):
        """Broadcast ``c_embed`` over the leading dimensions of ``x`` for last-axis concatenation.

        Singleton axes are inserted after the batch axis until ``c_embed`` has
        as many dimensions as ``x``. For example, ``(B, k)`` against ``x`` of
        shape ``(B, N, h)`` becomes ``(B, N, k)``, and against ``(B, h)`` it is
        returned as-is.
        """
        while c_embed.dim() < x.dim():
            c_embed = c_embed.unsqueeze(1)
        return c_embed.expand(*x.shape[:-1], c_embed.shape[-1])

    def forward(self, x, context):
        """
        Forward pass of the conditional MLP.

        Args:
            x: input tensor whose last dimension is ``in_dim``, e.g.
                ``(batch_size, in_dim)`` or ``(batch_size, n, in_dim)``.
            context: condition tensor; its embedding is broadcast over the
                leading dimensions of ``x``. A typical shape is
                ``(batch_size, cond_dim)``.

        Returns:
            Output tensor with the leading dimensions of ``x`` and last
            dimension ``out_dim``.
        """
        if self.fourier_dim_x is not None:
            x_fourier = torch.cat([x, self.fourier_layer_x(x)], dim=-1)
            x = self.embed_x(x_fourier)

        c_fourier = self.fourier_layer_cond(context)  # (..., 2 * fourier_dim)
        c_embed = self.embed_cond(c_fourier)  # (..., cond_dim_post_fourier)

        last_idx = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x))
            if i < last_idx:
                x = torch.cat([x, self._expand_condition(c_embed, x)], dim=-1)
            x = self.dropout_layer(x)

        return self.output_layer(x)
