# test.py
import copy
import math
from typing import Optional

import numpy as np
import torch
from tqdm import tqdm

from config import TrainingConfig
from loadData import Dataset_Custom, Dataset_ETTminVR, Dataset_ETThourVR


class Tester:
    """Evaluate a model on the validation split and report MSE/MAE.

    For every batch the model output is decoded twice via the dataset's
    ``Pixel2data`` (``method='max'`` and ``method='expection'``), and errors
    are accumulated separately for each horizon in ``PRED_LENGTHS``.
    """

    # Forecast horizons (in original time steps) evaluated on every batch.
    PRED_LENGTHS = (96, 192, 336, 720)

    def __init__(self, config: TrainingConfig, model: torch.nn.Module):
        """Bind the config and model and load the validation dataset.

        Raises:
            ValueError: if ``model`` is None.
        """
        if model is None:
            raise ValueError("Tester requires a model instance; got None")
        self.config = config
        self.model = model
        # Batches are moved to whichever device holds the model's parameters.
        self.device = next(model.parameters()).device
        self.val_dataset = self._load_dataset()

    def _load_dataset(self):
        """Choose the dataset class from substrings of ``config.data_path``."""
        data_path = self.config.data_path
        if 'ETTh' in data_path:
            return Dataset_ETThourVR(self.config)
        if 'ETTm' in data_path:
            return Dataset_ETTminVR(self.config)
        return Dataset_Custom(self.config)

    def _make_loader(self):
        """Build the evaluation DataLoader over ``self.val_dataset``.

        Shuffling is disabled on purpose. Metrics are a mean of per-batch
        means, so with a partial last batch a random order would make the
        reported numbers differ between otherwise identical runs.
        """
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
        )

    def test(self, epoch: Optional[int] = None):
        """Run one pass over the validation set, print and return metrics.

        Args:
            epoch: Accepted for API compatibility; not used.

        Returns:
            ``(mse, mae)``: lists with one entry per horizon in
            ``PRED_LENGTHS`` (``'max'`` decoding). Each entry is the SUM over
            batches of per-batch means. It is not divided by the batch count;
            only the printed values are averaged.
        """
        val_loader = self._make_loader()
        self.model.eval()
        metrics = self._initialize_metrics()

        # Create a progress bar only when requested, sized to the real number
        # of batches, and always close it (even if a batch raises).
        pbar = (
            tqdm(total=len(val_loader))
            if getattr(self.config, 'pbar', False)
            else None
        )
        count = 0
        try:
            with torch.no_grad():
                for batch in val_loader:
                    if pbar is not None:
                        pbar.update(1)
                    count += 1
                    self._process_batch(batch, metrics)
        finally:
            if pbar is not None:
                pbar.close()

        self._report_metrics(metrics, count)
        return metrics['mse'], metrics['mae']

    def _initialize_metrics(self):
        """Return zeroed accumulators with one slot per horizon."""
        slots = len(self.PRED_LENGTHS)
        return {key: [0] * slots for key in ('mse', 'mae', 'mse_exp', 'mae_exp')}

    def _process_batch(self, batch, metrics):
        """Predict one batch and add its per-horizon errors to ``metrics``.

        ``batch`` is unpacked as ``(x, y, d, xseq, yseq, mu, std)``. ``d`` and
        ``xseq`` are not used here.
        """
        x, y, _d, _xseq, yseq, mu, std = batch
        x = x.to(self.device)

        def to_numpy(t):
            return t.detach().cpu().numpy()

        y, yseq, mu, std = to_numpy(y), to_numpy(yseq), to_numpy(mu), to_numpy(std)

        ypred_max, ypred_exp = self._test_step(x, y)

        for idx, pred_len in enumerate(self.PRED_LENGTHS):
            self._update_metric_for_length(
                metrics, idx, pred_len,
                ypred_max, ypred_exp, yseq, mu, std,
            )

    def _test_step(self, x, y):
        """Run the model on ``x`` and decode the output to series values.

        Args:
            x: 4-D input tensor on ``self.device`` (rank fixed by the
                slicing below; axis meanings not shown here).
            y: NumPy target with the same 4-D layout (assumed; TODO confirm).

        Returns:
            ``(ypred_max, ypred_exp)`` from ``val_dataset.Pixel2data`` with
            ``method='max'`` and ``method='expection'``.
        """
        size0 = self.config.size[0]
        size1 = self.config.size[1]

        x_orig = x.clone()
        # Zero the mask over the first size[0] positions of dim 2, so there
        # the result equals the input; elsewhere the model output is added
        # to the input. ones_like already matches x's device.
        mask = torch.ones_like(x_orig)
        mask[:, :, :size0, :] = 0

        output = self.model(x)
        output = (output * mask + x_orig).detach().cpu().numpy()

        # Overwrite the leading size[0] - size[1] positions of dim 2 with
        # the ground truth before decoding.
        keep = size0 - size1
        output[:, :, :keep, :] = y[:, :, :keep, :]

        # 'expection' is the method name expected by the dataset; keep as-is.
        ypred_max = self.val_dataset.Pixel2data(output, method='max')
        ypred_exp = self.val_dataset.Pixel2data(output, method='expection')
        return ypred_max, ypred_exp

    def _update_metric_for_length(self, metrics, idx, pred_len, ypred_max, ypred_exp, yseq, mu, std):
        """Accumulate MSE/MAE for one forecast horizon into slot ``idx``.

        The decoded predictions are de-normalised with ``* std + mu``. They
        are sliced starting at ``config.size[0]``, optionally subsampled
        (``scal == 2``), and mapped back to ``pred_len`` steps by the
        dataset's ``reverse_interpolate_sequence``. Then they are compared
        against ``yseq[:, size[0]:size[0] + pred_len, :]``.
        """
        start = self.config.size[0]
        pred_len_real = pred_len
        # Length of the decoded window to take. The ceil(pred_len / size[0]
        # * 2) * 2 formula is preserved from the original. Presumably it
        # reflects how Pixel2data lays out time steps (TODO confirm).
        span = math.ceil(pred_len_real / start * 2) * 2

        ye = ypred_exp[:, start:start + span, :] * std + mu
        yp = ypred_max[:, start:start + span, :] * std + mu
        yt = yseq[:, start:start + pred_len_real, :]

        if self.config.scal == 2:
            # Keep every second decoded step, starting at offset 1.
            ye = ye[:, 1::2, :]
            yp = yp[:, 1::2, :]

        ye = self.val_dataset.reverse_interpolate_sequence(ye, pred_len_real)
        yp = self.val_dataset.reverse_interpolate_sequence(yp, pred_len_real)

        metrics['mse'][idx] += np.mean((yp - yt) ** 2)
        metrics['mae'][idx] += np.mean(np.abs(yp - yt))
        metrics['mse_exp'][idx] += np.mean((ye - yt) ** 2)
        metrics['mae_exp'][idx] += np.mean(np.abs(ye - yt))

    def _report_metrics(self, metrics, count):
        """Print per-horizon metrics averaged over ``count`` batches.

        With ``count == 0`` (empty validation set) a notice is printed
        instead of dividing by zero.
        """
        if count == 0:
            print('\nNo validation batches were evaluated; metrics unavailable.')
            return
        for idx, pred_len in enumerate(self.PRED_LENGTHS):
            print(f'\nPrediction Length: {pred_len}')
            print(f"MSE (max): {metrics['mse'][idx] / count:.4f}")
            print(f"MAE (max): {metrics['mae'][idx] / count:.4f}")
            print(f"MSE (exp): {metrics['mse_exp'][idx] / count:.4f}")
            print(f"MAE (exp): {metrics['mae_exp'][idx] / count:.4f}")


def test(config: TrainingConfig, model: Optional[torch.nn.Module] = None, epoch: Optional[int] = None):
    """Evaluate ``model`` on the validation split described by ``config``.

    Args:
        config: Training configuration used to pick and build the dataset.
        model: The model to evaluate. Although the parameter defaults to
            None, a model is required.
        epoch: Forwarded to ``Tester.test``.

    Returns:
        The ``(mse, mae)`` pair returned by ``Tester.test``.

    Raises:
        ValueError: if ``model`` is None. Without this check it would fail
            later with an opaque AttributeError.
    """
    if model is None:
        raise ValueError("test() requires a model; got None")
    tester = Tester(config, model)
    return tester.test(epoch)