import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy

from model.jepa_modules import JEPAEncoder, JEPAPredictor
from model.utils import Decoder

class RJEPA(nn.Module):
    """Latent time-series model with a JEPA branch and/or a reconstruction branch.

    Components built in ``__init__``:

    * ``context_encoder`` + ``predictor`` -- always present. The encoder sees
      only the time steps selected by ``temporal_mask``; the predictor is then
      queried at ``full_times`` and yields a (mean, std) pair plus ``alphas``.
    * ``target_encoder`` -- a frozen deep copy of ``context_encoder`` that is
      updated as an exponential moving average (``update_target_encoder``);
      only built when ``args.tau > 0``.
    * ``decoder`` -- decodes sampled latents into ``obs_mean``; only built when
      ``args.eta > 0``.

    NOTE(review): ``args`` is read both by attribute (``args.eta``) and via
    ``args.get(...)``, so it is presumably a DictConfig/EasyDict-style object
    -- confirm against the caller.
    """

    def __init__(self, args):
        super().__init__()

        self.use_recon = args.eta > 0  # reconstruction branch (decoder)
        self.use_jepa = args.tau > 0   # JEPA branch (EMA target encoder)

        if not self.use_recon and not self.use_jepa:
            raise NotImplementedError(
                "RJEPA requires args.eta > 0 (reconstruction) and/or args.tau > 0 (JEPA)"
            )

        self.task = args.task
        self.ts = args.ts  # multiplicative scale applied to every time stamp in forward()

        self.history_only = args.get("mask_history", False)
        if self.history_only:
            print("note that history masking is used in temporal encoder!")

        # When False, forward() replaces the predicted std with a constant 0.1.
        self.learn_std = args.get("learn_std", True)

        self.context_encoder = JEPAEncoder(args)
        self.predictor = JEPAPredictor(args)

        self.mask_fill_mode = args.get("mask_fill_mode", "zero")
        self.encoder_mode = args.get("encoder_mode", "temporal")

        self.decoder = Decoder(args) if self.use_recon else None

        if self.use_jepa:
            # Momentum ramps linearly from args.init_momentum to 1.0 over
            # num_steps calls to update_target_encoder(); afterwards it stays 1.0.
            num_steps = args.get("num_steps", 10000)
            self.momentum_scheduler = iter(torch.linspace(args.init_momentum, 1.0, num_steps))
            self.target_encoder = deepcopy(self.context_encoder)
            for p in self.target_encoder.parameters():
                p.requires_grad = False
        else:
            self.target_encoder = None

    def update_target_encoder(self):
        """EMA-update the target encoder from the context encoder.

        For every parameter pair: ``target <- m * target + (1 - m) * context``.
        If ``self.momentum_scheduler`` is a plain float it is used as a constant
        momentum; otherwise the next value is drawn from the iterator, falling
        back to 1.0 (the schedule's end value) once it is exhausted instead of
        raising ``StopIteration``.

        Returns:
            The momentum ``m`` used for this update (0-d tensor or float).
        """
        with torch.no_grad():
            if isinstance(self.momentum_scheduler, float):
                m = self.momentum_scheduler
            else:
                m = next(self.momentum_scheduler, 1.0)
            for param_q, param_k in zip(self.context_encoder.parameters(),
                                        self.target_encoder.parameters()):
                param_k.mul_(m).add_((1.0 - m) * param_q.detach())
        return m

    def forward_context_encoder(self, context_obs, context_times, full_times, temporal_mask, spatial_mask):
        """Encode the kept steps and query the predictor at ``full_times``.

        Returns:
            ``(Z_pred, alphas, C)``: the predictor outputs and the raw context
            encoding ``C``. ``Z_pred`` is unpacked as a ``(means, stds)`` pair
            by ``forward``.
        """
        C = self.context_encoder(context_obs, context_times, temporal_mask, spatial_mask,
                                 history_only=self.history_only)
        Z_pred, alphas = self.predictor(C, full_times, temporal_mask)
        return Z_pred, alphas, C

    def forward_target_encoder(self, full_obs, full_times):
        """Encode the full sequence with the target encoder (no gradients).

        The output is layer-normalized over its last dimension (no affine).
        """
        with torch.no_grad():
            Z = self.target_encoder(full_obs, full_times, history_only=self.history_only)
            Z = F.layer_norm(Z, (Z.size(-1),))
        return Z

    def forward_predictor(self, context, times, context_mask):
        """Query the predictor; returns ``((means, stds), alphas)``."""
        (means, stds), alphas = self.predictor(context, times, context_mask)
        return (means, stds), alphas

    def drop_temporal_masking(self, full_obs, full_times, temporal_mask):
        """Keep only the time steps selected by ``temporal_mask``.

        full_obs:      [B, T, M] or [B, T, N, D]
        full_times:    [B, T]
        temporal_mask: [B, T]; truthy entries are kept (cast to bool)
        Returns:
        context_obs:   [B, t, M] or [B, t, N, D]
        context_times: [B, t]

        Raises:
            ValueError: if ``full_obs`` is not 3-D/4-D, or if the rows of
                ``temporal_mask`` keep different numbers of steps.
        """
        if full_obs.ndim not in (3, 4):
            raise ValueError(f"Unsupported full_obs.ndim={full_obs.ndim}")

        # Must be boolean: an integer 0/1 mask would be treated as gather indices.
        temporal_mask = temporal_mask.bool()
        B, _ = temporal_mask.shape

        num_valid = temporal_mask.sum(dim=1)
        if B > 0 and not torch.all(num_valid == num_valid[0]):
            raise ValueError(f"All batches must have same # of valid timesteps, got {num_valid}")
        t = int(num_valid[0]) if B > 0 else 0

        # Boolean indexing with a [B, T] mask flattens the first two dims in
        # row-major order ([B, T, ...] -> [B*t, ...]); because every row keeps
        # exactly t steps, the result reshapes back to [B, t, ...].
        context_obs = full_obs[temporal_mask].reshape(B, t, *full_obs.shape[2:])
        context_times = full_times[temporal_mask].reshape(B, t)

        return context_obs, context_times

    def forward(self, full_obs, full_times, temporal_mask, spatial_mask, n_samples=3):
        """Run the model on one batch.

        full_obs:      [B, T, M] or [B, T, N, D]
        full_times:    [B, T]
        temporal_mask: [B, T] boolean
        spatial_mask : [B, M] or [B, T, N]

        Returns:
            ``(O_target, (X_mean, X_std), KLs)`` and, when a decoder exists, an
            extra ``(obs_mean, obs_var)`` element, where

            * ``O_target`` -- layer-normalized target-encoder output for the
              full sequence, or ``None`` without a target encoder;
            * ``X_std`` -- predicted std, or a constant 0.1 when
              ``learn_std`` is False (also used for sampling below);
            * ``KLs`` -- ``0.5 * alpha_k^2 * (t_{k+1} - t_k)`` over the first
              ``T - 1`` steps (``alphas`` is assumed to be time-major along
              dim 1 -- TODO confirm against JEPAPredictor);
            * ``obs_mean`` -- decoder output for ``n_samples`` reparameterized
              latent samples, with the sample dim moved to position 1;
              ``obs_var`` is a constant 0.01 of the same shape.
        """
        full_times = self.ts * full_times

        context_obs, context_times = self.drop_temporal_masking(full_obs, full_times, temporal_mask)

        if self.target_encoder is not None:
            O_target = self.forward_target_encoder(full_obs, full_times)
        else:
            O_target = None
        (X_mean, X_std), alphas, _ = self.forward_context_encoder(
            context_obs, context_times, full_times, temporal_mask, spatial_mask
        )

        times = full_times[..., None]
        KLs = 0.5 * (alphas[:, :-1].pow(2) * (times[:, 1:] - times[:, :-1]))

        if not self.learn_std:
            # Fixed std, used both for the returned distribution and for the
            # decoder samples so the two stay consistent.
            X_std = torch.full_like(X_std, 0.1)
        X_mean_var = (X_mean, X_std)

        if self.decoder is None:
            return O_target, X_mean_var, KLs

        # Reparameterized samples, drawn directly on X_mean's device/dtype:
        # shape [n_samples, *X_mean.shape].
        Z = torch.randn(n_samples, *X_mean.shape, device=X_mean.device, dtype=X_mean.dtype)
        X = X_mean + X_std * Z

        obs_mean = self.decoder(X)
        obs_var = torch.full_like(obs_mean, 0.01)

        # Move the sample dim behind the batch dim (assumes the decoder keeps
        # the leading sample dim -- TODO confirm against Decoder).
        obs_mean, obs_var = obs_mean.transpose(0, 1), obs_var.transpose(0, 1)
        return O_target, X_mean_var, KLs, (obs_mean, obs_var)