
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from typing import Dict, Any
from tqdm.auto import tqdm
from ops import LearnableQ
from models import ZDenoiser, RDenoiser, GatePosterior, GatePrior, BaselineActionDenoiser
from losses import Diffusion1D, dirichlet_kl, mse
from data import make_dataloaders, make_data_splits
import math, time

class SkillMoETrainer(nn.Module):
    """Gated latent-diffusion trainer: actions are projected to a K-dim code z
    through a state-conditioned decoder Q, and a denoiser is trained on z.

    Calling the module on a batch runs one full optimisation step (loss,
    backward, gradient clipping, optimizer step) and returns scalar logs.
    """

    def __init__(self, d, K, s_dim, a_dim, T=64, use_adapters=True, top_k=2, device='cpu'):
        super().__init__()
        self.device = device
        self.K = K
        self.d = d
        # Decoder Q (orthonormal via QR retraction)
        self.Qdec = LearnableQ(s_dim, d, K)
        # Denoisers for z (and r optional)
        self.zdiff = Diffusion1D(dim=K, T=T, device=device)
        self.zden = ZDenoiser(z_dim=K, s_dim=s_dim, hidden=128, K=K, use_adapters=use_adapters, top_k=top_k, rank=4)

        # Gate posterior and global usage
        self.gate_post = GatePosterior(s_dim=s_dim, a_dim=a_dim, K=K, hidden=128)
        self.gate_prior = GatePrior(s_dim=s_dim, K=K, hidden=128)
        self.alpha_global = nn.Parameter(torch.ones(K) * 2.0)  # q(vartheta)=Dir(softplus(alpha_global))

        # hyperparams for priors
        self.alpha0 = 0.5
        self.kappa = 20.0
        self.alpha_global_prior = 2.0  # p(vartheta)=Dir(alpha * 1)

    def decode(self, s, z0, g, return_Q=False):
        """Decode gated latent codes to actions: a_hat = Q(s) @ (g * z0).

        Args:
            s: states handed to the decoder network.
            z0: latent codes, (B, k).
            g: gate weights, (B, k).
            return_Q: when True, also return the decoder matrix.

        Returns:
            a_hat of shape (B, d), or the tuple (a_hat, Q) if return_Q is set.
        """
        Q = self.Qdec(s)
        # matmul broadcasts, so this works for a per-sample Q (B, d, k) as used
        # elsewhere in this class and for a single shared Q (d, k).
        a_hat = torch.matmul(Q, (g * z0).unsqueeze(-1)).squeeze(-1)
        if return_Q:
            return a_hat, Q
        return a_hat

    def z_from_action(self, S, a, g, eps=1e-4):
        """Project action to z via whitened decoder: z = (Q^T a) / (g + eps).

        Args:
            S: states handed to the decoder network (one row per action).
            a: actions, (B, d).
            g: gate weights, (B, k).
            eps: added to g to keep the division finite.

        Returns:
            z of shape (B, k).
        """
        Q = self.Qdec(S)                          # (B, d, k)
        # Transposing the last two axes also accepts a shared 2-D Q.
        proj = torch.matmul(Q.transpose(-1, -2), a.unsqueeze(-1)).squeeze(-1)  # (B, k)
        return proj / (g + eps)

    def forward(self, batch: Dict[str, torch.Tensor], opt: torch.optim.Optimizer, Tsteps=64, lambda_ortho=1e-4, lambda_gate_align=0.1):
        """Run one training step on a batch and return scalar logs.

        Args:
            batch: dict with 'S' (B, T, s_dim) and 'A' (B, T, d); an optional
                'G' entry adds a supervised gate-alignment term.
            opt: optimizer over this module's parameters; it is zeroed and
                stepped inside this call.
            Tsteps: accepted for interface compatibility; not used here.
            lambda_ortho: accepted for interface compatibility; not used here.
            lambda_gate_align: weight of the supervised gate-alignment KL.

        Returns:
            dict of Python floats: total_loss, diff_z, recon, kl_v, kl_g0,
            kl_gt, kl_gate_align, mse_action, switch_rate, usage_entropy.
        """
        S = batch['S'].to(self.device)   # (B, T, s_dim)
        A = batch['A'].to(self.device)   # (B, T, d)
        Bsz, T, _ = S.shape
        # reshape to (B*T, ...)
        Sflat = S.reshape(Bsz*T, -1)
        Aflat = A.reshape(Bsz*T, -1)

        # Posterior gate q(g|s,a) = Dir(beta_hat); use mean as routing signal
        beta_hat = self.gate_post(Sflat, Aflat)
        dist_g = torch.distributions.Dirichlet(beta_hat)
        try:
            g_sample = dist_g.rsample()
        except NotImplementedError:
            # MPS may not implement Dirichlet rsample; fallback to CPU sampling
            g_sample = torch.distributions.Dirichlet(beta_hat.detach().cpu()).sample().to(self.device)
        g_mean = beta_hat / beta_hat.sum(dim=-1, keepdim=True)

        # Project to z0 targets from actions. The flattened states must be used
        # so that Q has one matrix per row of Aflat.
        z0 = self.z_from_action(Sflat, Aflat, g_mean).detach()  # stop grad through target
        # Diffusion training on z
        t = torch.randint(0, self.zdiff.T, (Bsz*T,), device=self.device)
        z_t, eps = self.zdiff.q_sample(z0, t)
        eps_hat = self.zden(z_t, t, Sflat, g=g_sample)
        loss_diff_z = F.mse_loss(eps_hat, eps)

        # Dirichlet KLs for sticky gates
        # Global usage posterior q(vartheta)=Dir(softplus(alpha_global))
        alpha_q_v = F.softplus(self.alpha_global) + 1e-3
        alpha_p_v = torch.ones_like(alpha_q_v) * self.alpha_global_prior
        kl_v = dirichlet_kl(alpha_q_v.unsqueeze(0), alpha_p_v.unsqueeze(0)).mean()

        # Sequential KLs KL( q(g_t) || Dir(kappa g_{t-1} + alpha0 vartheta) )
        vartheta_mean = alpha_q_v / alpha_q_v.sum()
        beta_hat_seq = beta_hat.reshape(Bsz, T, -1)
        g_mean_seq = g_mean.reshape(Bsz, T, -1)
        # initial t=0 KL
        prior0 = (self.alpha0 * vartheta_mean).unsqueeze(0).expand(Bsz, -1)
        kl_g0 = dirichlet_kl(beta_hat_seq[:, 0, :], prior0).mean()
        # t>=1 KLs: the prior at step t is built from the gate mean at step t-1
        if T > 1:
            priors_cat = self.kappa * g_mean_seq[:, :-1, :] + self.alpha0 * vartheta_mean
            kl_gt = dirichlet_kl(beta_hat_seq[:, 1:, :].reshape(Bsz*(T-1), -1),
                                 priors_cat.reshape(Bsz*(T-1), -1)).mean()
        else:
            kl_gt = torch.tensor(0.0, device=self.device)

        # Supervised gate alignment to ground-truth gates (if provided)
        if 'G' in batch:
            Gtrue = batch['G'].to(self.device).reshape(Bsz*T, -1)
            kl_ga = (Gtrue.add(1e-8) * (Gtrue.add(1e-8).log() - g_mean.add(1e-8).log())).sum(dim=-1).mean()
        else:
            kl_ga = torch.tensor(0.0, device=self.device)
        kl_gate_align = lambda_gate_align * kl_ga

        # Reconstruction loss: residual of A after projecting onto the columns of Q
        Qmat = self.Qdec(Sflat)
        coeff = torch.matmul(Qmat.transpose(1, 2), Aflat.unsqueeze(-1))   # (B, k, 1)
        a_proj = torch.matmul(Qmat, coeff).squeeze(-1)                    # (B, d)
        proj_res = Aflat - a_proj
        loss_recon = (proj_res ** 2).sum(dim=-1).mean()
        loss = loss_diff_z + loss_recon + kl_v + kl_g0 + kl_gt + kl_gate_align

        # Backprop & step
        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.parameters(), 1.0)
        opt.step()

        with torch.no_grad():
            # Predict x0 for z from the current noisy pairs and decode to actions
            z0_hat = self.zdiff.predict_x0(z_t, t, eps_hat)
            Qmat = self.Qdec(Sflat)
            coeff = (g_sample * z0_hat).unsqueeze(-1)      # (B, k, 1)
            Ahat = torch.matmul(Qmat, coeff).squeeze(-1)   # (B, d)
            mse_a = F.mse_loss(Aflat, Ahat).item()
            g_argmax = torch.argmax(g_mean_seq, dim=-1)
            if T > 1:
                switches = (g_argmax[:, 1:] != g_argmax[:, :-1]).float().mean().item()
            else:
                # A length-1 sequence has no transitions; avoid the NaN from an empty mean.
                switches = 0.0
            usage = g_mean_seq.mean(dim=(0, 1))
            entropy = (- (usage * (usage+1e-9).log()).sum() / math.log(self.K)).item()
        logs = dict(total_loss=float(loss.item()),
                    diff_z=float(loss_diff_z.item()),
                    recon=float(loss_recon.item()),
                    kl_v=float(kl_v.item()),
                    kl_g0=float(kl_g0.item()),
                    kl_gt=float(kl_gt.item()),
                    kl_gate_align=float(kl_gate_align.item()),
                    mse_action=mse_a,
                    switch_rate=switches,
                    usage_entropy=entropy)
        return logs

class BaselineActionDiffusion(nn.Module):
    """Reference model that runs diffusion directly on actions.

    Calling the module on a batch performs one optimisation step of the
    noise-prediction objective and returns the loss as plain floats.
    """

    def __init__(self, a_dim, s_dim, T=64, device='cpu'):
        super().__init__()
        from models import BaselineActionDenoiser
        self.device = device
        self.diff = Diffusion1D(dim=a_dim, T=T, device=device)
        self.den = BaselineActionDenoiser(a_dim=a_dim, s_dim=s_dim, hidden=128)

    def forward(self, batch, opt):
        """Take one gradient step on `batch` using optimizer `opt`.

        Args:
            batch: dict with 'S' and 'A' tensors; 'S' must be 3-D and both are
                flattened over their first two axes.
            opt: optimizer; zeroed and stepped inside this call.

        Returns:
            dict with 'total_loss' and 'diff', both holding the same MSE value.
        """
        states = batch['S'].to(self.device)
        actions = batch['A'].to(self.device)
        n_seq, seq_len, _ = states.shape
        n_rows = n_seq * seq_len
        flat_states = states.reshape(n_rows, -1)
        flat_actions = actions.reshape(n_rows, -1)

        # One random diffusion step per flattened row.
        steps = torch.randint(0, self.diff.T, (n_rows,), device=self.device)
        noisy_actions, noise = self.diff.q_sample(flat_actions, steps)
        noise_pred = self.den(noisy_actions, steps, flat_states)
        loss = F.mse_loss(noise_pred, noise)

        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.parameters(), 1.0)
        opt.step()

        loss_value = float(loss.item())
        return dict(total_loss=loss_value, diff=loss_value)

@torch.no_grad()
def _eval_skillmoe(model: SkillMoETrainer, dl: DataLoader):
    """Evaluate a SkillMoETrainer on a data loader without updating it.

    Args:
        model: trained SkillMoETrainer; its `device` attribute decides where
            batches are moved.
        dl: iterable of batches with 'S' and 'A' (both flattened over their
            first two axes) and optionally 'G' (ground-truth gates).

    Returns:
        dict of per-batch averages; a metric with no contributing batch is NaN:
            mse_action_proj: MSE between actions and their reconstruction
                through Q using the posterior-mean gate.
            mse_action_denoised: MSE after one noising/denoising round trip in
                z with a sampled gate (timesteps and gates are random).
            gate_kl_true: KL(G_true || g_mean), over batches that carry 'G'.
            switch_rate: fraction of consecutive steps whose argmax gate
                changes, over batches with at least two time steps.
    """
    device = model.device
    mse_action_proj = []
    mse_action_denoised = []
    gate_kl_true = []
    switch_rates = []
    for batch in dl:
        S = batch['S'].to(device)
        A = batch['A'].to(device)
        Bsz, T, _ = S.shape
        Sflat = S.reshape(Bsz*T, -1)
        Aflat = A.reshape(Bsz*T, -1)
        beta_hat = model.gate_post(Sflat, Aflat)
        g_mean = beta_hat / beta_hat.sum(dim=-1, keepdim=True)

        # Projection-based recon via Q. Both the decoder and z_from_action take
        # the states; Q is applied with a batched matmul, as in training.
        z0 = model.z_from_action(Sflat, Aflat, g_mean)
        Qmat = model.Qdec(Sflat)
        Ahat = torch.matmul(Qmat, (g_mean * z0).unsqueeze(-1)).squeeze(-1)
        mse_action_proj.append(F.mse_loss(Aflat, Ahat).item())

        # Gate KL(G_true || g_mean); ground-truth gates are optional, as in training
        if 'G' in batch:
            Gflat = batch['G'].to(device).reshape(Bsz*T, -1)
            kl = (Gflat.add(1e-8) * (Gflat.add(1e-8).log() - g_mean.add(1e-8).log())).sum(dim=-1).mean().item()
            gate_kl_true.append(kl)

        # Switch rate from predicted gates; undefined for single-step sequences
        if T > 1:
            g_argmax = torch.argmax(g_mean.reshape(Bsz, T, -1), dim=-1)
            switches = (g_argmax[:, 1:] != g_argmax[:, :-1]).float().mean().item()
            switch_rates.append(switches)

        # Denoised action MSE via predict_x0 (comparable to baseline):
        # sample g and noise timesteps, predict x0 in z, reconstruct A
        dist_g = torch.distributions.Dirichlet(beta_hat)
        try:
            g_sample = dist_g.rsample()
        except NotImplementedError:
            # Same CPU fallback as training for backends without Dirichlet rsample
            g_sample = torch.distributions.Dirichlet(beta_hat.detach().cpu()).sample().to(device)
        t = torch.randint(0, model.zdiff.T, (Bsz*T,), device=device)
        z_t, eps = model.zdiff.q_sample(z0, t)
        eps_hat = model.zden(z_t, t, Sflat, g=g_sample)
        z0_hat = model.zdiff.predict_x0(z_t, t, eps_hat)
        Ahat_den = torch.matmul(Qmat, (g_sample * z0_hat).unsqueeze(-1)).squeeze(-1)
        mse_action_denoised.append(F.mse_loss(Aflat, Ahat_den).item())
    return dict(
        mse_action_proj=float(sum(mse_action_proj)/len(mse_action_proj)) if mse_action_proj else float('nan'),
        mse_action_denoised=float(sum(mse_action_denoised)/len(mse_action_denoised)) if mse_action_denoised else float('nan'),
        gate_kl_true=float(sum(gate_kl_true)/len(gate_kl_true)) if gate_kl_true else float('nan'),
        switch_rate=float(sum(switch_rates)/len(switch_rates)) if switch_rates else float('nan'),
    )

@torch.no_grad()
def _eval_baseline(model: BaselineActionDiffusion, dl: DataLoader):
    """Score the action-space baseline with a single noise/denoise round trip.

    For each batch, actions are noised at random timesteps, the denoiser
    predicts the noise, and the recovered actions are compared with the
    originals.

    Returns:
        dict with 'mse_action_denoised': the mean of the per-batch MSE values,
        or NaN when `dl` yields no batches.
    """
    device = model.device
    batch_errors = []
    for batch in dl:
        states = batch['S'].to(device)
        actions = batch['A'].to(device)
        n_seq, seq_len, _ = states.shape
        n_rows = n_seq * seq_len
        flat_states = states.reshape(n_rows, -1)
        flat_actions = actions.reshape(n_rows, -1)
        # Mirrors the training step: random timestep per row, one denoise call.
        steps = torch.randint(0, model.diff.T, (n_rows,), device=device)
        noisy, _noise = model.diff.q_sample(flat_actions, steps)
        noise_pred = model.den(noisy, steps, flat_states)
        recovered = model.diff.predict_x0(noisy, steps, noise_pred)
        batch_errors.append(F.mse_loss(flat_actions, recovered).item())
    if not batch_errors:
        return dict(mse_action_denoised=float('nan'))
    return dict(mse_action_denoised=float(sum(batch_errors) / len(batch_errors)))


def train_and_eval(config: Dict[str, Any]):
    """Build the data splits, train a SkillMoETrainer, and report test metrics.

    Args:
        config: option mapping. Every key is optional; the default used for a
            missing key is the one given at its lookup in the body.

    Returns:
        dict with a single 'SkillMoE' entry holding the metrics returned by
        _eval_skillmoe on the test loader.
    """
    # Device preference: CUDA first, then MPS, otherwise CPU.
    mps_backend = getattr(torch.backends, 'mps', None)
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')

    # Data splits; the split seed is offset from the data seed.
    seed = config.get('seed', 0)
    split_options = dict(
        batch_size=config.get('batch_size', 16),
        val_batch_size=config.get('val_batch_size', None),
        test_batch_size=config.get('test_batch_size', None),
        val_ratio=config.get('val_ratio', 0.15),
        test_ratio=config.get('test_ratio', 0.15),
        split_seed=seed + 999,
        num_tasks=config.get('num_tasks', 3),
        seq_len=config.get('seq_len', 40),
        num_seq_per_task=config.get('num_seq_per_task', 120),
        d=config.get('d', 6),
        K=config.get('K', 4),
        seed=seed,
    )
    ds, dl_train, dl_val, dl_test = make_data_splits(**split_options)

    progress_bar = bool(config.get('progress_bar', True))
    sanity_every = int(config.get('sanity_every', 0))  # 0 disables

    model = SkillMoETrainer(
        d=ds.d,
        K=ds.K,
        s_dim=ds.s_dim,
        a_dim=ds.a_dim,
        T=config.get('T', 64),
        use_adapters=True,
        top_k=config.get('top_k', 2),
        device=device,
    ).to(device)
    opt = torch.optim.Adam(
        model.parameters(),
        lr=config.get('lr', 3e-4),
        weight_decay=config.get('weight_decay', 1e-5),
    )

    # Training: each call to the model performs its own optimizer step.
    epochs = config.get('epochs', 10)
    step_options = dict(
        lambda_ortho=config.get('lambda_ortho', 1e-4),
        lambda_gate_align=config.get('lambda_gate_align', 0.1),
    )
    model.train()
    for epoch in range(epochs):
        if progress_bar:
            iterator = tqdm(dl_train, desc=f"SkillMoE epoch {epoch+1}/{epochs}", leave=False)
        else:
            iterator = dl_train
        for batch in iterator:
            model(batch, opt, **step_options)

        # Periodic validation check, always including the final epoch.
        finished = epoch + 1
        if sanity_every > 0 and (finished % sanity_every == 0 or finished == epochs):
            model.eval()
            val_metrics = _eval_skillmoe(model, dl_val)
            print(f"[sanity] SkillMoE epoch {epoch+1}: {val_metrics}")
            model.train()

    # Evaluation on test
    model.eval()
    return {'SkillMoE': _eval_skillmoe(model, dl_test)}
