import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import os
from typing import Optional, Type, Callable, Union
from src.util.tb_log import TBWritter
from src.util.metric import Metric
from src.util.utils import EarlyStopping, DataLoaderSampler
import sys
import math


class MLP(nn.Sequential):
    r"""Multilayer perceptron built as an ``nn.Sequential`` of linear layers.

    For each consecutive pair of sizes in ``(input_size, *hidden_sizes)`` one
    ``linear_class(in_size, out_size)`` layer is added. ``act_class()`` follows
    every linear layer except the last one; the activation after the last layer
    is controlled separately by ``out_act``. When ``dropout`` is given,
    ``nn.Dropout(dropout)`` follows every linear layer (after its activation),
    including the last one.

    Args:
        input_size: size of each input sample
        hidden_sizes: sizes of each modules, the last one is the output size of the MLP
        act_class: the class of the activation function placed between linear layers
        out_act: ``True`` to append ``act_class()`` after the last layer, ``False``
            for no output activation, or a zero-argument callable (e.g. a module
            class) whose return value is appended as the output activation
        dropout: dropout probability; ``None`` disables dropout
        linear_class: factory called as ``linear_class(in_features, out_features)``
            to build each linear layer

    Raises:
        ValueError: if ``out_act`` is neither a bool nor callable, or if
            ``out_act=True`` while ``act_class`` is ``None``.

    Inputs: input
        - input: a Tensor with shape [batch_size, input_size]

    Outputs: output
        - output: a Tensor with shape [batch_size, hidden_sizes[-1]]
    """

    def __init__(
            self,
            input_size: int,
            *hidden_sizes: int,
            act_class: Optional[Type[nn.Module]] = None,
            out_act: Union[bool, Callable] = False,
            dropout: Optional[float] = None,
            linear_class: Callable[[int, int], nn.Module] = nn.Linear,
    ):
        r"""Build the layer stack; see the class docstring for the arguments."""
        sizes = [input_size, *hidden_sizes]
        n_linear = len(sizes) - 1
        layers = []
        for i, (in_size, out_size) in enumerate(zip(sizes[:-1], sizes[1:])):
            layers.append(linear_class(in_size, out_size))
            # Hidden activations only sit *between* linear layers; the output
            # activation is decided by `out_act` below.
            if act_class is not None and i < n_linear - 1:
                layers.append(act_class())
            # NOTE(review): dropout is also appended after the final linear layer.
            if dropout is not None:
                layers.append(nn.Dropout(dropout))

        if isinstance(out_act, bool):
            if out_act:
                # Previously this crashed with "'NoneType' object is not callable".
                if act_class is None:
                    raise ValueError('`out_act=True` requires `act_class` to be set.')
                layers.append(act_class())
        elif callable(out_act):
            layers.append(out_act())
        else:
            raise ValueError('`out_act` must be a bool value or a callable function.')
        super().__init__(*layers)


class EmbeddingAttention(nn.Module):
    r"""Pool a group of embeddings with attention against learned key embeddings.

    Queries and values are linear projections (no bias) of the incoming
    embeddings; keys are rows of a learned ``[n, embedding_dim]`` table picked
    by ``subsets``. Scores are scaled by ``sqrt(embedding_dim)``, softmaxed over
    the key axis, passed through dropout, used to mix the values, and the result
    is averaged over the query axis.

    Args:
        n: number of learned key embeddings
        embedding_dim: size of every embedding
        attn_dropout: dropout probability applied to the attention weights
    """

    def __init__(self, n, embedding_dim, attn_dropout=0.1):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.temperature = embedding_dim ** 0.5
        self.dropout = nn.Dropout(attn_dropout)

        # One learnable key vector per index in range(n); rows are chosen by `subsets`.
        self.embeddings_k = nn.Parameter(torch.ones([n, embedding_dim]), requires_grad=True)
        self.embedding_trans = nn.Linear(embedding_dim, embedding_dim, bias=False)  # query projection
        self.v_trans = nn.Linear(embedding_dim, embedding_dim, bias=False)  # value projection

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the key table; the projection layers keep their own init."""
        nn.init.kaiming_uniform_(self.embeddings_k, a=math.sqrt(5))

    def forward(self, embeddings, subsets):
        """Attend from ``embeddings`` to the keys selected by ``subsets``.

        Args:
            embeddings: tensor of shape ``[batch, k, embedding_dim]`` (must be
                viewable as ``[-1, embedding_dim]``)
            subsets: one index list per batch row; each list must have length
                ``k`` so the attention weights can be multiplied with the values

        Returns:
            tensor of shape ``[batch, embedding_dim]``
        """
        batch_size, k, _ = embeddings.shape
        flat = embeddings.view(-1, self.embedding_dim)
        queries = self.embedding_trans(flat).view(batch_size, k, self.embedding_dim)
        values = self.v_trans(flat).view(batch_size, k, self.embedding_dim)

        # [batch, len(subset), embedding_dim]
        keys = torch.stack([self.embeddings_k[subset] for subset in subsets])
        scaled_queries = queries / self.temperature
        scores = scaled_queries.bmm(keys.transpose(1, 2))
        weights = self.dropout(scores.softmax(dim=-1))

        mixed = weights.bmm(values)  # [batch, k, embedding_dim]
        return mixed.mean(dim=1)


class MetaLinear(nn.Module):
    r"""Linear layer whose weight matrix is generated from learned index embeddings.

    The layer owns ``n`` learnable embeddings of size ``embedding_dim``. For a
    ``select`` (a list of indices into those embeddings) the selected embeddings
    are pooled into one vector (by their mean, or by ``embedding_module`` when
    one is given). ``generator`` (an :class:`MLP` with LeakyReLU hidden
    activations) then turns that vector into an ``in_dim x out_dim`` weight
    matrix. The bias is shared by all selects.

    Args:
        in_dim: size of each input sample
        out_dim: size of each output sample
        embedding_dim: size of each learned embedding
        n: number of learned embeddings (indices in a select address these rows)
        generator_hidden: hidden sizes of the weight-generator MLP
        embedding_module: optional class, constructed as
            ``embedding_module(n, embedding_dim, **embedding_module_param)`` and
            called as ``embedding_module(stacked_embeddings, selects)``, that
            replaces mean pooling
        embedding_module_param: keyword arguments for ``embedding_module``
    """

    def __init__(self, in_dim, out_dim, embedding_dim, n, generator_hidden=(), embedding_module=None,
                 embedding_module_param=None):
        super(MetaLinear, self).__init__()
        self._in_dim = in_dim
        self._out_dim = out_dim
        self._embedding_dims = embedding_dim

        self.embeddings = nn.Parameter(torch.ones([n, embedding_dim]), requires_grad=True)
        # Maps one pooled embedding to the flattened (in_dim * out_dim) weight matrix.
        self.generator = MLP(embedding_dim, *generator_hidden, in_dim * out_dim, act_class=nn.LeakyReLU)

        self.bias = nn.Parameter(torch.ones([out_dim]), requires_grad=True)

        self.embedding_module = None
        if embedding_module is not None:
            module_kwargs = {} if embedding_module_param is None else embedding_module_param
            self.embedding_module = embedding_module(n, embedding_dim, **module_kwargs)

        self.reset_parameters()

    def _pool_embeddings(self, selects):
        """Pool the embeddings of each select into a ``[len(selects), embedding_dim]`` tensor."""
        if self.embedding_module is None:
            # Average each select on its own before stacking, so selects of
            # different lengths can share one batch (stacking the raw
            # embeddings first requires every select to have the same length).
            return torch.stack([self.embeddings[select].mean(0) for select in selects])
        # The embedding module receives one stacked tensor, so on this path
        # every select must still have the same length.
        return self.embedding_module(torch.stack([self.embeddings[select] for select in selects]), selects)

    def forward_selects(self, x, selects):
        """Apply the layer with one select per sample: row ``i`` of ``x`` uses ``selects[i]``."""
        assert len(x) == len(selects)
        return self.forward_batch(x, selects, 1)

    def forward_batch(self, x, selects, sample_batch_size: int):
        """Apply the layer to consecutive groups of rows, one select per group.

        Args:
            x: 2-D tensor with ``len(selects) * sample_batch_size`` rows and
                ``in_dim`` columns
            selects: list of index lists; ``selects[i]`` generates the weight used
                for rows ``i * sample_batch_size`` up to ``(i + 1) * sample_batch_size``
            sample_batch_size: number of consecutive rows that share one select

        Returns:
            tensor with the same number of rows as ``x`` and ``out_dim`` columns
        """
        assert len(x) == len(selects) * sample_batch_size
        n_rows, n_cols = x.shape
        n_groups = len(selects)
        pooled = self._pool_embeddings(selects)
        weights = self.generator(pooled).reshape(n_groups, self._in_dim, self._out_dim)
        grouped = x.reshape(n_groups, sample_batch_size, n_cols)
        return torch.bmm(grouped, weights).reshape(n_rows, self._out_dim) + self.bias

    def forward(self, x, select):
        """Apply the layer to every row of ``x`` using the single index list ``select``."""
        return self.forward_batch(x, [select], len(x))

    def reset_parameters(self) -> None:
        """Re-initialize the embeddings and the bias (the generator is not touched here)."""
        nn.init.kaiming_uniform_(self.embeddings, a=math.sqrt(5))
        if self.bias is not None:
            # NOTE(review): nn.Linear bounds its bias by 1/sqrt(fan_in); this uses
            # out_dim instead -- confirm that is intended.
            bound = 1 / math.sqrt(self._out_dim)
            nn.init.uniform_(self.bias, -bound, bound)


class MetaMLP(nn.Module):
    r"""Stack of :class:`MetaLinear` layers with ``F.leaky_relu`` between consecutive layers.

    Args:
        dims: layer sizes; layer ``i`` maps ``dims[i]`` features to ``dims[i + 1]``
        embedding_dim: embedding size used by every layer
        n: number of learned embeddings per layer; defaults to ``dims[0]``
        generator_hidden: hidden sizes of each layer's weight generator
        embedding_module: optional pooling module class forwarded to every layer
        embedding_module_param: keyword arguments for ``embedding_module`` (the
            same object is handed to every layer)
    """

    def __init__(self, dims: list, embedding_dim, n=None, generator_hidden=(), embedding_module=None,
                 embedding_module_param=None):
        super(MetaMLP, self).__init__()

        self.n = dims[0] if n is None else n
        self._dims = dims
        self._embedding_dim = embedding_dim
        self._generator_hidden = generator_hidden

        self._mlp = nn.ModuleList([
            MetaLinear(in_dim, out_dim, self._embedding_dim, self.n, self._generator_hidden,
                       embedding_module, embedding_module_param)
            for in_dim, out_dim in zip(dims[:-1], dims[1:])
        ])

    def _apply_layers(self, x, run_layer):
        """Feed ``x`` through every layer via ``run_layer(layer, h)``, with leaky ReLU in between."""
        for depth, layer in enumerate(self._mlp):
            if depth > 0:
                x = F.leaky_relu(x)
            x = run_layer(layer, x)
        return x

    def forward(self, x, select):
        """Run all rows of ``x`` with the single index list ``select``."""
        return self._apply_layers(x, lambda layer, h: layer.forward(h, select))

    def forward_selects(self, x, selects):
        """Run row ``i`` of ``x`` with ``selects[i]``."""
        return self._apply_layers(x, lambda layer, h: layer.forward_selects(h, selects))

    def forward_batch(self, x, selects, sample_batch_size: int):
        """Run each group of ``sample_batch_size`` consecutive rows with its own select."""
        return self._apply_layers(
            x, lambda layer, h: layer.forward_batch(h, selects, sample_batch_size))


class MetaFE(nn.Module):
    r"""Feature extractor: zero the unselected input columns, then apply a :class:`MetaMLP`.

    Each select is used twice. It chooses which columns of ``x`` are kept, and it
    is forwarded to the :class:`MetaMLP`, whose layers use it to pick their
    weight-generating embeddings.

    Args:
        dims: layer sizes of the underlying :class:`MetaMLP`
        embedding_dim: embedding size of every layer
        generator_hidden: hidden sizes of each layer's weight generator
        embedding_module: optional pooling module class forwarded to every layer
        embedding_module_param: keyword arguments for ``embedding_module``
    """

    def __init__(self, dims, embedding_dim, generator_hidden=(), embedding_module=None,
                 embedding_module_param=None):
        super(MetaFE, self).__init__()
        self._dims = dims
        self._embedding_dim = embedding_dim
        self._generator_hidden = generator_hidden

        self._meta_mlp = MetaMLP(dims, embedding_dim, generator_hidden=generator_hidden,
                                 embedding_module=embedding_module,
                                 embedding_module_param=embedding_module_param)

    def forward_batch(self, x, selects, sample_batch_size: int):
        """Mask and encode consecutive groups of ``sample_batch_size`` rows, one select per group."""
        assert len(x) == len(selects) * sample_batch_size
        masked = torch.zeros_like(x)
        for group, select in enumerate(selects):
            rows = slice(group * sample_batch_size, (group + 1) * sample_batch_size)
            # keep only the selected columns of this group's rows
            masked[rows, select] = x[rows, select]
        return self._meta_mlp.forward_batch(masked, selects, sample_batch_size)

    def forward(self, x, select):
        """Mask and encode every row of ``x`` with the single index list ``select``."""
        return self.forward_batch(x, [select], len(x))

    def forward_selects(self, x, selects):
        """Mask and encode row ``i`` of ``x`` with ``selects[i]``."""
        return self.forward_batch(x, selects, 1)


class NormalFE(nn.Module):
    r"""Baseline feature extractor: zero the unselected input columns, then apply a plain :class:`MLP`.

    Exposes the same constructor and ``forward``/``forward_selects``/``forward_batch``
    interface as ``MetaFE``. ``embedding_dim`` and ``generator_hidden`` are only
    stored; ``embedding_module`` and ``embedding_module_param`` are accepted and
    ignored.

    Args:
        dims: layer sizes of the MLP
    """

    def __init__(self, dims, embedding_dim=None, generator_hidden=(), embedding_module=None,
                 embedding_module_param=None):
        super(NormalFE, self).__init__()
        self._dims = dims
        self.embedding_dim = embedding_dim
        self.generator_hidden = generator_hidden
        # NOTE(review): MLP(*dims) uses act_class=None, so there is no nonlinearity
        # between the linear layers -- confirm that is intended for this baseline.
        self._mlp = MLP(*self._dims)

    def forward_batch(self, x, selects, sample_batch_size: int):
        """Mask and encode consecutive groups of ``sample_batch_size`` rows, one select per group."""
        assert len(x) == len(selects) * sample_batch_size
        masked = torch.zeros_like(x)
        for i, select in enumerate(selects):
            masked[i * sample_batch_size: (i + 1) * sample_batch_size, select] = \
                x[i * sample_batch_size: (i + 1) * sample_batch_size, select]
        # Feed the *masked* input; the unmasked `x` was passed before, which
        # silently discarded the feature selection.
        return self._mlp(masked)

    def forward(self, x, select):
        """Mask and encode every row of ``x`` with the single index list ``select``."""
        return self.forward_batch(x, [select], len(x))

    def forward_selects(self, x, selects):
        """Mask and encode row ``i`` of ``x`` with ``selects[i]``."""
        return self.forward_batch(x, selects, 1)


if __name__ == "__main__":
    # Benchmark of the MetaLinear entry points on 128 samples. Each comment gives
    # the timing recorded for the statement below it (timeit-style output).
    meta_linear = MetaLinear(10, 20, 8, 500, (32,))

    # 128 single-sample calls -- 22.2 ms ± 322 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    for _ in range(128):
        meta_linear.forward(torch.ones(1, 10), list(range(5)))

    # one select per sample -- 3.86 ms ± 102 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    meta_linear.forward_selects(torch.ones(128, 10), [list(range(5)) for _ in range(128)])

    # one select per pair of samples -- 1.94 ms ± 104 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    meta_linear.forward_batch(torch.ones(128, 10), [list(range(5)) for _ in range(64)], 2)

    # one select for the whole batch -- 193 µs ± 2.39 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    meta_linear.forward(torch.ones(128, 10), list(range(5)))

if __name__ == "__main__":
    # Benchmark of the MetaMLP entry points on 128 samples. Each comment gives
    # the timing recorded for the statement below it (timeit-style output).
    meta_mlp = MetaMLP((100, 32, 64, 2), 32, 100)

    # 128 single-sample calls -- 100 ms ± 4.35 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    for _ in range(128):
        meta_mlp.forward(torch.ones(1, 100), list(range(5)))

    # one select per sample -- 17.3 ms ± 761 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    meta_mlp.forward_selects(torch.ones(128, 100), [list(range(5)) for _ in range(128)])

    # one select per pair of samples -- 9.31 ms ± 674 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    meta_mlp.forward_batch(torch.ones(128, 100), [list(range(5)) for _ in range(64)], 2)

    # one select for the whole batch -- 934 µs ± 24.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    meta_mlp.forward(torch.ones(128, 100), list(range(5)))

if __name__ == "__main__":
    # Benchmark of the MetaFE entry points on 128 samples. Each comment gives
    # the timing recorded for the statement below it (timeit-style output).
    meta_fe = MetaFE((100, 32, 64, 2), 32)

    # 128 single-sample calls -- 128 ms ± 5.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    for _ in range(128):
        meta_fe.forward(torch.ones(1, 100), list(range(5)))

    # one select per sample -- 24.4 ms ± 1.37 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    meta_fe.forward_selects(torch.ones(128, 100), [list(range(5)) for _ in range(128)])

    # one select per pair of samples -- 12.5 ms ± 270 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    meta_fe.forward_batch(torch.ones(128, 100), [list(range(5)) for _ in range(64)], 2)

    # one select for the whole batch -- 1.28 ms ± 62.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    meta_fe.forward(torch.ones(128, 100), list(range(5)))
