import math
import torch
import torch.nn as nn
import logging
import torch.nn.functional as F
from utils.data_utils import Example, Feature
from model.architectures.set_transformer import ISAB, PMA
from utils.transformer_utils import get_tokenizer, get_model

logger = logging.getLogger(__name__)

class FeatureEmbedding(nn.Module):
    """Embed a feature's metadata (description, dtype, category set) into ``config.model_dim``.

    The output is the sum of
      1. ``desc_proj`` applied to ``feature.description_embedding[config.embedding_model]``,
      2. a learned dtype embedding (index 0 for ``'real'``, 1 for any other dtype), and
      3. for ``'categorical'`` features when ``config.use_set_transformer`` is truthy,
         an ISAB -> PMA pooling over the feature's category embeddings.

    Results are memoised per ``Feature`` object in ``embedding_cache``, but only while
    autograd is disabled (e.g. under ``torch.no_grad()``). Serving a cached, detached
    tensor during training would cut the gradient to ``desc_proj`` / ``dtype_embed`` /
    ``class_pool`` and would go stale as soon as those parameters are updated.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dtype_embed = nn.Embedding(2, config.model_dim)
        self.desc_proj = nn.Linear(config.embedding_dim, config.model_dim)
        if getattr(config, 'use_set_transformer', False):
            # NOTE(review): the pool's input width is model_dim while the category
            # embeddings fed to it are raw (unprojected) vectors, so this presumably
            # assumes config.embedding_dim == config.model_dim. TODO confirm.
            self.class_pool = nn.Sequential(
                ISAB(config.model_dim, config.model_dim, config.num_heads, config.num_inds),
                PMA(config.model_dim, config.num_heads, 1)
            )
        else:
            self.class_pool = None
        # id(feature) -> (feature, detached embedding). Keeping the feature object in
        # the entry stops its id from being recycled by another object while cached.
        self.embedding_cache = {}

    def forward(self, feature: Feature):
        """Return the embedding of *feature* (``[model_dim]`` for a 1-D description embedding).

        Assumes ``feature.description_embedding[config.embedding_model]`` is a vector of
        length ``config.embedding_dim`` (tensor, array or list) — TODO confirm upstream.
        """
        device = self.desc_proj.weight.device
        dtype = self.desc_proj.weight.dtype
        cache_key = id(feature)

        use_cache = not torch.is_grad_enabled()
        if use_cache:
            entry = self.embedding_cache.get(cache_key)
            if entry is not None and entry[0] is feature:
                return entry[1].to(device)
        else:
            # Parameters may be updated after a grad-enabled pass: drop stale entries.
            self.embedding_cache.clear()

        # 1) Description embedding -> model_dim. Cast to the layer's dtype so that e.g.
        #    float64 NumPy inputs do not trigger a dtype mismatch inside the Linear.
        embedding = feature.description_embedding[self.config.embedding_model]
        desc = self.desc_proj(torch.as_tensor(embedding, dtype=dtype, device=device))

        # 2) Dtype embedding
        dtype_idx = 0 if feature.dtype == 'real' else 1
        out = desc + self.dtype_embed(torch.tensor(dtype_idx, device=device))

        # 3) If categorical and pooling enabled, add pooled category prototypes.
        #    Categories without an embedding are skipped.
        if feature.dtype == 'categorical' and self.class_pool is not None:
            cat_embs = []
            for cat in feature.categories:
                cat_emb = feature.categories_embedding[self.config.embedding_model].get(cat)
                if cat_emb is not None:
                    cat_embs.append(torch.as_tensor(cat_emb, dtype=dtype, device=device))

            if cat_embs:
                cat_tensor = torch.stack(cat_embs, dim=0).unsqueeze(0)  # [1, C, emb_dim]
                # TODO: change it to some PIV layer
                # Presumably PMA returns [1, 1, model_dim] here -> squeeze to [model_dim].
                pooled = self.class_pool(cat_tensor).squeeze(0).squeeze(0)
                out = out + pooled

        if use_cache:
            self.embedding_cache[cache_key] = (feature, out.detach())
        return out


class FeatureValuePairEmbedding(nn.Module):
    """
    Embeds one (feature, value) pair into a token embedding for BERT.
    Output dimension: config.model_dim

    The feature embedding (``FeatureEmbedding``) and a value embedding are
    concatenated and projected by ``embed_proj``. The value embedding is
      * ``missing_embed`` for ``None`` / NaN values,
      * ``real_proj(value)`` for ``'real'`` features,
      * otherwise the category embedding at ``int(value)`` (clamped into range), or
        ``missing_embed`` when the feature defines no categories.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.feature_embed = FeatureEmbedding(config)
        self.real_proj = nn.Linear(1, config.model_dim)
        self.missing_embed = nn.Embedding(1, config.model_dim)
        self.embed_proj = nn.Linear(config.model_dim * 2, config.model_dim)
        self.register_buffer('_dummy_tensor', torch.zeros(1))
        # (id(feature), value) -> (feature, category embedding on CPU).
        # Only fixed lookups from feature.categories_embedding are cached; the
        # learnable missing_embed output is never cached so it keeps its gradient.
        self.cat_value_cache = {}

    @property
    def device(self):
        """Device of this module's ``_dummy_tensor`` buffer."""
        return self._dummy_tensor.device

    @staticmethod
    def _is_missing(value) -> bool:
        """Return True for ``None`` or any NaN scalar (Python float, NumPy float, 0-d tensor)."""
        if value is None:
            return True
        try:
            return math.isnan(value)
        except (TypeError, ValueError, OverflowError):
            # Non-numeric values (e.g. strings) are never "missing" here.
            return False

    def _missing_embedding(self, device):
        """Return the learned embedding shared by all missing values."""
        return self.missing_embed(torch.zeros(1, dtype=torch.long, device=device)).squeeze(0)

    def _categorical_embedding(self, feature, value, device, dtype):
        """Return the category embedding selected by *value*, clamped into range."""
        categories = feature.categories
        if len(categories) == 0:
            # No categories defined: treat as missing (not cached, it is learnable).
            return self._missing_embedding(device)

        cache_key = (id(feature), value)
        entry = self.cat_value_cache.get(cache_key)
        if entry is not None and entry[0] is feature:
            return entry[1].to(device=device, dtype=dtype)

        # Clamp the index into [0, num_cats - 1]
        idx = max(0, min(int(value), len(categories) - 1))
        emb = feature.categories_embedding[self.config.embedding_model][categories[idx]]
        v_emb = torch.as_tensor(emb, dtype=dtype, device=device)
        self.cat_value_cache[cache_key] = (feature, v_emb.detach().cpu())
        return v_emb

    def forward(self, feature: Feature, value):
        """Return the token embedding for (*feature*, *value*)."""
        device = next(self.parameters()).device
        dtype = self.embed_proj.weight.dtype

        # 1) Feature description + dtype -> model_dim
        f_emb = self.feature_embed(feature).to(device)

        # 2) Value embedding
        if self._is_missing(value):
            v_emb = self._missing_embedding(device)
        elif feature.dtype == 'real':
            # Real value: project scalar
            v_emb = self.real_proj(torch.tensor([[value]], device=device, dtype=dtype)).squeeze(0)
        else:
            v_emb = self._categorical_embedding(feature, value, device, dtype)

        # 3) Combine and project → Feature-Value Pair Embedding Vector
        combined = torch.cat([f_emb, v_emb], dim=-1)
        return self.embed_proj(combined)

class PEQRowEmbedding(nn.Module):
    """Embed one table row as a sequence of per-feature tokens.

    Each (feature, value) pair is embedded by ``FeatureValuePairEmbedding`` and then
    passed through a shared ``LayerNorm -> Linear -> ReLU`` block, which acts on each
    token independently.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.feature_value_embed = FeatureValuePairEmbedding(config)
        self.peq_layer = nn.Sequential(
            nn.LayerNorm(config.model_dim),
            nn.Linear(config.model_dim, config.model_dim),
            nn.ReLU()
        )
        # (feature ids, values) -> (features, detached CPU result).
        # Only read/written while autograd is disabled.
        self.row_cache = {}

    def forward(self, features: list[Feature], values: list[float]):
        """Return ``[n, model_dim]`` token embeddings, ``n = min(len(features), len(values))``.

        Best-effort: if embedding raises, the error is logged with its traceback and a
        zero tensor of shape ``[len(features), model_dim]`` is returned instead.
        """
        device = next(self.parameters()).device
        feature_tuple = tuple(features)
        cache_key = (tuple(id(f) for f in feature_tuple), tuple(values))

        use_cache = not torch.is_grad_enabled()
        if use_cache:
            entry = self.row_cache.get(cache_key)
            if entry is not None and all(a is b for a, b in zip(entry[0], feature_tuple)):
                return entry[1].to(device)
        else:
            # A detached cached result would cut the gradient to every parameter below,
            # and entries go stale once those parameters are updated.
            self.row_cache.clear()

        try:
            embs = [self.feature_value_embed(f, v).to(device) for f, v in zip(features, values)]
            if not embs:
                # Return empty tensor if no embeddings
                return torch.zeros((0, self.config.model_dim), device=device)
            result = self.peq_layer(torch.stack(embs, dim=0))
        except Exception as e:
            logger.exception("Error in PEQRowEmbedding.forward: %s", e)
            return torch.zeros((len(features), self.config.model_dim), device=device)

        if use_cache:
            # Cache on CPU to save accelerator memory.
            self.row_cache[cache_key] = (feature_tuple, result.detach().cpu())
        return result


class MDNRegressionHead(nn.Module):
    """
    Gaussian-mixture output head (Mixture Density Network).

    Three parallel linear maps from ``hidden_dim`` produce, per mixture component:
      * ``pi``    — mixing weights, softmax-normalised over the last axis,
      * ``mu``    — component means (unconstrained),
      * ``sigma`` — component scales, ``exp`` of the raw output (strictly positive).
    Each output has the input's leading shape with the last axis replaced by
    ``num_mixtures``.
    """
    def __init__(self, hidden_dim, num_mixtures):
        super().__init__()
        self.num_mixtures = num_mixtures
        self.pi_layer = nn.Linear(hidden_dim, num_mixtures)
        self.mu_layer = nn.Linear(hidden_dim, num_mixtures)
        self.sigma_layer = nn.Linear(hidden_dim, num_mixtures)

    def forward(self, h):
        weight_logits = self.pi_layer(h)
        means = self.mu_layer(h)
        log_scales = self.sigma_layer(h)
        return weight_logits.softmax(dim=-1), means, log_scales.exp()


class PermutationEquivariantClassifier(nn.Module):
    """Score a feature representation against a set of class embeddings.

    Every (feature, class) pair is concatenated and pushed through a stack of
    ``EquivariantBlock``s that share weights across classes and attend between them.
    A final linear layer then maps each pair to a scalar logit.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding_dim = config.embedding_dim
        width = self.embedding_dim

        # model_dim -> classifier working width
        self.input_proj = nn.Linear(config.model_dim, width)

        depth = getattr(config, 'num_equivariant_layers', 2)
        # The first block consumes [feature ‖ class] pairs (2 * width); the rest keep width.
        stack = [EquivariantBlock(width * 2, width)]
        stack.extend(EquivariantBlock(width, width) for _ in range(depth - 1))
        self.equivariant_layers = nn.ModuleList(stack)

        # One scalar logit per (row, class) pair
        self.logit_proj = nn.Linear(width, 1)

    def forward(self, feature_rep, class_embeddings):
        """Return logits of shape ``[rows, num_classes]``.

        ``feature_rep`` is projected by ``input_proj``; a 1-D result is treated as a
        single row. ``class_embeddings`` is expected as ``[num_classes, embedding_dim]``.
        """
        query = self.input_proj(feature_rep)
        if query.ndim == 1:
            query = query[None, :]  # promote a single vector to a batch of one

        rows = query.shape[0]
        classes = class_embeddings.shape[0]

        # Pair each row with every class along the last axis.
        pairs = torch.cat(
            (
                query.unsqueeze(1).expand(-1, classes, -1),
                class_embeddings.unsqueeze(0).expand(rows, -1, -1),
            ),
            dim=-1,
        )

        for block in self.equivariant_layers:
            pairs = block(pairs)

        return self.logit_proj(pairs).squeeze(-1)


class EquivariantBlock(nn.Module):
    """Shared per-element MLP followed by residual self-attention across the set axis.

    Input ``x`` is ``[batch, set_size, input_dim]`` (the attention uses
    ``batch_first=True``); output is ``[batch, set_size, output_dim]``. No positional
    information is added, so permuting the set axis of the input permutes the output
    the same way.

    Args:
        input_dim: width of each input element.
        output_dim: width of each output element; must be divisible by ``num_heads``.
        num_heads: attention heads (default 4, the previously hard-coded value).
    """
    def __init__(self, input_dim, output_dim, num_heads=4):
        super().__init__()

        # Shared MLP applied to every set element independently
        self.phi = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.LayerNorm(output_dim),
            nn.ReLU()
        )

        # Self-attention between set elements (e.g. classes)
        self.attention = nn.MultiheadAttention(
            embed_dim=output_dim,
            num_heads=num_heads,
            batch_first=True
        )

        # Post-residual normalisation
        self.norm = nn.LayerNorm(output_dim)

    def forward(self, x):
        # Apply shared MLP to each element independently
        x = self.phi(x)

        # Self-attention across elements, then residual connection + LayerNorm
        attn_out, _ = self.attention(x, x, x)
        return self.norm(attn_out + x)


# ——— PEQ-BERT layer and integration ———
class PEQBertTransformerLayer(nn.Module):
    """
    Replaces BertSelfAttention with a permutation-equivariant block:
      ISAB (Induced Self-Attention) → PMA (Pooling by Multihead Attention).

    PMA uses a single seed, and its output is broadcast back over the sequence axis.
    So (assuming PMA returns ``[batch, 1, model_dim]``) every position receives the
    same vector from this module; per-token information survives only through the
    surrounding residual connection in the host transformer layer.
    """
    def __init__(self, config):
        super().__init__()
        self.isab = ISAB(
            embed_dim=config.model_dim,
            num_inds=config.num_inds,
            num_heads=config.num_heads,
        )
        self.pma = PMA(
            embed_dim=config.model_dim,
            num_heads=config.num_heads,
            num_seeds=1,
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        **kwargs,
    ):
        """Drop-in for ``BertSelfAttention.forward``.

        The parameter list mirrors BertSelfAttention's. Hugging Face ``BertAttention``
        (transformers 4.x) passes these seven arguments positionally — verify against
        the installed version. ``**kwargs`` absorbs arguments added by newer releases.
        All arguments except ``hidden_states`` and ``output_attentions`` are ignored.

        Returns ``(output,)`` or ``(output, None)`` when ``output_attentions`` is set,
        where ``output`` has the shape of ``hidden_states`` ``[batch, seq_len, model_dim]``.
        """
        # NOTE(review): attention_mask is ignored, so padded positions take part in
        # ISAB/PMA pooling. TODO: thread a key-padding mask through ISAB/PMA.
        x = self.pma(self.isab(hidden_states))
        seq_len = hidden_states.size(1)
        peq_out = x.expand(-1, seq_len, -1)
        if output_attentions:
            return (peq_out, None)
        return (peq_out,)

class FSLModel(nn.Module):
    """
    PEQ-BERT few-shot model that matches the figure:
    1) Tokenize description into BERT embeddings.
    2) Append one PEQ token per feature-value for each few-shot row and the target row.
    3) Run through BERT.
    4) Extract the hidden state at the target column position.
    5) Regression via MDN, classification via dot-product with class prototypes.

    ``forward`` returns the mean per-example loss:
      * ``'real'`` targets: MSE between that example's MDN mixture mean and the target;
      * ``'categorical'`` targets: cross-entropy over
        ``class_head(h) · prototypes / temperature``.
    Examples whose target is ``None``/NaN are skipped.

    ``desc_cache`` and ``example_cache`` are only used while autograd is disabled.
    A grad-enabled call clears them: cached tensors are detached, so serving them during
    training would cut gradients, and they go stale once the weights change.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.row_embed = PEQRowEmbedding(config)
        self.tokenizer = get_tokenizer(config.model_name)
        self.model = get_model(config.model_name)
        # PEQ-BERT hook
        if getattr(config, 'use_peq_bert', False):
            self._replace_attention_layers()
        self.temperature = getattr(config, 'temperature', 1.0)
        self.class_head = nn.Linear(config.model_dim, config.model_dim)
        self.reg_head = MDNRegressionHead(config.model_dim, config.num_mixtures)
        self.register_buffer('_dummy_tensor', torch.zeros(1))
        # description text -> detached CPU description embeddings
        self.desc_cache = {}
        self._embedding_layer = None
        # id(example) -> (example, seq CPU, mask CPU, target_pos); holding the example
        # object prevents a recycled id from matching a different example.
        self.example_cache = {}

    def _replace_attention_layers(self):
        """
        Walk through each Transformer layer and swap out the .self
        module in BertAttention for our PEQBertTransformerLayer.
        """
        for layer in self.model.encoder.layer:
            layer.attention.self = PEQBertTransformerLayer(self.config)

    @property
    def embedding_layer(self):
        """Lazily fetched input-embedding layer of the wrapped model."""
        if self._embedding_layer is None:
            self._embedding_layer = self.model.get_embedding_layer()
        return self._embedding_layer

    @property
    def device(self):
        """Device of this module's ``_dummy_tensor`` buffer."""
        return self._dummy_tensor.device

    def _use_caches(self) -> bool:
        """Return whether the caches may be used now; clear them when autograd is on."""
        if torch.is_grad_enabled():
            self.desc_cache.clear()
            self.example_cache.clear()
            return False
        return True

    def _description_embedding(self, description, device, use_cache):
        """Return the input embeddings of *description* (batch axis squeezed)."""
        if use_cache and description in self.desc_cache:
            return self.desc_cache[description].to(device)
        toks = self.tokenizer(description, return_tensors="pt")
        desc_emb = self.embedding_layer(toks.input_ids.to(device)).squeeze(0)
        if use_cache:
            self.desc_cache[description] = desc_emb.detach().cpu()
        return desc_emb

    def _row_embedding(self, features, values, device, label):
        """Embed one row; on failure log a warning and fall back to zeros."""
        try:
            return self.row_embed(features, values).to(device)
        except Exception as e:
            logger.warning(f"Error in {label} embedding: {e}")
            return torch.zeros((len(features), self.config.model_dim), device=device)

    @staticmethod
    def _is_missing_target(value) -> bool:
        """Return True for a ``None`` or NaN target value."""
        if value is None:
            return True
        try:
            return math.isnan(value)
        except (TypeError, ValueError, OverflowError):
            return False

    def _prepare_example(self, example: Example):
        """Build the token sequence for one example.

        The sequence is: description tokens, then one token per feature for every
        few-shot row, then the target row with the target and missing columns set
        to ``None``. Returns ``(seq [1, L, hidden], attention mask [1, L], target_pos)``.
        """
        device = next(self.parameters()).device
        use_cache = self._use_caches()
        cache_key = id(example)
        if use_cache:
            entry = self.example_cache.get(cache_key)
            if entry is not None and entry[0] is example:
                _, seq, mask, pos = entry
                return seq.to(device), mask.to(device), pos

        desc_emb = self._description_embedding(example.description, device, use_cache)
        fs_embs = [
            self._row_embedding(example.features, row, device, f"few-shot row {i}")
            for i, row in enumerate(example.fewshot_rows)
        ]
        excluded = set(example.missing_column_ids + [example.target_column_id])
        tgt_vals = [None if idx in excluded else v for idx, v in enumerate(example.target_row)]
        tgt_emb = self._row_embedding(example.features, tgt_vals, device, "target row")

        seq = torch.cat([desc_emb] + fs_embs + [tgt_emb], dim=0).unsqueeze(0)
        attn_mask = torch.ones(1, seq.size(1), device=device)
        fs_len = sum(m.size(0) for m in fs_embs)
        # Assumes the target row yields one token per feature, in column order.
        target_pos = desc_emb.size(0) + fs_len + example.target_column_id
        if use_cache:
            self.example_cache[cache_key] = (
                example,
                seq.detach().cpu(),
                attn_mask.detach().cpu(),
                target_pos,
            )
        return seq, attn_mask, target_pos

    def _prepare_batch(self, batch: list[Example]):
        """Right-pad per-example sequences into one batch.

        Returns ``(inputs_embeds [B, L, hidden], attention_mask [B, L], positions, batch)``.
        Sequences longer than ``max_position_embeddings`` are truncated. A target position
        past the cut is clamped to the last kept token and a warning is logged, because
        the hidden state read there no longer belongs to the target column.
        """
        device = self.device
        processed = [self._prepare_example(ex) for ex in batch]
        embeds, masks, positions = zip(*processed)
        max_len = max(x.size(1) for x in embeds)
        padded_embeds = torch.zeros(
            (len(batch), max_len, embeds[0].size(2)), dtype=embeds[0].dtype, device=device
        )
        padded_masks = torch.zeros((len(batch), max_len), device=device)
        for i, (emb, mask) in enumerate(zip(embeds, masks)):
            seq_len = emb.size(1)
            padded_embeds[i, :seq_len] = emb.squeeze(0)
            padded_masks[i, :seq_len] = mask.squeeze(0)

        max_pos = self.model.config.max_position_embeddings
        if max_len > max_pos:
            padded_embeds = padded_embeds[:, :max_pos, :]
            padded_masks = padded_masks[:, :max_pos]
            clamped = []
            for i, pos in enumerate(positions):
                if pos >= max_pos:
                    logger.warning(
                        "Example %d: target position %d exceeds max_position_embeddings=%d; "
                        "sequence truncated and position clamped",
                        i, pos, max_pos,
                    )
                clamped.append(min(pos, max_pos - 1))
            positions = clamped
        return padded_embeds, padded_masks, positions, batch

    def _create_position_ids(self, batch_size, seq_len, desc_len, device):
        """Return plain ``0..seq_len-1`` position ids per row (``desc_len`` is unused here)."""
        return torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)

    def _classification_loss(self, target_h, feat, target, prototype_cache, device):
        """Cross-entropy of one example against its feature's category prototypes.

        Returns None when the feature defines no categories. The target index is
        ``int(target)`` clamped into ``[0, num_cats - 1]``.
        """
        cats = getattr(feat, 'categories', []) or []
        if not cats:
            return None
        vec = self.class_head(target_h)
        protos = prototype_cache.get(id(feat))
        if protos is None:
            lookup = feat.categories_embedding[self.config.embedding_model]
            # Entries may be tensors, arrays or lists; normalise to vec's dtype/device.
            protos = torch.stack(
                [torch.as_tensor(lookup[c], dtype=vec.dtype, device=device) for c in cats],
                dim=0,
            )
            prototype_cache[id(feat)] = protos
        logits = (vec.unsqueeze(0) @ protos.t()) / self.temperature  # [1, num_cats]
        target_idx = max(0, min(int(target), len(cats) - 1))
        tgt = torch.tensor([target_idx], device=device)
        return F.cross_entropy(logits, tgt)

    def forward(self, batch: list[Example]):
        """Return the mean loss over *batch* as a 0-d tensor (0.0 when nothing is scored)."""
        device = next(self.parameters()).device
        if not batch:
            # _prepare_batch cannot unpack an empty batch, so handle it up front.
            return torch.tensor(0.0, requires_grad=True, device=device)

        inputs_embeds, attention_mask, positions, examples = self._prepare_batch(batch)
        batch_size, seq_len, _ = inputs_embeds.shape
        token_type_ids = torch.zeros((batch_size, seq_len), dtype=torch.long, device=device)
        # desc_len only feeds _create_position_ids; tokenizing is enough to obtain it.
        desc_len = self.tokenizer(examples[0].description, return_tensors="pt").input_ids.size(1)
        position_ids = self._create_position_ids(batch_size, seq_len, desc_len, device)
        out = self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            return_dict=True
        )
        h = out.last_hidden_state
        target_h = torch.stack([h[i, pos] for i, pos in enumerate(positions)], dim=0)
        pi, mu, sigma = self.reg_head(target_h)
        # Mixture mean per example. NOTE(review): sigma is unused by the MSE objective.
        y_pred = (pi * mu).sum(dim=-1)

        prototype_cache = {}
        losses = []
        for i, ex in enumerate(examples):
            feat = ex.features[ex.target_column_id]
            target = ex.target_row[ex.target_column_id]
            if self._is_missing_target(target):
                logger.warning("Example %d has a missing target; skipped in loss", i)
                continue
            if feat.dtype == 'real':
                y_t = torch.tensor(float(target), dtype=y_pred.dtype, device=device)
                # Compare example i's own prediction with its target.
                losses.append(F.mse_loss(y_pred[i], y_t))
            elif feat.dtype == 'categorical':
                loss = self._classification_loss(target_h[i], feat, target, prototype_cache, device)
                if loss is not None:
                    losses.append(loss)
        return torch.stack(losses).mean() if losses else torch.tensor(0.0, requires_grad=True, device=device)