# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Dict, List, Optional

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.cli import instantiate_class

from pdearena import utils
from pdearena.data.utils import PDEDataConfig
from pdearena.modules.loss import CustomMSELoss, ScaledLpLoss
from pdearena.rollout import rollout2d, rollout3d_maxwell

from .registry import MODEL_REGISTRY
from .registry import COND_MODEL_REGISTRY

logger = utils.get_logger(__name__)
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb

class MoE(nn.Module):
    """Hard (top-1) gated Mixture of Experts with a straight-through Gumbel-Softmax.

    Every expert is a two-layer MLP (``Linear -> ReLU -> Linear``) mapping
    ``input_dim -> hidden_dim -> input_dim``. A linear gate scores the experts
    for each input row and exactly one expert's output is selected per row.

    Args:
        input_dim: Size of the last dimension of the input (and of the output).
        hidden_dim: Hidden width of each expert MLP.
        num_experts: Number of experts.
        lb_coef: Standard deviation of the Gaussian noise added to the gating
            logits in training mode. Despite the name, no load-balancing loss
            is computed by this module (see ``load_balance_loss``).
        temperature: Softmax temperature applied to the gating logits.

    Attributes:
        gating_probs: Detached ``[N, num_experts]`` routing probabilities of
            the most recent forward pass (``None`` before the first call).
        load_balance_loss: Scalar tensor reset on every forward pass. It is
            currently always zero; it exists so that callers can add it to
            their training loss unconditionally.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_experts: int,
        lb_coef: float = 1.0,
        temperature: float = 1.0,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.lb_coef = lb_coef
        self.temperature = temperature

        # Gating network: projects each row's features to logits over experts.
        self.gate = nn.Linear(input_dim, num_experts)

        # Expert networks.
        self.experts = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(input_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, input_dim),
                )
                for _ in range(num_experts)
            ]
        )

        # Placeholder auxiliary loss, refreshed by every forward pass.
        self.load_balance_loss = torch.tensor(0.0)
        # Routing probabilities of the last forward pass (for logging).
        self.gating_probs = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route every row of ``x`` through exactly one expert.

        Args:
            x: Tensor of shape ``[N, D]`` with ``D == input_dim``.

        Returns:
            Tensor of shape ``[N, D]``.
        """
        # 1) Gating logits, [N, num_experts].
        gating_logits = self.gate(x)

        if self.training:
            # Gaussian exploration noise (scaled by ``lb_coef``) encourages
            # different experts to be tried during training.
            gating_logits = gating_logits + torch.randn_like(gating_logits) * self.lb_coef

            # 2) Gumbel-Softmax sample: perturb the logits with Gumbel(0, 1)
            #    noise; the epsilons keep both logarithms finite.
            gumbel_noise = -torch.log(-torch.log(torch.rand_like(gating_logits) + 1e-8) + 1e-8)
            gating_probs = F.softmax((gating_logits + gumbel_noise) / self.temperature, dim=-1)

            # 3) Straight-through estimator: the forward value is the hard
            #    one-hot choice, the gradient is that of the soft sample.
            chosen_experts = gating_probs.argmax(dim=-1)  # [N]
            one_hot = F.one_hot(chosen_experts, self.num_experts).float()
            gates = one_hot + gating_probs - gating_probs.detach()
        else:
            # Inference: deterministic top-1 routing on the noise-free logits.
            # The softmax is only kept for logging; it does not change which
            # expert is selected.
            gating_probs = F.softmax(gating_logits.detach() / self.temperature, dim=-1)
            chosen_experts = gating_logits.argmax(dim=-1)  # [N]
            gates = F.one_hot(chosen_experts, self.num_experts).float()

        # Keep a gradient-free copy of the routing distribution for logging.
        self.gating_probs = gating_probs.detach()

        # 4) Evaluate all experts densely: [N, num_experts, D]. Every expert
        #    is run on every row; the one-hot gates then select one of them.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)

        # 5) Select the chosen expert's output per row: [N, D].
        output = (gates.unsqueeze(-1) * expert_outputs).sum(dim=1)

        # 6) No load-balancing term is implemented; expose a zero on the
        #    input's device so it can be added to a loss safely.
        self.load_balance_loss = torch.tensor(0.0, device=x.device)

        return output




class MoEWrapper(nn.Module):
    """Put a hard (top-1) MoE layer in front of an existing PDE model.

    Each sample of the ``[B, T, C, H, W]`` input is flattened to a single
    vector of length ``T * C * H * W`` and routed through the MoE as one row,
    so routing happens once per sample (not per spatial location or time
    step). The MoE output is reshaped back to ``[B, T, C, H, W]`` and handed
    to ``base_model``.
    """

    def __init__(self,
                 base_model: nn.Module,
                 input_dim: int,
                 hidden_dim: int,
                 num_experts: int,
                 lb_coef: float = 1e-2):
        """
        Args:
            base_model: Wrapped PDE model; it is called with the MoE output of
                shape ``[B, T, C, H, W]`` followed by any extra arguments.
            input_dim: Feature size seen by the MoE. Because every sample is
                flattened, this must equal ``T * C * H * W`` of the input.
            hidden_dim: Hidden layer size inside each expert's MLP.
            num_experts: Number of experts.
            lb_coef: Forwarded to ``MoE`` as its ``lb_coef`` argument.
        """
        super().__init__()
        self.base_model = base_model
        self.moe_layer = MoE(input_dim, hidden_dim, num_experts, lb_coef=lb_coef)

        # Mirror of the MoE layer's auxiliary loss, refreshed on every forward.
        self.load_balance_loss = torch.tensor(0.0)
        # Mirror of the MoE layer's routing probabilities (None before the
        # first forward pass).
        self.gating_probs = None

    def forward(self, *args, **kwargs):
        """Apply the MoE layer to ``args[0]`` and then call the wrapped model.

        Args:
            *args: ``args[0]`` is the input tensor of shape ``[B, T, C, H, W]``;
                remaining positional arguments are passed to ``base_model``.
            **kwargs: Passed to ``base_model`` unchanged.

        Returns:
            Whatever ``base_model`` returns.
        """
        x = args[0]  # [B, T, C, H, W]
        B, T, C, H, W = x.shape

        # 1) Flatten each sample into one row: [B, T*C*H*W].
        #    ``reshape`` (rather than ``view``) also accepts non-contiguous
        #    inputs such as slices or transposed tensors.
        x_flat = x.reshape(B, -1)

        # 2) Hard-routed MoE, one routing decision per sample: [B, T*C*H*W].
        x_moe = self.moe_layer(x_flat)
        # Expose the routing probabilities for logging: [B, num_experts].
        self.gating_probs = self.moe_layer.gating_probs

        # 3) Restore the original layout: [B, T, C, H, W].
        x_moe = x_moe.reshape(B, T, C, H, W)

        # 4) Run the wrapped PDE model on the MoE output.
        out = self.base_model(x_moe, *args[1:], **kwargs)

        # 5) Expose the MoE's auxiliary loss so the training loop can read it.
        self.load_balance_loss = self.moe_layer.load_balance_loss

        return out




def _build_base_model(args, pde):
    """Instantiate the plain (non-MoE) PDE model registered under ``args.name``.

    The registry entry is copied before the PDE-specific ``init_args`` are
    merged in, so the shared ``MODEL_REGISTRY`` entry is never mutated.

    Raises:
        NotImplementedError: If ``args.name`` is not a key of ``MODEL_REGISTRY``.
    """
    if args.name not in MODEL_REGISTRY:
        raise NotImplementedError(f"Model {args.name} not found in registry.")

    entry = MODEL_REGISTRY[args.name]
    _model = entry.copy()
    # ``copy()`` is shallow: give this call its own ``init_args`` dict instead
    # of updating the one stored in the registry.
    init_args = dict(entry["init_args"])
    init_args.update(
        dict(
            n_input_scalar_components=pde.n_scalar_components,
            n_output_scalar_components=pde.n_scalar_components,
            n_input_vector_components=pde.n_vector_components,
            n_output_vector_components=pde.n_vector_components,
            time_history=args.time_history,
            time_future=args.time_future,
            activation=args.activation,
        )
    )
    _model["init_args"] = init_args
    return instantiate_class(tuple(), _model)


def get_model(args, pde):
    """Build the model named ``args.name``, optionally wrapped in a MoE front end.

    Args:
        args: Mapping with attribute access (e.g. Lightning ``hparams``)
            providing ``name``, ``time_history``, ``time_future`` and
            ``activation``. Optional keys, read with ``args.get``:

            * ``use_moe`` (default ``False``): wrap the model in ``MoEWrapper``.
            * ``moe_input_dim`` (default ``4 * 3 * 128 * 128``): flattened
              per-sample input size expected by the MoE layer.
            * ``moe_hidden_dim`` (default ``512``): expert hidden width.
            * ``moe_num_experts`` (default ``4``): number of experts.
        pde: Object providing ``n_scalar_components`` and
            ``n_vector_components``.

    Returns:
        The instantiated model, or a ``MoEWrapper`` around it when
        ``use_moe`` is truthy.

    Raises:
        NotImplementedError: If ``args.name`` is not in ``MODEL_REGISTRY``.
    """
    if not args.get("use_moe", False):
        print("Expert Model Name: ", args.name)
        return _build_base_model(args, pde)

    print("Model Name: ", args.name)
    print("Using MoE")
    base_model = _build_base_model(args, pde)

    # ``MoEWrapper`` flattens every sample, so ``input_dim`` has to equal
    # T * C * H * W of the data; the default is 4 * 3 * 128 * 128 = 196608.
    input_dim = args.get("moe_input_dim", 4 * 3 * 128 * 128)
    hidden_dim = args.get("moe_hidden_dim", 512)
    num_experts = args.get("moe_num_experts", 4)
    return MoEWrapper(base_model, input_dim, hidden_dim, num_experts)



class PDEModel(LightningModule):
    """Lightning module for one-step training and rollout evaluation of a PDE model.

    The wrapped network is created by ``get_model`` from the saved
    hyperparameters. MoE-specific logic (auxiliary loss, gating statistics,
    gate gradient norm) is applied only when the network actually exposes the
    corresponding attributes, so plain registry models work as well.
    """

    def __init__(
        self,
        name: str,
        time_history: int,
        time_future: int,
        time_gap: int,
        max_num_steps: int,
        activation: str,
        criterion: str,
        lr: float,
        pdeconfig: PDEDataConfig,
        model: Optional[Dict] = None,
        use_moe: bool = False,
    ) -> None:
        """
        Args:
            name: Key of the model in ``MODEL_REGISTRY``.
            time_history: Number of input time steps.
            time_future: Number of time steps predicted per model call.
            time_gap: Gap (in time steps) between inputs and targets.
            max_num_steps: Number of rollout steps used for the rollout loss.
            activation: Activation name forwarded to the model.
            criterion: Training loss, ``"mse"`` or ``"scaledl2"``.
            lr: Learning rate for Adam.
            pdeconfig: Description of the PDE data (excluded from hparams).
            model: Stored in hparams; not otherwise used by this class.
            use_moe: Stored in hparams and read by ``get_model`` to decide
                whether the network is wrapped in a ``MoEWrapper``.

        Raises:
            NotImplementedError: For unsupported spatial dimensions or
                criterion names.
        """
        super().__init__()
        self.save_hyperparameters(ignore="pdeconfig")
        self.pde = pdeconfig
        if (self.pde.n_spatial_dim) == 3:
            self._mode = "3DMaxwell"
            assert self.pde.n_scalar_components == 0
            assert self.pde.n_vector_components == 2
        elif (self.pde.n_spatial_dim) == 2:
            self._mode = "2D"
        else:
            raise NotImplementedError(f"{self.pde}")

        self.model = get_model(self.hparams, self.pde)
        if criterion == "mse":
            self.train_criterion = CustomMSELoss()
        elif criterion == "scaledl2":
            self.train_criterion = ScaledLpLoss()
        else:
            raise NotImplementedError(f"Criterion {criterion} not implemented yet")

        self.val_criterions = {"mse": CustomMSELoss(), "scaledl2": ScaledLpLoss()}
        self.rollout_criterion = torch.nn.MSELoss(reduction="none")
        time_resolution = self.pde.trajlen
        # Max number of previous points solver can eat
        reduced_time_resolution = time_resolution - self.hparams.time_history
        # Last start index that still leaves room for a full rollout target.
        self.max_start_time = (
            reduced_time_resolution - self.hparams.time_future * self.hparams.max_num_steps - self.hparams.time_gap
        )

    def forward(self, *args):
        """Call the wrapped model with ``args``."""
        return self.model(*args)

    def train_step(self, batch):
        """Run the model on ``batch`` and return ``(loss, pred, target)``.

        The model's ``load_balance_loss`` is added to the training criterion
        when the model exposes one (i.e. when it is MoE-wrapped).
        """
        x, y = batch
        pred = self.model(x)
        loss = self.train_criterion(pred, y)
        # Plain (non-MoE) models have no auxiliary loss attribute.
        lb_loss = getattr(self.model, "load_balance_loss", None)
        if lb_loss is not None:
            loss = loss + lb_loss
        return loss, pred, y

    def eval_step(self, batch):
        """Run the model on ``batch`` and return ``(losses, pred, target)``.

        ``losses`` maps every key of ``self.val_criterions`` to its value.
        """
        x, y = batch
        pred = self.model(x)
        loss = {k: vc(pred, y) for k, vc in self.val_criterions.items()}
        return loss, pred, y

    def _log_gating_stats(self) -> None:
        """Log MoE routing statistics of the last forward pass, if available."""
        moe_layer = getattr(self.model, "moe_layer", None)
        gating_probs = getattr(moe_layer, "gating_probs", None)  # [N, num_experts]
        if gating_probs is None:
            return

        # Mean routing probability per expert.
        for expert_idx, mean_prob in enumerate(gating_probs.mean(dim=0)):
            self.log(f"train/gating_prob_mean_expert_{expert_idx}", mean_prob)

        # Per-expert histograms go straight to the experiment object. Only
        # attempt this when the experiment offers a ``log`` method (as the
        # W&B run does); other loggers, or no logger, are skipped.
        experiment = getattr(self.logger, "experiment", None)
        log_fn = getattr(experiment, "log", None)
        if callable(log_fn):
            log_fn(
                {
                    f"train/gating_probs_expert_{i}": wandb.Histogram(gating_probs[:, i].cpu())
                    for i in range(gating_probs.size(1))
                }
            )

    def training_step(self, batch, batch_idx: int):
        """Compute and log the training loss plus scalar/vector component losses.

        Raises:
            NotImplementedError: If the module is not in ``"2D"`` mode.
        """
        loss, preds, targets = self.train_step(batch)

        if self._mode != "2D":
            raise NotImplementedError(f"{self._mode}")

        n_scalar = self.pde.n_scalar_components
        scalar_loss = self.train_criterion(
            preds[:, :, 0:n_scalar, ...],
            targets[:, :, 0:n_scalar, ...],
        )

        if self.pde.n_vector_components > 0:
            vector_loss = self.train_criterion(
                preds[:, :, n_scalar:, ...],
                targets[:, :, n_scalar:, ...],
            )
        else:
            vector_loss = torch.tensor(0.0, device=preds.device)
        self.log("train/loss", loss)
        self.log("train/scalar_loss", scalar_loss)
        self.log("train/vector_loss", vector_loss)

        self._log_gating_stats()

        return {
            "loss": loss,
            "scalar_loss": scalar_loss.detach(),
            "vector_loss": vector_loss.detach(),
        }

    def training_epoch_end(self, outputs: List[Any]):
        """Log bootstrap mean/std of every ``*loss*`` entry returned by ``training_step``."""
        # `outputs` is a list of dicts returned from `training_step()`
        if not outputs:
            return
        for key in outputs[0].keys():
            if "loss" in key:
                loss_vec = torch.stack([step_out[key] for step_out in outputs])
                mean, std = utils.bootstrap(loss_vec, 64, 1)
                self.log(f"train/{key}_mean", mean)
                self.log(f"train/{key}_std", std)

    def compute_rolloutloss2D(self, batch: Any):
        """Return the per-time-step rollout MSE averaged over all start offsets.

        ``batch`` is unpacked as ``(u, v, cond, grid)``; ``cond`` is not used.
        """
        (u, v, cond, grid) = batch

        losses = []
        for start in range(
            0,
            self.max_start_time + 1,
            self.hparams.time_future + self.hparams.time_gap,
        ):
            end_time = start + self.hparams.time_history
            target_start_time = end_time + self.hparams.time_gap
            target_end_time = target_start_time + self.hparams.time_future * self.hparams.max_num_steps

            init_u = u[:, start:end_time, ...]
            if self.pde.n_vector_components > 0:
                init_v = v[:, start:end_time, ...]
            else:
                init_v = None

            pred_traj = rollout2d(
                self.model,
                init_u,
                init_v,
                grid,
                self.pde,
                self.hparams.time_history,
                self.hparams.max_num_steps,
            )
            targ_u = u[:, target_start_time:target_end_time, ...]
            if self.pde.n_vector_components > 0:
                targ_v = v[:, target_start_time:target_end_time, ...]
                targ_traj = torch.cat((targ_u, targ_v), dim=2)
            else:
                targ_traj = targ_u
            # Average over every dimension except dim 1 (the time axis).
            loss = self.rollout_criterion(pred_traj, targ_traj).mean(dim=(0, 2, 3, 4))
            losses.append(loss)
        loss_vec = torch.stack(losses, dim=0).mean(dim=0)
        return loss_vec

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
        """One-step losses for dataloader 0, rollout losses for dataloader 1.

        Raises:
            NotImplementedError: If the module is not in ``"2D"`` mode.
        """
        if dataloader_idx == 0:
            # one-step loss
            loss, preds, targets = self.eval_step(batch)
            if self._mode == "2D":
                loss["scalar_mse"] = self.val_criterions["mse"](
                    preds[:, :, 0 : self.pde.n_scalar_components, ...],
                    targets[:, :, 0 : self.pde.n_scalar_components, ...],
                )
                loss["vector_mse"] = self.val_criterions["mse"](
                    preds[:, :, self.pde.n_scalar_components :, ...],
                    targets[:, :, self.pde.n_scalar_components :, ...],
                )

                for k in loss.keys():
                    self.log(f"valid/loss/{k}", loss[k])
                return {f"{k}_loss": v for k, v in loss.items()}

            else:
                raise NotImplementedError(f"{self._mode}")

        elif dataloader_idx == 1:
            # rollout loss
            if self._mode == "2D":
                loss_vec = self.compute_rolloutloss2D(batch)
            else:
                raise NotImplementedError(f"{self._mode}")
            # summing across "time axis"
            loss = loss_vec.sum()
            loss_t = loss_vec.cumsum(0)
            chan_avg_loss = loss / (self.pde.n_scalar_components + self.pde.n_vector_components)
            self.log("valid/unrolled_loss", loss)
            return {
                "unrolled_loss": loss,
                "loss_timesteps": loss_t,
                "unrolled_chan_avg_loss": chan_avg_loss,
            }

    def validation_epoch_end(self, outputs: List[Any]):
        """Aggregate validation outputs; ``outputs[0]`` is one-step, ``outputs[1]`` rollout."""
        if len(outputs) > 1:
            onestep_outputs, rollout_outputs = outputs[0], outputs[1]
            if len(onestep_outputs) > 0:
                for key in onestep_outputs[0].keys():
                    if "loss" in key:
                        loss_vec = torch.stack([step_out[key] for step_out in onestep_outputs])
                        mean, std = utils.bootstrap(loss_vec, 64, 1)
                        self.log(f"valid/{key}_mean", mean)
                        self.log(f"valid/{key}_std", std)

            if len(rollout_outputs) > 0:
                unrolled_loss = torch.stack([step_out["unrolled_loss"] for step_out in rollout_outputs])
                loss_timesteps_B = torch.stack([step_out["loss_timesteps"] for step_out in rollout_outputs])
                loss_timesteps = loss_timesteps_B.mean(0)

                for i in range(self.hparams.max_num_steps):
                    self.log(f"valid/intime_{i}_loss", loss_timesteps[i])

                mean, std = utils.bootstrap(unrolled_loss, 64, 1)
                self.log("valid/unrolled_loss_mean", mean)
                self.log("valid/unrolled_loss_std", std)

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
        """One-step losses for dataloader 0, rollout losses for dataloader 1.

        Raises:
            NotImplementedError: If the module is not in ``"2D"`` mode.
        """
        if dataloader_idx == 0:
            loss, preds, targets = self.eval_step(batch)
            if self._mode == "2D":
                loss["scalar_mse"] = self.val_criterions["mse"](
                    preds[:, :, 0 : self.pde.n_scalar_components, ...],
                    targets[:, :, 0 : self.pde.n_scalar_components, ...],
                )
                loss["vector_mse"] = self.val_criterions["mse"](
                    preds[:, :, self.pde.n_scalar_components :, ...],
                    targets[:, :, self.pde.n_scalar_components :, ...],
                )

                # ``loss`` is a dict here; it is passed to ``self.log`` as is.
                self.log("test/loss", loss)
                return {f"{k}_loss": v for k, v in loss.items()}
            else:
                raise NotImplementedError(f"{self._mode}")

        elif dataloader_idx == 1:
            if self._mode == "2D":
                loss_vec = self.compute_rolloutloss2D(batch)
            else:
                raise NotImplementedError(f"{self._mode}")
            # summing across "time axis"
            loss = loss_vec.sum()
            loss_t = loss_vec.cumsum(0)
            self.log("test/unrolled_loss", loss)
            # self.log("valid/normalized_unrolled_loss", loss)
            return {
                "unrolled_loss": loss,
                "loss_timesteps": loss_t,
            }

    def test_epoch_end(self, outputs: List[Any]):
        """Aggregate test outputs; ``outputs[0]`` is one-step, ``outputs[1]`` rollout."""
        assert len(outputs) > 1
        onestep_outputs, rollout_outputs = outputs[0], outputs[1]
        if len(onestep_outputs) > 0:
            for key in onestep_outputs[0].keys():
                if "loss" in key:
                    loss_vec = torch.stack([step_out[key] for step_out in onestep_outputs])
                    mean, std = utils.bootstrap(loss_vec, 64, 1)
                    self.log(f"test/{key}_mean", mean)
                    self.log(f"test/{key}_std", std)
        if len(rollout_outputs) > 0:
            unrolled_loss = torch.stack([step_out["unrolled_loss"] for step_out in rollout_outputs])
            loss_timesteps_B = torch.stack([step_out["loss_timesteps"] for step_out in rollout_outputs])
            loss_timesteps = loss_timesteps_B.mean(0)
            for i in range(self.hparams.max_num_steps):
                self.log(f"test/intime_{i}_loss", loss_timesteps[i])

            mean, std = utils.bootstrap(unrolled_loss, 64, 1)
            self.log("test/unrolled_loss_mean", mean)
            self.log("test/unrolled_loss_std", std)

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped model's parameters."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.hparams.lr)
        return optimizer

    def on_after_backward(self):
        """Log the L2 norm of the MoE gate's gradients after each backward pass.

        Does nothing beyond the parent hook when the model has no ``moe_layer``.
        """
        super().on_after_backward()

        moe_layer = getattr(self.model, "moe_layer", None)
        if moe_layer is None:
            return

        # Global L2 norm over all gate parameters that received a gradient.
        squared_norm = 0.0
        for p in moe_layer.gate.parameters():
            if p.grad is not None:
                squared_norm += (p.grad.data.norm(2).item()) ** 2
        gate_grad_norm = squared_norm**0.5

        self.log("train/moe_gate_grad_norm", gate_grad_norm)

class Maxwell3DPDEModel(PDEModel):
    """``PDEModel`` variant for the ``"3DMaxwell"`` mode.

    Predictions and targets are split along dim 2 into the first three
    channels (logged as ``d``) and the remaining channels (logged as ``h``).
    """

    def compute_rolloutloss3D(self, batch: Any):
        """Return the per-time-step rollout MSE averaged over all start offsets.

        ``batch`` is unpacked as ``(d, h, _)``; the third element is ignored.
        """
        d, h, _ = batch
        losses = []
        for start in range(
            0,
            self.max_start_time + 1,
            self.hparams.time_future + self.hparams.time_gap,
        ):
            end_time = start + self.hparams.time_history
            target_start_time = end_time + self.hparams.time_gap
            target_end_time = target_start_time + self.hparams.time_future * self.hparams.max_num_steps
            init_d = d[:, start:end_time]
            init_h = h[:, start:end_time]
            pred_traj = rollout3d_maxwell(
                self.model,
                init_d,
                init_h,
                self.hparams.time_history,
                self.hparams.max_num_steps,
            )
            targ_d = d[:, target_start_time:target_end_time]
            targ_h = h[:, target_start_time:target_end_time]
            targ_traj = torch.cat((targ_d, targ_h), dim=2)  # along channel
            # Average over every dimension except dim 1 (the time axis).
            loss = self.rollout_criterion(pred_traj, targ_traj).mean(dim=(0, 2, 3, 4, 5))
            losses.append(loss)
        loss_vec = torch.stack(losses, dim=0).mean(dim=0)
        return loss_vec

    def training_step(self, batch, batch_idx: int):
        """Compute and log the training loss plus the ``d`` and ``h`` component losses.

        Raises:
            NotImplementedError: If the module is not in ``"3DMaxwell"`` mode.
        """
        loss, preds, targets = self.train_step(batch)

        if self._mode == "3DMaxwell":
            d_loss = self.train_criterion(preds[:, :, :3, ...], targets[:, :, :3, ...])
            h_loss = self.train_criterion(preds[:, :, 3:, ...], targets[:, :, 3:, ...])
            self.log("train/loss", loss)
            self.log("train/d_loss", d_loss)
            self.log("train/h_loss", h_loss)
            # Only "loss" needs its graph; the monitoring losses are detached
            # (as in the 2D training step) so the collected step outputs do
            # not keep autograd graphs alive.
            return {
                "loss": loss,
                "d_loss": d_loss.detach(),
                "h_loss": h_loss.detach(),
            }
        else:
            raise NotImplementedError(f"{self._mode}")

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
        """One-step losses for dataloader 0, rollout losses for dataloader 1.

        Raises:
            NotImplementedError: If the module is not in ``"3DMaxwell"`` mode.
        """
        if dataloader_idx == 0:
            # one-step loss
            loss, preds, targets = self.eval_step(batch)
            if self._mode == "3DMaxwell":
                loss["d_field_mse"] = self.val_criterions["mse"](preds[:, :, :3, ...], targets[:, :, :3, ...])
                loss["h_field_mse"] = self.val_criterions["mse"](preds[:, :, 3:, ...], targets[:, :, 3:, ...])

                # NOTE(review): every criterion is logged under the same
                # "valid/loss" key, unlike the per-criterion keys used by the
                # 2D validation step - confirm whether this is intended
                # before renaming, since the key may be monitored elsewhere.
                for k in loss.keys():
                    self.log("valid/loss", loss[k])
                return {f"{k}_loss": v for k, v in loss.items()}
            else:
                raise NotImplementedError(f"{self._mode}")

        elif dataloader_idx == 1:
            # rollout loss
            if self._mode == "3DMaxwell":
                loss_vec = self.compute_rolloutloss3D(batch)
            else:
                raise NotImplementedError(f"{self._mode}")
            # summing across "time axis"
            loss = loss_vec.sum()
            loss_t = loss_vec.cumsum(0)
            chan_avg_loss = loss / (self.pde.n_scalar_components + self.pde.n_vector_components)
            self.log("valid/unrolled_loss", loss)
            return {
                "unrolled_loss": loss,
                "loss_timesteps": loss_t,
                "unrolled_chan_avg_loss": chan_avg_loss,
            }

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
        """One-step losses for dataloader 0, rollout losses for dataloader 1.

        Raises:
            NotImplementedError: If the module is not in ``"3DMaxwell"`` mode.
        """
        if dataloader_idx == 0:
            loss, preds, targets = self.eval_step(batch)
            if self._mode == "3DMaxwell":
                d_loss = self.val_criterions["mse"](preds[:, :, :3, ...], targets[:, :, :3, ...])
                h_loss = self.val_criterions["mse"](preds[:, :, 3:, ...], targets[:, :, 3:, ...])
                self.log("test/loss", loss)
                self.log("test/d_loss", d_loss)
                self.log("test/h_loss", h_loss)
                # ``loss`` is a dict of criteria. Return one tensor per key so
                # that ``test_epoch_end`` can ``torch.stack`` every "*loss*"
                # entry; a nested dict under "loss" cannot be stacked.
                step_out = {f"{k}_loss": v for k, v in loss.items()}
                step_out["d_loss"] = d_loss
                step_out["h_loss"] = h_loss
                return step_out
            else:
                raise NotImplementedError(f"{self._mode}")

        elif dataloader_idx == 1:
            # The mode string must match the one set in ``PDEModel.__init__``.
            if self._mode == "3DMaxwell":
                loss_vec = self.compute_rolloutloss3D(batch)
            else:
                raise NotImplementedError(f"{self._mode}")
            # summing across "time axis"
            loss = loss_vec.sum()
            loss_t = loss_vec.cumsum(0)
            self.log("test/unrolled_loss", loss)
            # self.log("valid/normalized_unrolled_loss", loss)
            return {
                "unrolled_loss": loss,
                "loss_timesteps": loss_t,
            }
