import logging
from typing import Sequence, Tuple

import torch
from torch import nn
from torch.nn import functional as F

log = logging.getLogger(__name__)


class SumAggregation(nn.Module):
    """Normalize each subspace independently, then sum them element-wise.

    Base class: ``self.norm_layers`` is not created here. Subclasses must
    assign an indexable collection of ``num_subspaces`` modules (one per
    subspace) after calling ``super().__init__``.
    """

    def __init__(self, subspace_dim: int, num_subspaces: int):
        super().__init__()

        self.subspace_dim = subspace_dim
        self.num_subspaces = num_subspaces

        # type(self) rather than __class__: __class__ is always SumAggregation,
        # so every subclass would otherwise log under the base-class name.
        log.info(f"{type(self).__name__}: subspace_dim={self.subspace_dim}, num_subspaces={self.num_subspaces}")

    @property
    def out_dim(self):
        """Output width: subspaces are summed, so it is one subspace wide."""
        return self.subspace_dim

    def forward(self, concat_subspaces: torch.Tensor) -> torch.Tensor:
        """Aggregate concatenated subspaces by summation.

        Args:
            concat_subspaces: tensor whose dim 1 holds the subspaces back to
                back; it is split into ``subspace_dim``-wide chunks along dim 1.

        Returns:
            ``sum_i norm_layers[i](chunk_i)``.
        """
        concat_subspaces = concat_subspaces.split(self.subspace_dim, dim=1)

        out = [norm_layer(x) for norm_layer, x in zip(self.norm_layers, concat_subspaces)]

        return torch.stack(out, dim=1).sum(dim=1)


class LinearSumAggregation(SumAggregation):
    """Sum aggregation in which every subspace gets its own LayerNorm."""

    def __init__(self, subspace_dim: int, num_subspaces: int):
        super().__init__(subspace_dim=subspace_dim, num_subspaces=num_subspaces)

        # One independent LayerNorm per subspace, in subspace order.
        self.norm_layers = nn.ModuleList(
            nn.LayerNorm(normalized_shape=subspace_dim) for _ in range(num_subspaces)
        )


class NonLinearSumAggregation(SumAggregation):
    """Sum aggregation with a LayerNorm -> Linear -> Tanh branch per subspace."""

    def __init__(self, subspace_dim: int, num_subspaces: int):
        super().__init__(subspace_dim=subspace_dim, num_subspaces=num_subspaces)

        def make_branch() -> nn.Module:
            # Positional Sequential (0: norm, 1: linear, 2: tanh) keeps the
            # parameter names of the per-subspace branches stable.
            return nn.Sequential(
                nn.LayerNorm(subspace_dim),
                nn.Linear(subspace_dim, subspace_dim),
                nn.Tanh(),
            )

        self.norm_layers = nn.ModuleList([make_branch() for _ in range(num_subspaces)])


class ConcatAggregation(nn.Module):
    """Apply a per-subspace LayerNorm and concatenate the results back together.

    Unlike the summing aggregations, every subspace is kept side by side, so
    the output width is ``subspace_dim * num_subspaces``.
    """

    def __init__(self, subspace_dim: int, num_subspaces: int):
        super().__init__()

        self.subspace_dim = subspace_dim
        self.num_subspaces = num_subspaces

        log.info(f"ConcatAggregation: subspace_dim={self.subspace_dim}, num_subspaces={self.num_subspaces}")

        layers = [nn.LayerNorm(subspace_dim) for _ in range(num_subspaces)]
        self.norm_layers = nn.ModuleList(layers)

    @property
    def out_dim(self):
        """Total output width: all subspaces are retained."""
        return self.num_subspaces * self.subspace_dim

    def forward(self, concat_subspaces: torch.Tensor) -> torch.Tensor:
        """Normalize each ``subspace_dim``-wide slice of dim 1 and re-concatenate along dim 1."""
        pieces = torch.split(concat_subspaces, self.subspace_dim, dim=1)
        normalized = tuple(layer(piece) for layer, piece in zip(self.norm_layers, pieces))
        return torch.cat(normalized, dim=1)


class WeightedAvgAggregation(nn.Module):
    """Normalize each subspace, then take a learned convex combination of them.

    The mixing coefficients are ``softmax(self.weight)``: positive and summing
    to 1. With the initial all-ones ``weight`` this is a plain mean.
    """

    def __init__(self, subspace_dim: int, num_subspaces: int):
        super().__init__()
        self.subspace_dim = subspace_dim
        self.num_subspaces = num_subspaces

        log.info(f"WeightedAvgAggregation: subspace_dim={self.subspace_dim}, num_subspaces={self.num_subspaces}")

        # One learnable logit per subspace; turned into weights by softmax in forward.
        self.weight = nn.Parameter(torch.ones(num_subspaces))

        # Each LayerNorm stays wrapped in a one-element Sequential so parameter
        # names (norm_layers.<i>.0.*) match previously saved checkpoints.
        self.norm_layers = nn.ModuleList(
            [nn.Sequential(nn.LayerNorm(subspace_dim)) for _ in range(num_subspaces)]
        )

    @property
    def out_dim(self):
        """Output width: a weighted average is one subspace wide."""
        return self.subspace_dim

    def forward(self, concat_subspaces: torch.Tensor) -> torch.Tensor:
        """Weighted average of the normalized subspaces.

        Args:
            concat_subspaces: tensor whose dim 1 holds the subspaces back to
                back; it is split into ``subspace_dim``-wide chunks along dim 1.

        Returns:
            ``sum_i softmax(weight)_i * norm_layers[i](chunk_i)``.
        """
        chunks = concat_subspaces.split(self.subspace_dim, dim=1)
        normalized = [norm_layer(x) for norm_layer, x in zip(self.norm_layers, chunks)]

        softmax_weights = F.softmax(self.weight, dim=0)  # [num_subspaces]
        # Stack the *normalized* chunks. Stacking the raw inputs would leave
        # norm_layers out of the computation (and without gradients).
        stacked = torch.stack(normalized, dim=1)  # [batch_size, num_subspaces, subspace_dim]
        return torch.einsum("bns,n -> bs", stacked, softmax_weights)  # [batch_size, subspace_dim]


class Identity(nn.Module):
    """Pass-through aggregation for the single-subspace case.

    Only ``num_subspaces == 1`` is accepted; ``forward`` returns its input unchanged.
    """

    def __init__(self, subspace_dim: int, num_subspaces: int):
        super().__init__()
        self.subspace_dim, self.num_subspaces = subspace_dim, num_subspaces

        log.info(f"{__class__.__name__}: subspace_dim={self.subspace_dim}, num_subspaces={self.num_subspaces}")

        # Nothing to aggregate: exactly one subspace is supported.
        assert self.num_subspaces == 1

    @property
    def out_dim(self):
        """Output width equals the (single) subspace width."""
        return self.subspace_dim

    def forward(self, relative_space: torch.Tensor) -> torch.Tensor:
        """Return ``relative_space`` itself (the very same tensor object)."""
        return relative_space


class NoAggregation(nn.Module):
    """Single-subspace head that just applies ``self.norm_layers`` to its input.

    Base class: ``self.norm_layers`` (a single module, not a list) is not
    created here. Subclasses must assign it after calling ``super().__init__``.
    Only ``num_subspaces == 1`` is accepted.
    """

    def __init__(self, subspace_dim: int, num_subspaces: int):
        super().__init__()
        self.subspace_dim = subspace_dim
        self.num_subspaces = num_subspaces

        # type(self) rather than __class__: __class__ is always NoAggregation,
        # so subclasses such as LayerNorm/MLP would otherwise log under the base name.
        log.info(f"{type(self).__name__}: subspace_dim={self.subspace_dim}, num_subspaces={self.num_subspaces}")

        assert self.num_subspaces == 1

    @property
    def out_dim(self):
        """Output width equals the (single) subspace width."""
        return self.subspace_dim

    def forward(self, relative_space: torch.Tensor) -> torch.Tensor:
        """Apply the subclass-provided ``norm_layers`` module to ``relative_space``."""
        return self.norm_layers(relative_space)


class LayerNorm(NoAggregation):
    """Single-subspace head that applies only a LayerNorm."""

    def __init__(self, subspace_dim: int, num_subspaces: int):
        super().__init__(subspace_dim=subspace_dim, num_subspaces=num_subspaces)

        self.norm_layers = nn.LayerNorm(normalized_shape=subspace_dim)


class MLP(NoAggregation):
    """Single-subspace head: LayerNorm followed by a Tanh-activated square Linear layer."""

    def __init__(self, subspace_dim: int, num_subspaces: int):
        super().__init__(subspace_dim=subspace_dim, num_subspaces=num_subspaces)

        # Order matters: it fixes the Sequential indices (0, 1, 2) used in parameter names.
        stages = [
            nn.LayerNorm(subspace_dim),
            nn.Linear(in_features=subspace_dim, out_features=subspace_dim),
            nn.Tanh(),
        ]
        self.norm_layers = nn.Sequential(*stages)


# check also https://github.com/sooftware/attentions/blob/master/attentions.py
class SelfAttentionLayer(torch.nn.Module):
    """Aggregate subspaces with single-head self-attention, then sum them.

    ``forward`` splits dim 1 of its input into ``subspace_dim``-wide chunks
    and passes chunk ``i`` through ``self.norm_layers[i]``. It stacks the
    results as a sequence of ``num_subspaces`` tokens, self-attends over that
    sequence and sums the attended tokens.

    Base class: ``self.norm_layers`` is not created here. Subclasses must
    assign it after calling ``super().__init__``.
    """

    def __init__(self, subspace_dim: int, num_subspaces: int):
        super().__init__()
        self.subspace_dim = subspace_dim
        self.num_subspaces = num_subspaces

        # type(self) so subclasses log under their own name (the base-class
        # name is unchanged for direct instances).
        log.info(f"{type(self).__name__}: subspace_dim={self.subspace_dim}, num_subspaces={self.num_subspaces}")

        # batch_first=True: the attention consumes (batch, tokens, subspace_dim).
        self.attention = nn.MultiheadAttention(embed_dim=self.subspace_dim, num_heads=1, batch_first=True)

    def get_attention_weights(
        self, concat_subspaces: torch.Tensor, attention_idx: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the attention over the pre-processed subspaces.

        NOTE: despite the name, this returns the full ``(attn_output,
        attn_weights)`` pair from ``nn.MultiheadAttention``. ``TransformerBlock``,
        by contrast, returns only the weights. The pair is kept for caller
        compatibility; the weights are element ``[1]``.

        Args:
            concat_subspaces: concatenated subspaces along dim 1.
            attention_idx: must be 0; there is a single attention layer.
        """
        assert attention_idx == 0
        query = self.preattend(concat_subspaces)
        return self.attention(query=query, key=query, value=query)

    def preattend(self, concat_subspaces: torch.Tensor) -> torch.Tensor:
        """Split dim 1 into subspaces, apply ``norm_layers`` to each, and stack along a new dim 1.

        For a 2-D ``(batch, num_subspaces * subspace_dim)`` input, the result is
        ``(batch, num_subspaces, subspace_dim)``, provided each norm layer keeps
        the width.
        """
        query = concat_subspaces.split(self.subspace_dim, dim=1)
        query = [norm_layer(x) for norm_layer, x in zip(self.norm_layers, query)]
        return torch.stack(query, dim=1)

    @property
    def out_dim(self):
        """Output width: attended tokens are summed, so it is one subspace wide."""
        return self.subspace_dim

    def forward(self, concat_subspaces: torch.Tensor) -> torch.Tensor:
        """Self-attend over the subspace tokens and sum the attended tokens (dim 1)."""
        query = self.preattend(concat_subspaces)
        out, _ = self.attention(query=query, key=query, value=query)
        return torch.sum(out, dim=1)


class LinearSelfAttentionLayer(SelfAttentionLayer):
    """Self-attention aggregation whose per-subspace pre-processing is a plain LayerNorm."""

    def __init__(self, subspace_dim: int, num_subspaces: int):
        super().__init__(subspace_dim=subspace_dim, num_subspaces=num_subspaces)

        # One independent LayerNorm per subspace, in subspace order.
        self.norm_layers = nn.ModuleList(
            nn.LayerNorm(normalized_shape=subspace_dim) for _ in range(num_subspaces)
        )


class NonLinearSelfAttentionLayer(SelfAttentionLayer):
    """Self-attention aggregation with a LayerNorm -> Linear -> Tanh branch per subspace."""

    def __init__(self, subspace_dim: int, num_subspaces: int):
        super().__init__(subspace_dim=subspace_dim, num_subspaces=num_subspaces)

        def make_branch() -> nn.Module:
            # Positional Sequential (0: norm, 1: linear, 2: tanh) keeps the
            # parameter names of the per-subspace branches stable.
            return nn.Sequential(
                nn.LayerNorm(subspace_dim),
                nn.Linear(subspace_dim, subspace_dim),
                nn.Tanh(),
            )

        self.norm_layers = nn.ModuleList([make_branch() for _ in range(num_subspaces)])


class TransformerBlock(nn.Module):
    """One attention + feed-forward block over batch-first ``(batch, tokens, k)`` inputs.

    ``forward`` computes ``norm2(ff(norm1(self_attention(x))))``, where ``ff``
    is a ``k -> k // 4 -> k`` ReLU bottleneck.

    NOTE(review): there are no residual connections around either sub-layer.
    A standard post-norm block would compute ``norm1(attn(x) + x)`` and
    ``norm2(ff(h) + h)``. Also, for ``k < 4`` the hidden width is 0. Confirm
    both are intentional.

    Args:
        k: embedding width of each token.
        heads: number of attention heads.
    """

    def __init__(self, k, heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim=k, num_heads=heads, batch_first=True, dropout=0.1)
        self.norm1 = nn.LayerNorm(k)
        self.norm2 = nn.LayerNorm(k)
        hidden = k // 4
        self.ff = nn.Sequential(nn.Linear(k, hidden), nn.ReLU(), nn.Linear(hidden, k))

    def _self_attend(self, x):
        # Self-attention: the same tensor is query, key and value.
        # Returns the (attn_output, attn_weights) pair.
        return self.attention(query=x, key=x, value=x)

    def forward(self, x):
        """Attend, normalize, apply the feed-forward bottleneck, normalize again."""
        attended, _ = self._self_attend(x)
        hidden = self.ff(self.norm1(attended))
        return self.norm2(hidden)

    def get_attention_weights(self, x):
        """Return only the attention weights of this block for input ``x``."""
        _, weights = self._self_attend(x)
        return weights


class Transformer(torch.nn.Module):
    """Stack of ``TransformerBlock``s over the subspace tokens, mean-pooled.

    The input is reshaped to ``(-1, num_subspaces, subspace_dim)`` so that
    each subspace is one token. The output is the mean over tokens.

    Args:
        subspace_dim: width of each subspace (the attention embedding size).
        num_subspaces: number of subspaces (tokens) per example.
        depth: number of stacked blocks (keyword-only; default 1, the
            previously hard-coded value).
        heads: attention heads per block (keyword-only; default 1, the
            previously hard-coded value). ``nn.MultiheadAttention`` requires
            ``subspace_dim`` to be divisible by it.
    """

    def __init__(self, subspace_dim: int, num_subspaces: int, *, depth: int = 1, heads: int = 1):
        super().__init__()
        self.subspace_dim = subspace_dim
        self.num_subspaces = num_subspaces

        # Log the configuration like the other aggregation modules in this file.
        log.info(f"{type(self).__name__}: subspace_dim={self.subspace_dim}, num_subspaces={self.num_subspaces}")

        self.tblocks = nn.Sequential(*[TransformerBlock(k=self.subspace_dim, heads=heads) for _ in range(depth)])

    def get_attention_weights(self, concat_subspaces: torch.Tensor, attention_idx: int = 0) -> torch.Tensor:
        """Return the attention weights of block ``attention_idx``.

        Block ``i`` receives the output of blocks ``0..i-1`` in ``forward``, so
        those blocks are applied first here too. For ``attention_idx == 0``
        (the only valid index with the default ``depth=1``) the prefix slice is
        empty and the reshaped input is used directly.
        """
        query = concat_subspaces.reshape(-1, self.num_subspaces, self.subspace_dim)
        query = self.tblocks[:attention_idx](query)
        return self.tblocks[attention_idx].get_attention_weights(query)

    @property
    def out_dim(self):
        """Output width: tokens are mean-pooled, so it is one subspace wide."""
        return self.subspace_dim

    def forward(self, concat_subspaces: torch.Tensor) -> torch.Tensor:
        """Reshape into subspace tokens, run the block stack, and average over tokens."""
        query = concat_subspaces.reshape(-1, self.num_subspaces, self.subspace_dim)
        out = self.tblocks(query)
        return torch.mean(out, dim=1)
