
import torch.nn as nn
import torch
import torch.nn.functional as F 

def sample_gumbel(logits, eps=1e-10):
    """
    Draw standard Gumbel(0, 1) noise shaped like ``logits``.

    Uses the inverse-CDF transform ``-log(-log(U))`` with ``U ~ Uniform[0, 1)``
    drawn via ``torch.rand_like`` (so the noise matches the shape, dtype and
    device of ``logits``). ``eps`` is added inside both logarithms to keep
    them away from ``log(0)``.
    """
    uniform = torch.rand_like(logits)
    # First log: maps U into (0, inf) after negation.
    neg_log_u = -torch.log(uniform + eps)
    # Second log: turns the exponential-like variable into Gumbel noise.
    return -torch.log(neg_log_u + eps)

def gumbel_softmax(logits, tau=1.0, hard=False):
    """
    Gumbel-Softmax trick for differentiable sampling of a categorical.

    Args:
        logits: unnormalized log-probabilities; the categorical axis is the
            last dim (typically ``(batch_size, num_classes)``).
        tau: softmax temperature, must be > 0. Smaller values push the soft
            sample toward one-hot.
        hard: if True, use the straight-through estimator.

    Returns:
        Tensor with the same shape as ``logits``:
          - if ``hard=True``: one-hot along the last dim in the forward pass
            (up to floating-point rounding), while gradients flow through the
            soft sample;
          - if ``hard=False``: a probability distribution along the last dim.
        In both cases the result is differentiable w.r.t. ``logits``.

    Raises:
        ValueError: if ``tau`` is not strictly positive (zero would divide by
            zero, a negative value would invert the distribution).
    """
    if tau <= 0:
        raise ValueError(f"tau must be > 0, got {tau}")

    # Perturb the logits with Gumbel noise, then relax the argmax via softmax.
    perturbed = logits + sample_gumbel(logits)
    y_soft = F.softmax(perturbed / tau, dim=-1)

    if not hard:
        return y_soft

    # Straight-through trick: the forward value equals the one-hot argmax,
    # but because ``y_hard - y_soft.detach()`` is constant w.r.t. autograd,
    # the backward pass uses the gradient of the continuous ``y_soft``.
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    return y_hard.detach() - y_soft.detach() + y_soft


class HybridModel(nn.Module):
    """Blend a base model's prediction with the target stored for one fidelity level.

    The forward pass returns ``alpha * base_model(x) + (1 - alpha) * y_coarse``,
    where ``(x_coarse, y_coarse) = fidelity_data[self.fidelity]``.

    Args:
        base_model: module applied to the batch input ``x``.
        fidelity: key looked up in ``fidelity_data``; must be one of
            ``VALID_FIDELITIES``.
        alpha: weight of the base-model prediction; ``1 - alpha`` weights the
            fidelity-level target.

    Raises:
        ValueError: if ``fidelity`` is not one of ``VALID_FIDELITIES``.
    """

    # Fidelity levels accepted by the constructor.
    VALID_FIDELITIES = ("fine", "coarse", "xcoarse", "medium", "gt")

    def __init__(self, base_model: nn.Module, fidelity: str, alpha: float = 0.5):
        super().__init__()
        self.base_model = base_model
        self.alpha = alpha
        if fidelity not in self.VALID_FIDELITIES:
            raise ValueError(f"Invalid fidelity level: {fidelity}")
        self.fidelity = fidelity

    def forward(self, batch):
        """Return the alpha-weighted blend of the prediction and the fidelity target.

        Args:
            batch: ``(x, y, fidelity_data)``; ``y`` is unused, and
                ``fidelity_data`` is a mapping whose entry for
                ``self.fidelity`` is an ``(x_coarse, y_coarse)`` pair.

        Raises:
            ValueError: if ``fidelity_data`` has no entry (or a ``None`` entry)
                for ``self.fidelity``.
        """
        x, _y, fidelity_data = batch
        # Check the lookup before running the (possibly expensive) base model.
        coarse_data = fidelity_data.get(self.fidelity)
        if coarse_data is None:
            raise ValueError(
                f"Fidelity not found in `fidelity_data`: {self.fidelity!r} "
                f"(available keys: {list(fidelity_data)})"
            )
        _x_coarse, y_coarse = coarse_data
        y_pred = self.base_model(x)
        return self.alpha * y_pred + (1 - self.alpha) * y_coarse
    
class NeuralOperator(nn.Module):
    """Adapter that feeds only the input of a ``(x, y, fidelity_data)`` batch to a model.

    ``fidelity`` is stored as ``None``: this wrapper does not use any
    fidelity-level data.
    """

    def __init__(self, base_model: nn.Module):
        super().__init__()
        self.model = base_model
        self.fidelity = None

    def forward(self, batch):
        """Unpack the 3-tuple batch and return ``self.model(x)``."""
        inputs, _target, _fidelity_data = batch
        prediction = self.model(inputs)
        return prediction


class GatingModel(nn.Module):
    """Small MLP mapping a flattened input (optionally plus viscosity) to logits.

    Architecture: ``flatten -> fc1 -> ReLU -> fc2 -> ReLU -> [concat viscosity] -> fc3``.

    Args:
        input_dim: number of features after flattening all non-batch dims of ``x``.
        output_dim: number of logits produced (e.g. one per expert).
        hidden_dim: width of the two hidden layers. NOTE(review): the default
            of 1 is a very narrow bottleneck; confirm this is intended.
        use_fidelity: if True, ``fidelity_data["viscosity"]`` is appended as
            one extra feature before the output layer.
    """

    def __init__(self, input_dim: int,
                 output_dim: int,
                 hidden_dim: int = 1,
                 use_fidelity: bool = True):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.use_fidelity = use_fidelity
        # One extra input feature for the per-sample viscosity when enabled.
        fc3_in = hidden_dim + 1 if use_fidelity else hidden_dim
        self.fc3 = nn.Linear(fc3_in, output_dim)

    def forward(self, batch):
        """Compute logits of shape ``(batch_size, output_dim)``.

        Args:
            batch: ``(x, y, fidelity_data)``; ``y`` is unused. When
                ``use_fidelity`` is True, ``fidelity_data["viscosity"]`` must
                provide one value per sample, shaped ``(batch_size,)`` or
                ``(batch_size, 1)``.

        Raises:
            ValueError: if ``use_fidelity`` is True and no viscosity is given.
        """
        x, _y, fidelity_data = batch
        # reshape (not view) so non-contiguous inputs are accepted too.
        h = x.reshape(x.size(0), -1)
        h = F.relu(self.fc1(h))
        h = F.relu(self.fc2(h))
        if self.use_fidelity:
            viscosity = fidelity_data.get("viscosity")
            if viscosity is None:
                raise ValueError(
                    "`fidelity_data` has no 'viscosity' entry but use_fidelity=True."
                )
            # Match the activations' dtype/device and accept a flat
            # (batch_size,) vector as well as a (batch_size, 1) column.
            viscosity = torch.as_tensor(
                viscosity, dtype=h.dtype, device=h.device
            ).reshape(h.size(0), -1)
            h = torch.cat([h, viscosity], dim=1)
        return self.fc3(h)

class MoEModel(nn.Module):
    """Mixture of experts whose outputs are blended by Gumbel-Softmax gate weights.

    Args:
        experts_list: the experts. Each is called as ``expert(batch)`` and all
            must return tensors of the same shape whose first dim is the batch size.
        use_fidelity: forwarded to the gating network (appends viscosity).
        tau: Gumbel-Softmax temperature.
        hard: if True, use straight-through one-hot gate weights.
        input_dim: flattened input size seen by the gate. Defaults to the value
            that was previously hard-coded (``4*3*128*128``).
        hidden_dim: hidden width of the gating MLP (default 1, unchanged).

    Raises:
        ValueError: if ``experts_list`` is empty.
    """

    def __init__(self, experts_list: nn.ModuleList,
                 use_fidelity: bool = True,
                 tau: float = 1.0,
                 hard: bool = False,
                 input_dim: int = 4 * 3 * 128 * 128,
                 hidden_dim: int = 1):
        super().__init__()
        num_experts = len(experts_list)
        if num_experts == 0:
            raise ValueError("experts_list must contain at least one expert")
        self.num_experts = num_experts
        self.tau = tau
        self.hard = hard
        self.experts = experts_list
        self.gate = GatingModel(input_dim, num_experts,
                                hidden_dim=hidden_dim,
                                use_fidelity=use_fidelity)

    def forward(self, batch):
        """Return ``(mixed_output, gate_probs, expert_outputs)``.

        ``mixed_output`` has the shape of a single expert's output.
        ``gate_probs`` is reshaped to ``(B, 1, ..., 1, E)`` so that it
        broadcasts against the experts stacked on a trailing axis.
        ``expert_outputs`` is the list of raw per-expert outputs.

        NOTE(review): ``gumbel_softmax`` in this module draws fresh noise on
        every call and does not check ``self.training``, so gating is
        stochastic in eval mode too. Confirm this is intended.
        """
        gate_logits = self.gate(batch)
        gate_probs = gumbel_softmax(gate_logits, tau=self.tau, hard=self.hard)
        expert_outputs = [expert(batch) for expert in self.experts]
        # Stack experts on a new trailing axis: (B, ..., E). For 5-D expert
        # outputs this is the same as the former hard-coded dim=5.
        stacked = torch.stack(expert_outputs, dim=-1)
        # Broadcast the (B, E) gate weights over every non-batch output axis.
        batch_size, num_experts = gate_logits.shape[0], gate_logits.shape[-1]
        singleton_axes = [1] * (stacked.dim() - 2)
        gate_probs = gate_probs.view(batch_size, *singleton_axes, num_experts)
        out = torch.sum(stacked * gate_probs, dim=-1)
        return out, gate_probs, expert_outputs
    
 