import logging
from typing import List, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, nn

from models.layers.proj_layer import ProjBlock
from models.csformer.encoder import cs_encoder
from models.vision.vision_encoder import VisionEncoder

logger = logging.getLogger(__name__)

class Model(nn.Module):
    """Series model that fuses an optional frozen vision branch with the raw input.

    The input ``x`` is laid out as ``(b, l, m)`` (see the ``'b l m -> b m l'``
    rearrange in ``_normalize_input``). It is normalized per ``(b, m)`` row
    using statistics from the last ``norm_window`` steps, passed through a
    stack of ``cs_encoder`` layers and a linear head that emits
    ``configs.pred_len`` values per row, de-normalized, and returned with
    layout ``(b, f, m)`` where ``f`` is the prediction-horizon axis.

    ``configs.cross_modal`` selects the fusion path: ``'mlp'`` or ``'att'``.
    Any other value raises ``ValueError`` at construction time.

    When ``configs.use_venc`` is true, a ``VisionEncoder`` is created with all
    of its parameters frozen, and it is kept in eval mode even while the rest
    of the model is in training mode (see ``train``).
    """

    def __init__(self, configs):
        """Build all sub-modules from ``configs``.

        Reads ``seq_len``, ``pred_len``, ``num_seg``, ``num_encoder``,
        ``hidden_dim``, ``cl_ps``, ``use_venc``, ``cross_modal``,
        ``norm_window`` and ``dropout`` from ``configs``; ``configs`` is also
        forwarded unchanged to every ``cs_encoder`` layer.

        Raises:
            ValueError: if ``configs.cross_modal`` is neither ``'mlp'`` nor
                ``'att'``.
        """
        super().__init__()
        self.hid_dim = configs.seq_len
        self.pred_horizon = configs.pred_len
        self.seq_len = configs.seq_len
        self.num_seg = configs.num_seg
        self.patch_len = self.seq_len // self.num_seg
        self.emb_hidden = 1024
        self.text_len = 1024
        self.audio_len = 512

        self.num_encoder = configs.num_encoder
        self.hidden_dim = configs.hidden_dim
        self.cl_ps = configs.cl_ps
        self.emb_dim = 768
        self.use_venc = configs.use_venc
        # Audio/text branches are hard-disabled in this class; only the
        # vision branch can be toggled through the config.
        self.use_aenc = False
        self.use_tenc = False
        # Hard-coded modality count; it scales the width of ``linear`` and
        # ``pred_nn`` below.
        self.num_enc = 1
        self.fusion_method = configs.cross_modal
        self.decoder_method = 'none'
        self.norm_window = configs.norm_window
        self.vs_len = 512

        # With ``num_enc`` fixed to 1 this guard cannot fire; it is kept so
        # the check is already in place if ``num_enc`` becomes configurable.
        if self.num_enc == 0:
            raise ValueError("At least one modality encoder must be enabled.")

        if self.fusion_method not in ('mlp', 'att'):
            raise ValueError(f"Unsupported fusion method: {self.fusion_method}")
        is_att = self.fusion_method == 'att'

        # NOTE: sub-modules are created in exactly the same order as before
        # the two branches were merged, so parameter ordering, state_dict
        # layout and RNG consumption during initialization are unchanged.
        self.dropout = nn.Dropout(configs.dropout)
        # The projection targets ``emb_dim`` for 'att' and ``seq_len`` for 'mlp'.
        proj_out_dim = self.emb_dim if is_att else self.seq_len
        self._venc_proj = ProjBlock(self.emb_dim, proj_out_dim)
        self.nl_dim = self._fusion_dim()
        # NOTE(review): in 'mlp' mode this head expects ``seq_len + vs_len``
        # inputs while ``_collect_mlp_modalities`` concatenates the projected
        # vision features with ``x_norm``; whether these widths agree depends
        # on ProjBlock / cs_encoder output sizes, which are not visible here
        # -- TODO confirm.
        self.fc = nn.Linear(self.nl_dim, self.pred_horizon)
        self.gelu = nn.GELU()
        if is_att:
            # ``ln`` and ``linear`` are not used by any method in this class;
            # they are kept so existing checkpoints still load.
            self.ln = nn.LayerNorm(self.emb_hidden)
            self.linear = nn.Linear(1024 * self.num_enc, 1024 * self.num_enc)

        self.cs_encoder = self._build_cs_layers(configs)

        if self.use_venc:
            logger.info('Vision encoder enabled.')
            self.vision_encoder = VisionEncoder(self.emb_dim, self.seq_len)
            # Consumes the concatenation built in ``_encode_vision_branch``.
            self.vs_fc = nn.Linear(self.seq_len + self.emb_dim, 1024)

        self.pred_nn = nn.Linear(1024 * self.num_enc, self.pred_horizon)

        self._no_grad()

    def train(self, mode: bool = True):
        """Set training mode, keeping the frozen vision encoder in eval mode.

        ``nn.Module.train`` recursively switches every child module, which
        would silently undo the ``eval()`` call made in ``_no_grad`` the first
        time a training loop calls ``model.train()``. The encoder's eval mode
        is therefore re-applied after delegating to the base implementation.

        Returns:
            ``self``, as ``nn.Module.train`` does.
        """
        super().train(mode)
        # ``getattr`` keeps this safe on partially constructed instances.
        if getattr(self, 'use_venc', False) and hasattr(self, 'vision_encoder'):
            self.vision_encoder.eval()
        return self

    def forward(
        self,
        x: Tensor,
        batch_x_mark,
        dec_inp,
        batch_y_mark,
        flatten_output: bool = True,
    ) -> Tensor:
        """Dispatch to the decoder matching ``self.fusion_method``.

        Args:
            x: input tensor with layout ``(b, l, m)``.
            batch_x_mark, dec_inp, batch_y_mark: forwarded to the decoder,
                which does not use them.
            flatten_output: forwarded to the decoder, which does not use it.

        Returns:
            Predictions with layout ``(b, f, m)``.

        Raises:
            ValueError: if ``self.fusion_method`` was changed after
                construction to an unsupported value.
        """
        if self.fusion_method == 'mlp':
            return self.forward_decoder_mlp(
                x, batch_x_mark, dec_inp, batch_y_mark, flatten_output=flatten_output
            )
        if self.fusion_method == 'att':
            return self.forward_decoder_att(
                x, batch_x_mark, dec_inp, batch_y_mark, flatten_output=flatten_output
            )
        raise ValueError(f"Unsupported fusion method: {self.fusion_method}")

    def forward_decoder_att(
        self, x: Tensor, batch_x_mark, dec_inp, batch_y_mark, flatten_output: bool = True
    ) -> Tensor:
        """Run the 'att' path: per-modality hidden states -> cs stack -> ``pred_nn``.

        Only ``x`` is used; the remaining arguments are accepted for signature
        compatibility with ``forward`` and ignored.

        Raises:
            ValueError: if no modality branch is enabled (``use_venc`` false).
        """
        x_raw, x_norm, mean, std = self._normalize_input(x)

        hiddens: List[Tensor] = []
        if self.use_venc:
            hiddens.append(self._encode_vision_branch(x_raw, x_norm))

        if not hiddens:
            raise ValueError("No modality outputs available for attention fusion.")

        fused = torch.cat(hiddens, dim=-1)
        fused = self._apply_cs_stack(fused, residual=True)
        preds = self.pred_nn(fused)
        preds = self._denormalize(preds, mean, std)
        return rearrange(preds, 'b m f -> b f m')

    def forward_decoder_mlp(
        self, x: Tensor, batch_x_mark, dec_inp, batch_y_mark, flatten_output: bool = True
    ) -> Tensor:
        """Run the 'mlp' path: concatenated features -> cs stack -> dropout -> ``fc``.

        Only ``x`` is used; the remaining arguments are accepted for signature
        compatibility with ``forward`` and ignored.
        """
        x_raw, x_norm, mean, std = self._normalize_input(x)
        features = self._collect_mlp_modalities(x_raw, x_norm)
        features = self._apply_cs_stack(features, residual=False)
        features = self.dropout(features)
        features = self.fc(features)
        features = self._denormalize(features, mean, std)
        return rearrange(features, 'b m f -> b f m')

    def _fusion_dim(self):
        """Return ``seq_len`` plus ``vs_len`` when the vision branch is enabled."""
        return int(self.seq_len + self.use_venc * self.vs_len)

    def _build_cs_layers(self, configs) -> nn.ModuleList:
        """Create ``num_encoder`` ``cs_encoder`` layers, but never fewer than one."""
        num_layers = max(1, self.num_encoder)
        return nn.ModuleList([cs_encoder(configs) for _ in range(num_layers)])

    def _no_grad(self):
        """Freeze the vision encoder's parameters and put it in eval mode.

        Does nothing when the vision branch is disabled.
        """
        if self.use_venc:
            for param in self.vision_encoder.parameters():
                param.requires_grad = False
            self.vision_encoder.eval()

    def cross_attention(self, query, key, value, n_heads=8, n_dim=4):
        """Multi-head scaled dot-product attention without learned projections.

        The last axis of each input is split into ``n_heads`` heads, attention
        is computed with ``F.scaled_dot_product_attention``, and the heads are
        merged back into the last axis.

        Args:
            query, key, value: tensors of rank ``n_dim`` whose last axis is
                the feature axis; the second-to-last axis is the one attended
                over.
            n_heads: number of heads; must divide the query's last axis.
            n_dim: rank of the inputs, either 3 or 4.

        Returns:
            A tensor with the same shape as ``query``.

        Raises:
            ValueError: if ``n_dim`` is not 3 or 4 (previously this surfaced
                as an ``UnboundLocalError``), if ``query`` does not have
                ``n_dim`` dimensions, or if the feature size is not divisible
                by ``n_heads``.
        """
        if n_dim == 4:
            split_pattern = 'B M T (nh hd) -> B M nh T hd'
            merge_pattern = 'B M nh L hd -> B M L (nh hd)'
        elif n_dim == 3:
            split_pattern = 'B T (nh hd) -> B nh T hd'
            merge_pattern = 'B nh L hd -> B L (nh hd)'
        else:
            raise ValueError(f"n_dim must be 3 or 4, got {n_dim}")

        if query.dim() != n_dim:
            raise ValueError(
                f"query must have {n_dim} dimensions, got {query.dim()}"
            )

        h = query.shape[-1]
        if h % n_heads != 0:
            raise ValueError("h must be divisible by n_heads")
        head_dim = h // n_heads

        q, k, v = (
            rearrange(t, split_pattern, nh=n_heads, hd=head_dim)
            for t in (query, key, value)
        )
        out = F.scaled_dot_product_attention(q, k, v)
        return rearrange(out, merge_pattern, nh=n_heads, hd=head_dim)

    def _normalize_input(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Transpose ``x`` to ``(b, m, l)`` and standardize it along ``l``.

        Mean and standard deviation are taken over the last ``norm_window``
        positions of each ``(b, m)`` row (a window of 0 slices the whole row).

        Returns:
            ``(x_raw, x_norm, mean, std)`` where ``x_raw`` is the transposed
            input, ``x_norm = (x_raw - mean) / (std + 1e-3)``, and ``mean`` /
            ``std`` keep a trailing singleton axis for broadcasting.
        """
        x_raw = rearrange(x, 'b l m -> b m l')
        window = x_raw[:, :, -self.norm_window:]
        mean = window.mean(dim=2, keepdim=True)
        std = window.std(dim=2, keepdim=True)
        # The same 1e-3 offset is added back in ``_denormalize``.
        x_norm = (x_raw - mean) / (std + 1e-3)
        return x_raw, x_norm, mean, std

    def _denormalize(self, x: Tensor, mean: Tensor, std: Tensor) -> Tensor:
        """Invert ``_normalize_input``: ``x * (std + 1e-3) + mean``."""
        return x * (std + 1e-3) + mean

    def _encode_vision_branch(self, x_raw: Tensor, x_norm: Tensor) -> Tensor:
        """Embed ``x_raw`` with the vision encoder and fuse it with ``x_norm``.

        The projected embedding is concatenated with ``x_norm`` on the last
        axis and mapped by ``vs_fc`` to 1024 features.
        """
        x_venc = self.vision_encoder.embedding_cc1(x_raw)
        x_venc = self._venc_proj(x_venc)
        x_venc = torch.cat([x_venc, x_norm], dim=-1)
        return self.vs_fc(x_venc)

    def _collect_mlp_modalities(self, x_raw: Tensor, x_norm: Tensor) -> Tensor:
        """Concatenate the enabled modality features with ``x_norm`` on the last axis.

        ``x_norm`` is always included and always comes last.
        """
        features: List[Tensor] = []
        if self.use_venc:
            features.append(self._venc_proj(self.vision_encoder.embedding_cc1(x_raw)))
        features.append(x_norm)
        return torch.cat(features, dim=-1)

    def _apply_cs_stack(self, x: Tensor, residual: bool) -> Tensor:
        """Apply every ``cs_encoder`` layer in order.

        Each layer is applied as ``x = x + layer(x)``, except in one case:
        when ``residual`` is false and the stack holds exactly one layer, that
        layer's output is returned directly. With more than one layer the
        residual form is used regardless of ``residual``.
        """
        if not residual and len(self.cs_encoder) == 1:
            return self.cs_encoder[0](x)

        for layer in self.cs_encoder:
            x = x + layer(x)
        return x


