from torch.nn import TransformerEncoderLayer
from torch import nn
import torch
from torch import Tensor
from typing import Optional
import torch.nn.functional as F
from tango.integrations.torch import Model

from typing import List

import torch
import torch.nn as nn
from torch.nn import TransformerEncoderLayer
from .ptdec import DEC
from typing import List
from .components import InterpretableTransformerEncoder, BernoulliTransformerEncoder
from torchmetrics import Accuracy
from tango.integrations.torch import Model

from ...utils import expand_task_and_dataset


class TransPoolingEncoder(nn.Module):
    """A :class:`HybridBlock` encoder layer, optionally followed by DEC pooling.

    Any extra ``**kwargs`` (e.g. ``tasks``, ``n_mask``, ``agg_strategy``,
    ``shared_mask``) are forwarded unchanged to :class:`HybridBlock`.
    """

    def __init__(self, transformer_type, input_feature_size, input_node_num, hidden_size, output_node_num, pooling=True, orthogonal=True, freeze_center=False, project_assignment=True, **kwargs):
        """
        :param transformer_type: label stored on the module; the encoder layer is
            always a HybridBlock regardless of this value.
        :param input_feature_size: per-node feature size (``d_model`` of the block).
        :param input_node_num: number of input nodes; sizes the DEC encoder MLP.
        :param hidden_size: feed-forward width of the HybridBlock.
        :param output_node_num: number of DEC clusters when pooling is enabled.
        :param pooling: whether to apply DEC pooling after the encoder layer.
        :param orthogonal: forwarded to DEC.
        :param freeze_center: forwarded to DEC.
        :param project_assignment: forwarded to DEC.
        """
        super().__init__()
        self.transformer_type = transformer_type
        self.transformer = HybridBlock(d_model=input_feature_size, nhead=4, dim_feedforward=hidden_size, batch_first=True, **kwargs)

        self.pooling = pooling
        if pooling:
            encoder_hidden_size = 32
            # MLP over the flattened node features, handed to DEC as its encoder.
            self.encoder = nn.Sequential(
                nn.Linear(input_feature_size *
                          input_node_num, encoder_hidden_size),
                nn.LeakyReLU(),
                nn.Linear(encoder_hidden_size, encoder_hidden_size),
                nn.LeakyReLU(),
                nn.Linear(encoder_hidden_size,
                          input_feature_size * input_node_num),
            )
            self.dec = DEC(cluster_number=output_node_num, hidden_dimension=input_feature_size, encoder=self.encoder,
                           orthogonal=orthogonal, freeze_center=freeze_center, project_assignment=project_assignment)

    def is_pooling_enabled(self):
        """Return whether DEC pooling is applied after the encoder layer."""
        return self.pooling

    def forward(self, x, key=None):
        """Encode ``x`` and, if enabled, pool it with DEC.

        :param x: node features passed to the HybridBlock.
        :param key: task name forwarded to the HybridBlock; it selects the
            per-task mask when masks are not shared.
        :return: ``(x, assignment)``; ``assignment`` is ``None`` when pooling is off.
        """
        # HybridBlock.forward takes ``key`` as a required positional argument, so
        # it is forwarded for every transformer_type (calling without it raises
        # TypeError).
        x = self.transformer(x, key)
        if self.pooling:
            x, assignment = self.dec(x)
            return x, assignment
        return x, None

    def get_attention_weights(self):
        """Return the attention weights recorded by the HybridBlock's last forward."""
        return self.transformer.get_attention_weights()

    def loss(self, assignment):
        """Delegate to ``self.dec.loss``.

        Only usable when pooling is enabled; ``self.dec`` exists only then.
        """
        return self.dec.loss(assignment)

@Model.register("hybrid")
class HyBRiD(Model):
    """Multi-task HyBRiD model.

    ``forward`` runs a stack of :class:`TransPoolingEncoder` layers. It then
    applies a small MLP (``dim_reduction``) that maps each output vector to a
    scalar. Finally, a linear head (``last``) combines the flattened scalars into
    one prediction per sample.

    ``dim_reduction`` and ``last`` are either shared across tasks or held in a
    per-task ``nn.ModuleDict`` indexed by the ``key`` passed to ``forward``.
    """

    def __init__(self, n_mask: int, tasks: str, hidden_size: int, transformer_type: str = "bernoulli", agg_strategy: str = "sum",
                 shared_mask: bool = False,
                 shared_dim_reduction: bool = True,
                 shared_last: bool = True,
                 ):
        """
        :param n_mask: number of learned masks; also the input width of ``last``.
        :param tasks: task specification, expanded by ``expand_task_and_dataset``
            into the task names used as ``ModuleDict`` keys.
        :param hidden_size: feed-forward width of each encoder layer.
        :param transformer_type: forwarded to each TransPoolingEncoder.
        :param agg_strategy: mask aggregation strategy forwarded to HybridBlock.
        :param shared_mask: forwarded to HybridBlock (one mask set for all tasks).
        :param shared_dim_reduction: share one dim-reduction MLP across tasks.
        :param shared_last: share one final linear head across tasks.
        """
        super().__init__()
        tasks = expand_task_and_dataset(tasks)
        self.n_mask = n_mask
        self.attention_list = nn.ModuleList()

        # Hard-coded configuration values.
        # NOTE: an alternative setting is pooling [False, True] with sizes [360, 100].
        config_dataset_node_sz = 164
        forward_dim = config_dataset_node_sz
        config_model_freeze_center = True
        config_model_pos_embeded_dim = 360
        config_model_pos_encoding = "none"
        config_model_pooling = [False]
        config_model_orthogonal = True
        config_model_project_assignment = True
        config_model_sizes = [172]

        self.pos_encoding = config_model_pos_encoding
        if self.pos_encoding == 'identity':
            # Learned per-node embedding concatenated to the node features.
            self.node_identity = nn.Parameter(torch.zeros(
                config_dataset_node_sz, config_model_pos_embeded_dim), requires_grad=True)
            forward_dim = config_dataset_node_sz + config_model_pos_embeded_dim
            nn.init.kaiming_normal_(self.node_identity)

        sizes = config_model_sizes
        # The first layer's output node count is forced to the dataset node count.
        sizes[0] = config_dataset_node_sz
        in_sizes = [config_dataset_node_sz] + sizes[:-1]
        do_pooling = config_model_pooling
        self.do_pooling = do_pooling
        for index, size in enumerate(sizes):
            self.attention_list.append(
                TransPoolingEncoder(
                                    transformer_type=transformer_type,
                                    input_feature_size=forward_dim,
                                    input_node_num=in_sizes[index],
                                    hidden_size=hidden_size,
                                    output_node_num=size,
                                    pooling=do_pooling[index],
                                    orthogonal=config_model_orthogonal,
                                    freeze_center=config_model_freeze_center,
                                    project_assignment=config_model_project_assignment,
                                    n_mask=n_mask,
                                    tasks=tasks,
                                    agg_strategy=agg_strategy,
                                    shared_mask=shared_mask,
                                    ))

        if shared_dim_reduction:
            self.dim_reduction = self._make_dim_reduction(forward_dim)
        else:
            self.dim_reduction = nn.ModuleDict({
                task: self._make_dim_reduction(forward_dim) for task in tasks})
        if shared_last:
            self.last = nn.Linear(n_mask, 1)
        else:
            self.last = nn.ModuleDict({
                task: nn.Linear(n_mask, 1) for task in tasks})

        self.fc = nn.Identity()

    @staticmethod
    def _make_dim_reduction(in_features: int) -> nn.Sequential:
        """Build the MLP that maps each ``in_features`` vector to one scalar."""
        return nn.Sequential(
            nn.Linear(in_features, 32),
            nn.LeakyReLU(),
            nn.Linear(32, 8),
            nn.LeakyReLU(),
            nn.Linear(8, 1),
        )

    def forward(self, x, key, is_predicting: bool = False):
        """Predict one value per sample for task ``key``.

        :param x: 3-D node-feature tensor whose first dimension is the batch size.
        :param key: task name; selects per-task heads/masks when not shared.
        :param is_predicting: accepted for interface compatibility; unused.
        :return: dict with ``preds`` (shape ``[bz]``), ``last`` (detached input
            to the final head), ``mask`` and ``mask_logits`` (from the last
            encoder layer), and the constant flag ``min_term``.
        """
        node_feature = x
        # Unpacking also enforces a 3-D input.
        bz, _, _ = node_feature.shape

        if self.pos_encoding == 'identity':
            pos_emb = self.node_identity.expand(bz, *self.node_identity.shape)
            node_feature = torch.cat([node_feature, pos_emb], dim=-1)

        for encoder in self.attention_list:
            node_feature, _ = encoder(node_feature, key)

        if isinstance(self.dim_reduction, nn.ModuleDict):
            dim_reduction = self.dim_reduction[key]
        else:
            dim_reduction = self.dim_reduction

        if isinstance(self.last, nn.ModuleDict):
            last_layer = self.last[key]
        else:
            last_layer = self.last

        node_feature = dim_reduction(node_feature)
        node_feature = node_feature.reshape((bz, -1))

        last = self.fc(node_feature)
        # squeeze(-1) removes only the head's size-1 output dim, so preds keeps
        # shape [bz] even when bz == 1 (a bare squeeze() would yield a 0-d tensor).
        preds = last_layer(last).squeeze(-1)

        mask, mask_logits = self.attention_list[-1].get_attention_weights()

        return {
            "preds": preds,
            "last": last.detach(),
            "mask": mask,
            "mask_logits": mask_logits,
            "min_term": True,
        }

    def get_attention_weights(self):
        """Return each encoder layer's recorded attention weights, in order."""
        return [atten.get_attention_weights() for atten in self.attention_list]

    def get_cluster_centers(self) -> torch.Tensor:
        """
        Get the cluster centers from the DEC module of the last encoder layer that
        has pooling enabled.

        :return: [number of clusters, hidden dimension] Tensor of dtype float
        :raises AttributeError: if no encoder layer has pooling enabled.
        """
        for encoder in reversed(self.attention_list):
            if encoder.is_pooling_enabled():
                return encoder.dec.get_cluster_centers()
        raise AttributeError(
            "HyBRiD has no cluster centers: no encoder layer has pooling enabled")

class MaskingModel(Model):
    """Learnable hard node masks trained with a straight-through estimator.

    Holds a single parameter of shape ``(n_head, n_nodes, 2)``. For each head
    and node there are two logits, and index 1 of the last axis is the one read
    as "node selected".
    """

    def __init__(
            self,
            n_head: int = 1,
            n_nodes: int = 164,
        ) -> None:
        """
        :param n_head: number of independent masks.
        :param n_nodes: number of nodes each mask covers (default 164).
        """
        super().__init__()

        # torch.Tensor(...) allocates uninitialised storage; xavier_normal_ fills it.
        self.mask = nn.Parameter(torch.Tensor(n_head, n_nodes, 2))
        nn.init.xavier_normal_(self.mask)

    def forward(self):
        """Compute the current masks.

        :return: ``(mask, mask_logits)``.
            - ``mask`` has shape ``(n_head, n_nodes)``. Its values are hard 0/1
              in the forward pass, and gradients flow through the softmax.
            - ``mask_logits`` is the ``(n_head, n_nodes, 2)`` tensor of
              log-probabilities.
        """
        # log_softmax is the numerically stable form of log(softmax(.)); it does
        # not produce -inf when a probability underflows to zero.
        mask_logits = F.log_softmax(self.mask, dim=-1)
        mask = self.straight_througth_max(mask_logits, dim=-1)[..., 1]
        return mask, mask_logits

    def straight_througth_max(self, logits, dim=-1, hard=True, tau = 1):
        """Straight-through argmax over ``dim`` (method name kept for compatibility).

        Same structure as the hard branch of ``F.gumbel_softmax``, but no Gumbel
        noise is added.

        :param logits: input scores.
        :param dim: dimension to take the (soft) max over.
        :param hard: if True, return a one-hot argmax whose gradient is the
            softmax's gradient; otherwise return the softmax itself.
        :param tau: softmax temperature.
        """
        logits = logits / tau  # temperature scaling only; no noise is sampled
        y_soft = logits.softmax(dim)

        if hard:
            index = y_soft.max(dim, keepdim=True)[1]
            y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
            # Forward value equals y_hard; backward gradient flows through y_soft.
            ret = y_hard - y_soft.detach() + y_soft
        else:
            ret = y_soft

        return ret

@Model.register("hybrid_block")
class HybridBlock(Model, TransformerEncoderLayer):
    """TransformerEncoderLayer variant that aggregates nodes through learned masks.

    The parent's self-attention sub-block is replaced by ``_sa_block``, which
    pools the node vectors into one vector per learned mask. The masks come from
    :class:`MaskingModel`, with ``n_mask`` of them. The feed-forward sub-block,
    layer norms and dropout modules are inherited from TransformerEncoderLayer.
    """

    # Aggregation strategies understood by _sa_block.
    _AGG_STRATEGIES = ("sum", "self_attention", "cross_attention")

    def __init__(self, tasks: List[str], d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu,
                 layer_norm_eps=1e-5, batch_first=True, norm_first=False,
                 device=None, dtype=None, n_mask=-1, agg_strategy: str = "sum", shared_mask: bool = False) -> None:
        super().__init__(d_model, nhead, dim_feedforward, dropout, activation,
                         layer_norm_eps, batch_first, norm_first, device, dtype)
        # (mask, selected-log-prob) pair recorded by the last _sa_block call.
        self.attention_weights: Optional[tuple] = None

        masking_model = MaskingModel
        if shared_mask:
            # One mask set used for every task.
            self.mask = masking_model(n_mask)
        else:
            # An independent mask set per task, chosen by ``key`` at call time.
            self.mask = nn.ModuleDict({task: masking_model(n_mask) for task in tasks})

        self.agg_strategy = agg_strategy
        # Learned token prepended to the nodes in the attention-based strategies.
        self.cls = nn.Embedding(1, d_model)

    def _sa_block(self, x: Tensor, key: str,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        """Aggregate the node vectors into one vector per learned mask.

        Input:
            x - [bs, n_nodes, dim]
            key - task name selecting the masker when masks are not shared
            attn_mask, key_padding_mask - accepted for signature compatibility
                with TransformerEncoderLayer but ignored; the attention strategies
                build their own mask from the learned node mask.
        Output:
            x - [bs, n_mask, dim]
        Raises:
            ValueError - if x is not 3-D or ``agg_strategy`` is unknown.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected x of shape [bs, n_nodes, dim], got {tuple(x.shape)}")
        if self.agg_strategy not in self._AGG_STRATEGIES:
            raise ValueError(
                f"Unknown agg_strategy {self.agg_strategy!r}; "
                f"expected one of {self._AGG_STRATEGIES}")

        if isinstance(self.mask, nn.ModuleDict):
            masker = self.mask[key]
        else:
            masker = self.mask
        # mask: [n_mask, n_nodes] hard 0/1; mask_logits: [n_mask, n_nodes, 2]
        mask, mask_logits = masker()

        if self.agg_strategy == "sum":
            # Masked mean over nodes: [bs, 1, n_nodes, dim] * [1, n_mask, n_nodes, 1]
            # summed over nodes -> [bs, n_mask, dim]. The 1e-7 guards against a
            # mask that selects no nodes.
            x = x[:, None, :, :] * mask[None, :, :, None]
            x = x.sum(-2) / (1e-7 + mask.sum(-1)[None, :, None])

        else:  # "self_attention" or "cross_attention"
            # Prepend an always-visible slot for the CLS token: [n_mask, n_nodes + 1].
            attn_mask = torch.cat([
                torch.ones(mask.size(0), 1, dtype=mask.dtype, device=mask.device),
                mask
            ], dim=1)
            # Prepend the CLS embedding to every sample: [bs, n_nodes + 1, dim].
            x = torch.cat([
                self.cls(torch.zeros(x.size(0), 1, dtype=torch.long, device=x.device)),
                x
            ], dim=1)
            # One masked copy of the sequence per mask: [bs, n_mask, L, dim].
            x = x[:, None, :, :] * attn_mask[None, :, :, None]
            # Per-mask key mask broadcast over queries: [n_mask, L, L].
            attn_mask = attn_mask.unsqueeze(-2).expand(-1, attn_mask.size(-1), -1)

            attn_mask = attn_mask.bool().repeat(x.size(0), 1, 1, 1)
            # nn.MultiheadAttention treats True as "not allowed to attend".
            attn_mask = ~attn_mask
            size_store = x.size()
            # Fold (bs, n_mask) into the batch dimension for MultiheadAttention.
            x = x.view(-1, x.size(-2), x.size(-1))
            attn_mask = attn_mask.view(-1, attn_mask.size(-2), attn_mask.size(-1))

            # A 3-D attn_mask must have batch * num_heads rows.
            attn_mask = attn_mask.repeat_interleave(self.self_attn.num_heads, dim=0)
            x, _ = self.self_attn(x, x, x, attn_mask=attn_mask)
            x = x.view(*size_store)

            if self.agg_strategy == "self_attention":
                # Masked mean of the attended node outputs (CLS slot excluded).
                x = x[:, :, 1:, :].sum(-2) / (1e-7 + mask.sum(-1)[None, :, None])
            else:  # "cross_attention"
                # Use the CLS slot's output as the per-mask summary.
                x = x[:, :, 0, :]

        self.attention_weights = (mask, mask_logits[..., 1])

        return self.dropout1(x)

    def get_attention_weights(self) -> Optional[tuple]:
        """Return the ``(mask, mask_logits[..., 1])`` pair from the last forward, or None."""
        return self.attention_weights

    def forward(self, src: Tensor, key: str, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the masked-aggregation encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            key: task name used to select the per-task mask (required).
            src_mask: forwarded to ``_sa_block``, which ignores it.
            src_key_padding_mask: forwarded to ``_sa_block``, which ignores it.

        Unlike the parent TransformerEncoderLayer, no residual connections are
        added. ``_sa_block`` changes the sequence length (n_nodes -> n_mask), so
        its output cannot be added to ``src``.
        """
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        x = src
        if self.norm_first:
            x = self._sa_block(self.norm1(x), key, src_mask, src_key_padding_mask)
            x = self._ff_block(self.norm2(x))
        else:
            x = self.norm1(self._sa_block(x, key, src_mask, src_key_padding_mask))
            x = self.norm2(self._ff_block(x))

        return x
