import math

import hydra
import torch
import torch.nn as nn
from rff.layers import GaussianEncoding

from nn.orf import OrthonormalRandomFeaturesPE


class NeoMLP(nn.Module):
    """MLP-as-a-graph model that treats inputs, hidden units and outputs as tokens.

    Every forward pass concatenates node embeddings into one token sequence laid
    out as ``[input tokens | hidden tokens | output tokens | head token]``
    (``num_nodes`` tokens in total), optionally adds ORF positional features,
    runs the sequence through ``neomlp_attention`` and decodes each output token
    to a scalar with ``out_linear``.

    Hidden, output and head ("simclr") embeddings are either shared by all fitted
    signals (stored once) or stored per signal (``signals_to_fit`` rows), in which
    case ``forward(x, indices)`` selects the row for each batch element.

    Args:
        num_inputs: Number of input coordinates in ``x``.
        num_outputs: Number of output tokens / predicted scalars.
        num_nodes: Total token count; must equal
            ``input tokens + num_hidden + num_outputs + 1`` where input tokens is
            1 if ``single_input_embedding`` else ``num_inputs``.
        trainable_features: If False, hidden embeddings are frozen zeros.
        requires_grad: ``requires_grad`` flag for the learnable embeddings.
        embed_dim: Token width seen by ``neomlp_attention``.
        use_pos_embedding: Encode inputs with rff ``GaussianEncoding`` + Linear
            instead of a plain Linear.
        pos_embedding_dim: Width of the Gaussian encoding (must be even).
        pos_embedding_sigma: ``sigma`` passed to ``GaussianEncoding``.
        single_input_embedding: Encode all input coordinates into one token.
        squeeze_output: Drop a trailing singleton output dimension.
        signals_to_fit: Number of signals with their own (unshared) embeddings.
        shared_hidden_embeddings / shared_output_embeddings /
        shared_head_embeddings: Share that embedding kind across signals.
        num_classes: Accepted but not used by ``__init__``
            (``add_cls_head`` sets ``self.num_classes``).
        compressed_embed_dim: If > 0, learnable embeddings are stored at this
            width and linearly projected to ``embed_dim``.
        neomlp_attention: Config passed to ``hydra.utils.instantiate``; the
            resulting module is applied to the token sequence.
        use_orf: Add ``OrthonormalRandomFeaturesPE`` features to the tokens.
        init_sigma: Std of the random init for hidden/output/head embeddings.
        input_init_sigma: Std of the random init for the input embedding.

    Raises:
        ValueError: if ``num_nodes`` leaves a negative hidden count, if
            ``pos_embedding_dim`` is odd while positional encoding is enabled, or
            if hidden and output embeddings are both shared while
            ``signals_to_fit > 1``.
    """

    def __init__(
        self,
        num_inputs,
        num_outputs,
        num_nodes,
        trainable_features,
        requires_grad=True,
        embed_dim=32,
        use_pos_embedding=True,
        pos_embedding_dim=128,
        pos_embedding_sigma=20.0,
        single_input_embedding=False,
        squeeze_output=False,
        signals_to_fit=1,
        shared_hidden_embeddings=True,
        shared_output_embeddings=True,
        shared_head_embeddings=True,
        num_classes=10,
        compressed_embed_dim=0,
        neomlp_attention=None,
        use_orf=False,
        init_sigma=0.001,
        input_init_sigma=1.0,
    ):
        super().__init__()
        self.num_inputs = num_inputs
        self.actual_num_inputs = 1 if single_input_embedding else num_inputs
        self.num_outputs = num_outputs
        # Token budget: input tokens + hidden tokens + output tokens + 1 head token.
        self.num_hidden = num_nodes - self.actual_num_inputs - self.num_outputs - 1
        if self.num_hidden < 0:
            raise ValueError(
                f"num_nodes={num_nodes} is too small: need at least "
                f"{self.actual_num_inputs + num_outputs + 1} "
                "(input tokens + output tokens + 1 head token)"
            )
        if use_pos_embedding and pos_embedding_dim % 2 != 0:
            # The Gaussian encoding is built with encoded_size = dim // 2 and is
            # followed by Linear(dim, ...), so an odd dim cannot line up.
            raise ValueError(
                f"pos_embedding_dim must be even when use_pos_embedding=True, "
                f"got {pos_embedding_dim}"
            )
        # Output tokens sit directly before the trailing head token.
        self.idx_output = list(range(num_nodes - num_outputs - 1, num_nodes - 1))
        self.num_nodes = num_nodes
        self.embed_dim = embed_dim
        self.single_input_embedding = single_input_embedding
        self.signals_to_fit = signals_to_fit
        if signals_to_fit > 1 and shared_hidden_embeddings and shared_output_embeddings:
            raise ValueError(
                "Cannot share both hidden and output embeddings when fitting multiple signals at once"
            )
        self.shared_hidden_embeddings = shared_hidden_embeddings
        self.shared_output_embeddings = shared_output_embeddings
        self.shared_head_embeddings = shared_head_embeddings
        # Shared embeddings are stored once; unshared ones once per fitted signal.
        self.num_hidden_signals = 1 if shared_hidden_embeddings else signals_to_fit
        self.num_output_signals = 1 if shared_output_embeddings else signals_to_fit
        self.num_head_signals = 1 if shared_head_embeddings else signals_to_fit

        # 0 disables compression: embeddings are stored directly at embed_dim.
        self.compressed_embed_dim = (
            compressed_embed_dim if compressed_embed_dim > 0 else embed_dim
        )

        self._init_sigma = init_sigma
        self._input_init_sigma = input_init_sigma

        # NOTE: the creation order of parameters/modules below is kept fixed so
        # seeded RNG draws and state_dict key order stay stable.
        self.trainable_features = trainable_features
        if trainable_features:
            self.hidden_embeddings = nn.Parameter(
                self._init_sigma
                * torch.randn(
                    self.num_hidden_signals,
                    self.num_hidden,
                    self.compressed_embed_dim,
                ),
                requires_grad=requires_grad,
            )
        else:
            # Frozen all-zero hidden embeddings.
            self.hidden_embeddings = nn.Parameter(
                torch.zeros(
                    self.num_hidden_signals,
                    self.num_hidden,
                    self.compressed_embed_dim,
                ),
                requires_grad=False,
            )
        # The input embedding is always shared (leading dim 1) and stored at
        # full embed_dim, since it is added directly to the encoded inputs.
        self.input_embedding = nn.Parameter(
            self._input_init_sigma * torch.randn(1, self.actual_num_inputs, self.embed_dim),
            requires_grad=requires_grad,
        )
        self.output_embedding = nn.Parameter(
            self._init_sigma
            * torch.randn(
                self.num_output_signals, self.num_outputs, self.compressed_embed_dim
            ),
            requires_grad=requires_grad,
        )
        self.simclr_embedding = nn.Parameter(
            self._init_sigma
            * torch.randn(self.num_head_signals, 1, self.compressed_embed_dim),
            requires_grad=requires_grad,
        )

        self.use_orf = use_orf
        if self.use_orf:
            # NOTE(review): this sizes ORF for num_nodes - 1 tokens (no head
            # token), while _forward always concatenates the head token too.
            # Confirm OrthonormalRandomFeaturesPE's expected first argument.
            self.orf = OrthonormalRandomFeaturesPE(
                self.actual_num_inputs + self.num_hidden + self.num_outputs,
                self.embed_dim,
            )

        if compressed_embed_dim > 0:
            self.hidden_lin = nn.Sequential(
                nn.Linear(self.compressed_embed_dim, self.embed_dim),
            )
            self.output_lin = nn.Sequential(
                nn.Linear(self.compressed_embed_dim, self.embed_dim),
            )
            self.simclr_lin = nn.Sequential(
                nn.Linear(self.compressed_embed_dim, self.embed_dim),
            )
        else:
            self.hidden_lin = nn.Identity()
            self.output_lin = nn.Identity()
            self.simclr_lin = nn.Identity()

        self.use_pos_embedding = use_pos_embedding
        # single_input_embedding: all coordinates are encoded jointly into one
        # token; otherwise each scalar coordinate is encoded into its own token.
        lin_input_dim = self.num_inputs if single_input_embedding else 1
        if use_pos_embedding:
            self.pos_embedding_sigma = pos_embedding_sigma
            self.pos_embedding_dim = pos_embedding_dim
            self.lin = nn.Sequential(
                GaussianEncoding(
                    input_size=lin_input_dim,
                    encoded_size=pos_embedding_dim // 2,
                    sigma=pos_embedding_sigma,
                ),
                nn.Linear(pos_embedding_dim, self.embed_dim),
            )
        else:
            self.lin = nn.Linear(lin_input_dim, self.embed_dim)

        # Per-token scalar readout applied to the output tokens.
        self.out_linear = nn.Sequential(
            nn.Linear(self.embed_dim, 1),
        )

        self.neomlp_attention = hydra.utils.instantiate(neomlp_attention)

        self.squeeze_output = squeeze_output

        self.reset_parameters()

    def add_cls_head(self, num_classes, single_representation_embedding):
        """Attach an MLP classification head used by ``forward_representations``.

        Args:
            num_classes: Number of classes.
            single_representation_embedding: If True, the head maps the
                flattened output tokens to ``num_classes`` logits; otherwise it
                maps each of the last ``num_classes`` tokens to one logit.
        """
        self.single_representation_embedding = single_representation_embedding
        self.num_classes = num_classes
        cls_out_dim = num_classes if single_representation_embedding else 1

        # NOTE(review): the first Linear expects len(idx_output) * embed_dim
        # features; in the per-token branch of forward_representations each
        # token is only embed_dim wide, so that branch presumably only works
        # when num_outputs == 1 — confirm.
        self.mlp_head = nn.Sequential(
            nn.Linear(len(self.idx_output) * self.embed_dim, 4 * self.embed_dim),
            nn.ReLU(),
            nn.Linear(4 * self.embed_dim, cls_out_dim),
        )

    def reset_parameters(self):
        """Zero the readout bias and He-initialise the input encoder's final Linear."""
        nn.init.zeros_(self.out_linear[-1].bias)
        # With positional encoding `lin` is Sequential(GaussianEncoding, Linear);
        # without it, `lin` is a bare Linear. Use whichever Linear is last and
        # its real fan-in (equals pos_embedding_dim in the encoded case).
        input_proj = self.lin[-1] if isinstance(self.lin, nn.Sequential) else self.lin
        nn.init.normal_(input_proj.weight, std=math.sqrt(2 / input_proj.in_features))

    def count_parameters(self):
        """Count trainable parameters, excluding per-signal (unshared) embeddings."""
        excluded_params = (
            (["hidden_embeddings"] if not self.shared_hidden_embeddings else [])
            + (["output_embedding"] if not self.shared_output_embeddings else [])
            + (["simclr_embedding"] if not self.shared_head_embeddings else [])
        )
        return sum(
            p.numel()
            for name, p in self.named_parameters()
            if p.requires_grad and name not in excluded_params
        )

    def generate_embeddings(
        self,
        num_signals,
        hidden_mu=None,
        hidden_sigma=None,
        out_mu=None,
        out_sigma=None,
        device=None,
    ):
        """Create fresh trainable embeddings for ``num_signals`` new signals.

        Only the unshared embedding kinds are created; shared kinds come back as
        ``None``.

        Args:
            num_signals: Leading dimension of the new embeddings.
            hidden_mu, hidden_sigma: Mean/std of the hidden init (defaults 0 and
                ``init_sigma``). Ignored when ``trainable_features`` is False, in
                which case hidden embeddings start at zero.
            out_mu, out_sigma: Mean/std of the output init (defaults 0 and
                ``init_sigma``).
            device: Device for the new tensors.

        Returns:
            ``(hidden_embeddings, output_embeddings, simclr_embeddings)``, each
            an ``nn.Parameter`` with ``requires_grad=True`` or ``None``. The head
            embeddings always use ``init_sigma`` with zero mean.
        """
        hidden_mu = 0.0 if hidden_mu is None else hidden_mu
        hidden_sigma = self._init_sigma if hidden_sigma is None else hidden_sigma
        out_mu = 0.0 if out_mu is None else out_mu
        out_sigma = self._init_sigma if out_sigma is None else out_sigma

        hidden_embeddings = None
        if not self.shared_hidden_embeddings:
            hidden_shape = (num_signals, self.num_hidden, self.compressed_embed_dim)
            if self.trainable_features:
                hidden_init = hidden_mu + hidden_sigma * torch.randn(
                    *hidden_shape, device=device
                )
            else:
                hidden_init = torch.zeros(*hidden_shape, device=device)
            hidden_embeddings = nn.Parameter(hidden_init, requires_grad=True)

        output_embeddings = None
        if not self.shared_output_embeddings:
            output_embeddings = nn.Parameter(
                out_mu
                + out_sigma
                * torch.randn(
                    num_signals,
                    self.num_outputs,
                    self.compressed_embed_dim,
                    device=device,
                ),
                requires_grad=True,
            )

        simclr_embeddings = None
        if not self.shared_head_embeddings:
            simclr_embeddings = nn.Parameter(
                self._init_sigma
                * torch.randn(num_signals, 1, self.compressed_embed_dim, device=device),
                requires_grad=True,
            )
        return hidden_embeddings, output_embeddings, simclr_embeddings

    @torch.no_grad()
    def reset_embeddings(self):
        """Perturb every unshared embedding with small Gaussian noise.

        Each unshared tensor gets noise with std ``1e-4 * tensor.std()`` added in
        place of its ``.data``; the measured std is printed for each one.
        """
        sigma = 1e-4
        for is_shared, attr in (
            (self.shared_hidden_embeddings, "hidden_embeddings"),
            (self.shared_output_embeddings, "output_embedding"),
            (self.shared_head_embeddings, "simclr_embedding"),
        ):
            if is_shared:
                continue
            param = getattr(self, attr)
            noise_std = param.detach().std()
            param.data = param.data + noise_std * sigma * torch.randn_like(param)
            print("noise_std", noise_std)

    def _get_embeddings(self):
        """Return detached unshared embeddings concatenated along the token axis.

        Layout is ``[hidden | output | head]``; shared kinds contribute zero
        tokens. The inverse split is ``_split_embeddings``.
        """

        def _per_signal_rows(param, shared):
            # Shared embeddings are not per-signal, so they contribute an empty slice.
            if not shared:
                return param.detach()
            return torch.empty(
                (self.signals_to_fit, 0, param.shape[-1]),
                device=param.device,
            )

        return torch.cat(
            [
                _per_signal_rows(self.hidden_embeddings, self.shared_hidden_embeddings),
                _per_signal_rows(self.output_embedding, self.shared_output_embeddings),
                _per_signal_rows(self.simclr_embedding, self.shared_head_embeddings),
            ],
            dim=1,
        )

    def _split_embeddings(self, embeddings):
        """Split a ``_get_embeddings``-style tensor into (hidden, output, head) parts."""
        _num_hidden = 0 if self.shared_hidden_embeddings else self.num_hidden
        _num_output = 0 if self.shared_output_embeddings else self.num_outputs
        hidden_embeddings = embeddings[:, :_num_hidden]
        output_embeddings = embeddings[:, _num_hidden : _num_hidden + _num_output]
        simclr_embeddings = embeddings[:, _num_hidden + _num_output :]
        return hidden_embeddings, output_embeddings, simclr_embeddings

    def _input_linear(self, x):
        """Encode inputs into tokens and add the learnable input embedding.

        Assumes ``x`` is ``(batch, num_inputs)`` — TODO confirm with callers.
        """
        g_x = (
            self.lin(x).unsqueeze(-2)
            if self.single_input_embedding
            else self.lin(x.unsqueeze(-1))
        )
        g_x = g_x + self.input_embedding.repeat(g_x.size(0), 1, 1)
        return g_x

    def _hidden_linear(self, hidden_embeddings=None, indices=None, batch_size=None):
        """Project hidden embeddings to ``embed_dim``.

        Uses ``hidden_embeddings`` if given; otherwise per-signal rows selected by
        ``indices`` (unshared case) or the stored tensor repeated ``batch_size``
        times.
        """
        if hidden_embeddings is None:
            if indices is not None and not self.shared_hidden_embeddings:
                hidden_embeddings = self.hidden_embeddings[indices]
            else:
                hidden_embeddings = self.hidden_embeddings.repeat(batch_size, 1, 1)
        return self.hidden_lin(hidden_embeddings)

    def _output_linear(self, output_embeddings=None, indices=None, batch_size=None):
        """Project output embeddings to ``embed_dim`` (same selection rules as hidden)."""
        if output_embeddings is None:
            if indices is not None and not self.shared_output_embeddings:
                output_embeddings = self.output_embedding[indices]
            else:
                output_embeddings = self.output_embedding.repeat(batch_size, 1, 1)
        return self.output_lin(output_embeddings)

    def _simclr_linear(self, simclr_embeddings=None, indices=None, batch_size=None):
        """Project head embeddings to ``embed_dim`` (same selection rules as hidden)."""
        if simclr_embeddings is None:
            if indices is not None and not self.shared_head_embeddings:
                simclr_embeddings = self.simclr_embedding[indices]
            else:
                simclr_embeddings = self.simclr_embedding.repeat(batch_size, 1, 1)
        return self.simclr_lin(simclr_embeddings)

    def _forward(self, input_emb, hidden_emb, out_emb, simclr_emb=None):
        """Run the token sequence through the backbone and decode output tokens."""
        parts = [input_emb, hidden_emb, out_emb]
        if simclr_emb is not None:
            parts.append(simclr_emb)
        x = torch.cat(parts, dim=1)
        if self.use_orf:
            x = x + self.orf()[None, ...]

        x = self.neomlp_attention(x)

        out = x[:, self.idx_output]
        out = self.out_linear(out)
        out = out.squeeze(-1)
        if self.squeeze_output and out.shape[-1] == 1:
            out = out.squeeze(-1)
        return out

    def forward(self, x, indices=None):
        """Predict outputs for inputs ``x`` using the stored embeddings.

        When `indices` is not None, we are fitting multiple signals at once.
        The hidden or output embeddings (but not both) can be shared across the signals.

        Raises:
            ValueError: if any index is ``>= signals_to_fit``.
        """
        if indices is not None and self.signals_to_fit <= indices.detach().max().item():
            raise ValueError("Index out of range")

        batch_size = x.size(0)
        input_emb = self._input_linear(x)
        hidden_emb = self._hidden_linear(indices=indices, batch_size=batch_size)
        out_emb = self._output_linear(indices=indices, batch_size=batch_size)
        simclr_emb = self._simclr_linear(indices=indices, batch_size=batch_size)

        return self._forward(input_emb, hidden_emb, out_emb, simclr_emb)

    def forward_embeddings(
        self, x, hidden_embeddings=None, output_embeddings=None, simclr_embeddings=None
    ):
        """Predict outputs for ``x`` using externally supplied embeddings.

        Embeddings must be given for every unshared kind (e.g. those returned by
        ``generate_embeddings``); shared kinds fall back to the stored tensors.

        Raises:
            ValueError: if an embedding required by an unshared kind is missing.
        """
        if hidden_embeddings is None and not self.shared_hidden_embeddings:
            raise ValueError(
                "Must provide embeddings when not sharing hidden embeddings"
            )
        if output_embeddings is None and not self.shared_output_embeddings:
            raise ValueError(
                "Must provide output embeddings when not sharing output embeddings"
            )
        if simclr_embeddings is None and not self.shared_head_embeddings:
            raise ValueError(
                "Must provide simclr embeddings when not sharing simclr embeddings"
            )

        batch_size = x.size(0)
        input_emb = self._input_linear(x)
        hidden_emb = self._hidden_linear(
            hidden_embeddings=hidden_embeddings, batch_size=batch_size
        )
        out_emb = self._output_linear(
            output_embeddings=output_embeddings, batch_size=batch_size
        )
        simclr_emb = self._simclr_linear(
            simclr_embeddings=simclr_embeddings, batch_size=batch_size
        )

        return self._forward(input_emb, hidden_emb, out_emb, simclr_emb)

    def forward_representations(self, embeddings):
        """Classify signals from their learned embeddings (requires ``add_cls_head``).

        ``embeddings`` uses the ``[hidden | output | head]`` layout produced by
        ``_get_embeddings`` (unshared kinds only, one row per signal). Shared
        kinds are filled in from the stored tensors. No input values are used:
        input tokens are just the learnable input embedding.

        Returns:
            Class logits from ``mlp_head``.
        """
        num_signals = embeddings.size(0)
        input_emb = self.input_embedding.repeat(num_signals, 1, 1)

        hidden_emb, out_emb, simclr_emb = self._split_embeddings(embeddings)
        if self.shared_hidden_embeddings:
            hidden_emb = self.hidden_embeddings.repeat(num_signals, 1, 1)
        if self.shared_output_embeddings:
            out_emb = self.output_embedding.repeat(num_signals, 1, 1)
        if self.shared_head_embeddings:
            simclr_emb = self.simclr_embedding.repeat(num_signals, 1, 1)

        hidden_emb = self.hidden_lin(hidden_emb)
        out_emb = self.output_lin(out_emb)
        simclr_emb = self.simclr_lin(simclr_emb)

        x = torch.cat([input_emb, hidden_emb, out_emb, simclr_emb], dim=1)
        if self.use_orf:
            x = x + self.orf()[None, ...]

        x = self.neomlp_attention(x)

        if self.single_representation_embedding:
            out = x[:, self.idx_output].flatten(1)
            out = self.mlp_head(out)
        else:
            # NOTE(review): the last num_classes tokens include the trailing head
            # token, not only output tokens — confirm this is intended.
            out = x[:, -self.num_classes :]
            out = self.mlp_head(out).squeeze(-1)
        return out


class BareNeoMLP(NeoMLP):
    """NeoMLP with no hidden nodes: only input tokens, output tokens and the head token.

    ``num_nodes`` is derived automatically. All other positional and keyword
    arguments are forwarded unchanged to ``NeoMLP.__init__`` (after its
    ``num_nodes`` parameter).
    """

    def __init__(
        self,
        num_inputs,
        num_outputs,
        *args,
        **kwargs,
    ):
        import inspect

        # Resolve `single_input_embedding` exactly as NeoMLP.__init__ will see it,
        # whether it arrives positionally (via *args), by keyword, or by default.
        # The 0 is a placeholder for num_nodes; invalid argument combinations
        # raise the same TypeError NeoMLP.__init__ itself would.
        bound = inspect.signature(NeoMLP.__init__).bind(
            self, num_inputs, num_outputs, 0, *args, **kwargs
        )
        bound.apply_defaults()
        actual_num_inputs = (
            1 if bound.arguments["single_input_embedding"] else num_inputs
        )
        # NeoMLP reserves one trailing head token, so zero hidden nodes requires
        # num_nodes = input tokens + output tokens + 1.
        num_nodes = actual_num_inputs + num_outputs + 1
        super().__init__(
            num_inputs,
            num_outputs,
            num_nodes,
            *args,
            **kwargs,
        )
