
""" Compute the per-step rollout NRMSE of trained models on a dataset split (validation by default). """
import sys
import os
from pathlib import Path
current_path = Path(os.getcwd())
sys.path.append(str(current_path))
import argparse
from pathlib import Path
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader, Sampler

from src.utils import OutDf, load_model, standardize, divide_grid, nrmpe, to_torch
from src.sf_euler_data import SfEulerDataModule
from src.global_constants import DATA_PATHS, MODEL_NAMES, STAT_PATH
from src.data.data_utils import DSET_NAME_TO_OBJECT


class CustomSampler(Sampler):
    """Sampler yielding exactly the dataset indices it was built with.

    The caller fully controls both which indices are visited and in what
    order; nothing is shuffled, filtered or copied. This is what lets each
    task evaluate only its own share of the dataset.
    """

    def __init__(self, indices):
        # Kept by reference and iterated in the given order.
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __iter__(self):
        yield from self.indices


def _predict_rollout(model, params, inp, n_future):
    """Return the model's ``n_future``-step rollout prediction from context ``inp``."""
    if params.model != "ardiff":
        output, _ = model(
            inp, predict_normed=True, n_future_steps=n_future,
        )
        return output
    # The diffusion model uses as context only the last 2 past steps.
    inp = inp[:, [-2, -1], ...]
    # Standardize over dim 1 and the spatial dims; trailing singleton axes are
    # squeezed away before counting, so they are not treated as spatial dims.
    spatial_dims = tuple(range(3, inp.squeeze(-1, -2).ndim))
    inp, mean, std = standardize(inp, dims=(1, *spatial_dims), return_stats=True)
    # Steps t-2 and t-1, concatenated along dim 2, are the conditioning.
    conditioning = torch.cat([inp[:, [-2], ...], inp[:, [-1], ...]], dim=2)
    # Placeholder for step t required by the model call; per the original
    # author's note it is not actually used.
    data = inp[:, [-1], ...]
    output = model(conditioning, data, n_future_steps=n_future)
    # Undo the standardization so the prediction is on the same scale as `tar`.
    return output * std + mean


def rollout_validate(
    dataset, data_idces, batch_size, n_future, split,
    model_name, epoch,
    ntot, tid,
    dset_name=None, n_past=16,
):
    """Evaluate the multi-step rollout error of one model and save it to disk.

    Runs ``model_name`` for ``n_future`` steps on the samples of ``dataset``
    selected by ``data_idces``. For each rollout step it records the
    normalized RMSE (``nrmpe`` with ``p=2``), averaged over channels and over
    samples, with every sample weighted equally.

    Args:
        dataset: map-style dataset whose items are dicts holding
            ``'input_fields'`` and ``'output_fields'``.
        data_idces: sequence of dataset indices handled by this task, in order.
        batch_size: DataLoader batch size.
        n_future: number of rollout steps predicted and scored.
        split: split name; only used in the output file name.
        model_name, epoch: forwarded to ``load_model``. ``epoch`` is also used
            in the output file name, where ``None`` is rendered as ``'none'``.
        ntot, tid: total number of tasks and index of this task; stored with
            every result row.
        dset_name: dataset name used in the output file name. If omitted, the
            module-level ``dset_name`` assigned by the ``__main__`` block is
            used, which is the previous (implicit) behaviour.
        n_past: number of context steps kept as model input. Any extra input
            steps are treated as the first target steps.

    Returns:
        None. Results are written through ``OutDf`` under ``STAT_PATH``.

    Raises:
        ValueError: if no dataset name is given and no module-level
            ``dset_name`` exists (e.g. when imported rather than run as a script).
    """
    if dset_name is None:
        # Backward compatibility: script callers rely on the __main__ global.
        dset_name = globals().get("dset_name")
    if dset_name is None:
        raise ValueError("rollout_validate: `dset_name` must be provided")
    if len(data_idces) == 0:
        # Possible when the grid is split into more tasks than there are samples.
        print("No data indices assigned to this task, nothing to evaluate.")
        return

    # DATA
    print(f"data indices: {data_idces[0]} -> {data_idces[-1]}")
    print("Creating dataloader...")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=CustomSampler(data_idces))
    # MODEL
    print("Loading model...")
    model, params = load_model(model_name, epoch=epoch, pretrained_MPP=False)
    model = model.eval()

    # COMPUTE
    nrmse_sum = None  # created lazily on the device the errors live on
    n_samples = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader), desc="Processing Batches"):
            inp, tar = to_torch(batch['input_fields']), to_torch(batch['output_fields'])
            if inp.shape[1] > n_past:
                # Input steps beyond the context are the first future steps:
                # prepend them to the target and keep only n_future steps.
                tar = torch.cat([inp[:, n_past:, ...], tar], dim=1)[:, :n_future, ...]
                inp = inp[:, :n_past, ...]
            output = _predict_rollout(model, params, inp, n_future)
            # Tensors are indexed (batch, time, channel, *spatial), so the
            # spatial dims are everything from dim 3 on.
            spatial_dims = tuple(range(3, tar.ndim))
            err = nrmpe(tar, output, dims=spatial_dims, p=2.0, avg=False)
            # Mean over channels, sum over samples: the final division by the
            # sample count weights every sample equally, even when the last
            # batch is smaller.
            batch_sum = err.mean(dim=2).sum(dim=0)
            nrmse_sum = batch_sum if nrmse_sum is None else nrmse_sum + batch_sum
            n_samples += err.shape[0]
    nrmse_raw = nrmse_sum.cpu() / n_samples
    print(f"nrmse_raw t+1 {nrmse_raw[0].item():.3f} counts {n_samples}")

    # SAVING
    epoch_tag = 'none' if epoch is None else epoch  # epoch 0 must not become 'none'
    save_name = f"rollout_{model_name}_E{epoch_tag}_{dset_name}_{split}_NP{n_past}_NF{n_future}"
    store_results = OutDf(STAT_PATH / save_name)
    for it, value in enumerate(nrmse_raw.tolist(), start=1):
        store_results.add(
            ntot=ntot, tid=tid,
            type="nrmse",
            it=it,
            nrmse_raw=value,
        )
    store_results.save()
    print("Saved to: ", store_results.path)


def get_args(argv=None):
    """Parse the command-line options of the rollout script.

    Args:
        argv: list of argument strings to parse; ``None`` (the default)
            parses ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with ``ntot`` (number of tasks), ``tid`` (index of
        this task) and ``dset`` (dataset name).

    Exits via ``parser.error`` (SystemExit) when ``-dset`` is missing, when
    ``-ntot`` < 1, or when ``-tid`` is not in ``[0, ntot)``.
    """
    parser = argparse.ArgumentParser(
        description='Compute the per-step rollout NRMSE of trained models on a dataset split.'
    )
    # multiprocessing args: forwarded to divide_grid to select this task's samples
    parser.add_argument('-ntot', type=int, default=1, help='Number of tasks')
    parser.add_argument('-tid', type=int, default=0, help='Current task')
    # script args: required, otherwise the dataset lookups fail with KeyError: None
    parser.add_argument('-dset', type=str, required=True, help='Dataset name')
    args = parser.parse_args(argv)
    if args.ntot < 1:
        parser.error(f"-ntot must be >= 1, got {args.ntot}")
    if not 0 <= args.tid < args.ntot:
        parser.error(f"-tid must be in [0, {args.ntot - 1}], got {args.tid}")
    return args


if __name__ == "__main__":

    # ---- run configuration ---------------------------------------------
    args = get_args()
    epoch = 300
    ndata = 1024 * 32  # upper bound on the number of evaluated samples
    split = 'val'

    # Must remain a module-level name: rollout_validate relies on the global
    # `dset_name` for its output file name when it is not passed explicitly.
    dset_name = args.dset
    batch_size = 16

    # Rollout horizon (number of predicted future steps) for each dataset.
    FUTURE_STEPS = dict(
        burgers=16,
        diffre2d=16,
        incompNS=16,
        compNS=4,
        compNS512=4,
        shearflow=16,
        euler=16,
    )
    n_future = FUTURE_STEPS[dset_name]

    # ---- data ----------------------------------------------------------
    print("Loading data...")
    data_paths = DATA_PATHS[dset_name]
    if dset_name not in ['shearflow', 'euler']:
        # PDEBench dataset: each resolution entry is split on 'x' into ints.
        reso = [list(map(int, str(res).split('x'))) for res in data_paths[2]]
        dataset = DSET_NAME_TO_OBJECT[dset_name](
            data_paths[0], reso, data_paths[2],
            n_steps=16 + n_future - 1, dt=1, train_val_test=(.8, .1, .1), split=split,
        )
    else:
        data_module = SfEulerDataModule(
            base_path=data_paths[0],
            dataset_name=data_paths[1],
            resolution=data_paths[2],
            batch_size=1,  # batch_size > 1 seems not to work on rollout datasets
            n_steps_input=16,
            n_steps_output=n_future,
            max_rollout_steps=n_future,
            data_workers=0,
            world_size=1,
            rank=0,
        )
        # All three rollout datasets are accessed, as before; only one is kept.
        rollout_datasets = {
            'train': data_module.rollout_train_dataset,
            'val': data_module.rollout_val_dataset,
            'test': data_module.rollout_test_dataset,
        }
        dataset = rollout_datasets[split]

    # ---- per-task evaluation grid ---------------------------------------
    n_samples = min(ndata, len(dataset))
    grid = divide_grid(list(range(n_samples)), args.ntot, args.tid)
    for model_name in MODEL_NAMES[dset_name]:
        print("Model: ", model_name)
        rollout_validate(
            dataset, grid, batch_size, n_future, split,
            model_name, epoch,
            args.ntot, args.tid,
        )
