import os
import json
import logging
import numpy as np
import pytorch_lightning as pl
from pathlib import Path
from rich.table import Table
from rich.console import Console
from typing import Dict, Optional, Tuple
from mStream.callback import build_callbacks
from mStream.config import parse_args
from mGPT.data.build_data import build_data
from mGPT.models.build_model import build_model
from mGPT.utils.logger import setup_logger
from mGPT.utils.load_checkpoint import load_pretrained, load_pretrained_vae


def print_table(
    title: str, metrics: Dict[str, float], logger: Optional[logging.Logger] = None
) -> None:
    """
    Render metrics as a centered two-column Rich table and optionally log them.

    Args:
        title (str): Caption shown above the table.
        metrics (Dict[str, float]): Metric names mapped to the values to show;
            each value is converted with ``str`` before being displayed.
        logger (Optional[logging.Logger]): When given, the whole ``metrics``
            mapping is additionally emitted through it at INFO level.
    """
    # Column headers paired with the keyword options handed to ``add_column``.
    column_specs = (
        ("Metrics", {"style": "cyan", "no_wrap": True}),
        ("Value", {"style": "magenta"}),
    )

    report = Table(title=title)
    for header, options in column_specs:
        report.add_column(header, **options)

    for metric_name, metric_value in metrics.items():
        report.add_row(metric_name, str(metric_value))

    Console().print(report, justify="center")

    # Nothing left to do when the caller did not supply a logger.
    if logger is None:
        return
    logger.info("Metrics: %s", metrics)


def get_metric_statistics(
    values: np.ndarray, replication_times: int
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Summarize metric values by their mean and a confidence-interval half-width.

    Both statistics are reduced along axis 0 of ``values``. The half-width is
    ``1.96 * std / sqrt(replication_times)``, where ``std`` is NumPy's default
    (population, ``ddof=0``) standard deviation.

    Args:
        values (np.ndarray): Metric values, with replications along axis 0.
        replication_times (int): The count placed under the square root when
            scaling the standard deviation.

    Returns:
        Tuple[np.ndarray, np.ndarray]: ``(mean, half_width)`` of the metrics.
    """
    z_score = 1.96
    center = np.mean(values, axis=0)
    spread = np.std(values, axis=0)

    # Multiply before dividing so the floating-point result is unchanged.
    half_width = (z_score * spread) / np.sqrt(replication_times)
    return center, half_width


def _needs_multimodality_pass(cfg, metrics_type: str, replication: int) -> bool:
    """
    Decide whether a replication gets the additional multimodality test pass.

    Args:
        cfg: Parsed configuration; reads ``cfg.model.target`` and, for
            TM2TMetrics, ``cfg.model.params.task/stage/condition``.
        metrics_type (str): The names in ``cfg.METRIC.TYPE`` joined by ", ".
        replication (int): Zero-based index of the current replication.

    Returns:
        bool: True when the extra pass should run.
    """
    # A target containing "sr" disables the pass regardless of the metric types.
    if "sr" in cfg.model.target:
        return False

    # MotionxMetrics only receive the extra pass in the first five replications.
    if "MotionxMetrics" in metrics_type and replication < 5:
        return True

    return (
        "TM2TMetrics" in metrics_type
        and cfg.model.params.task == "t2m"
        and cfg.model.params.stage != "vae"
        and cfg.model.params.condition != "pair"
    )


def main():
    """
    Run the test phase and save aggregated metrics.

    Builds the callbacks, data module, model and Lightning trainer from the
    parsed configuration, loads the pretrained weights, runs
    ``cfg.TEST.REPLICATION_TIMES`` test replications, then prints the
    per-metric mean and confidence interval and writes them, together with the
    raw per-replication values, to ``metrics_<cfg.TIME>.json`` next to the
    samples directory.
    """
    # Parse config file
    cfg = parse_args(phase="test")
    cfg.FOLDER = cfg.TEST.FOLDER

    # Set up logger
    logger = setup_logger(cfg, phase="test")

    # Set up output directory
    model_name = cfg.model.target.split(".")[-2].lower()
    output_dir = Path(cfg.FOLDER, model_name, cfg.NAME, f"samples_{cfg.TIME}")
    if cfg.TEST.SAVE_PREDICTIONS:
        output_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Saving predictions to {output_dir}")

    # Set random seed for reproducibility
    pl.seed_everything(cfg.SEED_VALUE)

    # Set environment variables
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Initialize callbacks
    callbacks = build_callbacks(cfg, phase="test")
    logger.info("Callbacks initialized")

    # Initialize dataset
    datamodule = build_data(cfg)
    logger.info(f"Dataset module {cfg.DATASET.target.split('.')[-2]} initialized")

    # Initialize model
    model = build_model(cfg, datamodule)
    logger.info(f"Model {cfg.model.target} loaded")

    # Initialize PyTorch Lightning trainer
    trainer = pl.Trainer(
        benchmark=False,
        max_epochs=cfg.TRAIN.END_EPOCH,
        accelerator=cfg.ACCELERATOR,
        devices=list(range(len(cfg.DEVICE))),
        default_root_dir=cfg.FOLDER_EXP,
        reload_dataloaders_every_n_epochs=1,
        deterministic=False,
        detect_anomaly=False,
        enable_progress_bar=True,
        logger=None,
        callbacks=callbacks,
    )

    # Load pretrained models
    load_pretrained(cfg, model, phase="test")
    load_pretrained_vae(cfg, model)
    print(model)

    # Calculate metrics
    all_metrics = {}
    replication_times = cfg.TEST.REPLICATION_TIMES
    # The metric type list does not change between replications; join it once.
    metrics_type = ", ".join(cfg.METRIC.TYPE)

    for i in range(replication_times):
        logger.info(f"Evaluating {metrics_type} - Replication {i}")
        metrics = trainer.test(model, datamodule=datamodule)[0]

        if _needs_multimodality_pass(cfg, metrics_type, i):
            logger.info(f"Evaluating MultiModality - Replication {i}")
            datamodule.mm_mode(True)
            try:
                mm_metrics = trainer.test(model, datamodule=datamodule)[0]
            finally:
                # Always leave multimodality mode, even if the test pass fails.
                datamodule.mm_mode(False)
            metrics.update(mm_metrics)

        for key, item in metrics.items():
            all_metrics.setdefault(key, []).append(item)

    all_metrics_new = {}
    for key, item in all_metrics.items():
        # Metrics produced only by the extra pass may have fewer values than
        # there were replications, so scale by the number actually collected.
        mean, conf_interval = get_metric_statistics(np.array(item), len(item))
        all_metrics_new[f"{key}/mean"] = mean
        all_metrics_new[f"{key}/conf_interval"] = conf_interval

    print_table("Mean Metrics", all_metrics_new, logger=logger)
    all_metrics_new.update(all_metrics)

    # Save metrics to file
    metric_file = output_dir.parent / f"metrics_{cfg.TIME}.json"
    # The directory is only created above when predictions are saved, so make
    # sure it exists before writing the results.
    metric_file.parent.mkdir(parents=True, exist_ok=True)
    with open(metric_file, "w", encoding="utf-8") as f:
        json.dump(all_metrics_new, f, indent=4)
    logger.info(f"Testing done, the metrics are saved to {metric_file}")


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
