
from functools import partial
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
from flow_matching.solver import Solver, ODESolver
from flow_matching.utils import ModelWrapper

import einops
import lightning as L
import torch
from torch import Tensor, nn
from src.utils.probe_sampling import set_values_by_indices

class MaskedAutoEncoder(L.LightningModule):
    """Reconstruct a full field from its values at a sparse set of probes.

    The wrapped model receives a copy of ``batch['field']`` in which every
    position not selected by ``batch['probe_idcs']`` is zero, plus an all-zero
    dummy "time" tensor with one entry per sample. Training compares the
    model output against the full, unmasked field using ``loss_fn``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        loss_fn: Callable,
        **kwargs
    ):
        """
        Args:
            model: called with no arguments here to build the network, so in
                practice it must be a factory (class, ``functools.partial``,
                ...) returning an ``nn.Module`` rather than a module instance.
            loss_fn: called as ``loss_fn(prediction, target)`` in
                ``train_step``.
            **kwargs: accepted and ignored (presumably so a shared config can
                pass extra keys -- verify against the caller).
        """
        super().__init__()
        self.model: torch.nn.Module = model()
        self.loss_fn = loss_fn

    def _masked_forward(self, batch) -> Tensor:
        """Run the model on the probe-masked field and return its output.

        Shared by ``train_step`` and ``inference`` so both always build the
        model input the same way.
        """
        x_1 = batch['field']
        mask = get_mask(x_1, batch['probe_idcs'])

        # Keep only the probed values; every other position stays zero.
        # zeros_like already places the result on x_1's device.
        x_0 = torch.zeros_like(x_1)
        x_0[mask] = x_1[mask]

        # Dummy time input (one entry per sample, default float dtype); the
        # model's signature expects a t argument even though it is unused here.
        t = torch.zeros(x_1.shape[0], device=x_1.device)

        # The entire batch dict (including 'field' and 'probe_idcs') is also
        # forwarded to the model as keyword arguments.
        return self.model(x_0, t, **batch)

    def train_step(self, batch):
        """Return ``loss_fn(prediction, batch['field'])`` for one batch."""
        y_hat = self._masked_forward(batch)
        return self.loss_fn(y_hat, batch['field'])

    def inference(self, batch, **kwargs):
        """Return the model's reconstruction for ``batch``.

        ``**kwargs`` is accepted for interface compatibility and ignored.
        Note: no ``torch.no_grad()`` is applied here; callers that do not
        need gradients should wrap the call themselves.
        """
        return self._masked_forward(batch)
    
    
def get_mask(x_1, probe_idcs):
    """Build a boolean probe mask over all axes of ``x_1`` except the last.

    Args:
        x_1: tensor whose leading dimensions (everything but the final axis)
            define the mask's shape; the mask is created on ``x_1``'s device.
        probe_idcs: probe locations forwarded to ``set_values_by_indices``;
            presumably shaped (b, n, 2) with n probes per sample
            (NOTE(review): confirm against that helper).

    Returns:
        The result of ``set_values_by_indices`` applied to an all-False
        ``torch.bool`` tensor of shape ``x_1.shape[:-1]``.
    """
    grid_shape = x_1.shape[:-1]
    blank = torch.zeros(*grid_shape, dtype=torch.bool, device=x_1.device)
    return set_values_by_indices(blank, probe_idcs)
