
from functools import partial
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple

import einops
import lightning as L
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from src.utils.probe_sampling import set_values_by_indices

class ARDiffusion(L.LightningModule):
    """Autoregressive wrapper around a conditional diffusion model.

    The wrapped ``model`` is called as ``model(conditioning=..., data=...)``
    during training (returning ``(noise, predicted_noise)``) and as
    ``model.inference(conditioning=..., data=...)`` during rollout.

    Batches are dicts with
      * ``"field"``: tensor laid out as ``B S W H C`` (channel-last),
      * ``"probe_idcs"``: probe locations forwarded to ``get_mask``
        (presumably ``(B, N, 2)`` -- TODO confirm against the data loader).

    The conditioning handed to the model is the channel-wise concatenation of
    the two history steps, a probe slice (target values at probe locations,
    zero elsewhere) and the probe mask as a single float channel, i.e.
    ``3 * C + 1`` channels.
    """

    # Number of leading time steps used as history / conditioning.
    _INPUT_STEPS = 2

    def __init__(
        self,
        model: torch.nn.Module,
    ):
        super().__init__()
        self.model: torch.nn.Module = model

    @staticmethod
    def _probe_mask(field: Tensor, probe_idcs) -> Tensor:
        """Return the boolean probe mask ``(B, 1, W, H)`` for a ``B S C W H`` field."""
        # get_mask only reads shape[:-1] and device of its first argument, so a
        # channel-last view of a single time step is a sufficient template.
        template = einops.rearrange(field[:, :1], 'b s c w h -> b s w h c')
        return get_mask(template, probe_idcs)

    @staticmethod
    def _probe_channels(target: Tensor, mask: Tensor) -> List[Tensor]:
        """Build the probe conditioning tensors for one target step.

        Args:
            target: step to take probe values from, shape ``B 1 C W H``.
            mask: boolean probe mask, shape ``B 1 W H``.

        Returns:
            ``[probe_slice, mask_channel]`` with shapes ``B 1 C W H`` and
            ``B 1 1 W H`` (float); ``probe_slice`` equals ``target`` where the
            mask is set and is zero elsewhere.
        """
        # Boolean indexing with a (B, 1, W, H) mask needs the channel axis last.
        target_cl = einops.rearrange(target, 'b s c w h -> b s w h c')
        probe_slice = torch.zeros(
            target_cl.shape, dtype=target.dtype, device=target.device
        )
        probe_slice[mask] = target_cl[mask]
        probe_slice = einops.rearrange(probe_slice, 'b s w h c -> b s c w h')
        mask_channel = einops.rearrange(mask, 'b s w h -> b s 1 w h').float()
        return [probe_slice, mask_channel]

    def train_step(self, batch):
        """Compute the training loss for one batch.

        Args:
            batch: dict with ``"field"`` (``B 3 W H C``) and ``"probe_idcs"``.

        Returns:
            Smooth-L1 loss between the noise and the predicted noise returned
            by the wrapped model.

        Raises:
            AssertionError: if the field does not contain exactly 3 time steps.
        """
        d = batch["field"]
        s = d.shape[1]

        # bring to shape B S C W H
        d = einops.rearrange(d, 'B S W H C -> B S C W H')

        assert s == 3, "needs to have 2 history, 1 to predict"
        n_in = self._INPUT_STEPS
        data = d[:, n_in:n_in + 1]  # step to predict

        # history steps, oldest first
        cond = [d[:, i:i + 1] for i in range(n_in)]

        # probe values of the target step plus the probe mask
        mask = self._probe_mask(d, batch['probe_idcs'])
        cond += self._probe_channels(data, mask)

        conditioning = torch.concat(cond, dim=2)  # combine along channel dimension

        # input shape (both inputs): B S C W H (D) -> output shape (both outputs): B S nC W H (D)
        noise, predictedNoise = self.model(conditioning=conditioning, data=data)

        return F.smooth_l1_loss(noise, predictedNoise)

    def inference(self, batch, numSamples=1, **kwargs):
        """Autoregressively roll out the field from its first two time steps.

        Args:
            batch: dict with ``"field"`` (``B S W H C``) and ``"probe_idcs"``.
            numSamples: number of samples per batch element; samples are
                stacked along the batch axis (output batch is ``numSamples * B``).
            **kwargs: accepted for call compatibility; unused.

        Returns:
            Tensor of shape ``(numSamples * B) S W H C``. The first two steps
            are copied from the input, the remaining ones are model output.
        """
        d = batch["field"].repeat(numSamples, 1, 1, 1, 1)  # reuse batch dim for samples
        # bring to shape B S C W H
        d = einops.rearrange(d, 'B S W H C -> B S C W H')

        probe_idcs = batch['probe_idcs']
        if numSamples != 1:
            # The field was tiled along the batch axis, so the probe indices
            # must be tiled identically or the mask would not match the
            # enlarged batch. Assumes probe_idcs is a tensor with the batch on
            # dim 0 -- TODO confirm.
            probe_idcs = probe_idcs.repeat(
                numSamples, *([1] * (probe_idcs.dim() - 1))
            )

        # The mask depends only on shape and probe locations: build it once.
        mask = self._probe_mask(d, probe_idcs)

        n_in = self._INPUT_STEPS
        prediction = torch.zeros_like(d)
        prediction[:, :n_in] = d[:, :n_in]  # no prediction of first steps

        n_rollout_steps = d.shape[1]
        for i in range(n_in, n_rollout_steps):
            # the n_in most recent (predicted) steps, oldest first
            cond = [prediction[:, i - j:i - j + 1] for j in range(n_in, 0, -1)]

            # probes are taken from the ground truth of the step being predicted
            data = d[:, i:i + 1]
            cond += self._probe_channels(data, mask)

            cond = torch.concat(cond, dim=2)  # combine along channel dimension

            # NOTE(review): the ground-truth step is passed as `data`;
            # presumably the model only uses it for shape -- confirm.
            result = self.model.inference(conditioning=cond, data=data)
            # The model output is stored as-is; nothing is overwritten with
            # ground-truth values.
            prediction[:, i:i + 1] = result

        # bring to shape B S W H C
        return einops.rearrange(prediction, 'B S C W H -> B S W H C')
    
    
def get_mask(x_1, probe_idcs):
    """Build a boolean probe mask shaped like ``x_1`` without its last axis.

    Args:
        x_1: tensor whose leading dimensions and device define the mask.
        probe_idcs: probe locations, documented as ``(b, n, 2)`` with ``n``
            the number of probes; passed unchanged to ``set_values_by_indices``.

    Returns:
        The result of ``set_values_by_indices`` applied to an all-False mask
        (presumably True at the probe locations -- confirm in that helper).
    """
    leading_shape = x_1.shape[:-1]
    blank = torch.zeros(leading_shape, dtype=torch.bool, device=x_1.device)
    return set_values_by_indices(blank, probe_idcs)
