import copy
import os
import numpy as np
from tqdm import tqdm
from typing import Tuple, Any, Dict, Union

import wandb
from pydantic import BaseModel

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import plotly.graph_objects as go
from plotly.colors import qualitative
import plotly.io as pio

from utils.general import set_seed
from utils.predictor import init_predictor
from utils.data import init_relation, DiffusionDataset

from momentfm import MOMENTPipeline
from timer.models.Timer import Model
from timer.configs import timer_imputation_args

# TIMER needs this to work on Windows (presumably because it unpickles files that were saved on a POSIX
# system and contain pathlib.PosixPath objects). Aliasing PosixPath to WindowsPath makes such pickles load
# on Windows. The alias is only installed on Windows: on POSIX systems WindowsPath cannot be instantiated,
# so applying it unconditionally would break every pathlib.Path(...) call in the process.
import pathlib
temp = pathlib.PosixPath    # the original class, kept so the patch can be undone if needed
if os.name == 'nt':
    pathlib.PosixPath = pathlib.WindowsPath


class Trainer:
    """
    Trains the predictor returned by `init_predictor` and checkpoints the best model.

    Checkpoints are written to 'expt/logs/<expt_name>/checkpoint.pth' and contain the predictor state dict,
    the validated config (plus the training data means/stds) and an 'info' dict.
    """

    class Config(BaseModel):
        # expt
        seed: int = 1
        project_name: str = 'mvts_rhallu'
        expt_name: str = 'test_expt'
        max_epoch: int = 8000
        learning_rate: float = 1e-3
        log_every_n: int = 1
        val_every_n: int = 20
        batch_size: int = 1024

        # dataset
        data_name: str = 'recl'
        input_length: int = 24

        # g_diff
        n_steps: int = 1000
        beta_first: float = 1e-4
        beta_last: float = 2e-2

        # predictor
        predictor_name: str = 'default'

        # model
        model_name: str = 'mlp'
        input_size: int = (3 * 24) + 1      # n_sequences * input_length + 1 (plus one for t)
        output_size: int = 3 * 24           # n_sequences * input_length
        hidden_sizes: tuple = (512, 512, 512, 512, 512)

    default_config = Config().dict()

    def __init__(self, config):
        """
        Validates the config, seeds the RNGs and builds the predictor on the available device.

        Args:
            config (dict): Experiment configuration. Keys missing from it fall back to the defaults declared in
                `Trainer.Config`. Keys not declared there (e.g. 'data_means'/'data_stds' stored in checkpoints)
                are not kept in `self.config` (pydantic ignores undeclared fields by default), but they are still
                forwarded to `init_predictor`.
        """
        self.config = self.Config(**config).dict()

        # attributes (read from the validated config so that defaults apply to missing keys)
        self.seed = self.config['seed']
        self.project_name = self.config['project_name']
        self.expt_name = self.config['expt_name']
        self.max_epoch = self.config['max_epoch']
        self.learning_rate = self.config['learning_rate']
        self.log_every_n = self.config['log_every_n']
        self.val_every_n = self.config['val_every_n']
        self.batch_size = self.config['batch_size']

        # init
        set_seed(self.seed)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # validated values (with defaults) override the raw ones; extra raw keys are preserved
        self.predictor = init_predictor({**config, **self.config}).to(self.device)

    def _validate(self, val_dataloader: DataLoader) -> float:
        """Returns the mean per-batch validation loss of the predictor (leaves the predictor in eval mode)."""
        self.predictor.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x_val, y_val in val_dataloader:
                val_step_output = self.predictor.val_step(x_val.to(self.device), y_val.to(self.device))
                val_loss += val_step_output['loss'].item()
        return val_loss / len(val_dataloader)

    def _save_checkpoint(self, path: str, epoch: int, best_val_loss: float) -> Dict:
        """Saves predictor weights, config and info to `path` and returns the info dict."""
        info = {
            'best_epoch': epoch,
            'best_val_loss': best_val_loss,
        }
        checkpoint = {
            'predictor_state_dict': self.predictor.state_dict(),
            'config': self.config,
            'info': info,
        }
        torch.save(checkpoint, path)
        return info

    def train(self, use_wandb: bool = True, save_best_val_model: bool = True) -> None:
        """
        Trains the model.

        Every `val_every_n` optimisation steps the model is validated and a checkpoint is written to
        'expt/logs/<expt_name>/checkpoint.pth' if it improved: on validation loss when `save_best_val_model`
        is True, otherwise on the current training-batch loss.

        Args:
            use_wandb (bool, optional): Whether to log training metrics to Weights & Biases.
                                        Defaults to True.
            save_best_val_model (bool, optional): Whether to save the model with the
                                                  best validation loss during training.
                                                  Defaults to True.

        Raises:
            FileExistsError: If the experiment's log directory already exists and is not empty.
        """
        save_dir = os.path.join('expt', 'logs', self.expt_name)
        if os.path.exists(save_dir):
            # refuse to overwrite the results of a previous experiment
            with os.scandir(save_dir) as entries:
                if any(entries):
                    raise FileExistsError(f"Directory '{save_dir}' already exists and is not empty.")
        os.makedirs(save_dir, exist_ok=True)
        checkpoint_path = os.path.join(save_dir, 'checkpoint.pth')

        # dataloaders
        train_dataloader = DataLoader(
            DiffusionDataset(self.config, 'train'),
            batch_size=self.batch_size,
            shuffle=True
        )
        val_dataloader = DataLoader(
            DiffusionDataset(self.config, 'val'),
            batch_size=self.batch_size,
            shuffle=False
        )

        # means/stds are computed after the train dataset is initiated, saved with config to be used during inference.
        self.config['data_means'] = train_dataloader.dataset.data_means
        self.config['data_stds'] = train_dataloader.dataset.data_stds

        # steppers
        optimizer = optim.Adam(self.predictor.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=self.learning_rate, steps_per_epoch=len(train_dataloader), epochs=self.max_epoch)

        # train loop init
        step = 0
        best_val_loss = torch.inf
        best_train_loss = torch.inf
        best_info = None    # info of the last saved checkpoint; stays None if no checkpoint was written
        n_train = len(train_dataloader.dataset)

        # loop
        if use_wandb:
            wandb.init(project=self.project_name, config=self.config, name=self.expt_name)
        pbar = tqdm(total=self.max_epoch * len(train_dataloader))
        for epoch in range(self.max_epoch):
            if use_wandb:
                wandb.log({"epoch": epoch}, step=step)

            # train -------------------------
            for x_train, y_train in train_dataloader:
                self.predictor.train()
                optimizer.zero_grad()

                # When batch_size > dataset size the dataloader yields fewer samples than requested:
                # first append whole extra (freshly shuffled) batches ...
                for _ in range(self.batch_size // n_train - 1):
                    x_extra, y_extra = next(iter(train_dataloader))
                    x_train = torch.cat([x_train, x_extra], dim=0)
                    y_train = torch.cat([y_train, y_extra], dim=0)

                # ... then top up to exactly batch_size (this also fills the last, partial batch of an epoch)
                n_missing = self.batch_size - len(x_train)
                if n_missing > 0:
                    x_extra, y_extra = next(iter(train_dataloader))
                    x_train = torch.cat([x_train, x_extra[0:n_missing]], dim=0)
                    y_train = torch.cat([y_train, y_extra[0:n_missing]], dim=0)

                train_step_output = self.predictor.train_step(x_train.to(self.device), y_train.to(self.device))
                train_loss = train_step_output['loss']
                train_loss.backward()

                optimizer.step()
                scheduler.step()
                if use_wandb and step % self.log_every_n == 0:
                    wandb.log({'train_loss': train_loss.item()}, step=step)
                step += 1
                pbar.update(1)

                # valid -------------------------
                if step % self.val_every_n == 0:
                    val_loss = self._validate(val_dataloader)
                    if use_wandb:
                        wandb.log({'val_loss': val_loss}, step=step)

                    # save best model
                    if save_best_val_model:
                        if val_loss < best_val_loss:
                            best_val_loss = val_loss
                            best_info = self._save_checkpoint(checkpoint_path, epoch, best_val_loss)
                    elif train_loss.item() < best_train_loss:
                        best_train_loss = train_loss.item()
                        best_info = self._save_checkpoint(checkpoint_path, epoch, best_val_loss)
        pbar.close()

        # train complete logging
        if use_wandb:
            if best_info is not None:
                wandb.log({'best_epoch': best_info['best_epoch']}, step=step)
                wandb.log({'best_val_loss': best_info['best_val_loss']}, step=step)
            wandb.finish()
        return None

    @classmethod
    def load(cls, expt_name: str) -> Tuple[Any, Dict]:
        """
            Loads a Predictor from a checkpoint file and returns a Trainer object along with an info dictionary.

            Args:
                expt_name (str): The name of the experiment that will be used to load the predictor.

            Returns:
                A tuple containing:
                    - The loaded Trainer instance (an instance of the class).
                    - A dictionary containing additional information extracted from the checkpoint file.

            Raises:
                FileNotFoundError: If the checkpoint file cannot be found.
        """

        checkpoint_path = os.path.join('expt', 'logs', expt_name, 'checkpoint.pth')
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint file not found at '{checkpoint_path}'.")

        # map to CPU so checkpoints saved on a GPU machine can be loaded on CPU-only machines;
        # load_state_dict copies the weights onto the predictor's device.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        loaded_orchestrator = cls(config=checkpoint['config'])
        loaded_orchestrator.predictor.load_state_dict(checkpoint['predictor_state_dict'])
        return loaded_orchestrator, checkpoint['info']


class Evaluator:
    # Default configuration schema for the evaluator.
    # NOTE(review): in the visible code this model is only used to build `default_config`; __init__ takes all
    # settings from the config stored in the checkpoint instead.
    class Config(BaseModel):
        # expt
        seed: int = 1
        project_name: str = 'mvts_rhallu'
        expt_name: str = 'test_expt'

        # model
        model_name: str = 'mlp'
        # NOTE(review): value is +2 but the comment (and Trainer.Config, which uses +1) says "+1 for t" — confirm
        input_size: int = (3 * 24) + 2        # n_sequences * input_length + 1 (plus one for t)
        output_size: int = 3 * 24       # n_sequences * input_length
        hidden_sizes: tuple = (512, 512, 512, 512, 512)

        # dataset
        data_name: str = 'recl'
        input_length: int = 24

        # g_diff
        n_steps: int = 1000
        beta_first: float = 1e-4
        beta_last: float = 2e-2
    default_config = Config().dict()
    # Values accepted for the checkpoint config's 'data_name' (checked in __init__).
    allowed_data_names = ['recl', 'rwth', 'rtraffic', 'rillness', 'rett']
    # Task codes used by the compute_* methods: 'oc' over-constrained, 'uc' under-constrained, 'fc' forecast.
    # NOTE(review): not referenced in the visible code; the compute_* methods validate `task` themselves.
    allowed_tasks = ['oc', 'uc', 'fc']

    def __init__(self, expt_name):
        """
            Loads a trained predictor and its config from 'expt/logs/<expt_name>/checkpoint.pth'.

            Args:
                expt_name (str): Name of the experiment directory that holds the checkpoint.

            Raises:
                FileNotFoundError: If the checkpoint file cannot be found.
                AssertionError: If the checkpoint's 'data_name' is not in `allowed_data_names`.
        """
        checkpoint_path = os.path.join('expt', 'logs', expt_name, 'checkpoint.pth')
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint file not found at '{checkpoint_path}'.")

        # load (mapped to CPU so checkpoints saved on a GPU machine also load on CPU-only machines;
        # load_state_dict below copies the weights onto self.device)
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        config = checkpoint['config']
        self.config = config

        # checks
        assert config['data_name'] in self.allowed_data_names, \
            f"data_name={config['data_name']} not in {self.allowed_data_names}."

        # attributes (expt_name is taken from the saved config, not from the directory name argument)
        self.seed = config['seed']
        self.project_name = config['project_name']
        self.expt_name = config['expt_name']
        self.model_name = config['model_name']
        self.input_size = config['input_size']
        self.output_size = config['output_size']
        self.hidden_sizes = config['hidden_sizes']
        self.input_length = config['input_length']
        self.n_steps = config['n_steps']
        self.beta_first = config['beta_first']
        self.beta_last = config['beta_last']

        # init
        set_seed(self.seed)
        self.relation = init_relation(config)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.predictor = init_predictor(config).to(self.device)
        self.predictor.load_state_dict(checkpoint['predictor_state_dict'])
        print(checkpoint['info'])

    def compute_ce_and_re_moment(
            self,
            task: str = 'fc',
            batch_size: int = 1024,
            context_len: int = 24,
            deterministic: bool = True
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
            Computes Combined Error (CE) and Relation Error (RE) for the given task using the MOMENT model.

            The test windows are extended by `context_len` steps that are never masked (for a fair comparison
            with TIMER, which needs its first token as context); errors are computed on the reconstruction
            after that context.

            Args:
                task (str, optional): The task to evaluate: 'oc' (over-constrained, last variable masked),
                    'uc' (under-constrained, all but the last variable masked) or 'fc' (forecast, second half
                    of the window after the context masked). Defaults to 'fc'.
                batch_size (int, optional): Batch size for evaluation. Defaults to 1024.
                context_len (int, optional): Length of the additional context provided to MOMENT. Defaults to 24.
                deterministic (bool, optional): Whether to use MOMENT deterministically or not. If False the model
                    is put in train mode. Defaults to True.

            Returns:
                A tuple containing:
                    - An array of CE.
                    - An array of RE.
                    - An array of baseline RE using the weak baseline (predicting the mean only).

            Raises:
                NotImplementedError: If `task` is not one of 'oc', 'uc', 'fc'.
        """

        # for making it a fair comparison against TIMER, which requires the first token as context
        config = copy.deepcopy(self.config)     # so that self.config is not overwritten
        config['input_length'] += context_len
        # first time step hidden in the forecast task: half of the window that follows the context
        forecast_start = context_len + int(0.5 * (config['input_length'] - context_len))
        patch_len = 8   # MOMENT's patch size; sequences are right-padded to a multiple of it

        # init dataset
        dataset = DiffusionDataset(config, mode='test')
        data_means = config['data_means']
        data_stds = config['data_stds']

        # init MOMENT model
        moment_model = MOMENTPipeline.from_pretrained(
            "AutonLab/MOMENT-1-large",
            model_kwargs={"task_name": "reconstruction"},
        )
        moment_model.init()
        moment_model = moment_model.to(self.device)
        if not deterministic:
            moment_model.train()

        # format data for MOMENT
        x_prompt_batch, x_gt_batch = [], []
        for idx in np.arange(len(dataset)):
            x_gt = dataset.get_raw_data(idx, concatenate=False).clone()
            x_prompt = x_gt.clone()

            # Set data that will be masked to zero for sanity even though MOMENT will mask it internally using the
            # provided mask. No mask for the first 'context_len' bit to make it fair for TIMER which requires one
            # token to be provided as context. Mask the rest accordingly.
            if task == 'oc':  # over-constrained
                x_prompt[context_len:, -1] *= 0
            elif task == 'uc':  # under-constrained
                x_prompt[context_len:, 0:-1] *= 0
            elif task == 'fc':  # forecast
                x_prompt[forecast_start:, :] *= 0
            else:
                raise NotImplementedError(f'task={task} not implemented.')

            # append
            x_prompt_batch.append(x_prompt)
            x_gt_batch.append(x_gt)

        # generate dataloader
        dataloader = DataLoader(
            TensorDataset(
                torch.stack(x_prompt_batch),
                torch.stack(x_gt_batch)
            ), batch_size=batch_size
        )

        # format mask used by MOMENT (same pattern as the prompts above)
        mask = torch.ones_like(x_gt_batch[0])  # [l, sl]; masked values are 0, observed are 1
        if task == 'oc':  # over-constrained
            mask[context_len:, -1] *= 0  # for MOMENT to not break, need at least one unmasked variable per channel
        elif task == 'uc':  # under-constrained
            mask[context_len:, 0:-1] *= 0  # for MOMENT to not break, need at least one unmasked variable per channel
        elif task == 'fc':  # forecast
            mask[forecast_start:, :] *= 0
        else:
            raise NotImplementedError(f'task={task} not implemented.')

        # pre-process mask dimensions
        mask = mask.unsqueeze(0)  # [1, l, sl]
        mask = mask.transpose(1, 2)  # [1, sl, l]

        # loop to get relation errors and combined errors
        relation_errors, combined_errors, baseline_relation_errors = [], [], []
        pbar = tqdm(total=len(dataloader), desc=f'MOMENT: computing relation errors and combined errors for task={task}')
        for x, x_gt in dataloader:

            # normalise (in place on a copy, so the prompt keeps x's dtype)
            x_prompt = x.clone()
            x_prompt -= data_means
            x_prompt /= data_stds

            # format x
            bs, l, sl = x_prompt.shape
            # right-padding needed to make l a multiple of the patch size (0 if already a multiple)
            n_pad = (-l) % patch_len

            # format prompt
            x_prompt = x_prompt.transpose(1, 2)  # [bs, sl, l]
            x_prompt = x_prompt.reshape(-1, 1, l)  # [bs*sl, 1, l]
            x_prompt = torch.nn.functional.pad(x_prompt, (0, n_pad))  # [bs*sl, 1, l+n_pad], zero padded

            # format mask
            masks = mask.repeat(bs, 1, 1)  # [bs, sl, l]
            masks = masks.reshape(-1, l)  # [bs*sl, l]
            masks = torch.nn.functional.pad(masks, (0, n_pad))  # [bs*sl, l+n_pad], padding is masked (0)

            # get response from model
            with torch.no_grad():
                x_sampled = moment_model(
                    x_enc=x_prompt.to(self.device),
                    input_mask=torch.ones_like(masks).to(self.device),
                    mask=masks.to(self.device)
                ).reconstruction.cpu()

            # Reshape back to [bs, l, sl]
            x_sampled = x_sampled.reshape(bs, sl, l + n_pad)  # [bs, sl, l+n_pad]
            x_sampled = x_sampled[:, :, 0:l]  # [bs, sl, l]
            x_sampled = x_sampled.transpose(1, 2)  # [bs, l, sl]
            x_sampled = x_sampled[:, context_len:, :]  # remove context bit

            # un-normalise
            x_sampled *= data_stds
            x_sampled += data_means

            # combined error
            combined_errors += self.predictor.compute_combined_error(x_sampled).tolist()

            # relation error: RMS over time between the relation applied to variables 0, 1 and variable 2
            relation_gt = self.relation(x_sampled[:, :, 0], x_sampled[:, :, 1])
            relation_error = np.sqrt(((relation_gt - x_sampled[:, :, 2]) ** 2).mean(axis=1))
            relation_errors += relation_error.tolist()

            # baseline relation error: always predict the training mean of variable 2
            mean_baseline = data_means[0, 2]
            baseline_relation_error = np.sqrt(((x_gt[:, context_len:, 2] - mean_baseline) ** 2).mean(axis=1))
            baseline_relation_errors += baseline_relation_error.tolist()

            # update
            pbar.update(1)
        pbar.close()
        return np.array(combined_errors), np.array(relation_errors), np.array(baseline_relation_errors)

    def compute_ce_and_re_timer(
            self,
            task: str = 'fc',
            batch_size: int = 1024,
            deterministic: bool = True
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
            Computes Combined Error (CE) and Relation Error (RE) for the given task using the TIMER model.

            The first `patch_len` steps of every window (TIMER's first token) are never masked and are removed
            before computing errors. If the configured window is not longer than one token, an extra token of
            context is prepended (on a local copy of the config; `self.config` is left unchanged).

            Args:
                task (str, optional): The task to evaluate: 'oc' (over-constrained, last variable masked),
                    'uc' (under-constrained, all but the last variable masked) or 'fc' (forecast, second half
                    of the window after the first token masked). Defaults to 'fc'.
                batch_size (int, optional): Batch size for evaluation. Defaults to 1024.
                deterministic (bool, optional): Whether to use TIMER deterministically or not. If False the model
                    is put in train mode. Defaults to True.

            Returns:
                A tuple containing:
                    - An array of CE.
                    - An array of RE.
                    - An array of baseline RE using the weak baseline (predicting the mean only).

            Raises:
                AssertionError: If input_length is not divisible by TIMER's patch_len.
                NotImplementedError: If `task` is not one of 'oc', 'uc', 'fc'.
        """

        # init TIMER model
        args = timer_imputation_args()
        timer_model = Model(args).float().to(self.device)
        if not deterministic:
            timer_model.train()

        # work on a copy so that adding TIMER's context token does not leak into self.config
        # (and from there into later calls of this or the other compute_* methods)
        config = copy.deepcopy(self.config)

        assert config['input_length'] % args.patch_len == 0, \
            f"input_length = {config['input_length']} has to be divisible by timer's patch_len = {args.patch_len}."

        # add first token for timer if required
        if config['input_length'] <= args.patch_len:
            config['input_length'] += args.patch_len
        # first time step hidden in the forecast task: half of the window that follows the first token
        forecast_start = args.patch_len + int(0.5 * (config['input_length'] - args.patch_len))

        # init dataset
        dataset = DiffusionDataset(config, mode='test')
        data_means = config['data_means']
        data_stds = config['data_stds']

        # format dataset for Timer
        x_prompt_batch, x_gt_batch = [], []
        for idx in np.arange(len(dataset)):
            x_gt = dataset.get_raw_data(idx, concatenate=False).clone()
            x_prompt = x_gt.clone()

            # no mask for the first token, mask the rest accordingly
            if task == 'oc':  # over-constrained
                x_prompt[args.patch_len:, -1] *= 0
            elif task == 'uc':  # under-constrained
                x_prompt[args.patch_len:, 0:-1] *= 0
            elif task == 'fc':  # forecast
                x_prompt[forecast_start:, :] *= 0
            else:
                raise NotImplementedError(f'task={task} not implemented.')

            # append
            x_prompt_batch.append(x_prompt)     # todo data should be [batch_size, length, variables]
            x_gt_batch.append(x_gt)

        # generate dataloader
        dataloader = DataLoader(
            TensorDataset(
                torch.stack(x_prompt_batch),
                torch.stack(x_gt_batch)
            ), batch_size=batch_size
        )

        # format mask used by TIMER (same pattern as the prompts above)
        mask = torch.ones_like(x_gt_batch[0])   # [l, sl]; masked values are 0, observed are 1
        if task == 'oc':                        # over-constrained
            mask[args.patch_len:, -1] *= 0
        elif task == 'uc':                      # under-constrained
            mask[args.patch_len:, 0:-1] *= 0
        elif task == 'fc':                      # forecast
            mask[forecast_start:, :] *= 0
        else:
            raise NotImplementedError(f'task={task} not implemented.')
        mask = mask.unsqueeze(0).to(self.device)  # [1, l, sl]

        # loop to get relation errors and combined errors
        relation_errors, combined_errors, baseline_relation_errors = [], [], []
        pbar = tqdm(total=len(dataloader), desc=f'TIMER: computing relation errors and combined errors for task={task}')
        for x, x_gt in dataloader:
            bs = x.shape[0]

            # normalise (in place on a copy, so the prompt keeps x's dtype)
            x_prompt = x.clone()
            x_prompt -= data_means
            x_prompt /= data_stds

            # get response from model
            with torch.no_grad():
                x_sampled = timer_model(  # [bs, l, sl]
                    x_prompt.to(self.device), None, None, None, mask.repeat(bs, 1, 1)
                ).cpu()
                x_sampled = x_sampled[:, args.patch_len:, :]        # remove first token

            # un-normalise
            x_sampled *= data_stds
            x_sampled += data_means

            # combined error
            combined_errors += self.predictor.compute_combined_error(x_sampled).tolist()

            # relation error: RMS over time between the relation applied to variables 0, 1 and variable 2
            relation_gt = self.relation(x_sampled[:, :, 0], x_sampled[:, :, 1])
            relation_error = np.sqrt(((relation_gt - x_sampled[:, :, 2]) ** 2).mean(axis=1))
            relation_errors += relation_error.tolist()

            # baseline relation error: always predict the training mean of variable 2
            mean_baseline = data_means[0, 2]
            baseline_relation_error = np.sqrt(((x_gt[:, args.patch_len:, 2] - mean_baseline) ** 2).mean(axis=1))
            baseline_relation_errors += baseline_relation_error.tolist()

            # update
            pbar.update(1)
        pbar.close()
        return np.array(combined_errors), np.array(relation_errors), np.array(baseline_relation_errors)

    def compute_ce_and_re_dm(
            self,
            task: str = 'fc',
            batch_size: int = 1024,
            delta: float = 0.0,
            mode: str = 'test',
            n_data: Union[int, None] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
            Evaluates the diffusion model on a task and reports Combined Error (CE) and Relation Error (RE).

            Each window's ground truth may be shifted by `delta` (on variable 0 for 'oc'/'fc', after which
            variable 2 is recomputed from the relation; on the last variable for 'uc'). Hidden entries of the
            prompt are set to NaN and the predictor's guided sampling fills them in.

            Args:
                task (str, optional): 'oc' (hide the last variable), 'uc' (hide all but the last variable) or
                    'fc' (hide the second half of the window). Defaults to 'fc'.
                batch_size (int, optional): Number of windows sampled together. Defaults to 1024.
                delta (float, optional): Offset applied to the ground truth (offset OOD variation).
                    Defaults to 0.0.
                mode (str, optional): Dataset split to use (train, val or test). Defaults to 'test'.
                n_data (int or None, optional): If given and smaller than the dataset, only this many equally
                    spaced windows are evaluated, to keep the computation fast. None uses every window.
                    Defaults to None.

            Returns:
                Three arrays, one entry per evaluated window:
                    - CE of the sampled windows.
                    - RE of the sampled windows.
                    - RE of the weak baseline that always predicts the mean.

            Raises:
                NotImplementedError: If `task` is not one of 'oc', 'uc', 'fc'.
        """
        dataset = DiffusionDataset(self.config, mode=mode)
        data_means = self.config['data_means']
        data_stds = self.config['data_stds']

        # optionally keep an equally spaced subset of windows (smaller plots, faster runs)
        n_total = len(dataset)
        if n_data is None or n_data >= n_total:
            indices = np.arange(n_total)
        else:
            indices = np.linspace(0, n_total - 1, n_data, dtype=int)

        # build (prompt, target) pairs; hidden prompt entries become NaN
        prompts, targets = [], []
        for index in indices:
            target = dataset.get_raw_data(index, concatenate=False).clone()

            if task in ('oc', 'fc'):
                # shift variable 0, then keep variable 2 consistent with the relation
                target[:, 0] += delta
                target[:, 2] = self.relation(target[:, 0], target[:, 1])
            elif task == 'uc':
                target[:, -1] += delta
            else:
                raise NotImplementedError(f'task={task} not implemented.')

            prompt = target.clone()
            if task == 'oc':
                prompt[:, -1] *= np.nan
            elif task == 'uc':
                prompt[:, 0:-1] *= np.nan
            else:  # 'fc': hide everything from the middle of the window onwards
                prompt[int(self.config['input_length'] * 0.5):, :] *= np.nan

            prompts.append(prompt)
            targets.append(target)

        loader = DataLoader(TensorDataset(torch.stack(prompts), torch.stack(targets)), batch_size=batch_size)

        combined_errors, relation_errors, baseline_relation_errors = [], [], []
        progress = tqdm(total=len(loader), desc=f'DM: computing relation errors and combined errors for task={task}')
        for x, x_gt in loader:
            n_batch, n_steps, _ = x.shape

            x_sampled, _info = self.predictor.guided_sample(x_prompt=x)

            combined_errors.extend(self.predictor.compute_combined_error(x_sampled).tolist())

            # RMS (over time) mismatch between variable 2 and the relation of variables 0 and 1
            implied = self.relation(x_sampled[:, :, 0], x_sampled[:, :, 1])
            squared_gap = (implied - x_sampled[:, :, 2]) ** 2
            relation_errors.extend(np.sqrt(squared_gap.mean(axis=1)).tolist())

            # weak baseline: predict the training mean of variable 2 everywhere
            baseline = data_means[0, 2][None, None].repeat(n_batch, n_steps)
            baseline_gap = (x_gt[:, :, 2] - baseline) ** 2
            baseline_relation_errors.extend(np.sqrt(baseline_gap.mean(axis=1)).tolist())

            progress.update(1)
        progress.close()

        return np.array(combined_errors), np.array(relation_errors), np.array(baseline_relation_errors)

    def compute_m_and_re_dm(
            self,
            task: str = 'fc',
            batch_size: int = 1024,
            delta: float = 0.0,
            mode: str = 'test',
            n_data: Union[int, None] = None
    ) -> Dict[str, np.ndarray]:
        """
            Computes the per-window metrics reported by the predictor's guided sampling, together with Combined
            Error (CE) and Relation Error (RE), for the given task using the diffusion model.

            Args:
                task (str, optional): The task to evaluate: 'oc' (last variable masked), 'uc' (all but the last
                    variable masked) or 'fc' (second half of the window masked). Defaults to 'fc'.
                batch_size (int, optional): Batch size for evaluation. Defaults to 1024.
                delta (float, optional): The offset to be added to the first variable (offset OOD variation)
                    Defaults to 0.0.
                mode (str, optional): The dataset mode to use (train, val or test). Defaults to 'test'.
                n_data (int or None, optional): The number of data to use for the computation, will subsample dataset
                    by using equally spaced intervals. This is to make computations faster. Set to None to use all data.
                    Defaults to None.

            Returns:
                A dictionary of arrays keyed by metric abbreviation:
                    - 'pts': info['prompt_ts_list'] from sampling with the masked prompts.
                    - 'pe': info['prompt_error_list'] from sampling with the masked prompts.
                    - 'rts': info['response_ts_list'] from sampling with the masked prompts.
                    - 'cts': info['prompt_ts_list'] from a second sampling pass prompted with the generated
                      (complete) windows.
                    - 'ce': combined error of the generated windows.
                    - 're': relation error of the generated windows.

            Raises:
                NotImplementedError: If `task` is not one of 'oc', 'uc', 'fc'.
        """

        # init dataset
        dataset = DiffusionDataset(self.config, mode=mode)

        # sample data, to reduce size when plotting:
        if n_data is not None and n_data < len(dataset):
            sampled_indices = np.linspace(0, len(dataset) - 1, n_data, dtype=int)
        else:
            sampled_indices = np.arange(len(dataset))

        # format dataloader for DM; masked prompt entries are set to NaN
        x_prompt_batch, x_gt_batch = [], []
        for idx in sampled_indices:
            x_gt = dataset.get_raw_data(idx, concatenate=False).clone()

            # mask based on task
            if task == 'oc':  # over-constrained
                # add delta, then keep variable 2 consistent with the relation
                x_gt[:, 0] += delta
                x_gt[:, 2] = self.relation(x_gt[:, 0], x_gt[:, 1])

                x_prompt = x_gt.clone()
                x_prompt[:, -1] *= np.nan

            elif task == 'uc':  # under-constrained
                x_gt[:, -1] += delta

                x_prompt = x_gt.clone()
                x_prompt[:, 0:-1] *= np.nan

            elif task == 'fc':  # forecast
                # add delta, then keep variable 2 consistent with the relation
                x_gt[:, 0] += delta
                x_gt[:, 2] = self.relation(x_gt[:, 0], x_gt[:, 1])

                x_prompt = x_gt.clone()
                x_prompt[int(self.config['input_length'] * 0.5):, :] *= np.nan

            else:
                raise NotImplementedError(f'task={task} not implemented.')

            x_prompt_batch.append(x_prompt)
            x_gt_batch.append(x_gt)

        # generate dataloader
        dataloader = DataLoader(
            TensorDataset(
                torch.stack(x_prompt_batch),
                torch.stack(x_gt_batch)
            ), batch_size=batch_size
        )

        # loop to generate prompt and response metrics
        relation_errors, combined_errors = [], []
        pts_list, rts_list, cts_list, pe_list = [], [], [], []
        sampled_batches = []     # generated windows, re-used as prompts for the combined ts below
        pbar = tqdm(total=len(dataloader), desc=f'DM: computing relation errors and combined errors for task={task}')
        for x, _x_gt in dataloader:
            # sample given batch
            x_sampled, info = self.predictor.guided_sample(x_prompt=x)
            sampled_batches.append(x_sampled)

            # prompt and response metrics
            pts_list += info['prompt_ts_list']
            pe_list += info['prompt_error_list']
            rts_list += info['response_ts_list']

            # combined error
            combined_errors += self.predictor.compute_combined_error(x_sampled).tolist()

            # relation error: RMS over time between the relation applied to variables 0, 1 and variable 2
            relation_gt = self.relation(x_sampled[:, :, 0], x_sampled[:, :, 1])
            relation_error = np.sqrt(((relation_gt - x_sampled[:, :, 2]) ** 2).mean(axis=1))
            relation_errors += relation_error.tolist()

            # update
            pbar.update(1)
        pbar.close()

        # loop to generate combined ts: prompt the model with its own (complete) samples
        dataloader = DataLoader(torch.cat(sampled_batches), batch_size=batch_size)
        pbar = tqdm(total=len(dataloader), desc='computing combined ts')
        for x in dataloader:
            _, info = self.predictor.guided_sample(x_prompt=x)
            cts_list += info['prompt_ts_list']
            pbar.update(1)
        pbar.close()

        return {
            'pts': np.array(pts_list),
            'pe': np.array(pe_list),
            'rts': np.array(rts_list),
            'cts': np.array(cts_list),
            'ce': np.array(combined_errors),
            're': np.array(relation_errors)
        }

    def plot_ce_v_re(
            self,
            prompt_delta_s1_list: np.ndarray,
            task: str = 'oc',
            batchsize: int = 4096,
            n_data: Union[int, None] = None,
            save: bool = True,
            log_dir_name: str = 'plots',
            mode: str = 'train'
    ) -> go.Figure:
        """
        Plots the Combined Error (CE) against the Relation Error (RE).

        For every delta in ``prompt_delta_s1_list`` the CE and RE values returned by ``compute_ce_and_re_dm`` are
        scattered as points; points are coloured by their delta and the delta is shown on hover.

        Args:
            prompt_delta_s1_list (np.ndarray): A NumPy array of offset delta values.
            task (str, optional): The task to evaluate. Defaults to 'oc'.
            batchsize (int, optional): Batch size for evaluation. Defaults to 4096.
            n_data (Union[int, None], optional): The number of data to use for the computation, will subsample dataset
                by using equally spaced intervals. This is to make computations faster. Set to None to use all data.
                Defaults to None.
            save (bool, optional): Whether to save the plot. Defaults to True.
            log_dir_name (str, optional): Directory name (not path) to save the plot in. Defaults to 'plots'.
            mode (str, optional):  The dataset mode to use (train, val or test). Defaults to 'train'.

        Returns:
            A Plotly figure object.

        Raises:
            ValueError: If ``prompt_delta_s1_list`` is empty.
        """
        if len(prompt_delta_s1_list) == 0:
            raise ValueError('prompt_delta_s1_list must contain at least one delta.')

        # log_dir
        log_dir = os.path.join('expt', 'logs', self.expt_name, log_dir_name)
        os.makedirs(log_dir, exist_ok=True)

        deltas, combined_errors, relation_errors = [], [], []
        for delta in prompt_delta_s1_list:
            ce, rel_err, _ = self.compute_ce_and_re_dm(task, batchsize, delta, mode, n_data)

            # combine for all offsets; every point is tagged with the delta that produced it
            combined_errors += ce.tolist()
            relation_errors += rel_err.tolist()
            deltas += (delta * np.ones_like(ce)).tolist()

        # convert to array
        combined_errors = np.array(combined_errors)
        relation_errors = np.array(relation_errors)
        deltas = np.array(deltas)

        # deal with colours: one palette entry per distinct delta, ordered by delta value. The palette is cycled
        # when there are more distinct deltas than colours (a scaled-delta index would run past its end).
        cmap = qualitative.Plotly
        _, delta_rank = np.unique(deltas, return_inverse=True)
        colors = [cmap[i % len(cmap)] for i in delta_rank]

        # plot
        fig = go.Figure(
            go.Scatter(
                x=combined_errors, y=relation_errors, mode='markers',
                marker=dict(
                    size=2,
                    color=colors,
                    opacity=1.0
                ),
                text=[f'delta s1: {val:.2f}' for val in deltas],
                hoverinfo='text'
            )
        )
        fig.update_layout(
            title=f'Task={task.upper()}: Combined Error vs Relation Error.',
            xaxis_title='Combined Error',
            yaxis_title='Relation Error'
        )

        if save:
            save_path = os.path.join(log_dir, f"ce_v_re_{task}.html")
            pio.write_html(fig, file=save_path, auto_open=False)

        return fig

    def plot_m_v_re(
            self,
            prompt_delta_s1_list: np.ndarray,
            task: str = 'oc',
            batchsize: int = 4096,
            n_data: Union[int, None] = None,
            save: bool = True,
            log_dir_name: str = 'plots',
            mode: str = 'train'
    ) -> Dict[str, go.Figure]:
        """
        Plots each auxiliary metric returned by ``compute_m_and_re_dm`` against the Relation Error (RE).

        One scatter figure is produced per metric; points are coloured by their delta and the delta is shown on hover.

        Args:
            prompt_delta_s1_list (np.ndarray): A NumPy array of offset delta values.
            task (str, optional): The task to evaluate. Defaults to 'oc'.
            batchsize (int, optional): Batch size for evaluation. Defaults to 4096.
            n_data (Union[int, None], optional): The number of data to use for the computation, will subsample dataset
                by using equally spaced intervals. This is to make computations faster. Set to None to use all data.
                Defaults to None.
            save (bool, optional): Whether to save the plot. Defaults to True.
            log_dir_name (str, optional): Directory name (not path) to save the plot in. Defaults to 'plots'.
            mode (str, optional):  The dataset mode to use (train, val or test). Defaults to 'train'.

        Returns:
            A dictionary of plotly figures keyed by metric abbreviation: 'pts', 'rts', 'cts', 'pe' and 'ce'
            (the same keys as the dict returned by ``compute_m_and_re_dm``).

        Raises:
            ValueError: If ``prompt_delta_s1_list`` is empty.
        """
        if len(prompt_delta_s1_list) == 0:
            raise ValueError('prompt_delta_s1_list must contain at least one delta.')

        # log_dir
        log_dir = os.path.join('expt', 'logs', self.expt_name, log_dir_name)
        os.makedirs(log_dir, exist_ok=True)

        # per-metric accumulators; insertion order fixes the order in which figures are built
        metric_values = {'pts': [], 'rts': [], 'cts': [], 'pe': [], 'ce': []}
        deltas, relation_errors = [], []
        for delta in prompt_delta_s1_list:
            output = self.compute_m_and_re_dm(task, batchsize, delta, mode, n_data)

            # combine for all offsets; every point is tagged with the delta that produced it
            relation_errors += output['re'].tolist()
            deltas += (delta * np.ones_like(output['ce'])).tolist()
            for metric_name, values in metric_values.items():
                values.extend(output[metric_name].tolist())

        # convert to array
        relation_errors = np.array(relation_errors)
        deltas = np.array(deltas)
        results = {metric_name: np.array(values) for metric_name, values in metric_values.items()}

        # colours and hover text depend only on the deltas, so compute them once for all figures. One palette entry
        # per distinct delta (ordered by value), cycling when there are more distinct deltas than colours.
        cmap = qualitative.Plotly
        _, delta_rank = np.unique(deltas, return_inverse=True)
        colors = [cmap[i % len(cmap)] for i in delta_rank]
        hover_text = [f'delta s1: {val:.2f}' for val in deltas]

        figs = {}
        for metric_name, metric in results.items():
            # plot
            fig = go.Figure(
                go.Scatter(
                    x=metric, y=relation_errors, mode='markers',
                    marker=dict(
                        size=2,
                        color=colors,
                        opacity=1.0
                    ),
                    text=hover_text,
                    hoverinfo='text'
                )
            )
            fig.update_layout(
                title=f'Task={task.upper()}: {metric_name.upper()} vs Relation Error.',
                xaxis_title=f'{metric_name.upper()}',
                yaxis_title='Relation Error'
            )

            if save:
                save_path = os.path.join(log_dir, f"M{metric_name}_v_re_{task}.html")
                pio.write_html(fig, file=save_path, auto_open=False)

            figs[metric_name] = fig
        return figs


class TrainerRContour:
    """
    Trainer for a small predictor built by ``init_predictor`` (default: 2 outputs, 2 + 1 inputs).

    Trains with Adam + OneCycleLR, optionally logs to Weights & Biases, and checkpoints the best model to
    ``expt/logs/<expt_name>/checkpoint.pth``.
    """

    class Config(BaseModel):
        # expt
        seed: int = 1
        project_name: str = 'mvts_rhallu'
        expt_name: str = 'test_expt'
        max_epoch: int = 1000
        learning_rate: float = 1e-2
        log_every_n: int = 1
        val_every_n: int = 20
        batch_size: int = 256

        # g_diff
        n_steps: int = 1000
        beta_first: float = 1e-4
        beta_last: float = 2e-2

        # predictor
        predictor_name: str = 'default'

        # model
        model_name: str = 'mlp'
        input_size: int = 2 + 1     # presumably 2 data dims + 1 extra input (e.g. step t) — TODO confirm
        output_size: int = 2
        hidden_sizes: tuple = (128, 128, 128, 128, 128)

    default_config = Config().dict()

    def __init__(self, config):
        """
        Args:
            config (dict): Experiment configuration. Keys declared in ``Config`` that are omitted fall back to the
                ``Config`` defaults.
        """
        self.config = self.Config(**config).dict()

        # attributes (read from the validated config so Config defaults apply to omitted keys)
        self.seed = self.config['seed']
        self.project_name = self.config['project_name']
        self.expt_name = self.config['expt_name']
        self.max_epoch = self.config['max_epoch']
        self.learning_rate = self.config['learning_rate']
        self.log_every_n = self.config['log_every_n']
        self.val_every_n = self.config['val_every_n']
        self.batch_size = self.config['batch_size']

        # init
        set_seed(self.seed)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # defaults fill in missing keys; user-supplied values (and any extra keys) are passed through unchanged
        self.predictor = init_predictor({**self.config, **config}).to(self.device)

    def train(self, dataset_train, dataset_val, use_wandb: bool = True, save_best_val_model: bool = True) -> None:
        """
        Trains the predictor and checkpoints the best model to ``expt/logs/<expt_name>/checkpoint.pth``.

        Args:
            dataset_train: Training dataset yielding ``(x, y)`` pairs. Must expose a ``sequences`` attribute with a
                ``tolist()`` method; its contents are stored in the saved config under ``data_train``.
            dataset_val: Validation dataset yielding ``(x, y)`` pairs.
            use_wandb (bool, optional): Whether to log training metrics to Weights & Biases.
                                        Defaults to True.
            save_best_val_model (bool, optional): If True, save the model with the best validation loss.
                                                  If False, save the model with the lowest training-batch loss
                                                  observed at validation steps. Defaults to True.

        Raises:
            FileExistsError: If the experiment directory already exists and is not empty.
        """
        save_dir = os.path.join('expt', 'logs', self.expt_name)
        if os.path.exists(save_dir):
            with os.scandir(save_dir) as entries:
                if any(entries):
                    raise FileExistsError(f"Directory '{save_dir}' already exists and is not empty.")
        else:
            os.makedirs(save_dir)

        # dataloaders
        train_dataloader = DataLoader(dataset_train, batch_size=self.batch_size, shuffle=True)
        val_dataloader = DataLoader(dataset_val, batch_size=self.batch_size, shuffle=False)

        # save train datapoints in config
        self.config['data_train'] = dataset_train.sequences.tolist()

        # steppers
        optimizer = optim.Adam(self.predictor.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=self.learning_rate, steps_per_epoch=len(train_dataloader), epochs=self.max_epoch)

        # train loop init
        step = 0
        best_val_loss = torch.inf
        best_train_loss = torch.inf
        best_info = None    # 'info' of the most recently saved checkpoint; stays None if nothing was saved

        # loop
        if use_wandb:
            wandb.init(project=self.project_name, config=self.config, name=self.expt_name)
        pbar = tqdm(total=self.max_epoch * len(train_dataloader))
        for epoch in range(self.max_epoch):
            if use_wandb:
                wandb.log({"epoch": epoch}, step=step)

            # train -------------------------
            for x_train, y_train in train_dataloader:
                self.predictor.train()
                optimizer.zero_grad()

                x_train, y_train = self._fill_batch(x_train, y_train, train_dataloader)

                train_step_output = self.predictor.train_step(x_train.to(self.device), y_train.to(self.device))
                train_loss = train_step_output['loss']
                train_loss.backward()

                optimizer.step()
                scheduler.step()
                if use_wandb and step % self.log_every_n == 0:
                    wandb.log({'train_loss': train_loss.item()}, step=step)
                step += 1
                pbar.update(1)

                # valid -------------------------
                if step % self.val_every_n == 0:
                    val_loss = self._validate(val_dataloader)
                    if use_wandb:
                        wandb.log({'val_loss': val_loss}, step=step)

                    # save best model
                    if save_best_val_model and val_loss < best_val_loss:
                        best_val_loss = val_loss
                        best_info = self._save_checkpoint(save_dir, epoch, best_val_loss)
                    elif not save_best_val_model and train_loss.item() < best_train_loss:
                        best_train_loss = train_loss.item()
                        best_info = self._save_checkpoint(save_dir, epoch, best_val_loss)
        pbar.close()

        # train complete logging (guarded: no checkpoint exists if no validation step ever improved)
        if use_wandb:
            if best_info is not None:
                wandb.log({'best_epoch': best_info['best_epoch']}, step=step)
                wandb.log({'best_val_loss': best_info['best_val_loss']}, step=step)
            wandb.finish()
        return None

    def _fill_batch(self, x, y, dataloader):
        """Tops up ``(x, y)`` to exactly ``self.batch_size`` rows by drawing extra batches from ``dataloader``."""
        # when batch_size > dataset size, add whole extra copies of the (re-drawn) data first
        for _ in range((self.batch_size // len(dataloader.dataset)) - 1):
            x_extra, y_extra = next(iter(dataloader))
            x = torch.cat([x, x_extra], dim=0)
            y = torch.cat([y, y_extra], dim=0)

        # then take only as many rows as are still missing (also tops up a short final batch)
        missing = self.batch_size - len(x)
        if missing > 0:
            x_extra, y_extra = next(iter(dataloader))
            x = torch.cat([x, x_extra[0:missing]], dim=0)
            y = torch.cat([y, y_extra[0:missing]], dim=0)
        return x, y

    def _validate(self, val_dataloader) -> float:
        """Returns the validation loss averaged over batches; leaves the predictor in eval mode."""
        self.predictor.eval()
        with torch.no_grad():
            val_loss = 0.0
            for x_val, y_val in val_dataloader:
                val_step_output = self.predictor.val_step(x_val.to(self.device), y_val.to(self.device))
                val_loss += val_step_output['loss'].item()
            val_loss /= len(val_dataloader)
        return val_loss

    def _save_checkpoint(self, save_dir: str, epoch: int, best_val_loss: float) -> Dict:
        """Saves predictor weights and config to ``<save_dir>/checkpoint.pth``; returns the stored ``info`` dict."""
        info = {
            'best_epoch': epoch,
            'best_val_loss': best_val_loss,
        }
        checkpoint = {
            'predictor_state_dict': self.predictor.state_dict(),
            'config': self.config,
            'info': info,
        }
        torch.save(checkpoint, os.path.join(save_dir, 'checkpoint.pth'))
        return info

    @classmethod
    def load(cls, expt_name: str) -> Tuple[Any, Dict, list]:
        """
        Loads a Predictor from a checkpoint file and returns a Trainer object along with an info dictionary.

        Args:
            expt_name (str): The name of the experiment that will be used to load the predictor.

        Returns:
            A tuple containing:
                - The loaded Trainer instance (an instance of the class).
                - A dictionary containing additional information extracted from the checkpoint file.
                - List of data used for training.

        Raises:
            FileNotFoundError: If the checkpoint file cannot be found.
        """
        checkpoint_path = os.path.join('expt', 'logs', expt_name, 'checkpoint.pth')
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint file not found at '{checkpoint_path}'.")

        # map tensors to CPU first so GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies them onto the predictor's device
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        loaded_orchestrator = cls(config=checkpoint['config'])
        loaded_orchestrator.predictor.load_state_dict(checkpoint['predictor_state_dict'])
        return loaded_orchestrator, checkpoint['info'], checkpoint['config']['data_train']
