""" Compute validation accuracy across different parameters of Euler equation. """
import sys
import os
from pathlib import Path
current_path = Path(os.getcwd())
sys.path.append(str(current_path))
import argparse
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader, Sampler

from src.utils import OutDf, load_model, standardize, divide_grid, nrmpe, to_torch
from src.sf_euler_data import SfEulerDataModule
from src.global_constants import DATA_PATHS, STAT_PATH


class CustomSampler(Sampler):
    """Sampler that replays a caller-supplied collection of indices in its own order."""

    def __init__(self, indices):
        # Kept by reference (no copy): iteration and length are delegated to it.
        self.indices = indices

    def __len__(self):
        """Return how many indices the sampler holds."""
        return len(self.indices)

    def __iter__(self):
        """Yield the stored indices one by one, in the order they were given."""
        yield from self.indices


def compute_accuracy(dataset, data_idces, batch_size, n_future, model, model_type=None):
    """Average the model's normalized error (`nrmpe` with p=2) over selected samples.

    Args:
        dataset: Dataset handed to the `DataLoader`; each batch must provide the
            'input_fields' and 'output_fields' entries.
        data_idces: Indices of the samples to evaluate, visited in the given
            order through `CustomSampler`. Must be non-empty: its first and
            last elements are printed.
        batch_size: Number of samples per batch.
        n_future: Number of future steps requested from the model; also the
            length of the returned tensor.
        model: Model to evaluate. For the "ardiff" type it is called as
            `model(conditioning, data, n_future_steps=...)`, otherwise as
            `model(inp, predict_normed=True, n_future_steps=...)` and is
            expected to return a pair whose first item is the prediction.
        model_type: Optional model identifier; "ardiff" selects the diffusion
            code path. When omitted, the module-level `params.model` is used,
            which matches the behavior when this file is run as a script.

    Returns:
        A CPU tensor with `n_future` entries holding the per-batch error
        averaged over all processed batches.
    """
    if model_type is None:
        # Backward-compatible fallback on the global created by the script part.
        model_type = params.model
    is_diffusion = model_type == "ardiff"

    # DATA
    print(f"data indices: {data_idces[0]} -> {data_idces[-1]}")
    print("Creating dataloader...")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=CustomSampler(data_idces))

    # COMPUTE
    counts = 0
    nrmse_raw = None  # allocated on the first batch, on the device of the error
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader), desc="Processing Batches"):
            inp, tar = to_torch(batch['input_fields']), to_torch(batch['output_fields'])
            if inp.shape[1] > 16:
                # Only 16 input steps are kept: the surplus ones are moved to the
                # front of the target, which is then cut to n_future steps.
                tar = torch.cat([inp[:,16:,...], tar], dim=1)[:,:n_future,...]
                inp = inp[:,:16,...]
            # forward
            if not is_diffusion:
                output, _ = model(
                    inp, predict_normed=True, n_future_steps=n_future,
                )
            else:
                # the diffusion model uses as context 2 steps in the past
                inp = inp[:,[-2,-1],...]
                # preprocessing
                norm_dims = tuple(range(3,inp.squeeze(-1,-2).ndim))
                inp, mean, std = standardize(inp, dims=(1,*norm_dims), return_stats=True)
                # t-2, t-1 steps as conditioning
                conditionning = torch.cat([inp[:,[-2],...], inp[:,[-1],...]], dim=2)
                # t step  # won't be used anyway
                data = inp[:,[-1],...]
                output = model(conditionning, data, n_future_steps=n_future)
                # unnormalize
                output = output * std + mean
            # nrmse
            # Error reduced over dims 3 and 4, then averaged over dims 0 and 2.
            # Presumably the layout is (B, T, C, H, W) -- TODO confirm, notably
            # for 1d data.
            spatial_dims = (3, 4)
            batch_err = nrmpe(tar, output, dims=spatial_dims, p=2.0, avg=False).mean((0,2))
            if nrmse_raw is None:
                # Follow the device of the data instead of assuming 'cuda:0'.
                nrmse_raw = torch.zeros(n_future, device=batch_err.device)
            nrmse_raw += batch_err
            counts += 1
    # NOTE(review): every batch has the same weight in this average, so a
    # smaller final batch counts as much as a full one.
    if nrmse_raw is None:
        # No batch was processed: same NaN result as dividing zeros by zero.
        nrmse_raw = torch.full((n_future,), float("nan"))
    else:
        nrmse_raw = nrmse_raw.cpu() / counts
    print(f"nrmse_raw t+1 {nrmse_raw[0].item():.3f} counts {counts}")
    return nrmse_raw


def get_args():
    """Parse the integer `-ntot` and `-tid` options from the command line.

    Returns:
        The `argparse.Namespace` with attributes `ntot` (default 1) and
        `tid` (default 0).
    """
    cli = argparse.ArgumentParser(description='')
    # multiprocessing args, declared as (flag, default, help)
    int_options = (
        ('-ntot', 1, 'Number of tasks'),
        ('-tid', 0, 'Current task'),
    )
    for flag, default, help_text in int_options:
        cli.add_argument(flag, type=int, default=default, help=help_text)
    return cli.parse_args()


if __name__ == "__main__":

    ################################################################
    # Run configuration
    args = get_args()
    epoch = 100
    ndata = 1024 * 32
    split = 'val'
    ################################################################

    dset_name = 'euler'
    batch_size = 16
    n_future = 1
    data_paths = DATA_PATHS[dset_name]

    GAMMAS = [1.13,1.22,1.3,1.33,1.365,1.4,1.404,1.453,1.597,1.76]
    MODEL_NAMES_EULER = [
        'FINAL_unet_single_eulerpBC',
        'FINAL_fno_single_eulerpBC',
        'FINAL_ard_single_eulerpBC',
        'FINAL_TF_single_eulerpBC',
        'FINAL_icnpde_single_eulerpBC',
    ]

    # Sample indices handled by this task.
    # 340000: nb of samples in the val set for n_future = 1
    n_samples = min(ndata,340000)
    grid = divide_grid(list(range(n_samples)), args.ntot, args.tid)

    for model_name in MODEL_NAMES_EULER:
        # One result store per model; it is saved again after every gamma.
        save_name = f"rollout_{model_name}_E{epoch or 'none'}_{dset_name}_{split}_NP16_NF{n_future}_gammas_bis"
        store_results = OutDf(STAT_PATH / save_name)

        # MODEL
        # `params` must keep this module-level name: compute_accuracy reads
        # `params.model` to pick the forward path.
        model, params = load_model(model_name, epoch=epoch, pretrained_MPP=False)
        model = model.eval()
        print("Model: ", model_name)

        for gamma in GAMMAS:
            print(f"Gamma: {gamma}")

            # DATA: only the files matching the current gamma are included.
            module_kwargs = dict(
                base_path=data_paths[0],
                dataset_name=data_paths[1],
                resolution=data_paths[2],
                include_filters=[f'gamma_{gamma}_'],
                batch_size=1,  # the evaluation builds its own DataLoader with `batch_size`
                n_steps_input=16,
                n_steps_output=n_future,
                max_rollout_steps=n_future,
                data_workers=0,
                world_size=1,
                rank=0,
            )
            data_module = SfEulerDataModule(**module_kwargs)
            datasets_by_split = {
                'train': data_module.rollout_train_dataset,
                'val': data_module.rollout_val_dataset,
                'test': data_module.rollout_test_dataset,
            }
            dataset = datasets_by_split[split]

            nrmse_raw = compute_accuracy(dataset, grid, batch_size, n_future, model)

            # SAVING: one row per future step, numbered from 1.
            for idx in range(n_future):
                store_results.add(
                    ntot=args.ntot, tid=args.tid,
                    type="nrmse",
                    it=idx + 1,
                    gamma=gamma,
                    nrmse_raw=nrmse_raw[idx].item(),
                )
            store_results.save()
            print("Saved to: ", store_results.path)
