
from functools import partial
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
from flow_matching.solver import Solver, ODESolver
from flow_matching.utils import ModelWrapper

import einops
import lightning as L
import torch
from torch import Tensor, nn
# flow_matching
from flow_matching.path.scheduler import CondOTScheduler, Scheduler
from flow_matching.path import AffineProbPath

from src.utils.probe_sampling import set_values_by_indices

class FlowMatching(L.LightningModule):
    """Conditional flow-matching model that generates the last frame of a sequence.

    The source sample ``x_0`` is built from the data itself (data-dependent
    coupling): every frame except the last is copied from the data, the last
    frame is Gaussian noise scaled by ``std``, and the probe locations given by
    ``batch['probe_idcs']`` are overwritten with their true values. The velocity
    network receives the probe mask (and optionally the flow time) as extra
    input channels.

    Args:
        model: Factory called once (with no arguments) in ``__init__`` to build
            the velocity network.
        fm_scheduler: Scheduler defining the ``AffineProbPath``.
        loss_fn: Called as ``loss_fn(y_pred=..., y_true=..., mask=...)``.
        std: Standard deviation of the noise used for the target frame.
        n_steps: Stored on the instance (see note in ``__init__``).
        concat_timestep: If True, the flow time is broadcast over
            ``(T, h, w)`` and appended as an extra input channel.
    """

    def __init__(
        self,
        model: Callable[[], torch.nn.Module],
        fm_scheduler: Scheduler,
        loss_fn: Callable,
        std: float,
        n_steps: int,
        concat_timestep: bool = False,
    ):
        super().__init__()
        # `model` is a factory (e.g. a functools.partial); build the network here.
        self.model: torch.nn.Module = model()
        self.path = AffineProbPath(scheduler=fm_scheduler)
        self.loss_fn = loss_fn
        self.std = std
        # NOTE(review): not read by any method in this class; `inference` uses
        # its own `n_steps` argument (default 30). Confirm intended usage.
        self.n_steps = n_steps
        self.concat_timestep = concat_timestep

    def _build_source(self, x_1: Tensor, probe_idcs) -> Tuple[Tensor, Tensor]:
        """Build the data-dependent source sample ``x_0`` and the probe mask.

        Frames ``[:-1]`` are copied from ``x_1``, the last frame is Gaussian
        noise times ``self.std``, and probe locations are then overwritten with
        their true values from ``x_1``.

        Returns:
            ``(x_0, mask)``: ``x_0`` has ``x_1``'s shape; ``mask`` is the boolean
            probe mask with ``x_1``'s shape minus the last dimension.
        """
        x_1_condition = x_1[:, :-1]
        x_1_target = x_1[:, -1:]
        noise = torch.randn_like(x_1_target) * self.std
        # torch.cat allocates a fresh tensor, so writing into x_0 never touches x_1.
        x_0 = torch.cat([x_1_condition, noise], dim=1)
        mask = get_mask(x_1, probe_idcs)  # (b T h w)
        x_0[mask] = x_1[mask]
        return x_0, mask

    def _model_input(self, x_t: Tensor, mask: Tensor, t: Tensor) -> Tensor:
        """Append the probe mask (and optionally the time) as extra channels.

        Args:
            x_t: Current state, shape (b, T, h, w, c).
            mask: Boolean probe mask, shape (b, T, h, w).
            t: Flow time with one value per sample, shape (b,).

        Returns:
            Tensor of shape (b, T, h, w, c + 1), or (b, T, h, w, c + 2) when
            ``concat_timestep`` is set.
        """
        _, T, h, w, _ = x_t.shape
        # "the model sees the mask"
        channels = [x_t, mask.to(dtype=x_t.dtype).unsqueeze(-1)]
        if self.concat_timestep:
            channels.append(einops.repeat(t, 'b -> b T h w 1', T=T, h=h, w=w))
        return torch.cat(channels, dim=-1)

    # https://github.com/facebookresearch/flow_matching/blob/main/examples/2d_flow_matching.ipynb
    def train_step(self, batch):
        '''
        Compute the flow-matching loss for one batch.

        follow
        Stochastic interpolants with data-dependent couplings
        http://arxiv.org/abs/2310.03725

        Only the velocity of the last frame (first two channels) is supervised.
        '''
        x_1 = batch['field']  # (b T h w c)
        x_0, mask = self._build_source(x_1, batch['probe_idcs'])

        # one flow time per sample
        t = torch.rand(x_1.shape[0], device=x_1.device)

        path_sample = self.path.sample(t=t, x_0=x_0, x_1=x_1)
        x_t = self._model_input(path_sample.x_t, mask, t)

        dx_t = self.model(x_t, path_sample.t, **batch)

        # NOTE(review): `mask` spans all T frames while predictions are sliced to
        # the last frame; confirm loss_fn handles this.
        loss = self.loss_fn(
            y_pred=dx_t[:, -1:, :, :, :2],
            y_true=path_sample.dx_t[:, -1:, :, :, :2],
            mask=mask,
        )
        return loss

    def inference(self, batch, n_steps=30, return_intermediates=False, **kwargs):
        """Integrate the learned velocity field from ``x_0`` (t=0) to t=1.

        Args:
            batch: Dict with at least ``'field'`` (b, T, h, w, c) and
                ``'probe_idcs'``; all entries are also forwarded to the model
                and to ``ODESolver.sample``.
            n_steps: Number of points in the time grid; the solver step size is
                ``1 / n_steps``.
            return_intermediates: Forwarded to ``ODESolver.sample``.
            **kwargs: Extra keyword arguments for ``ODESolver.sample``.

        Returns:
            Whatever ``ODESolver.sample`` returns for these arguments.
        """
        x_1 = batch['field']  # (b T h w c)
        batch_size = x_1.shape[0]
        x_0, mask = self._build_source(x_1, batch['probe_idcs'])

        time_grid = torch.linspace(0, 1, n_steps, device=x_1.device)

        # Points whose velocity is forced to zero so they keep their x_0 value:
        # the probe locations (already known), and every conditioning frame.
        # Only the last frame is supervised in `train_step`, so the network's
        # output on frames [:-1] is untrained and would otherwise make the
        # conditioning drift during integration.
        frozen = mask.clone()
        frozen[:, :-1] = True
        frozen = frozen.unsqueeze(-1)  # (b T h w 1), broadcasts over channels

        flow = self

        class WrappedModel(ModelWrapper):
            def forward(self, x: Tensor, t: Tensor, **extras):
                # The solver supplies a single time value; broadcast it to one
                # value per sample, matching the (b,) times used in training.
                t_b = t.reshape(-1).expand(batch_size)
                dx_t = self.model(flow._model_input(x, mask, t_b), t_b, **batch)
                return dx_t[..., :2].masked_fill(frozen, 0)

        solver = ODESolver(velocity_model=WrappedModel(self.model))
        sol = solver.sample(
            time_grid=time_grid,
            x_init=x_0,
            method='midpoint',
            step_size=1. / n_steps,
            return_intermediates=return_intermediates,
            **batch,
            **kwargs,
        )
        return sol
    
    
def get_mask(x_1, probe_idcs):
    """Return the boolean probe mask for ``x_1``.

    An all-False boolean tensor is created with ``x_1``'s shape minus its last
    dimension, on ``x_1``'s device. It is then handed to
    ``set_values_by_indices`` together with ``probe_idcs``, and that call's
    result is returned.

    probe_idcs: presumably (b, n, 2) with n the number of probes — the exact
    index semantics are defined by ``set_values_by_indices``; verify there.
    """
    blank = x_1.new_zeros(x_1.shape[:-1], dtype=torch.bool)
    return set_values_by_indices(blank, probe_idcs)
