
from functools import partial
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
from flow_matching.solver import Solver, ODESolver
from flow_matching.utils import ModelWrapper

import einops
import lightning as L
import torch
from torch import Tensor, nn
# flow_matching
from flow_matching.path.scheduler import CondOTScheduler, Scheduler
from flow_matching.path import AffineProbPath

from src.utils.probe_sampling import set_values_by_indices

class Autoregression(L.LightningModule):
    """Single-pass next-frame predictor for a field sequence.

    All frames but the last of ``batch['field']`` are fed to the wrapped
    model, and the last time slice of the model output is compared against
    the last frame of the field during training.
    """

    def __init__(
        self,
        model: Callable[[], torch.nn.Module],
        fm_scheduler: Scheduler,
        loss_fn: Callable,
        std: float,
        n_steps: int,
    ):
        """Build the module.

        Args:
            model: Zero-argument factory returning the network (it is called
                once here), e.g. a ``functools.partial`` over a module class.
            fm_scheduler: Scheduler handed to ``AffineProbPath``.
            loss_fn: Callable invoked with the keyword arguments ``y_pred``,
                ``y_true`` and ``mask``.
            std: Stored on the instance; not read by the methods of this
                class.
            n_steps: Stored on the instance; not read by the methods of this
                class.
        """
        super().__init__()
        # ``model`` is a factory, not an instance: instantiate it here.
        self.model: torch.nn.Module = model()
        self.path = AffineProbPath(scheduler=fm_scheduler)
        self.loss_fn = loss_fn
        self.std = std
        self.n_steps = n_steps

    def _predict_next(self, batch: Mapping[str, Any]) -> Tensor:
        """Run the model on every frame but the last, shared by train/inference."""
        field = batch['field']  # (b T h w c) per the original author's note
        history = field[:, :-1]
        # The model is always queried at time 0, as a 1-element tensor that
        # matches the field's device and dtype.
        t_zero = torch.tensor([0.], device=field.device, dtype=field.dtype)
        return self.model(
            history,
            t_zero,
            probe_pos=batch['probe_pos'],
            # Only the probe readings of the last time step are passed on.
            probe_field=batch['probe_field'][:, -1:],
        )

    # https://github.com/facebookresearch/flow_matching/blob/main/examples/2d_flow_matching.ipynb
    def train_step(self, batch):
        """Compute the training loss for one batch.

        Reference given by the original author:
        Stochastic interpolants with data-dependent couplings
        http://arxiv.org/abs/2310.03725

        Args:
            batch: Mapping providing ``'field'``, ``'probe_idcs'``,
                ``'probe_pos'`` and ``'probe_field'``.

        Returns:
            Whatever ``loss_fn`` returns for the last-frame prediction.
        """
        field = batch['field']  # (b T h w c)

        # NOTE(review): the mask is built from the conditioning frames
        # (all but the last), whereas the loss below compares only the last
        # frame -- confirm that ``loss_fn`` expects / broadcasts this shape.
        mask = get_mask(field[:, :-1], batch['probe_idcs'])  # (b T h w)

        prediction = self._predict_next(batch)

        # Only the last time slice and the first two channels are scored.
        return self.loss_fn(
            y_pred=prediction[:, -1:, :, :, :2],
            y_true=field[:, -1:, :, :, :2],
            mask=mask,
        )

    def inference(self, batch, n_rollouts=16, **kwargs):
        """Return the raw model output for ``batch``.

        Args:
            batch: Mapping providing ``'field'``, ``'probe_pos'`` and
                ``'probe_field'``.
            n_rollouts: Accepted for interface compatibility; currently
                unused.
            **kwargs: Accepted and ignored.

        Returns:
            The model output for all frames but the last of
            ``batch['field']``.
        """
        return self._predict_next(batch)
    
    
def get_mask(x_1, probe_idcs):
    """Build a boolean mask over every dimension of ``x_1`` except the last.

    The mask starts out all ``False`` on the same device as ``x_1`` and is
    then handed to ``set_values_by_indices`` together with ``probe_idcs``;
    that helper's return value is returned unchanged.

    Args:
        x_1: Tensor whose leading dimensions define the mask shape.
        probe_idcs: Probe indices, documented by the original author as
            ``(b n 2)`` with ``n`` the number of probes.
    """
    leading_dims = x_1.shape[:-1]
    blank = torch.zeros(*leading_dims, dtype=torch.bool, device=x_1.device)
    return set_values_by_indices(blank, probe_idcs)
