from logging.config import dictConfig
import math
import operator
import os
import time
import warnings
import subprocess
from importlib.util import find_spec
from pathlib import Path
from typing import Callable, List, Optional

import hydra
from omegaconf import DictConfig
import lightning as L
from lightning.pytorch.loggers import Logger
from pytorch_lightning.utilities import rank_zero_only
import torch
import pandas as pd

from src.utils import pylogger, rich_utils

# Module-level logger obtained from the project's `pylogger.get_pylogger` helper;
# used by the utility functions in this module for info/warning/exception output.
log = pylogger.get_pylogger(__name__)

# Short IK column name -> name of the source column it corresponds to.
# Insertion order matters: it defines `IK_columns`, whose positions are the
# column indices used on IK data (see `get_index_to_IK_column_name`).
# Each kinematic coordinate appears as a (value, "d" + value, "dd" + value)
# triplet mapped to pos_/vel_/acc_ source columns. Entries mapped to "" presumably
# have no direct source column here — TODO confirm.
IK_column_mapping = {
    "tx": "pos_pelvis_tx",
    "dtx": "vel_pelvis_tx",
    "ddtx": "acc_pelvis_tx",
    "ty": "pos_pelvis_ty",
    "dty": "vel_pelvis_ty",
    "ddty": "acc_pelvis_ty",
    "a_pelvis": "pos_pelvis_tilt",
    "da_pelvis": "vel_pelvis_tilt",
    "dda_pelvis": "acc_pelvis_tilt",
    "a_hip_r": "pos_hip_flexion_r",
    "da_hip_r": "vel_hip_flexion_r",
    "dda_hip_r": "acc_hip_flexion_r",
    "a_knee_r": "pos_knee_angle_r",
    "da_knee_r": "vel_knee_angle_r",
    "dda_knee_r": "acc_knee_angle_r",
    "a_ankle_r": "pos_ankle_angle_r",
    "da_ankle_r": "vel_ankle_angle_r",
    "dda_ankle_r": "acc_ankle_angle_r",
    "a_hip_l": "pos_hip_flexion_l",
    "da_hip_l": "vel_hip_flexion_l",
    "dda_hip_l": "acc_hip_flexion_l",
    "a_knee_l": "pos_knee_angle_l",
    "da_knee_l": "vel_knee_angle_l",
    "dda_knee_l": "acc_knee_angle_l",
    "a_ankle_l": "pos_ankle_angle_l",
    "da_ankle_l": "vel_ankle_angle_l",
    "dda_ankle_l": "acc_ankle_angle_l",
    "a_lumbar": "pos_lumbar_extension",
    "da_lumbar": "vel_lumbar_extension",
    "dda_lumbar": "acc_lumbar_extension",
    "torque_hip_r": "",
    "torque_knee_r": "",
    "torque_ankle_r": "",
    "torque_hip_l": "",
    "torque_knee_l": "",
    "torque_ankle_l": "",
    "gf_rx": "",
    "gf_ry": "",
    "gf_lx": "",
    "gf_ly": "",
    "cops_rx": "",
    "cops_ry": "",
    "cops_lx": "",
    "cops_ly": "",
}

# Ordered IK column names; a name's position is its column index in IK data.
IK_columns = list(IK_column_mapping)


# Short force/pressure column name -> source column name. Keys prefixed "grf"
# (presumably ground reaction force) and "cop" (presumably centre of pressure)
# are split into the two lists below, preserving insertion order.
GRF_COP_column_mapping = {
    "grf_rx": "ground_force_vx",
    "grf_ry": "ground_force_vy",
    "grf_lx": "1_ground_force_vx",
    "grf_ly": "1_ground_force_vy",
    "cop_rx": "ground_force_px",
    "cop_ry": "ground_force_py",
    "cop_lx": "1_ground_force_px",
    "cop_ly": "1_ground_force_py",
}

GRF_columns = [name for name in GRF_COP_column_mapping if name.startswith("grf")]

COP_columns = [name for name in GRF_COP_column_mapping if name.startswith("cop")]


def init_global_cfg(cfg: DictConfig) -> None:
    """Publish ``cfg`` as the module-level attribute ``GLOBAL_CFG``.

    No module-level default for ``GLOBAL_CFG`` is defined in this file, so the
    attribute only exists once this function has been called; later calls
    overwrite it.
    """
    globals()["GLOBAL_CFG"] = cfg


def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that wraps the task function in extra utilities.

    Makes multirun more resistant to failure.

    Utilities:
    - Calling the `utils.extras()` before the task is started
    - Calling the `utils.close_loggers()` after the task is finished
    - Logging the exception if occurs
    - Logging the task total execution time
    - Logging the output dir

    Args:
        task_func: Callable invoked as ``task_func(cfg=cfg)``; it must return a
            ``(metric_dict, object_dict)`` tuple.

    Returns:
        A wrapper that takes ``cfg`` and returns ``task_func``'s
        ``(metric_dict, object_dict)``. It keeps ``task_func``'s name and
        docstring (via ``functools.wraps``). Exceptions from ``task_func`` are
        logged and re-raised unchanged.
    """
    import functools

    @functools.wraps(task_func)
    def wrap(cfg: DictConfig):
        # apply extra utilities (runs before timing starts; if it raises, no
        # execution-time file is written)
        extras(cfg)

        # Start the clock before `try` so `finally` can always read it.
        # perf_counter is monotonic, unlike time.time(), which can jump when the
        # system clock is adjusted.
        start_time = time.perf_counter()
        try:
            metric_dict, object_dict = task_func(cfg=cfg)
        except Exception:
            log.exception("")  # save exception to `.log` file
            raise  # bare raise re-raises without adding this frame to the traceback
        finally:
            elapsed = time.perf_counter() - start_time
            path = Path(cfg.paths.output_dir, "exec_time.log")
            content = f"'{cfg.task_name}' execution time: {elapsed} (s)"
            # save task execution time (even if exception occurs)
            save_file(path, content)  # type: ignore
            # close loggers (even if exception occurs so multirun won't fail)
            close_loggers()

        log.info(f"Output dir: {cfg.paths.output_dir}")

        return metric_dict, object_dict

    return wrap


def extras(cfg: DictConfig) -> None:
    """Run the optional pre-task utilities enabled under ``cfg.extras``.

    Flags are checked in this order; each enabled flag logs a message and then acts:
    - ``ignore_warnings``: ``warnings.filterwarnings("ignore")``
    - ``enforce_tags``: ``rich_utils.enforce_tags(cfg, save_to_file=True)``
      (intended to prompt for tags on the command line when the config has none)
    - ``print_config``: pretty-print the resolved config tree with Rich and save it

    A missing or empty ``extras`` section only logs a warning.
    """
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # (flag under cfg.extras, info message, action), processed in order.
    steps = (
        (
            "ignore_warnings",
            "Disabling python warnings! <cfg.extras.ignore_warnings=True>",
            lambda: warnings.filterwarnings("ignore"),
        ),
        (
            "enforce_tags",
            "Enforcing tags! <cfg.extras.enforce_tags=True>",
            lambda: rich_utils.enforce_tags(cfg, save_to_file=True),
        ),
        (
            "print_config",
            "Printing config tree with Rich! <cfg.extras.print_config=True>",
            lambda: rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True),
        ),
    )

    for flag, message, action in steps:
        if cfg.extras.get(flag):
            log.info(message)
            action()


@rank_zero_only
def save_file(path: "str | os.PathLike[str]", content: str) -> None:
    """Write ``content`` to ``path``, replacing any existing file.

    Wrapped in ``rank_zero_only`` so that in a multi-process (e.g. multi-GPU)
    run only the rank-zero process writes.

    Args:
        path: Destination file path (a ``str`` or path-like such as ``pathlib.Path``).
        content: Text to write.
    """
    # Write-only mode suffices (the file is never read back here), and an explicit
    # encoding keeps the output independent of the platform's locale encoding.
    with open(path, "w", encoding="utf-8") as file:
        file.write(content)


def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[L.Callback]:
    """Build Lightning callbacks from a Hydra config group.

    Each entry of ``callbacks_cfg`` that is a ``DictConfig`` containing a
    ``_target_`` key is passed to ``hydra.utils.instantiate``; all other
    entries are skipped.

    Args:
        callbacks_cfg: Mapping of callback name -> callback config (may be empty/None).

    Returns:
        The instantiated callbacks in config order; empty if there is no config.

    Raises:
        TypeError: If ``callbacks_cfg`` is non-empty but not a ``DictConfig``.
    """
    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping..")
        return []

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    built: List[L.Callback] = []
    for _, entry in callbacks_cfg.items():
        has_target = isinstance(entry, DictConfig) and "_target_" in entry
        if not has_target:
            continue
        log.info(f"Instantiating callback <{entry._target_}>")
        built.append(hydra.utils.instantiate(entry))

    return built


def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Build Lightning loggers from a Hydra config group.

    Each entry of ``logger_cfg`` that is a ``DictConfig`` containing a
    ``_target_`` key is passed to ``hydra.utils.instantiate``; all other
    entries are skipped.

    Args:
        logger_cfg: Mapping of logger name -> logger config (may be empty/None).

    Returns:
        The instantiated loggers in config order; empty if there is no config.

    Raises:
        TypeError: If ``logger_cfg`` is non-empty but not a ``DictConfig``.
    """
    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return []

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    built: List[Logger] = []
    for _, entry in logger_cfg.items():
        has_target = isinstance(entry, DictConfig) and "_target_" in entry
        if not has_target:
            continue
        log.info(f"Instantiating logger <{entry._target_}>")
        built.append(hydra.utils.instantiate(entry))

    return built


@rank_zero_only
def log_hyperparameters(object_dict: dict) -> None:
    """Send a selected subset of the run config to every attached Lightning logger.

    ``object_dict`` must provide ``"cfg"``, ``"model"`` and ``"trainer"``. Along with
    the chosen config sections, the model's total, trainable and non-trainable
    parameter counts are recorded. When the trainer has no logger, a warning is
    logged and nothing is sent. Runs on rank zero only.
    """
    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    # One pass over the parameters, splitting element counts by requires_grad.
    n_trainable = 0
    n_frozen = 0
    for param in model.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
        else:
            n_frozen += param.numel()

    hparams = {
        "model": cfg["model"],
        "model/params/total": n_trainable + n_frozen,
        "model/params/trainable": n_trainable,
        "model/params/non_trainable": n_frozen,
        "datamodule": cfg["datamodule"],
        "trainer": cfg["trainer"],
        "callbacks": cfg.get("callbacks"),
        "extras": cfg.get("extras"),
        "task_name": cfg.get("task_name"),
        "tags": cfg.get("tags"),
        "ckpt_path": cfg.get("ckpt_path"),
        "seed": cfg.get("seed"),
    }

    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)


def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]:
    """Safely retrieves value of the metric logged in LightningModule.

    Args:
        metric_dict: Mapping of metric name -> value. Values exposing ``.item()``
            (e.g. tensors, NumPy scalars) are converted with it; other values are
            converted with ``float()``.
        metric_name: Key to look up. A falsy name (``None`` or ``""``) skips the lookup.

    Returns:
        The metric as a Python number, or ``None`` when ``metric_name`` is falsy.

    Raises:
        Exception: If ``metric_name`` is not a key of ``metric_dict``.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    raw_value = metric_dict[metric_name]
    # Logged metrics are typically tensors; also accept plain Python numbers,
    # which have no `.item()`.
    metric_value = raw_value.item() if hasattr(raw_value, "item") else float(raw_value)
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value


def close_loggers() -> None:
    """Finish any active Weights & Biases run (prevents logging failure during multirun).

    Apart from logging, this does nothing when ``wandb`` is not installed or
    has no active run.
    """
    log.info("Closing loggers...")

    # wandb is an optional dependency: only import it if it is installed.
    if find_spec("wandb") is None:
        return

    import wandb

    if not wandb.run:
        return

    log.info("Closing wandb!")
    wandb.finish()


def git_repo_root() -> str:
    """Return the top-level directory of the Git repository containing the CWD.

    The returned path always ends with ``os.sep``.

    Raises:
        Exception: If the current directory is not within a Git repository
            (``git rev-parse`` exits non-zero); the git error is chained as the cause.
        FileNotFoundError: If the ``git`` executable cannot be found.
    """
    try:
        # Passing an encoding makes check_output return `str` (decoded as UTF-8),
        # so no manual bytes handling is needed.
        git_root = subprocess.check_output(
            ["git", "rev-parse", "--show-toplevel"],
            cwd=os.getcwd(),
            encoding="utf-8",
        ).strip()
    except subprocess.CalledProcessError as err:
        # Handle the case where the path is not within a Git repository
        raise Exception("The current directory is not within a Git repository.") from err

    if not git_root.endswith(os.sep):
        git_root += os.sep
    return git_root


def save_predictions_to_csv(
    prediction_runs: dict[str, dict[str, torch.Tensor]], cfg: DictConfig
) -> List[str]:
    """Write each predicted variable of each run to its own Parquet file.

    Despite the function name, files are written with ``DataFrame.to_parquet`` to
    ``<cfg.paths.output_dir>/pred_<run_name>_<var_name>.parquet``.

    Args:
        prediction_runs: ``{run_name: {var_name: tensor}}``; each tensor must have a
            leading dimension of size 1, i.e. shape ``(1, len, dim)``.
        cfg: Config providing ``cfg.paths.output_dir`` and, for each variable,
            the DataFrame column names in ``cfg.model.estimated_variables[var_name]``.

    Returns:
        The written file paths (as strings), in iteration order.

    Raises:
        AssertionError: If a prediction tensor's first dimension is not 1.
    """
    output_dir = Path(cfg.paths.output_dir)
    saved_files: List[str] = []
    for run_name, prediction in prediction_runs.items():
        for var_name, pred_tensor in prediction.items():
            assert (
                pred_tensor.shape[0] == 1
            ), f"Shape should be (1, len, dim), actually: {tuple(pred_tensor.shape)}"
            # .numpy() fails on tensors that require grad or live on an accelerator,
            # so detach and move to CPU first (no-ops for plain CPU tensors).
            pred_np = pred_tensor.detach().cpu().numpy()[0, :, :]

            # Column names for this variable come from the model config.
            columns = cfg.model.estimated_variables[var_name]
            pred_df = pd.DataFrame(pred_np, columns=columns)
            parquet_path = os.path.join(
                output_dir, f"pred_{run_name}_{var_name}.parquet"
            )
            pred_df.to_parquet(parquet_path, index=False)
            saved_files.append(parquet_path)
    return saved_files


def really_safe_normalise_in_place(d: dict):
    """Rescale the values of ``d`` in place so that they sum to (as nearly as possible) 1.0.

    Every value is multiplied by ``1 / fsum(values)``. The leftover floating-point
    discrepancy ``1.0 - fsum(values)`` is then added to the entry with the largest
    value (the first such entry on ties). An empty dict is left unchanged.

    Raises:
        ZeroDivisionError: If the values sum to zero.
    """
    if not d:
        return  # nothing to normalise

    factor = 1.0 / math.fsum(d.values())
    for key in d:
        d[key] = d[key] * factor
    # Fold the rounding residue into the largest entry, where its relative
    # effect is smallest.
    key_for_max = max(d.items(), key=operator.itemgetter(1))[0]
    d[key_for_max] += 1.0 - math.fsum(d.values())


def get_index_to_IK_column_name(column_name: str) -> int:
    """Return the position of ``column_name`` within ``IK_columns``.

    Raises:
        AssertionError: If ``column_name`` is not one of ``IK_columns``
            (when assertions are enabled; otherwise ``list.index`` raises ``ValueError``).
    """
    is_known = column_name in IK_columns
    assert is_known, f"{column_name} is not present in IK_data"
    position = IK_columns.index(column_name)
    return position


def compare_derivatives(IK_data: torch.Tensor, spacing: "float | torch.Tensor" = 1.0):
    """Finite-difference derivatives of the kinematic IK columns along dim 0.

    For every kinematic coordinate ``c`` (columns ``c``, ``"d" + c`` and ``"dd" + c``)
    the returned tensor holds the derivative of each of those columns along dim 0.
    For example, the derivative of a position column can then be compared with the
    matching velocity column of ``IK_data``. All other columns (torques, ground
    forces, CoPs) are left at zero.

    Args:
        IK_data: Tensor indexed as ``IK_data[:, column]`` with columns ordered like
            ``IK_columns``. Assumes a 2-D ``(samples, columns)`` layout with time along
            dim 0 — TODO confirm. Needs at least 2 samples along dim 0.
        spacing: Sample spacing along dim 0. Either a scalar step (default ``1.0``,
            i.e. the derivative per sample) or a 1-D tensor of sample times.

    Returns:
        A tensor shaped like ``IK_data`` containing the derivatives.
    """

    def _idx_for_col_der(col: str):
        # column indices of `col` and its first and second derivative columns
        return [
            get_index_to_IK_column_name(col),
            get_index_to_IK_column_name("d" + col),
            get_index_to_IK_column_name("dd" + col),
        ]

    # Kinematic coordinates present in IK_columns. There is a single lumbar
    # coordinate ("a_lumbar"); "a_lumbar_r"/"a_lumbar_l" do not exist.
    coordinates = [
        "tx",
        "ty",
        "a_pelvis",
        "a_hip_r",
        "a_knee_r",
        "a_ankle_r",
        "a_hip_l",
        "a_knee_l",
        "a_ankle_l",
        "a_lumbar",
    ]
    idxs = [idx for col in coordinates for idx in _idx_for_col_der(col)]

    # IK_data holds sampled values, not a function of a time tensor that lives in
    # the autograd graph (and IK_columns has no sample/time column), so autograd
    # cannot differentiate it w.r.t. time. Use finite differences instead:
    # central differences in the interior, one-sided at the edges.
    spacing_arg = [spacing] if isinstance(spacing, torch.Tensor) else spacing

    derivative_IK_data = torch.zeros_like(IK_data)
    derivative_IK_data[:, idxs] = torch.gradient(
        IK_data[:, idxs], spacing=spacing_arg, dim=0
    )[0]

    return derivative_IK_data
