
from functools import partial
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple

import einops
import lightning as L
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from src.utils.probe_sampling import set_values_by_indices

class ARDiffusion(L.LightningModule):
    """Autoregressive training / rollout wrapper around a conditional diffusion model.

    Batches are mappings whose ``"field"`` entry is laid out as ``B S W H C``.
    Internally everything is rearranged to ``B S C W H``. The wrapped model
    receives a ``conditioning`` tensor (the previous ``input_steps`` frames
    stacked along the channel axis) and a ``data`` tensor (a single frame).

    Args:
        model: Diffusion backbone. Called as ``model(conditioning=..., data=...)``
            in :meth:`train_step` and as
            ``model.inference(conditioning=..., data=...)`` in :meth:`inference`.
        input_steps: Number of consecutive frames used as conditioning.
            Defaults to 2, the value that used to be hard-coded.

    Raises:
        ValueError: If ``input_steps`` is smaller than 1.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        input_steps: int = 2,
    ):
        super().__init__()
        if input_steps < 1:
            raise ValueError(f"input_steps must be >= 1, got {input_steps}")
        self.model: torch.nn.Module = model
        self.input_steps: int = input_steps

    def _stack_conditioning(self, frames: Tensor, end: int) -> Tensor:
        """Stack the ``input_steps`` frames preceding index ``end`` along channels.

        ``frames`` is ``B S C W H``; the result is ``B 1 (input_steps*C) W H``,
        ordered from the oldest to the most recent frame.
        """
        window = [frames[:, j:j + 1] for j in range(end - self.input_steps, end)]
        return torch.concat(window, dim=2)

    def train_step(self, batch: Mapping[str, Tensor]) -> Tensor:
        """Compute the noise-prediction loss for one batch.

        The first ``input_steps`` frames are the conditioning; the frame right
        after them is the diffusion target.

        NOTE(review): Lightning's own hook is called ``training_step``; this
        method is presumably invoked by outer code — confirm.

        Args:
            batch: Mapping with a ``"field"`` tensor laid out as ``B S W H C``.

        Returns:
            Scalar smooth-L1 loss between the two tensors returned by the model.

        Raises:
            ValueError: If the sequence has no frame after the conditioning
                window (``S <= input_steps``).
        """
        # bring to shape B S C W H
        field = einops.rearrange(batch["field"], 'B S W H C -> B S C W H')

        n_cond = self.input_steps
        if field.shape[1] <= n_cond:
            raise ValueError(
                f"train_step needs more than {n_cond} frames per sequence, "
                f"got {field.shape[1]}"
            )

        conditioning = self._stack_conditioning(field, n_cond)
        data = field[:, n_cond:n_cond + 1]

        # Per the original note, both inputs are B S C W H (D) and both
        # outputs are B S nC W H (D).
        noise, predictedNoise = self.model(conditioning=conditioning, data=data)

        # smooth_l1_loss is documented as (input, target); the loss is
        # symmetric, so this ordering yields the same value as the reverse.
        return F.smooth_l1_loss(predictedNoise, noise)

    def inference(self, batch: Mapping[str, Tensor], numSamples: int = 1, **kwargs) -> Tensor:
        """Roll the model out autoregressively over the sequence in ``batch``.

        The first ``input_steps`` frames are copied from the input; every later
        frame is produced by ``self.model.inference`` conditioned on the
        ``input_steps`` previously produced frames.

        Args:
            batch: Mapping with a ``"field"`` tensor laid out as ``B S W H C``.
            numSamples: Number of samples per sequence. The batch is tiled this
                many times along the batch axis, so the output has
                ``B * numSamples`` entries.
            **kwargs: Accepted for call-site compatibility; not used.

        Returns:
            Tensor laid out as ``(B * numSamples) S W H C``.

        Raises:
            ValueError: If the sequence is shorter than ``input_steps``.
        """
        # reuse batch dim for samples
        d = batch["field"].repeat(numSamples, 1, 1, 1, 1)
        # bring to shape B S C W H
        d = einops.rearrange(d, 'B S W H C -> B S C W H')

        n_cond = self.input_steps
        seq_len = d.shape[1]
        if seq_len < n_cond:
            raise ValueError(
                f"inference needs at least {n_cond} frames per sequence, "
                f"got {seq_len}"
            )

        prediction = torch.zeros_like(d)
        # The conditioning frames are given, not predicted.
        prediction[:, :n_cond] = d[:, :n_cond]

        for i in range(n_cond, seq_len):
            cond = self._stack_conditioning(prediction, i)

            # NOTE(review): `data` is taken from the input frame i-1, not from
            # the rollout; presumably the model only needs it as a template —
            # confirm against the model implementation.
            result = self.model.inference(conditioning=cond, data=d[:, i - 1:i])

            # NOTE(review): an earlier variant overwrote parts of `result` with
            # ground-truth values from `d` (simulation parameters); that is
            # intentionally not done here.
            prediction[:, i:i + 1] = result

        # bring back to shape B S W H C
        return einops.rearrange(prediction, 'B S C W H -> B S W H C')
    
    
def get_mask(x_1, probe_idcs):
    """Build a boolean probe mask spanning every axis of ``x_1`` but the last.

    Args:
        x_1: Reference tensor; only its shape (without the trailing axis) and
            its device are used.
        probe_idcs: Probe indices forwarded to ``set_values_by_indices``. Per
            the original note they are shaped ``(b, n, 2)`` with ``n`` the
            number of probes — TODO confirm against the helper.

    Returns:
        The result of ``set_values_by_indices`` applied to an all-False mask
        and ``probe_idcs`` (presumably the mask with the probe locations
        set — verify against the helper).
    """
    leading_shape = x_1.shape[:-1]
    all_false = torch.zeros(leading_shape, dtype=torch.bool, device=x_1.device)
    return set_values_by_indices(all_false, probe_idcs)
