"""
Copyright 2025 [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import argparse, sys, os, time
from typing import Tuple
import yaml
from tqdm import tqdm
import numpy as np
import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_packed_sequence
import gpytorch
from src.nn import OUFlow
from benchmark import ouflow
# from benchmark import gpr, latentsde, acssm, dspd
from src.utils import fix_random



# Device used for models and data: CUDA when available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')



def energy_score(truth: Tensor, samples: Tensor, b: float=1.0, batch_size: int=None) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Compute the energy score of each ground-truth vector against model samples.

    For a truth vector ``y`` and model samples ``X`` the score is
    ``E||y - X||^b - 0.5 * E||X - X'||^b``. The self term ``E||X - X'||^b`` is
    estimated from all pairs (even-indexed sample, odd-indexed sample).

    Parameters
    ----------
    truth: Tensor, shape (n_truth, dim)
        Ground truth.
    samples: Tensor, shape (n_sample, dim)
        Samples from the model distribution.
    b: float, default=1.0
        Exponent of the distance. The range of b is (0, 2).
    batch_size: int, optional
        If given, distances are computed chunk by chunk to bound memory. Any
        positive value yields the same result as ``None`` up to floating-point
        rounding.

    Returns
    -------
    energy_score: Tensor, shape (n_truth)
        Energy score.
    dist_cross: Tensor, shape (n_truth)
        Cross term of the energy score.
    dist_self: Tensor
        Self term of the energy score (already multiplied by 0.5). It is the same
        for every truth row: a 0-dim tensor when ``batch_size`` is None,
        otherwise repeated to shape (n_truth).
    """
    n_sample = samples.shape[0]
    evens, odds = samples[::2], samples[1::2]
    if batch_size is None:
        dist_cross = (torch.norm(truth.unsqueeze(1) - samples, dim=-1)**b).mean(dim=1)
        dist_self = 0.5*(torch.norm(evens.unsqueeze(0) - odds.unsqueeze(1), dim=-1)**b).mean(dim=(0, 1))
    else:
        n_truth = truth.shape[0]
        # Accumulate in the inputs' dtype so double-precision data is not
        # silently truncated to the float32 default of torch.zeros.
        dtype = torch.promote_types(truth.dtype, samples.dtype)

        dist_cross = torch.zeros(n_truth, dtype=dtype, device=truth.device)
        for start in range(0, n_sample, batch_size):
            dist_cross += (torch.norm(truth.unsqueeze(1) - samples[start:start + batch_size], dim=-1)**b).sum(dim=1)
        dist_cross /= n_sample

        # Chunk the even/odd halves directly (instead of relying on slice parity
        # inside a chunk) so any batch_size pairs exactly the same samples as the
        # unbatched path; each chunk pair holds at most (batch_size//2)**2 distances.
        half = max(1, batch_size//2)
        self_sum = torch.zeros((), dtype=dtype, device=truth.device)
        for i in range(0, evens.shape[0], half):
            for j in range(0, odds.shape[0], half):
                self_sum += (torch.norm(evens[i:i + half].unsqueeze(1) - odds[j:j + half], dim=-1)**b).sum()
        # Divide by the true number of (even, odd) pairs, which matches `mean`
        # in the unbatched path also for odd n_sample.
        dist_self = (0.5*self_sum/(evens.shape[0]*odds.shape[0])).expand(n_truth).clone()

    energy_score = dist_cross - dist_self

    return energy_score, dist_cross, dist_self



def energy_distance(samples1: Tensor, samples2: Tensor, b: float=1.0, batch_size: int=None) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """
    Compute the energy distance between two sample sets.

    The value is ``2 E||X - Y||^b - E||X - X'||^b - E||Y - Y'||^b`` with
    ``X ~ samples1`` and ``Y ~ samples2``. Each self term is estimated from all
    pairs (even-indexed sample, odd-indexed sample) of the same set; the cross
    term uses all pairs between the two sets.

    Parameters
    ----------
    samples1: Tensor, shape (n_samples1, dim)
        Samples from the first distribution.
    samples2: Tensor, shape (n_samples2, dim)
        Samples from the second distribution.
    b: float, default=1.0
        Exponent of the distance. The range of b is (0, 2).
    batch_size: int, optional
        If given, distances are computed chunk by chunk to bound memory. Any
        positive value yields the same result as ``None`` up to floating-point
        rounding.

    Returns
    -------
    energy_distance: Tensor
        Energy distance (0-dim when ``batch_size`` is None, otherwise shape (1,)).
    dist_cross: Tensor
        Cross term of the energy distance.
    dist_self1: Tensor
        Self term of the energy distance for samples1.
    dist_self2: Tensor
        Self term of the energy distance for samples2.
    """

    if batch_size is None:
        dist_self1 = (torch.norm(samples1[::2].unsqueeze(0) - samples1[1::2].unsqueeze(1), dim=-1)**b).mean(dim=(0, 1))
        dist_self2 = (torch.norm(samples2[::2].unsqueeze(0) - samples2[1::2].unsqueeze(1), dim=-1)**b).mean(dim=(0, 1))
        dist_cross = (torch.norm(samples1.unsqueeze(0) - samples2.unsqueeze(1), dim=-1)**b).mean(dim=(0, 1))
    else:
        # Accumulate in the inputs' dtype so double-precision data is not
        # silently truncated to the float32 default of torch.zeros.
        dtype = torch.promote_types(samples1.dtype, samples2.dtype)
        # Even/odd halves are chunked with half the batch size so each chunk
        # pair holds at most (batch_size//2)**2 distances.
        half = max(1, batch_size//2)

        def self_term(samples: Tensor) -> Tensor:
            """Mean distance**b over all (even-indexed, odd-indexed) row pairs."""
            evens, odds = samples[::2], samples[1::2]
            total = torch.zeros(1, dtype=dtype, device=samples1.device)
            for i in range(0, evens.shape[0], half):
                for j in range(0, odds.shape[0], half):
                    total += (torch.norm(evens[i:i + half].unsqueeze(1) - odds[j:j + half], dim=-1)**b).sum()
            # True pair count, so odd sample counts agree with the unbatched mean.
            return total/(evens.shape[0]*odds.shape[0])

        dist_self1 = self_term(samples1)
        dist_self2 = self_term(samples2)

        n_sample1, n_sample2 = samples1.shape[0], samples2.shape[0]
        dist_cross = torch.zeros(1, dtype=dtype, device=samples1.device)
        for i in range(0, n_sample1, batch_size):
            for j in range(0, n_sample2, batch_size):
                dist_cross += (torch.norm(samples1[i:i + batch_size].unsqueeze(1) - samples2[j:j + batch_size], dim=-1)**b).sum()
        dist_cross /= n_sample1*n_sample2

    energy_distance = 2*dist_cross - dist_self1 - dist_self2

    return energy_distance, dist_cross, dist_self1, dist_self2



def energy_distance_generation(xs_test: Tensor, gen_func, n_sample: int=4096, n_trial: int=64, batch_size: int=256) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, float]:
    """
    Calculate energy distance for generation.

    The observations of all scenarios and time points are pooled into a single
    sample set; rows containing NaN (e.g. padding added by
    ``pad_packed_sequence``) are dropped. In every trial ``gen_func`` is queried
    at the single time point 0, and its samples are compared with the pooled
    set. Mean and standard deviation are taken over the trials.

    Parameters
    ----------
    xs_test : Tensor
        Time series of the test data. The shape is `(n_scenario, n_time, d)`.
    gen_func : callable
        A generation model. The arguments must be `ts_pred` and `n_sample` and the return value must be `xs_pred`.
        `ts_pred` is a tensor of shape `(n_time_pred, 1)`, where `n_time_pred` is the number of time points in the forecast period.
        `n_sample` is the number of returned samples.
        `xs_pred` is a tensor of shape `(n_sample, n_time_pred, d)`.
    n_sample : int, optional
        The number of samples for the model. The default is 4096.
    n_trial : int, optional
        The number of trials for computing the energy distance. The default is 64.
    batch_size : int, optional
        Batch size for computing the energy distance. The default is 256.

    Returns
    -------
    energy_distance: Tensor
        Energy distance between the test data distribution and the model distribution.
    ed_std: Tensor
        Standard deviation of the energy distance.
    dist_cross: Tensor
        Cross term of the energy distance.
    dist_cross_std: Tensor
        Standard deviation of the cross term.
    dist_self1: Tensor
        Self term of the energy distance for the test data distribution.
    dist_self1_std: Tensor
        Standard deviation of the self term for the test data distribution.
    dist_self2: Tensor
        Self term of the energy distance for the model distribution.
    dist_self2_std: Tensor
        Standard deviation of the self term for the model distribution.
    generation_time: float
        Average time (seconds) spent in `gen_func` per trial.
    """
    def _sync() -> None:
        # CUDA kernels are launched asynchronously; synchronize so that the
        # wall-clock measurement covers the generation work itself.
        if torch.cuda.is_initialized():
            torch.cuda.synchronize()

    with torch.no_grad():
        xs_test = xs_test.reshape(-1, xs_test.shape[-1])
        # Padded time steps are NaN and would turn every distance into NaN.
        xs_test = xs_test[~torch.isnan(xs_test).any(dim=-1)]
        eds, dcs, ds1s, ds2s = [], [], [], []
        generation_time = 0.0
        for _ in tqdm(range(n_trial)):
            _sync()
            start = time.perf_counter()
            x_pred = gen_func(torch.zeros([1, 1], device=device), n_sample)[:, 0]
            _sync()
            generation_time += (time.perf_counter() - start)/n_trial
            ed, dc, ds1, ds2 = energy_distance(xs_test, x_pred, batch_size=batch_size)
            eds.append(ed[None])
            dcs.append(dc[None])
            ds1s.append(ds1[None])
            ds2s.append(ds2[None])
        eds = torch.cat(eds)
        dcs = torch.cat(dcs)
        ds1s = torch.cat(ds1s)
        ds2s = torch.cat(ds2s)
        ed, ed_std = eds.mean(), eds.std()
        dc, dc_std = dcs.mean(), dcs.std()
        ds1, ds1_std = ds1s.mean(), ds1s.std()
        ds2, ds2_std = ds2s.mean(), ds2s.std()
        return ed, ed_std, dc, dc_std, ds1, ds1_std, ds2, ds2_std, generation_time



def time_averaged_energy_scores_forecasting(ts_test: Tensor, xs_test: Tensor, pred_func, cond_interval: float=1/3, pred_interval: float=2/3, n_sample: int=256, batch_size: int=256, n_interval: int=None) -> Tuple[Tensor, Tensor, Tensor, Tensor, float]:
    """
    Calculate time-averaged energy scores for forecasting.

    A window starts at an observed time ``t0 = ts_test[scenario, it0]``. Its
    condition period is ``[t0, t0 + cond_interval)`` and its forecast period is
    ``[t0 + cond_interval, t0 + cond_interval + pred_interval]`` (the upper end
    has a 1e-4 tolerance). Windows without observations in either period are
    skipped; this includes windows starting on NaN padding, since every
    comparison with NaN is False. For each remaining window the forecast period
    is flattened into one vector, its energy score against `n_sample` model
    samples is computed, and the score is divided by the square root of the
    number of forecast time points.

    Parameters
    ----------
    ts_test : Tensor
        Time points of the test data. The shape is `(n_scenario, n_time, 1)`.
    xs_test : Tensor
        Time series of the test data. The shape is `(n_scenario, n_time, d)`.
    pred_func : callable
        A forecasting model. The arguments must be `ts_cond`, `xs_cond`, `ts_pred`, and `n_sample` and the return value must be `xs_pred`.
        `ts_cond` is a tensor of shape `(n_time_cond, 1)`, where `n_time_cond` is the number of time points in the condition period.
        `xs_cond` is a tensor of shape `(n_time_cond, d)`, where `d` is the dimension of the time series.
        `ts_pred` is a tensor of shape `(n_time_pred, 1)`, where `n_time_pred` is the number of time points in the forecast period.
        `n_sample` is the number of returned samples.
        `xs_pred` is a tensor of shape `(n_sample, n_time_pred, d)`.
    cond_interval : float, optional
        The length of the condition period. The default is 1/3.
    pred_interval : float, optional
        The length of the forecast period. The default is 2/3.
    n_sample : int, optional
        The number of samples for the model. The default is 256.
    batch_size : int, optional
        Batch size for computing the energy score. The default is 256.
    n_interval : int, optional
        The number of windows to evaluate. If given and smaller than
        `n_scenario * n_time`, that many (scenario, start) pairs are drawn
        uniformly with replacement; otherwise every pair is evaluated.
        The default is None.

    Returns
    -------
    taess : Tensor
        Time-averaged energy scores for all the valid scenarios and intervals.
    dist_cross : Tensor
        Cross term of the energy score.
    dist_self : Tensor
        Self term of the energy score.
    indices : Tensor
        Indices `[scenario, it0]` of the valid windows, shape `(n_valid, 2)`.
    generation_time : float
        Average time (seconds) spent in `pred_func` per valid window.

    Raises
    ------
    ValueError
        If no window has observations in both periods.
    """
    def _sync() -> None:
        # CUDA kernels are launched asynchronously; synchronize so that the
        # wall-clock measurement covers the prediction work itself.
        if torch.cuda.is_initialized():
            torch.cuda.synchronize()

    with torch.no_grad():
        n_scenario, n_time = ts_test.shape[0], ts_test.shape[1]
        dim = xs_test.shape[-1]
        if n_interval is not None and n_interval < n_scenario*n_time:
            idx = torch.randint(0, n_scenario*n_time, (n_interval,), device=ts_test.device)
            scenarios = idx // n_time
            it0s = idx % n_time
        else:
            scenarios = torch.arange(n_scenario, device=ts_test.device).repeat_interleave(n_time)
            it0s = torch.arange(n_time, device=ts_test.device).repeat(n_scenario)
        taess, dcs, dss, indices = [], [], [], []
        generation_time = 0.0
        # Plain Python ints give cheap basic indexing and a clean `indices` list.
        for scenario, it0 in tqdm(zip(scenarios.tolist(), it0s.tolist()), total=len(scenarios)):
            ts = ts_test[scenario]
            xs = xs_test[scenario]
            t0 = ts[it0]
            is_in_condition_time = (ts >= t0) & (ts < t0 + cond_interval)
            is_in_prediction_time = (ts >= t0 + cond_interval) & (ts <= t0 + cond_interval + pred_interval + 1e-4)
            if not is_in_condition_time.any() or not is_in_prediction_time.any():
                continue
            n_fore = int(is_in_prediction_time.sum())
            indices.append([scenario, it0])

            t_cond = ts[is_in_condition_time][..., None]
            t_pred = ts[is_in_prediction_time][..., None]
            x_cond = xs[is_in_condition_time.expand(-1, dim)].reshape(-1, dim)
            x_truth = xs[is_in_prediction_time.expand(-1, dim)].reshape(-1, dim)

            _sync()
            start = time.perf_counter()
            xs_pred = pred_func(t_cond, x_cond, t_pred, n_sample)
            _sync()
            generation_time += time.perf_counter() - start

            taes, dc, ds = energy_score(x_truth.reshape(1, -1), xs_pred.reshape(n_sample, -1), batch_size=batch_size)
            # The norm of the flattened window grows with its number of time
            # points; divide by sqrt(n_fore) to put windows on one scale.
            scale = n_fore**0.5
            taess.append((taes/scale).mean()[None])
            dcs.append((dc/scale).mean()[None])
            dss.append((ds/scale).mean()[None])

        if not taess:
            raise ValueError('No valid interval: every window lacks observations in the condition or forecast period.')
        taess = torch.cat(taess)
        dcs = torch.cat(dcs)
        dss = torch.cat(dss)
        indices = torch.tensor(indices, dtype=torch.long)
        generation_time /= taess.shape[0]
        return taess, dcs, dss, indices, generation_time
    


def time_averaged_energy_scores_imputation(ts_test: Tensor, xs_test: Tensor, pred_func, cond_interval1: float=1/6, pred_interval: float=2/3, cond_interval2: float=1/6, n_sample: int=256, batch_size: int=256, n_interval: int=None) -> Tuple[Tensor, Tensor, Tensor, Tensor, float]:
    """
    Calculate time-averaged energy scores for imputation.

    A window starts at an observed time ``t0 = ts_test[scenario, it0]``. With
    ``c1 = cond_interval1``, ``p = pred_interval`` and ``c2 = cond_interval2``,
    the first condition period is ``[t0, t0 + c1)``, the imputation period is
    ``[t0 + c1, t0 + c1 + p]`` and the second condition period is
    ``(t0 + c1 + p, t0 + c1 + p + c2]`` (the last upper end has a 1e-4
    tolerance). Windows missing observations in any of the three periods are
    skipped; this includes windows starting on NaN padding. For each remaining
    window the imputation period is flattened into one vector, its energy score
    against `n_sample` model samples is computed, and the score is divided by
    the square root of the number of imputed time points.

    Parameters
    ----------
    ts_test : Tensor
        Time points of the test data. The shape is `(n_scenario, n_time, 1)`.
    xs_test : Tensor
        Time series of the test data. The shape is `(n_scenario, n_time, d)`.
    pred_func : callable
        An imputation model. The arguments must be `ts_cond`, `xs_cond`, `ts_pred`, and `n_sample` and the return value must be `xs_pred`.
        `ts_cond` is a tensor of shape `(n_time_cond, 1)`, where `n_time_cond` is the number of time points in both condition periods.
        `xs_cond` is a tensor of shape `(n_time_cond, d)`, where `d` is the dimension of the time series.
        `ts_pred` is a tensor of shape `(n_time_pred, 1)`, where `n_time_pred` is the number of time points in the imputation period.
        `n_sample` is the number of returned samples.
        `xs_pred` is a tensor of shape `(n_sample, n_time_pred, d)`.
    cond_interval1 : float, optional
        The length of the first condition period. The default is 1/6.
    pred_interval : float, optional
        The length of the imputation period. The default is 2/3.
    cond_interval2 : float, optional
        The length of the second condition period. The default is 1/6.
    n_sample : int, optional
        The number of samples for the model. The default is 256.
    batch_size : int, optional
        Batch size for computing the energy score. The default is 256.
    n_interval : int, optional
        The number of windows to evaluate. If given and smaller than
        `n_scenario * n_time`, that many (scenario, start) pairs are drawn
        uniformly with replacement; otherwise every pair is evaluated.
        The default is None.

    Returns
    -------
    taess : Tensor
        Time-averaged energy scores for all the valid scenarios and intervals.
    dist_cross : Tensor
        Cross term of the energy score.
    dist_self : Tensor
        Self term of the energy score.
    indices : Tensor
        Indices `[scenario, it0]` of the valid windows, shape `(n_valid, 2)`.
    generation_time : float
        Average time (seconds) spent in `pred_func` per valid window.

    Raises
    ------
    ValueError
        If no window has observations in all three periods.
    """
    def _sync() -> None:
        # CUDA kernels are launched asynchronously; synchronize so that the
        # wall-clock measurement covers the imputation work itself.
        if torch.cuda.is_initialized():
            torch.cuda.synchronize()

    with torch.no_grad():
        n_scenario, n_time = ts_test.shape[0], ts_test.shape[1]
        dim = xs_test.shape[-1]
        if n_interval is not None and n_interval < n_scenario*n_time:
            idx = torch.randint(0, n_scenario*n_time, (n_interval,), device=ts_test.device)
            scenarios = idx // n_time
            it0s = idx % n_time
        else:
            scenarios = torch.arange(n_scenario, device=ts_test.device).repeat_interleave(n_time)
            it0s = torch.arange(n_time, device=ts_test.device).repeat(n_scenario)
        taess, dcs, dss, indices = [], [], [], []
        generation_time = 0.0
        # Plain Python ints give cheap basic indexing and a clean `indices` list.
        for scenario, it0 in tqdm(zip(scenarios.tolist(), it0s.tolist()), total=len(scenarios)):
            ts = ts_test[scenario]
            xs = xs_test[scenario]
            t0 = ts[it0]
            is_in_condition_time1 = (ts >= t0) & (ts < t0 + cond_interval1)
            is_in_condition_time2 = (ts > t0 + cond_interval1 + pred_interval) & (ts <= t0 + cond_interval1 + pred_interval + cond_interval2 + 1e-4)
            is_in_condition_time = is_in_condition_time1 | is_in_condition_time2
            is_in_prediction_time = (ts >= t0 + cond_interval1) & (ts <= t0 + cond_interval1 + pred_interval)
            if not is_in_condition_time1.any() or not is_in_prediction_time.any() or not is_in_condition_time2.any():
                continue
            n_fore = int(is_in_prediction_time.sum())
            indices.append([scenario, it0])

            t_cond = ts[is_in_condition_time][..., None]
            t_pred = ts[is_in_prediction_time][..., None]
            x_cond = xs[is_in_condition_time.expand(-1, dim)].reshape(-1, dim)
            x_truth = xs[is_in_prediction_time.expand(-1, dim)].reshape(-1, dim)

            _sync()
            start = time.perf_counter()
            xs_pred = pred_func(t_cond, x_cond, t_pred, n_sample)
            _sync()
            generation_time += time.perf_counter() - start

            taes, dc, ds = energy_score(x_truth.reshape(1, -1), xs_pred.reshape(n_sample, -1), batch_size=batch_size)
            # The norm of the flattened window grows with its number of time
            # points; divide by sqrt(n_fore) to put windows on one scale.
            scale = n_fore**0.5
            taess.append((taes/scale).mean()[None])
            dcs.append((dc/scale).mean()[None])
            dss.append((ds/scale).mean()[None])

        if not taess:
            raise ValueError('No valid interval: every window lacks observations in a condition or imputation period.')
        taess = torch.cat(taess)
        dcs = torch.cat(dcs)
        dss = torch.cat(dss)
        indices = torch.tensor(indices, dtype=torch.long)
        generation_time /= taess.shape[0]
        return taess, dcs, dss, indices, generation_time



def _load_model(config: dict, config_path: str, outdir: str, checkpoint: str, dim: int):
    """
    Build the model selected by ``config['model']`` and load its checkpoint.

    Baseline modules under ``benchmark`` are imported lazily, so only the
    dependencies of the selected model need to be installed.

    Parameters
    ----------
    config : dict
        Parsed YAML configuration. ``config['model']['name']`` selects the
        model; a missing name means OUFlow. For DSPD, ``config['model']`` is
        modified in place (``diff_steps`` is set from ``diff_steps_gen``).
    config_path : str
        Path of the YAML configuration (forwarded to ACSSM).
    outdir : str
        Experiment output directory that contains the checkpoint directories.
    checkpoint : str
        Checkpoint file name.
    dim : int
        Dimension of the time series.

    Returns
    -------
    model
        The model with the checkpoint's ``'model'`` state dict loaded.
    likelihood
        The GPR likelihood, or None for every other model.
    suffix : str
        Label suffix appended to the model name in reports (non-empty for DSPD).

    Raises
    ------
    ValueError
        If the model name is not recognized.
    """
    model_cfg = config['model']
    model_name = model_cfg.get('name', 'OUFlow')
    likelihood = None
    suffix = ''
    if model_name == 'OUFlow':
        checkpoint_path = os.path.join(outdir, 'checkpoint', checkpoint)
        # 'name' only selects the model in this script; it is not an OUFlow argument.
        ouflow_kwargs = {key: value for key, value in model_cfg.items() if key != 'name'}
        model = OUFlow(dim, **ouflow_kwargs).to(device)
    elif model_name == 'GPR':
        from benchmark import gpr
        checkpoint_path = os.path.join(outdir, 'gpr', 'checkpoint', checkpoint)
        likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=dim).to(device)
        likelihood.load_state_dict(torch.load(checkpoint_path, weights_only=False)['likelihood'])
        model = gpr.MultitaskGPModel(None, None, likelihood, dim).to(device)
    elif model_name == 'LatentSDE':
        from benchmark import latentsde
        checkpoint_path = os.path.join(outdir, 'latentsde', 'checkpoint', checkpoint)
        model = latentsde.LatentSDE(dim, model_cfg['latent_size'], model_cfg['context_size'], model_cfg['hidden_size']).to(device)
    elif model_name == 'ACSSM':
        from benchmark import acssm
        checkpoint_path = os.path.join(outdir, 'acssm', 'checkpoint', checkpoint)
        model = acssm.ACSSM(config_path, model_cfg, config['train'])
    elif model_name == 'DSPD':
        from benchmark import dspd
        suffix = '-' + model_cfg['noise']
        checkpoint_path = os.path.join(outdir, 'dspd' + suffix, 'checkpoint', checkpoint)
        model_cfg['diff_steps'] = model_cfg['diff_steps_gen']
        model = dspd.DSPD(dim, **model_cfg, forecast_interval=config['train']['forecast_interval']).to(device)
    else:
        raise ValueError(f"Unknown model name: {model_name!r}")
    # NOTE(review): the model is left in its default (training) mode; confirm the
    # benchmark generate/predict wrappers switch to eval mode where it matters.
    model.load_state_dict(torch.load(checkpoint_path, weights_only=False)['model'])
    return model, likelihood, suffix


def _make_model_func(action: str, model_name: str, model, likelihood, config: dict):
    """
    Return the benchmark callable produced by ``<module>.<action>`` for the model.

    Parameters
    ----------
    action : str
        ``'generate'`` or ``'predict'``: the factory looked up on the selected
        benchmark module.
    model_name : str
        Selected model name.
    model
        Loaded model.
    likelihood
        GPR likelihood (only used for GPR).
    config : dict
        Parsed configuration (LatentSDE reads ``config['model']['noise_std']``).

    Raises
    ------
    ValueError
        If the model name is not recognized.
    """
    if model_name == 'OUFlow':
        return getattr(ouflow, action)(model)
    if model_name == 'GPR':
        from benchmark import gpr
        return getattr(gpr, action)(model, likelihood)
    if model_name == 'LatentSDE':
        from benchmark import latentsde
        return getattr(latentsde, action)(model, config['model']['noise_std'])
    if model_name == 'ACSSM':
        from benchmark import acssm
        return getattr(acssm, action)(model)
    if model_name == 'DSPD':
        from benchmark import dspd
        return getattr(dspd, action)(model)
    raise ValueError(f"Unknown model name: {model_name!r}")


def _format_taes_report(label: str, taess: Tensor, dcs: Tensor, dss: Tensor, elapsed: float, generation_time: float) -> str:
    """Format the one-line summary of time-averaged energy scores."""
    return (f"{label}: {taess.mean():.4e} +- {taess.std():.4e} "
            f"(min: {taess.min():.4e}, max: {taess.max():.4e}, "
            f"cross: {dcs.mean():.4e} +- {dcs.std():.4e}, self: {dss.mean():.4e} +- {dss.std():.4e}) "
            f"(elapsed time: {elapsed:.1f} s, average generation time: {generation_time:.4f} s)")


def _report(path: str, message: str) -> None:
    """Append ``message`` as one line to the text file ``path`` and echo it to stdout."""
    with open(path, 'a', encoding='utf-8') as f:
        f.write(message + '\n')
    print(message)


def main() -> None:
    """Parse the command line, load data and model, and run the requested evaluations."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True, help='Config file name.')
    parser.add_argument('--checkpoint', type=str, required=True, help='Checkpoint file name.')
    parser.add_argument('-g', action='store_true', help='If specified, evaluation for generation is performed.')
    parser.add_argument('-f', action='store_true', help='If specified, evaluation for forecasting is performed.')
    parser.add_argument('-i', action='store_true', help='If specified, evaluation for imputation is performed.')
    args = parser.parse_args()

    if not (args.g or args.f or args.i):
        print('Please specify -g, -f, or -i. If you want to evaluate generation, specify -g. If you want to evaluate forecasting, specify -f. If you want to evaluate imputation, specify -i.')
        sys.exit(1)

    parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    config_path = os.path.join(parent_dir, 'config', args.config + '.yaml')
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    fix_random(config['seed'])

    outdir = os.path.join(parent_dir, 'output', config['experiment_name'])
    dataset_dir = os.path.join(outdir, 'dataset')
    evaldir = os.path.join(outdir, 'benchmark')
    os.makedirs(evaldir, exist_ok=True)

    model_name = config['model'].get('name', 'OUFlow')
    is_ouflow = model_name == 'OUFlow'

    # Load test dataset; variable-length sequences are padded with NaN.
    xs_test = torch.load(os.path.join(dataset_dir, 'test_x.pt'), weights_only=False)
    ts_test = torch.load(os.path.join(dataset_dir, 'test_t.pt'), weights_only=False)
    xs_test = pad_packed_sequence(xs_test, batch_first=True, padding_value=torch.nan)[0].to(device)
    ts_test = pad_packed_sequence(ts_test, batch_first=True, padding_value=torch.nan)[0].to(device)
    if is_ouflow and config['model']['double_precision']:
        xs_test = xs_test.double()
        ts_test = ts_test.double()

    model, likelihood, suffix = _load_model(config, config_path, outdir, args.checkpoint, xs_test.shape[-1])
    label = model_name + suffix

    # Let pending GPU work finish before any timing starts (skipped without CUDA).
    if torch.cuda.is_initialized():
        torch.cuda.synchronize()

    if args.g:
        gen_func = _make_model_func('generate', model_name, model, likelihood, config)
        t0 = time.time()
        ed, ed_std, dc, dc_std, ds1, ds1_std, ds2, ds2_std, generation_time = energy_distance_generation(xs_test, gen_func, n_sample=4096)
        elapsed = time.time() - t0
        _report(os.path.join(evaldir, 'energy_distance_generation.txt'),
                f"{label}: {ed:.4e} +- {ed_std:.4e} (cross: {dc:.4e} +- {dc_std:.4e}, self_test: {ds1:.4e} +- {ds1_std:.4e}, self_gen: {ds2:.4e} +- {ds2_std:.4e}) (elapsed time: {elapsed:.1f} s, average generation time: {generation_time:.4f} s)")

    tasks = []
    if args.f or args.i:
        # Evaluation windows are fractions of the forecast interval used in training.
        if is_ouflow:
            forecast_interval = config['train']['loss_train']['forecast_interval']
        else:
            forecast_interval = config['train']['forecast_interval']
        if args.f:
            tasks.append(('forecasting', time_averaged_energy_scores_forecasting,
                          dict(cond_interval=1/3 * forecast_interval, pred_interval=2/3 * forecast_interval)))
        if args.i:
            tasks.append(('imputation', time_averaged_energy_scores_imputation,
                          dict(cond_interval1=1/6 * forecast_interval, pred_interval=2/3 * forecast_interval, cond_interval2=1/6 * forecast_interval)))

    for task, scorer, intervals in tasks:
        # A fresh predictor per task, in case the benchmark wrapper keeps state.
        pred_func = _make_model_func('predict', model_name, model, likelihood, config)
        t0 = time.time()
        taess, dcs, dss, indices, generation_time = scorer(ts_test, xs_test, pred_func, **intervals, n_sample=256, n_interval=2048)
        elapsed = time.time() - t0
        _report(os.path.join(evaldir, f'taes_{task}.txt'),
                _format_taes_report(label, taess, dcs, dss, elapsed, generation_time))
        torch.save({'taess': taess, 'dcs': dcs, 'dss': dss, 'indices': indices}, os.path.join(evaldir, f"taess_{task}_{label}.pt"))


if __name__ == '__main__':
    main()
