from torch_geometric.loader import DataLoader
from tqdm.auto import tqdm
from epyt_flow import utils
from epyt.epanet import ToolkitConstants
import experiments.random_hydraulics_dataset
import utils as functions
import torch
import numpy as np
import matplotlib.pyplot as plt
import multiprocessing as mp
import os
from baselines.pde_gnn import PDEFunction, PDEModelInputWrapper
from baselines.nn_conv import NNConvWrapper
import pickle
from sklearn.metrics import explained_variance_score
import pandas as pd
import modules
from model.advection_reaction_layer import AdvectionReactionLayer
from model.advection_layer import AdvectionLayer
from model.advection_model_mp import AdvectionModelMP
from experiments.experiment_configs import CONFIGS
from itertools import product
from matplotlib import colors
import matplotlib.legend as mlegend
import shapely
from copy import deepcopy

# Worker processes are started with "spawn"; force=True overrides any start
# method configured earlier in this interpreter.
try:
    mp.set_start_method('spawn', force=True)
except RuntimeError:
    # Keep whatever start method is already active.
    pass
else:
    print("spawned")

# Simulation timing. N_SECONDS and HYDRAULIC_STEP are in seconds (see the
# hours conversion in init_dataset).
N_SECONDS = utils.to_seconds(hours=12)   # total simulated duration: 12 h
HYDRAULIC_STEP = 60                      # one minute
QUALITY_TIMESTEP = 1                     # presumably seconds as well - confirm against the simulator
PATTERN_STEP = 2 * HYDRAULIC_STEP        # injection pattern changes every two hydraulic steps
PATTERN_N_WAVES = 3
SEED = 42

RESULTS_DIR = 'Results/learning_reactions'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Sampling ranges for the randomised network properties.
diameter_range = (60, 250)   # [mm] --> 6 - 25 [cm]
length_range = (0.1, 80)     # [m]
demand_range = (0.0, 5)      # [m³/h]

def init_dataset(dataset_config, cT, Tau):
    """Build the torch dataset described by ``dataset_config``.

    A Cl injection pattern covering ``N_SECONDS`` (one value per hydraulic
    step) is generated with the fixed ``SEED``. It is plotted to
    ``RESULTS_DIR/injection_pattern.pdf`` and passed to a
    ``RandomHydraulicsDataset``. That dataset is wrapped in a
    ``TorchRandomHydraulicsDataset`` of ``size`` samples starting at
    ``idx_offset``.

    Args:
        dataset_config: dict with keys ``'inp_file'``, ``'msx_file'``,
            ``'sources_at'``, ``'idx_offset'``, ``'equal_demand_at_nodes'``,
            ``'size'`` and ``'data_path'``.
        cT: forwarded to ``RandomHydraulicsDataset``.
        Tau: forwarded to ``RandomHydraulicsDataset``.

    Returns:
        The ``TorchRandomHydraulicsDataset`` placed on ``device``.
    """
    inp_file = dataset_config['inp_file']
    msx_file = dataset_config['msx_file']
    sources_at = dataset_config['sources_at']
    idx_offset = dataset_config['idx_offset']
    equal_demand_at_nodes = dataset_config['equal_demand_at_nodes']
    size = dataset_config['size']

    pattern_length = N_SECONDS // HYDRAULIC_STEP
    pattern = functions.create_wavy_pattern(pattern_length, PATTERN_N_WAVES, HYDRAULIC_STEP, SEED)
    ts = np.linspace(0, N_SECONDS, pattern_length) / 60 / 60  # seconds -> hours

    fig, ax = plt.subplots(figsize=(14, 2), tight_layout=True)
    ax.plot(ts, pattern)
    ax.set_title('Cl Injection Pattern')
    ax.set_xlabel('Time [hours]')
    ax.set_ylabel('Cl injection ($\\frac{mg}{l}$)')
    os.makedirs(RESULTS_DIR, exist_ok=True)
    fig.savefig(os.path.join(RESULTS_DIR, 'injection_pattern.pdf'))
    # This function runs once per train/eval dataset; close the figure so that
    # repeated calls do not accumulate open pyplot figures.
    plt.close(fig)

    ds = experiments.random_hydraulics_dataset.RandomHydraulicsDataset(
        inp_file, N_SECONDS, cT, Tau, HYDRAULIC_STEP, QUALITY_TIMESTEP,
        PATTERN_STEP, diameter_range, length_range, demand_range,
        pattern, sources_at, f_msx_in=msx_file,
        equal_demand_at_nodes=equal_demand_at_nodes,
        source_type=ToolkitConstants.EN_SETPOINT,
        data_path=dataset_config['data_path']
    )
    ds = experiments.random_hydraulics_dataset.TorchRandomHydraulicsDataset(
        size, ds, device=device, idx_offset=idx_offset
    )
    return ds

def init_model(config):
    """Instantiate the model named by ``config['model_name']`` on ``device``.

    Supported names are ``'PDE-GNN'``, ``'NNConv'`` and ``'MeGA-MP'``. For
    ``'MeGA-MP'``, ``config['layer']`` selects the message-passing layer:
    ``'advect_react'`` or ``'advect'``.

    Raises:
        NotImplementedError: for an unknown model name or MeGA-MP layer.
    """
    model_name = config['model_name']

    if model_name == 'PDE-GNN':
        return PDEModelInputWrapper(
            PDEFunction(1, 1, config, device),
            config
        ).to(device)

    if model_name == 'NNConv':
        return NNConvWrapper(
            config['cT'], config['Tau'], config['hidden_dim'], config['num_layers'], config['aggr']
        ).to(device)

    if model_name == 'MeGA-MP':
        advection_op = modules.AdvectionModuleGridSampleDynamic(interpolation_mode=config['interpolation_mode'])
        mask_op = modules.MaskingModuleSigmoid()
        mask_op.mask_temp = config['mask_temp']
        if config['layer'] == 'advect_react':
            layer = AdvectionReactionLayer(advection_op, mixing_at_nodes=config['mixing_at_nodes'])
        elif config['layer'] == 'advect':
            layer = AdvectionLayer(advection_op, mixing_at_nodes=config['mixing_at_nodes'])
        else:
            raise NotImplementedError(f'Layer {config["layer"]} not implemented.')
        return AdvectionModelMP(
            layer, mask_op, max_msg_passing_rounds=config['max_msg_passing_rounds'], progress=False
        ).to(device)

    # Fail loudly instead of reaching an unbound local variable.
    raise NotImplementedError(f'Model {model_name} not implemented.')

def prepare_features(model_inputs):
    """Add max-normalised edge features and a trailing feature axis, in place.

    Each ``*_scaled`` entry is the raw feature divided by the first element
    of its matching ``max_*`` entry (in a batched input presumably the first
    graph's maximum - confirm all graphs share it). ``'x'`` and
    ``'boundary_values'`` get a trailing singleton dimension. The same
    (mutated) object is returned.
    """
    # (new key, raw feature key, normaliser key)
    normalisations = (
        ('edge_diameter_scaled', 'edge_diameter', 'max_diameter'),
        ('delay_steps_scaled', 'delay_steps', 'max_delay_steps'),
        ('edge_capacity_scaled', 'edge_capacity', 'max_edge_capacity'),
        ('flows_scaled', 'flows', 'max_flow'),
    )
    for scaled_key, raw_key, max_key in normalisations:
        model_inputs[scaled_key] = model_inputs[raw_key] / model_inputs[max_key][0]

    for feature_key in ('x', 'boundary_values'):
        model_inputs[feature_key] = model_inputs[feature_key].unsqueeze(-1)

    return model_inputs
    
def _training_loss(loss_fn, pred, target):
    """Return the scalar training loss named ``loss_fn`` between ``pred`` and ``target``.

    ``'relative_l1'`` divides the Euclidean norm of the error along dim 0 by
    that of ``target``; despite its name it uses the L2 norm. ``'relative_l2'``
    does the same with squared norms. Both are averaged over the remaining
    dimension.

    Raises:
        NotImplementedError: for an unknown ``loss_fn``.
    """
    if loss_fn == 'l1':
        return torch.nn.L1Loss()(pred, target)
    if loss_fn == 'mape':
        return ((pred - target) / (1e-3 + target)).abs().mean()
    if loss_fn == 'huber':
        return torch.nn.HuberLoss(delta=0.3)(pred, target)
    if loss_fn == 'mse':
        return torch.nn.MSELoss()(pred, target)
    if loss_fn == 'relative_l1':
        return (torch.norm(pred - target, dim=0) / (torch.norm(target, dim=0) + 1e-5)).mean()
    if loss_fn == 'relative_l2':
        return (torch.norm(pred - target, dim=0)**2 / (torch.norm(target, dim=0)**2 + 1e-5)).mean()
    raise NotImplementedError(f'A loss function with name {loss_fn} is not implemented.')

def train(model, dataset, learning_rate, epochs, loss_fn, cT, Tau, scheduler_factor=0.2, scheduler_patience=20, clip_grad_norm=0.0001):
    """Train ``model`` with Adam and a ReduceLROnPlateau schedule.

    For every batch the model is called with ``cT``/``Tau`` and ``n_steps=1``.
    Boundary nodes (``boundary_index``) are excluded, and the loss is taken
    over time columns ``cT:`` of the prediction and of
    ``y[:, :cT + Tau + 1]``.

    Args:
        model: module returning the prediction as its first output.
        dataset: iterable of batches accepted by ``prepare_features``.
        learning_rate: initial Adam learning rate.
        epochs: number of passes over ``dataset``.
        loss_fn: one of ``'l1'``, ``'mape'``, ``'huber'``, ``'mse'``,
            ``'relative_l1'``, ``'relative_l2'``.
        cT, Tau: forwarded to the model call.
        scheduler_factor, scheduler_patience: ReduceLROnPlateau settings. The
            scheduler is stepped on the summed loss of each epoch.
        clip_grad_norm: maximum total gradient norm.

    Returns:
        ``(model, losses)``, where ``losses`` holds one ``(l1, training loss)``
        pair of CPU scalars per batch.

    Raises:
        NotImplementedError: on the first batch if ``loss_fn`` is unknown.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=scheduler_factor, patience=scheduler_patience, threshold=1e-5
    )

    _losses = []
    progress = tqdm(range(epochs))

    model.to(device)

    for epoch in progress:
        epoch_loss = 0
        epoch_l1 = []
        for model_inputs in dataset:
            optimizer.zero_grad()
            model_inputs = prepare_features(model_inputs)
            pred, *_ = model(cT=cT, Tau=Tau, n_steps=1, **model_inputs)
            y_true = model_inputs['y'][:, :cT + Tau + 1]
            pred = pred.squeeze(-1)

            # Boundary nodes are not predicted (their values are assumed known),
            # so drop them from the loss. An O(N) boolean scatter replaces an
            # N x N identity matrix.
            is_boundary = torch.zeros(len(pred), dtype=torch.bool, device=pred.device)
            is_boundary[torch.as_tensor(model_inputs.boundary_index, device=pred.device)] = True
            pred = pred[~is_boundary]
            y_true = y_true[~is_boundary]

            # L1 is always tracked for monitoring, whatever loss is optimised.
            l1_loss = torch.nn.L1Loss()(pred[:, cT:], y_true[:, cT:])
            loss = _training_loss(loss_fn, pred[:, cT:], y_true[:, cT:])

            loss.backward()
            epoch_loss += loss.detach()

            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)

            optimizer.step()
            l1_value = l1_loss.detach().mean().cpu()
            _losses.append((l1_value, loss.detach().mean().cpu()))
            epoch_l1.append(l1_value.numpy())

        current_lr = optimizer.param_groups[0]["lr"]
        progress.set_description(f'{np.mean(epoch_l1):.5f}  {current_lr:.5f}')
        schedule.step(epoch_loss)

    return model, _losses
    
def run_experiment(cT, Tau, dataset_config, model_config, model_path):
    """Train a fresh model on one dataset and store its weights.

    The dataset from ``init_dataset(dataset_config, cT, Tau)`` is served in
    shuffled batches of ``model_config['batch_size']``. The model from
    ``init_model(model_config)`` is trained with the optimisation settings in
    ``model_config``, and its ``state_dict`` is saved to
    ``<model_path>/model.pt`` (the directory is created if needed).
    """
    loader = DataLoader(
        init_dataset(dataset_config, cT, Tau),
        batch_size=model_config['batch_size'],
        shuffle=True,
    )

    print(model_config)

    trained_model, _ = train(
        init_model(model_config),
        loader,
        learning_rate=model_config['learning_rate'],
        epochs=model_config['epochs'],
        loss_fn=model_config['loss_fn'],
        cT=cT,
        Tau=model_config['train_Tau'],
        clip_grad_norm=model_config['clip_grad_norm'],
        scheduler_patience=model_config['scheduler_patience'],
    )

    os.makedirs(model_path, exist_ok=True)
    torch.save(trained_model.state_dict(), os.path.join(model_path, 'model.pt'))

def plot_with_map(data, flow_field, topology, sources_at=None, plt_kwargs=None):
    """Plot the network map beside time series at up to eight sampled nodes.

    Args:
        data: sequence of per-node time-series arrays whose rows follow
            ``topology.node_name_list``; every entry is drawn in each panel.
        flow_field: per-link values. Their mean absolute value over axis 1
            colours the pipes on a log scale.
        topology: network topology providing ``node_name_list`` and
            ``query_node_attribute('coordinates')``.
        sources_at: injection node names, drawn as stars (optional).
        plt_kwargs: one ``Axes.plot`` kwargs dict per entry of ``data``.

    Returns:
        ``(fig, ax)``; ``ax`` is the mosaic dict with the map under ``'W'`` and
        the time-series panels under ``'0'`` .. ``'7'``.
    """
    sources_at = [] if sources_at is None else sources_at
    plt_kwargs = [] if plt_kwargs is None else plt_kwargs

    fig, ax = plt.subplot_mosaic(
        'WWW0\nWWW1\nWWW2\nWWW3\nWWW4\nWWW5\nWWW6\nWWW7',
        figsize=(14, 7), constrained_layout=True
    )

    nsteps = data[0].shape[1]
    n = 8  # number of time-series panels in the mosaic above
    node_names = topology.node_name_list
    node_index = {name: idx for idx, name in enumerate(node_names)}
    # Sample without replacement so that no node is shown in two panels.
    plot_nodes = np.random.RandomState(32).choice(
        node_names, size=min(n, len(node_names)), replace=False
    )
    highlight_colors = [(0.89, 0.48, 0.15, 1.0)] * n
    lg = (0.9, 0.9, 0.9, 1.0)

    edge_data = np.abs(flow_field).mean(1)
    edge_norm = colors.LogNorm(edge_data.min(), edge_data.max())
    edge_cmap = plt.cm.Blues

    functions.plot_graph_from_topology(topology, node_colors=lg, edge_colors=edge_cmap(edge_norm(edge_data)), ax=ax['W'], node_size=15, width=3)
    ax['W'].margins(-0.02)

    pos = topology.query_node_attribute('coordinates')
    ts = np.linspace(0, nsteps*HYDRAULIC_STEP+0.1, nsteps) / 60 / 60  # seconds -> hours

    if len(sources_at):
        ax['W'].scatter(*np.transpose([pos[s] for s in sources_at]), marker='*', zorder=2,
                        s=400, label='Injection Nodes', color=(0.4, 0.6, 0.9, 1), ec='k')

    last_panel = len(plot_nodes) - 1
    for i, node in enumerate(plot_nodes):
        node_idx = node_index[node]
        panel = ax[str(i)]
        for d, kwargs in zip(data, plt_kwargs):
            panel.plot(ts, d[node_idx, :nsteps], **kwargs)
        panel.set_xlim(min(ts), max(ts))
        panel.set_ylim(-0.05, 1.05)
        panel.set_ylabel(i+1, rotation=0, fontdict={'weight': 'bold', 'fontsize': 13, 'color': highlight_colors[i]}, va='center')
        # Number the node on the map and put a halo behind it.
        ax['W'].text(*pos[node], i+1, fontdict={'weight': 'bold', 'fontsize': 13, 'color': highlight_colors[i]}, va='center', ha='center', zorder=3)
        ax['W'].scatter(*pos[node], color=(1, 1, 1, 0.8), zorder=2, s=650)
        ax['W'].scatter(*pos[node], color=(*highlight_colors[i][:-1], 0.25), zorder=2, s=650)
        if i < last_panel:
            panel.set_xticklabels([])
        else:
            panel.set_xlabel('Time [hours]')

    handles, labels = ax['W'].get_legend_handles_labels()
    linehandles, linelabels = ax['0'].get_legend_handles_labels()

    # NOTE(review): drops the third map legend entry and swaps the first handle
    # for a pipe patch; this assumes plot_graph_from_topology registers its
    # legend entries in a fixed order - confirm.
    del handles[2]; del labels[2]

    # Colour bar placed under the right part of the map axes.
    x1, y1, w1, h1 = ax['W'].bbox.bounds
    x, y = fig.transFigure.inverted().transform([x1 + 0.55 * w1, y1])
    w, _ = fig.transFigure.inverted().transform([w1 * 0.45, y1])
    cax = fig.add_axes((x, y, w, 0.03))

    pipe_h = plt.Rectangle([0, 0], 3, 1, color=plt.cm.Blues(0.9), label='Pipe')
    ax['W'].legend([pipe_h, *handles[1:], *linehandles], [*labels, *linelabels])

    cb = plt.colorbar(plt.cm.ScalarMappable(norm=edge_norm, cmap=edge_cmap), cax=cax, orientation='horizontal')
    cb.set_label(r'Avg. Flow Velocity [$\frac{m}{s}$]')

    return fig, ax
    
def plot_on_map(data, topology, sources_at=None, plt_kwargs=None):
    """Draw the network with small time-series insets attached to nodes.

    Nodes are visited in a shuffled order (fixed seed). Each node whose inset
    does not overlap an earlier inset or an injection node gets an inset
    showing every entry of ``data`` for that node.

    Args:
        data: sequence of per-node time-series arrays whose rows follow
            ``topology.node_name_list``.
        topology: network topology providing ``node_name_list`` and
            ``query_node_attribute('coordinates')``.
        sources_at: injection node names, drawn as stars (optional).
        plt_kwargs: one ``Axes.plot`` kwargs dict per entry of ``data``; each
            must contain ``'label'``.

    Returns:
        ``(fig, ax)`` with ``ax`` the main map axes.
    """
    sources_at = [] if sources_at is None else sources_at
    plt_kwargs = [] if plt_kwargs is None else plt_kwargs

    fig, ax = plt.subplots(figsize=(14, 8))

    node_names = list(topology.node_name_list)
    node_index = {name: idx for idx, name in enumerate(node_names)}
    # Shuffle a private copy: shuffling topology.node_name_list itself could
    # reorder the topology's list and break the name -> data-row mapping,
    # here and in later plots that reuse the topology.
    plot_nodes = list(node_names)
    np.random.RandomState(5).shuffle(plot_nodes)
    lg = (0.9, 0.9, 0.9, 1.0)

    functions.plot_graph_from_topology(topology, node_colors=lg, edge_colors=lg, ax=ax, node_size=15, width=3)

    nsteps = data[0].shape[1]

    ts = np.linspace(0, nsteps*HYDRAULIC_STEP, nsteps) / 60 / 60  # seconds -> hours
    pos = topology.query_node_attribute('coordinates')

    def to_figure_coords(xy):
        """Map a data-space point of ``ax`` to figure-fraction coordinates."""
        return fig.transFigure.inverted().transform(ax.transData.transform(xy))

    # Region already occupied in figure coordinates; seeded with the source
    # nodes so that no inset hides them.
    occupied = shapely.MultiPolygon()
    occupied = occupied.union(shapely.MultiPoint(
        [to_figure_coords(pos[s]) for s in sources_at]
    ))

    if len(sources_at):
        xys = np.transpose([pos[s] for s in sources_at])
        ax.scatter(*xys, marker='*', zorder=2, s=200, label='Injection Nodes', c='C0')

    inset_w, inset_h = 0.18*0.7, 0.1*0.8
    # Initialised up front so the legend below works even if no inset fits.
    line_handles = []
    line_labels = []

    for node in plot_nodes:
        x, y = to_figure_coords(pos[node])
        footprint = shapely.Polygon.from_bounds(x, y, x + inset_w, y + inset_h)
        if footprint.buffer(0.03).intersects(occupied):
            continue
        occupied = occupied.union(footprint)
        mini_ax = fig.add_axes((x, y, inset_w, inset_h))

        ax.scatter(*pos[node], s=100, color=(0.3, 0.3, 0.3, 1.0), zorder=2)
        line_handles = []
        line_labels = []

        # deepcopy so popping 'label' does not mutate the caller's dicts.
        for d, kwargs in zip(data, deepcopy(plt_kwargs)):
            line_labels.append(kwargs.pop('label'))
            line_plot, = mini_ax.plot(ts, d[node_index[node], :nsteps], **kwargs)
            line_handles.append(line_plot)

        mini_ax.set_xlim(min(ts), max(ts))
        mini_ax.set_xticks([round(max(ts))])
        mini_ax.set_yticks([1])
        mini_ax.set_ylim(-0.05, 1.05)
        mini_ax.set_fc((1, 1, 1, 0.6))

    handles, labels = ax.get_legend_handles_labels()
    ax.legend([*handles, *line_handles], [*labels, *line_labels])
    return fig, ax

def evaluate(cT, Tau, dataset_config, model_config, model_path, batch_size=8):
    """Roll out a model over a dataset and collect predictions and targets.

    Args:
        cT: forwarded to ``init_dataset`` and to the model call.
        Tau: horizon used to build the dataset. The model itself is rolled out
            with ``model_config['eval_Tau']``.
        dataset_config: dataset description accepted by ``init_dataset``.
        model_config: model description accepted by ``init_model``.
        model_path: directory containing ``model.pt``.
        batch_size: number of graphs per batch.

    Returns:
        dict with CPU tensors ``'predictions'``, ``'ground_truths'`` and
        ``'flow_fields'`` (batches concatenated along dim 0), plus the
        dataset's ``'sources_at'`` and ``'topology'``.
    """
    model = init_model(model_config)

    # A MeGA-MP model with the pure 'advect' layer is evaluated without
    # loading weights (this script never trains it); all others are restored.
    if not (model_config['model_name'] == 'MeGA-MP' and model_config['layer'] == 'advect'):
        model_file = os.path.join(model_path, 'model.pt')
        # map_location lets weights saved on a GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(model_file, map_location=device, weights_only=True))

    ds = init_dataset(dataset_config, cT, Tau)
    data_loader = DataLoader(ds, batch_size=batch_size)
    eval_Tau = model_config['eval_Tau']

    predictions = []
    ground_truths = []
    flow_fields = []

    with torch.no_grad():
        for model_inputs in tqdm(data_loader):
            model_inputs = prepare_features(model_inputs)
            pred, *_ = model(cT=cT, Tau=eval_Tau, n_steps=1, **model_inputs)
            pred = pred.squeeze(-1)
            y = model_inputs['y']

            # Boundary (Dirichlet) nodes are not predicted; their values are
            # known, so copy the ground truth into those rows. Use a single
            # indexed assignment: ``pred[mask][...] = ...`` would write into a
            # temporary copy and leave ``pred`` unchanged.
            is_boundary = torch.zeros(len(pred), dtype=torch.bool, device=pred.device)
            is_boundary[torch.as_tensor(model_inputs.boundary_index, device=pred.device)] = True
            n_cols = min(pred.shape[1], y.shape[1])
            pred[is_boundary, :n_cols] = y[is_boundary, :n_cols].to(pred)

            predictions.append(pred.detach().cpu())
            ground_truths.append(y.cpu())
            flow_fields.append(model_inputs['flow_field'].cpu())

    predictions = torch.cat(predictions, dim=0)
    ground_truths = torch.cat(ground_truths, dim=0)
    flow_fields = torch.cat(flow_fields, dim=0)

    return {
        'predictions' : predictions,
        'ground_truths' : ground_truths,
        'flow_fields' : flow_fields,
        'sources_at' : ds.hydraulics_ds.sources_at,
        'topology' : ds.hydraulics_ds.topology,
    }
    

def _network_config(inp_file, msx_file, sources_at, data_path):
    """Return a fresh per-network dataset configuration dict.

    ``size`` and ``idx_offset`` are added later by the ``__main__`` driver.
    """
    return {
        'inp_file': inp_file,
        'msx_file': msx_file,
        'sources_at': list(sources_at),
        'equal_demand_at_nodes': False,
        'data_path': data_path,
    }


def _experiment(msx_file, data_path):
    """Return a setup that trains on Hanoi and evaluates on Hanoi and L-Town.

    All networks share the MSX reaction file ``msx_file`` and the dataset
    directory ``data_path``. Every call builds new dicts, so mutating one
    setup never affects another.
    """
    return {
        'data_splits': {'train': 512, 'test': 128},
        'train': _network_config('networks/Hanoi.inp', msx_file, ['1'], data_path),
        'eval': {
            'Hanoi': _network_config('networks/Hanoi.inp', msx_file, ['1'], data_path),
            'L-Town': _network_config('networks/l_town_no_tanks.inp', msx_file, ['R1', 'R2'], data_path),
            # BIWS evaluation ('networks/BIWS_no_tanks.inp' with sources R1,
            # W1_RI, W2_SA, W3_AB, W4_SM, W5_PL) is currently disabled.
        },
    }


# Experiment index -> setup. A validation split of 128 samples is currently unused.
EXP_SETUPS = {
    0: _experiment('networks/simple_reactions.msx', './dataset'),
    1: _experiment('networks/ltown.msx', './dataset_no_reactions'),
    2: _experiment('networks/complicated_reactions.msx', './dataset_complicated'),
}
    
def _train_and_evaluate(model_names):
    """Train (when no ``model.pt`` exists) and evaluate every model on every setup.

    Evaluation results are pickled to
    ``RESULTS_DIR/experiment_<k>/<model_name>/results/eval_<wds>.pkl``.
    """
    for model_name in model_names:
        for k, exp_config in EXP_SETUPS.items():
            # Experiment 1 is pure advection: MeGA-MP then uses its 'advect'
            # layer and is not trained. Everything else uses 'advect_react'.
            # The config is copied per experiment so that neither the shared
            # CONFIGS entry nor later experiments inherit this choice.
            pure_advection = k == 1 and model_name == 'MeGA-MP'
            model_config = {
                **CONFIGS[model_name],
                'device': device,
                'layer': 'advect' if pure_advection else 'advect_react',
            }

            model_path = os.path.join(RESULTS_DIR, f'experiment_{k}', model_name)
            results_path = os.path.join(model_path, 'results')
            os.makedirs(results_path, exist_ok=True)

            exp_config['train']['size'] = exp_config['data_splits']['train']
            exp_config['train']['idx_offset'] = 0

            if not pure_advection and not os.path.exists(os.path.join(model_path, 'model.pt')):
                run_experiment(1, model_config['train_Tau'], exp_config['train'], model_config, model_path)

            for wds, eval_config in exp_config['eval'].items():
                eval_config['size'] = exp_config['data_splits']['test']
                # Test samples are indexed directly after the training samples.
                eval_config['idx_offset'] = exp_config['data_splits']['train']
                result = evaluate(1, model_config['eval_Tau'], eval_config, model_config, model_path)
                result['key'] = wds

                with open(os.path.join(results_path, f'eval_{wds}.pkl'), 'wb') as f:
                    pickle.dump(result, f)


def _load_eval_result(out_path, model_name, wds):
    """Unpickle the evaluation result stored for ``model_name`` on network ``wds``.

    Only files written by ``_train_and_evaluate`` are read here (trusted input).
    """
    with open(os.path.join(out_path, model_name, 'results', f'eval_{wds}.pkl'), 'rb') as f:
        return pickle.load(f)


def _summarize_results(model_names, color_theme, sample_idx=0):
    """Plot one test sample per network and write ``metrics.csv`` per experiment.

    The topology, sources, flow field and ground truth are taken from the
    first model's results. Predictions of every model are compared against
    that ground truth (MAE, MSE, explained variance).
    """
    metric_names = ['MAE (mean)', 'MAE (std)', 'MSE (mean)', 'MSE (std)', 'Exp Var (mean)', 'Exp Var (std)']
    all_wds = EXP_SETUPS[0]['eval'].keys()
    metrics_index = pd.MultiIndex.from_tuples(product(model_names, metric_names, all_wds))

    for k, exp_config in EXP_SETUPS.items():
        out_path = os.path.join(RESULTS_DIR, f'experiment_{k}')
        metrics_df = pd.DataFrame(columns=exp_config['eval'].keys(), index=metrics_index)

        for wds in exp_config['eval']:
            reference = _load_eval_result(out_path, model_names[0], wds)
            topology = reference['topology']
            sources_at = reference['sources_at']
            flow_fields = reference['flow_fields']
            # Rows of one sample occupy a contiguous block of n_nodes rows.
            n_nodes = topology.num_nodes
            batch_slice = slice(n_nodes*sample_idx, n_nodes*(sample_idx+1))
            y_true = reference['ground_truths'][batch_slice, :]

            preds = []
            for model_name in model_names:
                pred = _load_eval_result(out_path, model_name, wds)['predictions'][batch_slice]
                preds.append(pred[:, :y_true.shape[1]])

            series = [*preds, y_true]

            fig, _ = plot_on_map(
                series, topology, sources_at=sources_at, plt_kwargs=[
                    *[dict(linestyle='-', linewidth=2, color=f'C{midx}', label=m) for midx, m in enumerate(model_names)],
                    dict(linestyle=':', linewidth=2, color='k', label='Ground Truth')
                ]
            )
            fig.savefig(os.path.join(out_path, f'{wds}_advection_reaction_on_map.pdf'), bbox_inches='tight')
            plt.close(fig)

            fig, _ = plot_with_map(
                series, flow_fields[:topology.num_links].cpu(), topology, sources_at=sources_at, plt_kwargs=[
                    *[dict(linestyle='-', linewidth=2, color=color_theme[midx], label=m) for midx, m in enumerate(model_names)],
                    dict(linestyle=':', linewidth=2, color='k', label='Ground Truth')
                ]
            )
            fig.savefig(os.path.join(out_path, f'{wds}_advection_reaction_with_map.pdf'), bbox_inches='tight')
            plt.close(fig)

            for pred, model_name in zip(preds, model_names):
                mae = torch.nn.L1Loss(reduction='none')(pred, y_true)
                mse = torch.nn.MSELoss(reduction='none')(pred, y_true)
                exp_var = explained_variance_score(
                    np.nan_to_num(y_true), np.nan_to_num(pred), multioutput='raw_values', force_finite=False
                )
                # NOTE(review): with a 3-level row index this key likely
                # addresses the whole row, so every column receives the value -
                # confirm the intended table layout.
                metrics_df.loc[model_name, 'MAE (mean)', wds] = mae.mean().item()
                metrics_df.loc[model_name, 'MAE (std)', wds] = mae.std().item()
                metrics_df.loc[model_name, 'MSE (mean)', wds] = mse.mean().item()
                metrics_df.loc[model_name, 'MSE (std)', wds] = mse.std().item()
                metrics_df.loc[model_name, 'Exp Var (mean)', wds] = exp_var.mean().item()
                metrics_df.loc[model_name, 'Exp Var (std)', wds] = exp_var.std().item()

        metrics_df.to_csv(os.path.join(out_path, 'metrics.csv'))


if __name__ == '__main__':
    # Imported before training starts so a missing name fails immediately.
    from utils import COLOR_THEME

    model_names = ['NNConv', 'PDE-GNN', 'MeGA-MP']
    _train_and_evaluate(model_names)
    _summarize_results(model_names, COLOR_THEME)