import os
from tqdm import tqdm
from typing import Optional

import hydra
import torch
import torch.backends.cudnn as cudnn
import pytorch_lightning as pl
from omegaconf import OmegaConf
from dataclasses import asdict

import wandb

from conf.main_config import GlobalConfiguration

from conf.wandb_params import get_wandb_logger
from data.get_datamodule import get_dm
from utils.Metrics import get_metrics
from utils.evaluation.core import get_prediction_from


@hydra.main(version_base=None, config_name='globalConfiguration')
def main(_cfg: GlobalConfiguration):
    """Evaluate stored predictions on the test split and log the metrics to wandb.

    Steps:
      1. Resolve the configuration: the Hydra-provided config, optionally merged
         with the YAML file named by ``yaml_conf``, then merged once more with the
         command-line overrides so that they take precedence over the YAML values.
      2. Seed everything and apply torch settings (hub directory, float32 matmul
         precision, cuDNN benchmark mode).
      3. For every test sample, fetch its precomputed prediction with
         ``get_prediction_from`` and accumulate metrics using the single fixed
         mode from ``model_params.logging.hack_mode``.
      4. Log the computed metrics to wandb and finish the run.

    Args:
        _cfg: Structured configuration supplied by the ``@hydra.main`` decorator.

    Raises:
        ValueError: if ``model_params.logging.hack_mode`` is empty/None or does
            not contain exactly one mode.
    """
    if _cfg.yaml_conf is not None:
        # command line configuration + yaml configuration
        _cfg = OmegaConf.merge(_cfg, OmegaConf.load(_cfg.yaml_conf))

    # command line configuration + yaml configuration + command line configuration:
    # re-apply the CLI overrides last so they win over values from the YAML file.
    # Keys containing '/' are skipped (presumably Hydra config-group selections,
    # which are not plain config fields -- TODO confirm).
    cli_overrides = {
        key: val for key, val in OmegaConf.from_cli().items() if '/' not in key}
    _cfg = OmegaConf.merge(_cfg, cli_overrides)

    print(OmegaConf.to_yaml(_cfg))
    cfg: GlobalConfiguration = OmegaConf.to_object(_cfg)

    pl.seed_everything(cfg.seed)

    torch_params = cfg.system_params.torch_params
    if torch_params.hub_dir is not None:
        # The sentinel 'cwd' means "<current working directory>/torch_hub".
        if torch_params.hub_dir == 'cwd':
            torch.hub.set_dir(os.path.join(os.getcwd(), 'torch_hub'))
        else:
            torch.hub.set_dir(torch_params.hub_dir)

    if torch_params.torch_float32_matmul_precision is not None:
        torch.set_float32_matmul_precision(
            torch_params.torch_float32_matmul_precision)

    # wandb: the returned object is not used below. Metrics are logged through
    # the global ``wandb`` module, so this call is presumably what initialises
    # the run -- keep it.
    run_wandb = get_wandb_logger(
        params=cfg.wandb_params, global_dict=asdict(cfg))

    # Setup the data module; only its test split is used here.
    dm = get_dm(cfg.dataset_params)
    dm.setup()

    if cfg.trainer_params.cudnn_benchmark is not None:
        # Honour the configured value (False must disable benchmarking);
        # None leaves PyTorch's default untouched.
        cudnn.benchmark = bool(cfg.trainer_params.cudnn_benchmark)

    test_dataset = dm.test_dataset
    test_metrics = get_metrics(cfg.model_params.metrics)(
        cfg.model_params, cfg.dataset_params)

    device = cfg.evaluation_params.device
    test_metrics = test_metrics.to(device)

    # hack_mode: Optional[list[list[int]]] -- if not None, the modes used during
    # generation. This evaluation supports exactly one fixed mode.
    mode_cfg = cfg.model_params.logging.hack_mode
    if not mode_cfg or len(mode_cfg) != 1:
        raise ValueError(
            'model_params.logging.hack_mode must contain exactly one mode, '
            f'got {mode_cfg!r}')
    # Single mode reshaped to (1, k) and reused for every test sample.
    mode: torch.Tensor = torch.tensor(mode_cfg[0]).float().reshape(1, -1).to(device)
    print(f'{mode=}: {mode.shape=}')

    # Pure evaluation: no gradients are needed for metric computation.
    with torch.no_grad():
        for data, _mode in tqdm(test_dataset):  # type: ignore
            # The last element of each sample is its index; the rest are inputs.
            # The per-sample mode (_mode) is ignored in favour of the fixed ``mode``.
            idx = data[-1]
            data = data[:-1]
            # Fetch the stored prediction for this sample.
            prediction = get_prediction_from(params=cfg.evaluation_params, idx=idx)

            # Move to the device and add a leading batch dimension of size 1.
            prediction = [y.to(device).unsqueeze(0) for y in prediction]
            data = [y.to(device).unsqueeze(0) for y in data]

            test_metrics.get_dict_generation(
                data=data,
                prediction=prediction,
                mode=mode,
            )

    metrics_dicts = test_metrics.compute_and_get()
    wandb.log(metrics_dicts)

    print('<TERMINATE WANDB>')
    wandb.finish()
    print('<WANDB TERMINATED>')


if __name__ == '__main__':
    # main() is called without arguments: the @hydra.main decorator is expected
    # to build the configuration (from the command line) and pass it in.
    main()
