import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import os
from typing import Optional, Type, Callable, Union
from src.util.tb_log import TBWritter
from src.util.metric import Metric
from src.util.utils import EarlyStopping, DataLoaderSampler
import sys
import math


class MLP(nn.Sequential):
    r"""Multilayer perceptron.

    For every consecutive pair of sizes in ``(input_size, *hidden_sizes)`` a
    ``linear_class(in, out)`` layer is added. The hidden activation is inserted
    only *between* linear layers (never after the last one). Dropout, when
    enabled, follows *every* linear layer, including the last one. An optional
    output activation is appended at the very end.

    Args:
        input_size: size of each input sample
        hidden_sizes: sizes of each layer; the last one is the output size of the MLP
        act_class: class of the hidden activation, instantiated with no
            arguments; ``None`` means no hidden activation
        out_act: ``True`` appends ``act_class()`` after everything else,
            ``False`` appends nothing, and a zero-argument callable appends
            the module it returns
        dropout: dropout probability; ``None`` disables dropout
        linear_class: factory called as ``linear_class(in_features, out_features)``

    Inputs: input
        - input: a Tensor with shape [batch_size, input_size]

    Outputs: output
        - output: a Tensor with shape [batch_size, hidden_sizes[-1]]

    Raises:
        ValueError: if ``out_act`` is neither a bool nor callable, or if
            ``out_act`` is ``True`` while ``act_class`` is ``None``.
    """

    def __init__(
            self,
            input_size: int,
            *hidden_sizes: int,
            act_class: Optional[Type[nn.Module]] = None,
            out_act: Union[bool, Callable] = False,
            dropout: Optional[float] = None,
            linear_class: Type[nn.Module] = nn.Linear,
    ):
        r"""Initialize a MLP layer; see the class docstring for the arguments."""
        # Validate before building any layer, so a bad configuration fails
        # with a clear message instead of ``'NoneType' object is not callable``.
        if isinstance(out_act, bool):
            if out_act and act_class is None:
                raise ValueError('`out_act=True` requires `act_class` to be set.')
        elif not callable(out_act):
            raise ValueError('`out_act` must be a bool value or a callable function.')

        layers = []
        sizes = [input_size, *hidden_sizes]
        n_linear = len(sizes) - 1
        for i in range(1, len(sizes)):
            layers.append(linear_class(sizes[i - 1], sizes[i]))
            # Hidden activation only between linear layers, not after the last.
            if act_class is not None and i < n_linear:
                layers.append(act_class())
            if dropout is not None:
                layers.append(nn.Dropout(dropout))

        if out_act is True:
            layers.append(act_class())
        elif not isinstance(out_act, bool):
            layers.append(out_act())
        super().__init__(*layers)


class EmbeddingNodeAttention(nn.Module):
    r"""Self-attention over the ``k`` selected embeddings, followed by an EmbeddingMLP.

    Takes the same constructor arguments as ``EmbeddingMLP``
    (``k, in_dim, embedding_dims``), so it can be used as ``MetaLinear``'s
    ``embedding_learner``.

    Args:
        k: number of selected embeddings per sample
        in_dim: number of output embeddings per sample (forwarded to EmbeddingMLP)
        embedding_dims: size of each embedding
        attn_dropout: dropout probability applied to the attention weights

    Inputs: correlated_embeddings
        - Tensor of shape [batch_size, k, embedding_dims], one row per selected
          embedding (the layout built by ``MetaLinear.forward_batch``).

    Outputs:
        - Tensor of shape [batch_size, in_dim, embedding_dims] (EmbeddingMLP's output).
    """

    def __init__(self, k, in_dim, embedding_dims, attn_dropout=0.1):
        super().__init__()
        self.k = k
        self.in_dim = in_dim
        self.embedding_dims = embedding_dims

        # Scaled dot-product attention: scores are divided by sqrt(embedding_dims).
        self.temperature = embedding_dims ** 0.5
        self.dropout = nn.Dropout(attn_dropout)

        self.qw = nn.Linear(self.embedding_dims, self.embedding_dims, bias=False)
        self.kw = nn.Linear(self.embedding_dims, self.embedding_dims, bias=False)
        self.vw = nn.Linear(self.embedding_dims, self.embedding_dims, bias=False)

        self.embedding_nlp = EmbeddingMLP(k, in_dim, embedding_dims)

    def forward(self, correlated_embeddings):
        # [b_s, k, E] -> [b_s * k, E]: one row per selected embedding.
        # The input is already node-major. Transposing before this reshape
        # would interleave values from different nodes and embedding dims.
        flat = correlated_embeddings.reshape(-1, self.embedding_dims)

        # Each of q / key / value: [b_s, k, E]
        q = self.qw(flat).reshape(-1, self.k, self.embedding_dims)
        key = self.kw(flat).reshape(-1, self.k, self.embedding_dims)
        value = self.vw(flat).reshape(-1, self.k, self.embedding_dims)

        # [b_s, k, k]: attention of every selected embedding over the k of its sample.
        attn = torch.bmm(q / self.temperature, key.transpose(1, 2))
        attn = self.dropout(F.softmax(attn, dim=-1))

        # [b_s, k, E]: the layout EmbeddingMLP expects, so pass it through unchanged.
        output = torch.bmm(attn, value)
        return self.embedding_nlp(output)


class EmbeddingMLP(nn.Module):
    r"""Mix ``k`` selected embeddings into ``in_dim`` embeddings with a linear map.

    For every embedding dimension independently, the ``k`` values of a sample
    are mapped to ``in_dim`` values by a shared ``MLP(k, in_dim)``. A LeakyReLU
    is applied to the result.

    Inputs: correlated_embeddings
        - Tensor of shape [batch_size, k, embedding_dims]

    Outputs:
        - Tensor of shape [batch_size, in_dim, embedding_dims]
    """

    def __init__(self, k, in_dim, embedding_dims):
        super().__init__()
        self.k, self.in_dim, self.embedding_dims = k, in_dim, embedding_dims
        self.learner = MLP(k, in_dim)
        self.relu = nn.LeakyReLU()

    def forward(self, correlated_embeddings):
        # [b_s, k, E] -> [b_s, E, k] -> [b_s * E, k]: one length-k vector per (sample, dim).
        per_dim_rows = correlated_embeddings.transpose(1, 2).reshape(-1, self.k)

        # [b_s * E, in_dim] -> [b_s, E, in_dim]
        mixed = self.learner(per_dim_rows).reshape(-1, self.embedding_dims, self.in_dim)

        # [b_s, in_dim, E]
        return self.relu(mixed.transpose(1, 2))


class MetaLinear(nn.Module):
    r"""Linear layer whose weight matrix is generated from learned embeddings.

    The layer keeps a bank of ``n`` learnable embeddings of size
    ``embedding_dim``. For a selection of ``k`` indices into that bank:

    1. ``embedding_learner`` turns the ``[k, embedding_dim]`` selection into
       ``in_dim`` embeddings.
    2. ``weight_learner`` maps each of them to one row of an
       ``[in_dim, out_dim]`` weight matrix.

    The inputs of that selection are multiplied by this matrix, and a bias
    shared by all selections is added.

    Args:
        in_dim: number of input features (must equal ``x.shape[1]``)
        out_dim: number of output features
        embedding_dim: size of each stored embedding
        k: number of indices in every selection
        n: number of stored embeddings
        generator_hidden: hidden layer sizes of the ``weight_learner`` MLP
        embedding_learner: module class, constructed as
            ``embedding_learner(k, in_dim, embedding_dim)``
    """

    def __init__(self, in_dim, out_dim, embedding_dim, k, n, generator_hidden=(), embedding_learner=EmbeddingMLP):
        super().__init__()

        self._in_dim = in_dim
        self._out_dim = out_dim
        self._embedding_dims = embedding_dim
        self.k = k
        self.n = n

        # Created with placeholder ones; reset_parameters() below re-initializes it.
        self.embeddings = nn.Parameter(torch.ones([n, embedding_dim]), requires_grad=True)

        self.embedding_learner = embedding_learner(k, in_dim, embedding_dim)
        self.weight_learner = MLP(embedding_dim, *generator_hidden, out_dim, act_class=nn.LeakyReLU)

        self.bias = nn.Parameter(torch.ones([out_dim]), requires_grad=True)

        self.reset_parameters()

    def forward_selects(self, x, selects):
        """Apply the layer with one selection per sample: row ``i`` of ``x`` uses ``selects[i]``."""
        assert len(x) == len(selects)
        return self.forward_batch(x, selects, sample_batch_size=1)

    def forward_batch(self, x, selects, sample_batch_size: int):
        """Apply the layer to consecutive groups of ``sample_batch_size`` rows.

        Group ``g`` (rows ``g*sample_batch_size`` to ``(g+1)*sample_batch_size - 1``)
        uses the weights generated from ``selects[g]``.

        Args:
            x: Tensor of shape [len(selects) * sample_batch_size, in_dim]
            selects: sequence of index selections, each of length ``k``
            sample_batch_size: number of rows of ``x`` sharing one selection

        Returns:
            Tensor of shape [len(x), out_dim]
        """
        n_groups = len(selects)
        assert len(x) == n_groups * sample_batch_size
        assert len(selects[0]) == self.k

        n_samples, feat_dim = x.shape

        # [n_groups, k, embedding_dim]
        picked = torch.stack([self.embeddings[idx] for idx in selects])
        # [n_groups, in_dim, embedding_dim]
        per_input = self.embedding_learner(picked)

        # One weight row per (group, input feature) -> [n_groups, in_dim, out_dim]
        flat_rows = self.weight_learner(per_input.reshape(-1, self._embedding_dims))
        weights = flat_rows.reshape(n_groups, self._in_dim, self._out_dim)

        grouped_x = x.reshape(n_groups, sample_batch_size, feat_dim)
        out = torch.bmm(grouped_x, weights).reshape(n_samples, self._out_dim)
        return out + self.bias

    def forward(self, x, select):
        """Apply the layer with the single selection ``select`` shared by every row of ``x``."""
        return self.forward_batch(x, [select], len(x))

    def reset_parameters(self) -> None:
        """Re-initialize the embedding bank (Kaiming uniform) and the bias (uniform)."""
        nn.init.kaiming_uniform_(self.embeddings, a=math.sqrt(5))
        if self.bias is None:
            return
        limit = 1 / math.sqrt(self._out_dim)
        nn.init.uniform_(self.bias, -limit, limit)


class MetaMLP(nn.Module):
    r"""Stack of ``MetaLinear`` layers with LeakyReLU between consecutive layers.

    Layer ``i`` maps ``dims[i]`` to ``dims[i + 1]`` features. No activation is
    applied before the first layer or after the last one. All layers receive
    the same selection(s), and each layer has its own embedding bank.

    Args:
        dims: feature sizes, input first and output last
        embedding_dim, k, n, generator_hidden, embedding_learner:
            forwarded unchanged to every ``MetaLinear``
    """

    def __init__(self, dims: list, embedding_dim, k, n, generator_hidden=(), embedding_learner=EmbeddingMLP):
        super().__init__()

        self.k = k
        self.n = n
        self._dims = dims
        self._embedding_dim = embedding_dim
        self._generator_hidden = generator_hidden
        self.embedding_learner = embedding_learner

        self._mlp = nn.ModuleList([
            MetaLinear(d_in, d_out, embedding_dim, k, n, generator_hidden, embedding_learner)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        ])

    def _propagate(self, x, apply_layer):
        """Run ``x`` through every layer via ``apply_layer(layer, x)``, with LeakyReLU in between."""
        for depth, layer in enumerate(self._mlp):
            if depth:
                x = F.leaky_relu(x)
            x = apply_layer(layer, x)
        return x

    def forward(self, x, select):
        return self._propagate(x, lambda layer, h: layer.forward(h, select))

    def forward_selects(self, x, selects):
        return self._propagate(x, lambda layer, h: layer.forward_selects(h, selects))

    def forward_batch(self, x, selects, sample_batch_size: int):
        return self._propagate(x, lambda layer, h: layer.forward_batch(h, selects, sample_batch_size))


class MetaFE(nn.Module):
    r"""Feature extractor: keep only the selected input columns, then apply a ``MetaMLP``.

    A selection does two things. It picks which columns of ``x`` are fed
    forward, and it picks which stored embeddings generate the MetaMLP
    weights. The gathered input therefore has ``k`` columns, so ``dims[0]``
    must equal ``k``.
    """

    def __init__(self, dims, embedding_dim, k, n, generator_hidden=(), embedding_learner=EmbeddingMLP):
        super().__init__()
        self._dims = dims
        self._embedding_dim = embedding_dim
        self._generator_hidden = generator_hidden
        self.k = k
        self.n = n
        self.embedding_learner = embedding_learner

        self._meta_mlp = MetaMLP(dims, embedding_dim, k, n, generator_hidden, embedding_learner)

    def forward_batch(self, x, selects, sample_batch_size: int):
        """Group ``g`` of ``sample_batch_size`` consecutive rows keeps the columns ``selects[g]``."""
        assert len(x) == len(selects) * sample_batch_size

        gathered = []
        for group, cols in enumerate(selects):
            start = group * sample_batch_size
            gathered.append(x[start:start + sample_batch_size, cols])

        return self._meta_mlp.forward_batch(torch.cat(gathered), selects, sample_batch_size)

    def forward(self, x, select):
        """Every row of ``x`` uses the same selection ``select``."""
        return self.forward_batch(x, [select], len(x))

    def forward_selects(self, x, selects):
        """Row ``i`` of ``x`` uses ``selects[i]``."""
        return self.forward_batch(x, selects, 1)


class NormalFE(nn.Module):
    r"""Non-meta feature extractor with the same call interface as ``MetaFE``.

    Unselected input columns are zeroed, and the masked input goes through a
    plain ``MLP(*dims)``. That MLP is built without ``act_class``, so it has no
    hidden activations.

    ``embedding_dim`` and ``generator_hidden`` are stored but unused.
    ``embedding_module`` and ``embedding_module_param`` are accepted and ignored.
    """

    def __init__(self, dims, embedding_dim=None, generator_hidden=(), embedding_module=None,
                 embedding_module_param=None):
        super(NormalFE, self).__init__()
        self._dims = dims
        self.embedding_dim = embedding_dim
        self.generator_hidden = generator_hidden
        # NOTE(review): no act_class, so this is a stack of purely linear layers.
        # Left unchanged to keep checkpoint (state_dict) keys stable.
        self._mlp = MLP(*self._dims)

    def forward_batch(self, x, selects, sample_batch_size: int):
        """Mask ``x`` per group, then apply the MLP.

        Group ``g`` (``sample_batch_size`` consecutive rows) keeps only the
        columns listed in ``selects[g]``; every other entry is zero.

        Returns:
            ``self._mlp`` applied to the masked tensor (same shape as ``x`` going in).
        """
        assert len(x) == len(selects) * sample_batch_size
        masked = torch.zeros_like(x)
        for group, cols in enumerate(selects):
            rows = slice(group * sample_batch_size, (group + 1) * sample_batch_size)
            masked[rows, cols] = x[rows, cols]
        # Feed the *masked* input so that unselected features cannot reach the MLP.
        return self._mlp(masked)

    def forward(self, x, select):
        """Every row of ``x`` uses the same selection ``select``."""
        return self.forward_batch(x, [select], len(x))

    def forward_selects(self, x, selects):
        """Row ``i`` of ``x`` uses ``selects[i]``."""
        return self.forward_batch(x, selects, 1)


if __name__ == "__main__":
    # Ad-hoc smoke/benchmark runs of every entry point. Each scenario feeds 128
    # samples that all select features 0..4, through four call paths:
    #   - 128 single-sample forward() calls
    #   - forward_selects() with one selection per sample
    #   - forward_batch() with 64 selections shared by 2 samples each
    #   - forward() with one selection for the whole batch
    # The timing comments are carried over from earlier manual measurements
    # (they look like IPython %timeit output); this script does not re-measure.

    # --- MetaLinear(in_dim=10, out_dim=20, embedding_dim=8, k=5, n=500, generator_hidden=(32,))
    layer = MetaLinear(10, 20, 8, 5, 500, (32,))
    # 22.2 ms ± 322 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    for _ in range(128):
        layer.forward(torch.ones(1, 10), list(range(5)))
    # 3.86 ms ± 102 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    layer.forward_selects(torch.ones(128, 10), [list(range(5)) for _ in range(128)])
    # 1.94 ms ± 104 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    layer.forward_batch(torch.ones(128, 10), [list(range(5)) for _ in range(64)], 2)
    # 193 µs ± 2.39 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    layer.forward(torch.ones(128, 10), list(range(5)))

    # --- MetaMLP with attention-based embedding learner
    meta_mlp = MetaMLP((5, 32, 64, 2), 32, 5, 100, embedding_learner=EmbeddingNodeAttention)
    # 100 ms ± 4.35 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    for _ in range(128):
        meta_mlp.forward(torch.ones(1, 5), list(range(5)))
    # 17.3 ms ± 761 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    meta_mlp.forward_selects(torch.ones(128, 5), [list(range(5)) for _ in range(128)])
    # 9.31 ms ± 674 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    meta_mlp.forward_batch(torch.ones(128, 5), [list(range(5)) for _ in range(64)], 2)
    # 934 µs ± 24.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    meta_mlp.forward(torch.ones(128, 5), list(range(5)))

    # --- MetaFE with attention-based embedding learner
    meta_fe = MetaFE((5, 32, 64, 2), 32, 5, 100, embedding_learner=EmbeddingNodeAttention)
    # 128 ms ± 5.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    for _ in range(128):
        meta_fe.forward(torch.ones(1, 5), list(range(5)))
    # 24.4 ms ± 1.37 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    meta_fe.forward_selects(torch.ones(128, 5), [list(range(5)) for _ in range(128)])
    # 12.5 ms ± 270 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    meta_fe.forward_batch(torch.ones(128, 5), [list(range(5)) for _ in range(64)], 2)
    # 1.28 ms ± 62.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    meta_fe.forward(torch.ones(128, 5), list(range(5)))
