import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import copy

import torch
import torch.nn as nn
import numpy as np

import torch
import torch.nn.functional as F

def moegate_entropy_loss(gate_probs: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Return the negated mean entropy of the gate probabilities.

    The probabilities are clamped to ``[eps, 1.0]`` so the logarithm is always
    finite. Entropy is summed over the last dimension and averaged over the
    remaining entries; the sign is flipped so that minimising the returned
    value maximises the entropy.

    Args:
        gate_probs: Tensor whose last dimension holds the probabilities.
        eps: Lower clamp bound applied before taking the log.

    Returns:
        Scalar tensor equal to ``-mean(entropy)``.
    """
    safe_probs = gate_probs.clamp(min=eps, max=1.0)
    # H = -sum(p * log p), one value per distribution along the last axis.
    per_sample_entropy = (safe_probs * safe_probs.log()).sum(dim=-1).neg()
    # Negated mean: a lower loss means a higher average entropy.
    return per_sample_entropy.mean().neg()


# Helper functions are defined locally rather than imported, avoiding a skorch dependency
def compute_same_pad1d(input_size, kernel_size, stride=1, dilation=1):
    """Split the total padding for one axis into a ``(before, after)`` pair.

    The total is ``(stride - 1) * input_size - stride
    + dilation * (kernel_size - 1) + 1``; when it is odd, the extra unit goes
    to the second ("after") element.
    """
    stride_term = (stride - 1) * input_size - stride
    total = stride_term + dilation * (kernel_size - 1) + 1
    before = total // 2
    after = total - before
    return (before, after)
def compute_same_pad2d(input_size, kernel_size, stride=(1, 1), dilation=(1, 1)):
    """Return 2-D padding as ``[left, right, top, bottom]``.

    Index 0 of every argument describes the height axis and index 1 the width
    axis. The width padding is listed first in the result.
    """
    # Evaluate the height axis (0) first, then the width axis (1).
    top_bottom, left_right = (
        compute_same_pad1d(
            input_size[axis],
            kernel_size[axis],
            stride=stride[axis],
            dilation=dilation[axis],
        )
        for axis in (0, 1)
    )
    return [left_right[0], left_right[1], top_bottom[0], top_bottom[1]]

def _narrow_normal_weight_zero_bias(module):
    """Re-initialise every parameter of ``module`` in place, chosen by name.

    Parameters whose name contains ``'weight'`` are drawn from N(0, 0.01**2);
    otherwise, parameters whose name contains ``'bias'`` are set to zero.
    Parameters matching neither are left untouched.
    """
    for param_name, tensor in module.named_parameters():
        # 'weight' takes priority when a name happens to contain both words.
        if 'weight' in param_name:
            nn.init.normal_(tensor, mean=0.0, std=0.01)
            continue
        if 'bias' in param_name:
            nn.init.zeros_(tensor)

def compute_out_size(input_size, kernel_size, stride=1, padding=0, dilation=1):
    """Return the ``(h_out, w_out)`` output size of a 2-D convolution.

    Each axis uses the usual convolution arithmetic::

        out = (in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

    Args:
        input_size: Input length. A scalar is used for both axes; a
            ``(height, width)`` tuple or list gives each axis its own length.
        kernel_size: ``int`` (used for both axes) or a pair indexed
            ``(height, width)``.
        stride: ``int`` or ``(height, width)`` pair.
        padding: ``int`` or ``(height, width)`` pair (applied on both sides).
        dilation: ``int`` or ``(height, width)`` pair.

    Returns:
        Tuple ``(h_out, w_out)``.
    """
    # Scalar input sizes keep the historical behaviour: same length on both axes.
    if isinstance(input_size, (tuple, list)):
        in_h, in_w = input_size
    else:
        in_h = in_w = input_size

    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(stride, int):
        stride = (stride, stride)
    if isinstance(padding, int):
        padding = (padding, padding)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)

    h_out = (in_h + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1
    w_out = (in_w + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1

    return h_out, w_out



class MoETemporalConv2d(nn.Module):

    def __init__(
        self,
        in_channels: int,
        gate_in_ch: int,
        out_channels: int,
        time_kernel: tuple,
        num_experts: int = 8,
        top_k: int = 2,
        gate_hidden: int | None = None,
        use_aux_balance: bool = True,
        aux_coef: float = 0.01,
        stride: tuple = (1, 1),
        enable_pad: bool = False,
    ):
        super().__init__()
        assert 1 <= top_k <= num_experts
        self.in_ch = in_channels
        self.gate_in_ch = gate_in_ch
        self.out_ch = out_channels
        self.k = time_kernel[1]
        self.stride = stride
        self.time_kernel = time_kernel
        self.num_experts = num_experts
        self.top_k = top_k
        self.use_aux = use_aux_balance
        self.aux_coef = aux_coef

        # Experts: identical shapes, different weights
        self.experts = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, kernel_size=time_kernel, stride=stride)
            for _ in range(num_experts)
        ])

        # Gating MLP (per sample): GAP → MLP → logits[E]
        gh = gate_hidden or max(16, in_channels // 2)
        self.gate = nn.Sequential(
            nn.Linear(gate_in_ch, gh),
            nn.ReLU(inplace=True),
            nn.Linear(gh, num_experts)
        )

        # We’ll compute SAME padding dynamically for width=T
        self.pad_cache_T = None
        self.pad_layer = None
        self.enable_pad = enable_pad
        self.aux_loss = None  # set on forward

    def _ensure_same_pad(self, T: int, device):
        if self.pad_cache_T == T and self.pad_layer is not None:
            return
        # SAME pad for (H=1, W=T) with kernel (1, k), stride (1,1)
        # Left/right padding only depends on W and k
        total = self.k - 1
        left = total // 2
        right = total - left
        # nn.ConstantPad2d pads (left, right, top, bottom)
        self.pad_layer = nn.ConstantPad2d((left, right, 0, 0), 0.0).to(device)
        self.pad_cache_T = T

    def forward(self, x: torch.Tensor) -> torch.Tensor:
 
        B, C, H, T = x.shape
        if not self.enable_pad:
            _,T = compute_out_size(T, self.time_kernel[1], stride=self.stride[1])
        if self.enable_pad:
            self._ensure_same_pad(T, x.device)

        # Gate: GAP over (H=1, W=T) → [B, C]
        if H==1:
            feat = x.mean(dim=-1).squeeze(-1)  # [B, C]
        else:
            feat = x.mean(dim=-1).squeeze(-2)  # [B, C]
        logits = self.gate(feat)           # [B, E]
        gate_probs = F.softmax(logits, dim=-1)  # [B, E]

        # Optional load-balance aux loss
        self.aux_loss = None
        if self.use_aux:
            mean_usage = gate_probs.mean(dim=0)  # [E]
            uniform = torch.full_like(mean_usage, 1.0 / self.num_experts)
            self.aux_loss = self.aux_coef * F.kl_div(
                (mean_usage + 1e-8).log(), uniform, reduction="batchmean"
            )

        # Top-k routing
        topk_val, topk_idx = gate_probs.topk(self.top_k, dim=-1)  # [B, k]
        weights = topk_val / (topk_val.sum(dim=-1, keepdim=True) + 1e-9)

        # Pre-pad once
        if self.enable_pad:
            x_pad = self.pad_layer(x)  # [B, C, 1, T_pad]
        else:
            x_pad = x

        # Aggregate expert outputs
        y = torch.zeros(B, self.out_ch, 1, T, device=x.device, dtype=x.dtype)
        for e in range(self.num_experts):
            mask = (topk_idx == e)  # [B, k] bool
            if not mask.any():
                continue
            b_idx, pos = mask.nonzero(as_tuple=True)
            uniq_b = b_idx.unique()
            x_sub = x_pad[uniq_b]
            out_e = self.experts[e](x_sub)
            w_per_b = torch.zeros(len(uniq_b), device=x.device, dtype=x.dtype)
            for i, b in enumerate(uniq_b):
                where = (b_idx == b)
                w_per_b[i] = weights[b, pos[where]].sum()
            w_per_b = w_per_b.view(-1, 1, 1, 1)
            y[uniq_b] += out_e * w_per_b
        return y
class GuneyNet_MoE_v3(nn.Module):
    """Convolutional classifier whose conv stages can be swapped for MoE layers.

    Pipeline: band/channel mixing (1x1 conv) -> spatial conv 1 -> dropout ->
    strided temporal conv 1 -> ReLU -> dropout -> spatial conv 2 -> dropout ->
    width-preserving temporal conv 2 -> dropout -> flatten -> linear.

    Args:
        n_channels, n_bands: Their product is the input channel count of the
            first convolution.
        n_samples: Input length along the last axis.
        n_classes: Number of output logits.
        spatial_dropout: Dropout rate after both spatial convs and after the
            first temporal conv.
        time2_kernel: Kernel width of the second temporal conv.
        time2_dropout: Dropout rate after the second temporal conv.
        moe_num_experts, moe_top_k, moe_gate_hidden: Forwarded to every
            ``MoETemporalConv2d`` that is built.
        moe_aux_coef: Stored on the model as ``moe_aux_coef``.
            NOTE(review): it is not forwarded to the MoE layers, which scale
            their aux loss with their own default ``aux_coef`` — confirm
            whether the training loop is meant to apply it.
        moe_switch: Selects which stages become MoE layers. A stage is
            switched when its tag (``'spatial'``, ``'time1'``, ``'spatial2'``,
            ``'time2'``) occurs as a substring.
            NOTE(review): ``'spatial'`` is also a substring of ``'spatial2'``,
            so requesting ``'spatial2'`` switches the first spatial stage too.
        n_spatial_filters, n_time1_filters, n_time2_filters: Channel widths.

    Attributes:
        last_aux_loss: Sum of the MoE layers' aux losses from the most recent
            forward pass (``None`` before the first call).
    """

    def __init__(
        self,
        n_channels, n_samples, n_classes, n_bands,
        spatial_dropout=0.1,
        time2_kernel=10,
        time2_dropout=0.95,
        moe_num_experts=4,
        moe_top_k=2,
        moe_gate_hidden=None,
        moe_aux_coef=0.01,
        moe_switch = 'time1_time2',
        n_spatial_filters = 120,
        n_time1_filters = 120,
        n_time2_filters = 120,
    ):
        super().__init__()
        self.moe_switch = moe_switch
        time1_kernel = 2
        time1_stride = 2
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_gate_hidden = moe_gate_hidden
        self.time2_kernel = time2_kernel

        self.n_channels = n_channels
        self.n_samples = n_samples
        self.n_classes = n_classes
        self.n_bands = n_bands
        self.n_spatial_filters = n_spatial_filters
        self.n_time1_filters = n_time1_filters
        self.n_time2_filters = n_time2_filters

        # Width of the feature map after the strided first temporal conv.
        temp_out_size_1 = compute_out_size(n_samples, time1_kernel, stride=time1_stride)[1]

        self.band_layer = nn.Conv2d(n_bands*n_channels, n_spatial_filters, (1, 1))
        if 'spatial' in self.moe_switch:
            self.spatial_layer_1 = self._moe_conv(n_spatial_filters, n_spatial_filters, (1, 1))
        else:
            self.spatial_layer_1 = nn.Conv2d(n_spatial_filters, n_spatial_filters, (1, 1))
        self.spatial_dropout_1 = nn.Dropout(spatial_dropout)

        if 'time1' in self.moe_switch:
            self.time1_layer = self._moe_conv(
                n_spatial_filters, n_time1_filters, (1, time1_kernel),
                stride=(1, time1_stride),
            )
        else:
            self.time1_layer = nn.Conv2d(
                n_spatial_filters, n_time1_filters, (1, time1_kernel),
                stride=(1, time1_stride),
            )
        self.relu = nn.ReLU()
        self.time1_dropout = nn.Dropout(spatial_dropout)

        if 'spatial2' in self.moe_switch:
            self.spatial_layer_2 = self._moe_conv(n_time1_filters, n_spatial_filters, (1, 1))
        else:
            self.spatial_layer_2 = nn.Conv2d(n_time1_filters, n_spatial_filters, (1, 1))
        self.spatial_dropout_2 = nn.Dropout(spatial_dropout)

        # Second temporal stage keeps the width: the MoE layer pads
        # internally, the plain variant uses an explicit padding layer.
        if 'time2' in self.moe_switch:
            self.time2_moe = self._moe_conv(
                n_spatial_filters, n_time2_filters, (1, time2_kernel),
                stride=(1, 1), enable_pad=True,
            )
        else:
            self.time2_moe = nn.Sequential(
                nn.ConstantPad2d(compute_same_pad2d((1, temp_out_size_1), (1, time2_kernel), stride=(1, 1)), 0.0),
                nn.Conv2d(n_spatial_filters, n_time2_filters, (1, time2_kernel)),
            )
        self.relu2 = nn.ReLU()  # registered but not applied in forward
        self.time2_dropout = nn.Dropout(time2_dropout)

        self.flatten = nn.Flatten()
        self.fc_layer = nn.Linear(n_time2_filters * temp_out_size_1, n_classes)

        self._reset_parameters()

        # expose aux loss (updated on forward)
        self.moe_aux_coef = moe_aux_coef
        self.last_aux_loss = None

    def _moe_conv(self, in_ch, out_ch, kernel, stride=(1, 1), enable_pad=False):
        """Build one MoE conv layer using the model-wide MoE settings."""
        return MoETemporalConv2d(
            in_channels=in_ch,
            gate_in_ch=in_ch,
            out_channels=out_ch,
            time_kernel=kernel,
            num_experts=self.moe_num_experts,
            top_k=self.moe_top_k,
            gate_hidden=self.moe_gate_hidden,
            enable_pad=enable_pad,
            stride=stride,
        )

    @torch.no_grad()
    def _reset_parameters(self):
        """Apply the narrow-normal/zero init, then special-case two layers."""
        _narrow_normal_weight_zero_bias(self)
        nn.init.ones_(self.band_layer.weight)
        nn.init.xavier_normal_(self.fc_layer.weight, gain=1)

    def forward(self, X):
        """Compute class logits and refresh ``last_aux_loss``.

        Args:
            X: 4-D tensor; dims 1 and 2 are merged into the channel axis.
                Presumably ``(batch, n_bands, n_channels, n_samples)`` —
                confirm against the data loader.

        Returns:
            Tensor ``[batch, n_classes]``.
        """
        self.last_aux_loss = torch.zeros((), device=X.device)
        if X.dtype != torch.float32:
            X = X.to(torch.float32)
        X = X.reshape(X.shape[0], X.shape[1] * X.shape[2], 1, X.shape[3])
        x = self.band_layer(X)
        x = self.spatial_layer_1(x)
        x = self.spatial_dropout_1(x)
        x = self.time1_layer(x)                # [B, n_time1, 1, T1]
        x = self.relu(x)
        x = self.time1_dropout(x)

        x = self.spatial_layer_2(x)
        x = self.spatial_dropout_2(x)
        x = self.time2_moe(x)                  # [B, n_time2, 1, T1]
        x = self.time2_dropout(x)

        y = self.flatten(x)
        y = self.fc_layer(y)

        # Surface the aux loss for training. Only MoE layers carry an
        # ``aux_loss``; a layer built with aux balancing disabled reports
        # None and is skipped rather than raising a TypeError.
        aux_total = self.last_aux_loss
        for layer in (
            self.spatial_layer_1,
            self.spatial_layer_2,
            self.time1_layer,
            self.time2_moe,
        ):
            if not isinstance(layer, MoETemporalConv2d):
                continue
            if layer.aux_loss is not None:
                aux_total = aux_total + layer.aux_loss
        self.last_aux_loss = aux_total
        return y

    def copy_model(self):
        """Return a new model with the same hyperparameters and weights.

        The copy is created through the constructor and loaded from this
        model's ``state_dict``, then moved to this model's device. Its
        ``last_aux_loss`` is reset to ``None``.
        """
        new_model = self.__class__(
            n_channels=self.n_channels,
            n_samples=self.n_samples,
            n_classes=self.n_classes,
            n_bands=self.n_bands,
            # Read the configured rates from the dropout modules so that
            # non-default values are carried over to the copy.
            spatial_dropout=self.spatial_dropout_1.p,
            time2_kernel=self.time2_kernel,
            time2_dropout=self.time2_dropout.p,
            moe_num_experts=self.moe_num_experts,
            moe_top_k=self.moe_top_k,
            # getattr keeps instances created before this attribute existed working.
            moe_gate_hidden=getattr(self, "moe_gate_hidden", None),
            moe_aux_coef=self.moe_aux_coef,
            moe_switch=self.moe_switch,
            n_spatial_filters=self.n_spatial_filters,
            n_time1_filters=self.n_time1_filters,
            n_time2_filters=self.n_time2_filters,
        )

        # Copy the state dict (weights and biases)
        new_model.load_state_dict(self.state_dict())

        # Move to the same device as the original model
        device = next(self.parameters()).device
        new_model = new_model.to(device)

        # Copy non-parameter attributes
        new_model.last_aux_loss = None
        new_model.moe_aux_coef = self.moe_aux_coef

        return new_model


