from pdb import run


import pandas as pd
import pyrootutils
import torch

# Project root directory: the first directory (searching upwards from this file) that
# contains ".git" or "pyproject.toml". It is also added to PYTHONPATH and ".env" is loaded
# from it. `root` is used below to build data and output paths. See the notes that follow.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)

# ------------------------------------------------------------------------------------ #
# `pyrootutils.setup_root(...)` above is an optional line that makes the environment more
# convenient; it should be placed at the top of each entry file
#
# main advantages:
# - allows you to keep all entry files in "src/" without installing the project as a package
# - launching a python file works no matter what your current working dir is
# - automatically loads environment variables from ".env" if it exists
#
# how it works:
# - `setup_root()` above recursively searches for either ".git" or "pyproject.toml" in the
#   current and parent dirs to determine the project root dir
# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from
#   any place without installing project as a package
# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml"
#   to make all paths always relative to project root
# - loads environment variables from ".env" in root dir (if `dotenv=True`)
#
# you can remove `pyrootutils.setup_root(...)` if you:
# 1. either install project as a package or move each entry file to the project root dir
# 2. remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml"
#
# https://github.com/ashleve/pyrootutils
# ------------------------------------------------------------------------------------ #

from typing import List, Optional, Tuple

import hydra
from omegaconf import DictConfig
import lightning as L
import lightning.pytorch.loggers as loggers

# from src.utils.kfold_loop import KFoldLoop
from src import utils
from src.utils import hydra_resolver, reconstruction
from src.utils.movement_vis import visualize_movement


# Module-level logger, created through the project's `utils.get_pylogger` helper.
log = utils.get_pylogger(__name__)
# Runs at import time, i.e. before `main()` (and therefore before @hydra.main composes the
# config). Presumably this registers custom OmegaConf resolvers referenced by the configs;
# see src/utils/hydra_resolver to confirm.
hydra_resolver.register_resolvers(log)


def _visualize_prediction_runs(
    cfg: DictConfig,
    prediction_runs: dict[str, dict[str, torch.Tensor]],
    frame_slice: slice = slice(1500, 1750),
) -> None:
    """Render one movement animation (GIF) per prediction run.

    Reads ``consts/joint_list.parquet`` and ``consts/cp_list.parquet`` from the predict
    dataset's data dir, filtering both by their ``subject`` column. For each run it also
    reads ``pred_<run_name>_IK_data.parquet`` from the output dir. When
    ``model.estimated_variables.COP_data`` is set, it additionally reads
    ``pred_<run_name>_COP_data.parquet``. Output is written to
    ``animation_<run_name>.gif`` in the output dir.

    Args:
        cfg (DictConfig): Configuration composed by Hydra. Reads
            ``datamodule.predict_dataset.data_dir``, ``paths.output_dir`` and
            ``model.estimated_variables``.
        prediction_runs (dict): Runs keyed by ``"<subject_id>_<trial_id>"``. Both parts must
            parse as ``int``. Only the keys are used here.
        frame_slice (slice): Rows (frames) of the IK/COP data to animate.
    """
    if not prediction_runs:
        # Nothing to render; this also avoids reading any files when there are no runs.
        return

    consts_dir = root / cfg.datamodule.predict_dataset.data_dir / "consts"
    output_dir = root / cfg.paths.output_dir
    estimates_cop = bool(cfg.model.estimated_variables.get("COP_data"))

    # Loaded once and filtered per subject inside the loop, instead of re-read for every run.
    joint_list_path = consts_dir / "joint_list.parquet"
    log.info(f"joint path: {joint_list_path}")
    all_joints = pd.read_parquet(joint_list_path)
    all_contact_points = pd.read_parquet(consts_dir / "cp_list.parquet")

    for run_name in prediction_runs:
        subject_id, trial_id = [int(part) for part in run_name.split("_")]
        log.info(f"Visualizing prediction for subject {subject_id}, trial {trial_id}...")

        joints = all_joints[all_joints["subject"] == subject_id]
        contact_points = all_contact_points[all_contact_points["subject"] == subject_id]
        segment_pos = reconstruction.compose_segment_positions(joints, contact_points)

        ik_data = pd.read_parquet(output_dir / f"pred_{run_name}_IK_data.parquet")
        cop_data = (
            pd.read_parquet(output_dir / f"pred_{run_name}_COP_data.parquet")
            if estimates_cop
            else None
        )

        visualize_movement(
            ik_data.iloc[frame_slice],
            segment_pos,
            cop_data=cop_data.iloc[frame_slice] if cop_data is not None else None,
            output_file=output_dir / f"animation_{run_name}.gif",
        )


@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[dict, dict]:
    """Trains the model. Can additionally evaluate on a testset, using best weights obtained
    during training, and run prediction.

    This method is wrapped in optional @task_wrapper decorator which applies extra utilities
    before and after the call.

    Config keys read here: ``seed``, ``datamodule``, ``model``, ``callbacks``, ``logger``,
    ``trainer``, ``train``, ``ckpt_path``, ``test``, ``predict`` and the optional (default off)
    ``visualize_predictions``.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """

    # Set the seed for the random number generators in pytorch, numpy and python.random.
    # Compare against None so that `seed: 0` is honoured (0 is falsy).
    if cfg.get("seed") is not None:
        L.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
    datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.datamodule)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: L.LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[L.Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
    # Reset the wait count in the early stopping callback if it is used (for continuing
    # from a checkpoint).
    for callback in callbacks:
        if isinstance(callback, L.pytorch.callbacks.early_stopping.EarlyStopping):
            # NOTE(review): when resuming with `ckpt_path`, Lightning restores callback
            # state (including EarlyStopping.wait_count) inside `trainer.fit`, which likely
            # undoes this reset. Confirm this, and if so move the reset into a callback hook.
            callback.wait_count = 0

    log.info("Instantiating loggers...")
    logger: List[loggers.Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: L.Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")

        # # added code part for k-fold training
        # # check if k-fold is activated
        # kfold = cfg.kfold
        # if kfold:
        #     internal_fit_loop = trainer.fit_loop
        #     trainer.fit_loop = KFoldLoop(kfold.num_folds,
        #                             export_path=kfold.export_path,iModel=iModel)  # type: ignore
        #     trainer.fit_loop.connect(internal_fit_loop)

        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        # `checkpoint_callback` is None when checkpointing is disabled. Treat that case
        # like "no best checkpoint" instead of raising AttributeError.
        checkpoint_callback = trainer.checkpoint_callback
        ckpt_path = (
            checkpoint_callback.best_model_path  # type: ignore
            if checkpoint_callback is not None
            else ""
        )
        if not ckpt_path:
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    if cfg.get("predict"):
        log.info("Starting prediction!")

        predictions: list[dict[str, torch.Tensor]] = trainer.predict(
            model=model, datamodule=datamodule
        )  # type: ignore

        # Maps "<subject_id>_<trial_id>" -> prediction. Population (and the CSV export) is
        # currently disabled, so this stays empty.
        prediction_runs: dict[str, dict[str, torch.Tensor]] = {}
        # TODO: re-enable, e.g.
        # i = 0
        # for subject_id in cfg.datamodule.predict_dataset.subject_ids:
        #     for trial_id in cfg.datamodule.predict_dataset.trial_ids:
        #         prediction_runs[f"{subject_id}_{trial_id}"] = predictions[i]
        #         i += 1
        # log.info("Saving predictions to CSV...")
        # utils.save_predictions_to_csv(prediction_runs, cfg)  # type: ignore

        # Opt-in; previously hard-disabled with `if False:`. It also requires prediction_runs
        # to be populated and the pred_<run>_*.parquet files to exist in the output dir.
        if cfg.get("visualize_predictions"):
            _visualize_prediction_runs(cfg, prediction_runs)

    test_metrics = trainer.callback_metrics

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict


@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Hydra entry point: run `train` and return the metric selected for optimization.

    Args:
        cfg (DictConfig): Configuration composed by Hydra from "../configs/train.yaml" plus
            any command-line overrides.

    Returns:
        Optional[float]: The value that `utils.get_metric_value` extracts for the metric named
        by ``cfg.optimized_metric``. This is the value Hydra-based hyperparameter sweeps consume.
    """
    # Hand the composed config to the project's global-config helper before training starts.
    utils.init_global_cfg(cfg)

    metrics, _instantiated_objects = train(cfg)

    return utils.get_metric_value(
        metric_dict=metrics,
        metric_name=cfg.get("optimized_metric"),
    )


if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on `main` composes the config
    # (including command-line overrides) and passes it in as `cfg`.
    main()
