
from functools import partial
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
from flow_matching.solver import Solver, ODESolver
from flow_matching.utils import ModelWrapper

import einops
import lightning as L
import torch
from torch import Tensor, nn
# flow_matching
from flow_matching.path.scheduler import CondOTScheduler, Scheduler
from flow_matching.path import AffineProbPath

from src.utils.probe_sampling import set_values_by_indices

class FlowMatching(L.LightningModule):
    """Flow-matching module whose noise sample is pinned to the data at probe locations.

    Both ``train_step`` and ``inference`` draw Gaussian noise ``x_0``, overwrite
    it with the target ``x_1`` wherever the probe mask is set, and feed the mask
    to the network as one extra trailing channel.

    Args:
        model: Zero-argument callable returning the velocity network; it is
            invoked once in ``__init__`` (e.g. a ``functools.partial`` of the
            network class).
        fm_scheduler: Scheduler handed to ``AffineProbPath``.
        loss_fn: Callable invoked as ``loss_fn(y_pred=..., y_true=..., mask=...)``.
        std: Factor the standard-normal noise sample is multiplied by.
        n_steps: Number of trailing entries along axis 1 whose mask is cleared
            (the "future" timesteps). ``0`` leaves the mask untouched.
    """

    def __init__(
        self,
        model: Callable[[], torch.nn.Module],
        fm_scheduler: Scheduler,
        loss_fn: Callable,
        std: float,
        n_steps: int,
    ):
        super().__init__()
        # ``model`` is a factory, not an instance: it is called here to build the network.
        self.model: torch.nn.Module = model()
        self.path = AffineProbPath(scheduler=fm_scheduler)
        self.loss_fn = loss_fn
        self.std = std
        self.n_steps = n_steps

    # https://github.com/facebookresearch/flow_matching/blob/main/examples/2d_flow_matching.ipynb
    def train_step(self, batch):
        """Compute the flow-matching loss for one batch.

        Follows "Stochastic interpolants with data-dependent couplings"
        (http://arxiv.org/abs/2310.03725).

        Args:
            batch: Mapping with at least ``'field'`` (the target ``x_1``,
                laid out as (b T h w c)) and ``'probe_idcs'``. The whole
                mapping is also forwarded to the network as keyword arguments.

        Returns:
            Whatever ``self.loss_fn`` returns for the predicted and target
            velocities.
        """
        x_1 = batch['field']  # (b T h w c)
        x_0 = torch.randn_like(x_1, device=x_1.device) * self.std

        mask = get_mask(x_1, batch['probe_idcs'])  # (b T h w)
        # Clear the mask on the trailing ``n_steps`` entries of axis 1.
        # The guard matters: a slice of ``-0:`` would clear the entire axis.
        if self.n_steps > 0:
            mask[:, -self.n_steps:, :, :] = False
        # Data-dependent coupling: the source sample equals the target at masked positions.
        x_0[mask] = x_1[mask]

        # One time value per batch element.
        t = torch.rand(x_1.shape[0], device=x_1.device)

        # Sample the probability path between x_0 and x_1 at time t.
        path_sample = self.path.sample(t=t, x_0=x_0, x_1=x_1)

        # Append the mask as an extra trailing channel: "the model sees the mask".
        x_t = torch.cat(
            [path_sample.x_t, mask.to(dtype=x_1.dtype).unsqueeze(-1)], dim=-1
        )

        dx_t = self.model(x_t, path_sample.t, **batch)

        # The last output channel lines up with the appended mask channel and is
        # dropped before comparing against the target velocity.
        # NOTE(review): an earlier inline version zeroed the residual at masked
        # positions and took the mean of squares; ``loss_fn`` presumably uses
        # ``mask`` the same way -- confirm against its implementation.
        return self.loss_fn(
            y_pred=dx_t[..., :-1], y_true=path_sample.dx_t, mask=mask
        )

    def inference(self, batch, n_steps=30, return_intermediates=False, **kwargs):
        """Integrate the learned velocity field from noise using the midpoint method.

        Args:
            batch: Mapping with at least ``'field'`` and ``'probe_idcs'``. It is
                forwarded to the network on every velocity evaluation and to
                ``ODESolver.sample`` as extra keyword arguments.
            n_steps: Number of points in the ``[0, 1]`` time grid; the solver
                step size is ``1 / n_steps``. Unrelated to ``self.n_steps``,
                which only controls the mask.
            return_intermediates: Forwarded to ``ODESolver.sample``.
            **kwargs: Extra keyword arguments for ``ODESolver.sample``. Keys
                must not collide with keys of ``batch``.

        Returns:
            The result of ``ODESolver.sample``.
        """
        x_1 = batch['field']
        device = x_1.device
        x_0 = torch.randn_like(x_1, dtype=torch.float32, device=device) * self.std
        T = torch.linspace(0, 1, n_steps, device=device)  # sample times

        mask = get_mask(x_1, batch['probe_idcs'])
        # Clear the mask on the trailing ``self.n_steps`` entries of axis 1.
        if self.n_steps > 0:
            mask[:, -self.n_steps:, :, :] = False
        # ``x_0`` is always float32; masked assignment needs matching dtypes, so
        # cast the copied values instead of failing for non-float32 fields.
        x_0[mask] = x_1[mask].to(dtype=x_0.dtype)

        # Loop-invariant tensors, built once instead of on every solver call.
        # The mask channel uses the solver state's dtype so concatenation never
        # promotes the model input.
        mask_channel = mask.to(dtype=x_0.dtype).unsqueeze(-1)
        frozen = mask.unsqueeze(-1)  # broadcasts over the channel axis

        class WrappedModel(ModelWrapper):
            def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
                # Necessary for inference with batch size > 1: give every
                # sample its own copy of the time value.
                t = t.expand(x.shape[0])

                # Append the mask to the input: "the model sees the mask".
                x = torch.cat([x, mask_channel], dim=-1)
                dx_t = self.model(x, t, **batch)
                # Drop the mask channel and zero the velocity at masked
                # positions so they keep the values copied from ``x_1``.
                # Out-of-place, so the network output is not mutated.
                return dx_t[..., :-1].masked_fill(frozen, 0)

        wrapped_model = WrappedModel(self.model)

        solver = ODESolver(velocity_model=wrapped_model)
        sol = solver.sample(
            time_grid=T,
            x_init=x_0,
            method='midpoint',
            step_size=1. / n_steps,
            return_intermediates=return_intermediates,
            **batch,
            **kwargs,
        )
        return sol

    
def get_mask(x_1, probe_idcs):
    """Return a boolean mask shaped like ``x_1`` without its last axis.

    The mask starts out all ``False`` on ``x_1``'s device and is then handed to
    ``set_values_by_indices`` together with ``probe_idcs``; that helper's
    result is returned unchanged.

    Args:
        x_1: Tensor whose leading axes (all but the last) define the mask shape.
        probe_idcs: Probe indices; per the original note laid out as (b n 2)
            with n the number of probes -- TODO confirm against the helper.
    """
    leading_shape = x_1.shape[:-1]
    blank = torch.zeros(*leading_shape, dtype=torch.bool, device=x_1.device)
    return set_values_by_indices(blank, probe_idcs)
