import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Sequence

import numpy as np
import torch
from konductor.init import ExperimentInitConfig
from konductor.trainer.pytorch import PyTorchTrainer, TrainingError, RunningMean
from konductor.utilities import comm
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel
from torch.profiler import record_function

from src.dataset.sc2_dataset import TorchSC2Data
from src.model import MotionPerceiver


class TrainStats:
    """Accumulate step timing and CUDA memory statistics for throughput profiling.

    Args:
        target: number of recorded samples after which ``finished()`` returns True.
            The instance is falsy when ``target`` is not positive; the trainers in
            this module use that as the "stats collection disabled" switch.
        skip_first: when True, the first call to ``log`` records nothing.
    """

    def __init__(self, target: int, skip_first: bool = True):
        self.target = target
        self._is_first = skip_first
        self._step_time = RunningMean()
        self._gpu_mem_avg = RunningMean()
        self._gpu_mem_max = 0

    def log(self, step_time: float):
        """Record one step duration and the CUDA memory currently allocated (in GB).

        While the first-sample skip is pending, this call only clears it.
        """
        if not self._is_first:
            self._step_time.update(step_time)
            allocated_gb = torch.cuda.memory_allocated() / 1e9
            self._gpu_mem_avg.update(allocated_gb)
            self._gpu_mem_max = max(self._gpu_mem_max, allocated_gb)
            return

        print("skipping first log")
        self._is_first = False

    @property
    def count(self):
        """How many step samples have been recorded so far."""
        return self._step_time.count

    @property
    def step_time(self):
        """Running mean of the recorded step durations."""
        return self._step_time.value

    @property
    def gpu_mem_avg(self):
        """Running mean of allocated CUDA memory (GB) at log time."""
        return self._gpu_mem_avg.value

    @property
    def gpu_mem_max(self):
        """Largest allocated CUDA memory (GB) observed at log time."""
        return self._gpu_mem_max

    def __bool__(self):
        return 0 < self.target

    def finished(self):
        """True once the number of recorded samples has reached ``target``."""
        return self.count >= self.target

    def reset(self):
        """Clear all accumulated statistics (the first-sample skip is not re-armed)."""
        self._step_time.reset()
        self._gpu_mem_avg.reset()
        self._gpu_mem_max = 0


@dataclass
class DaliPipeParams:
    """DALI pipeline tuning knobs injected into the experiment config.

    See ``apply_dali_pipe_kwargs`` for where each field is written.
    """

    # Number of python worker processes for each DALI loader ("py_num_workers")
    py_workers: int
    # Prefetch queue depth of the dataset source ("prefetch_queue_depth" on dataset args)
    source_prefetch: int
    # Prefetch queue depth of each loader pipeline ("prefetch_queue_depth" on loader args)
    pipe_prefetch: int


class SC2Trainer(PyTorchTrainer):
    """PyTorchTrainer for StarCraft II data loaded through DALI, with optional profiling.

    When ``collect_stats`` is positive, each train/val step is bracketed by
    ``torch.cuda.synchronize()`` and timed. Once ``collect_stats`` samples have been
    recorded, ``StopIteration`` is raised from the step to jump out of the loop.
    """

    def __init__(self, *args, collect_stats: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        self.collect_stats = TrainStats(collect_stats)

    def _run_profiled(self, step_fn, data):
        """Call ``step_fn(data)``, timing it when stats collection is enabled."""
        if not self.collect_stats:
            return step_fn(data)

        torch.cuda.synchronize()
        start = time.perf_counter()
        result = step_fn(data)
        torch.cuda.synchronize()
        self.collect_stats.log(time.perf_counter() - start)
        if self.collect_stats.finished():
            raise StopIteration  # Jump out of loop
        return result

    def data_transform(self, data: list[dict[str, Tensor]]) -> Any:
        """Convert dictionary loaded by DALI into TorchSC2Data with sequence dimension out front
        (except minimap)"""
        return TorchSC2Data.from_dali(data)

    def train_step(
        self, data: dict[str, Tensor]
    ) -> tuple[dict[str, Tensor], dict[str, Tensor] | None]:
        """Parent training step, optionally timed for throughput statistics."""
        return self._run_profiled(super().train_step, data)

    def val_step(self, data) -> tuple[dict[str, Tensor] | None, dict[str, Tensor]]:
        """Parent validation step, optionally timed for throughput statistics."""
        return self._run_profiled(super().val_step, data)


class GymTrainer(PyTorchTrainer):
    """PyTorchTrainer with optional throughput profiling and failure-time debug dumps.

    When ``collect_stats`` is positive, each train/val step is bracketed by
    ``torch.cuda.synchronize()`` and timed. Once ``collect_stats`` samples have been
    recorded, ``StopIteration`` is raised from the step to jump out of the loop.
    """

    def __init__(self, *args, collect_stats: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        self.collect_stats = TrainStats(collect_stats)

    @torch.no_grad()
    def trace_model_io(self, model: Any, data: dict[str, Tensor]) -> None:
        """Run one forward pass and dump encoder layer inputs/outputs to the workspace.

        Forward hooks are attached to the encoder's ``input_layer``, ``propagate_layer``
        and ``update_layer``. Each hook invocation writes a compressed file named
        ``{idx}_{input|propagate|update}_{rank}.npz``, where ``idx`` is the number of
        files of that kind already present for this rank. The hooks are removed once the
        pass finishes (or fails), so later forward passes are not traced and repeated
        calls do not stack duplicate hooks.
        """
        rnk = comm.get_rank()
        workspace = self.data_manager.workspace

        def next_path(tag: str) -> Path:
            # Index by how many dumps of this kind this rank has already written
            idx = len(list(workspace.glob(f"*_{tag}_{rnk}.npz")))
            return workspace / f"{idx}_{tag}_{rnk}"

        def make_latent_hook(tag: str):
            # input_layer and update_layer are dumped with the same field names
            def hook(_, inputs: list[Tensor], outputs: list[Tensor]) -> None:
                # NOTE(review): if the layer returns a bare Tensor rather than a tuple,
                # outputs[0] is the first slice along dim 0 — confirm layer return type.
                save_data = {"latent_out": outputs[0].cpu().numpy()}
                for name, tensor in zip(["latent_in", "data", "mask"], inputs):
                    save_data[name] = tensor.cpu().numpy()
                np.savez_compressed(next_path(tag), **save_data)

            return hook

        def prop_save_hook(_, inputs: list[Tensor], outputs: list[Tensor]) -> None:
            np.savez_compressed(
                next_path("propagate"),
                in_latent=inputs[0].cpu().numpy(),
                out_latent=outputs[0].cpu().numpy(),
            )

        model_: MotionPerceiver = (
            model.module if isinstance(model, DistributedDataParallel) else model
        )
        encoder = model_.encoder

        handles = [
            encoder.input_layer.register_forward_hook(make_latent_hook("input")),
            encoder.propagate_layer.register_forward_hook(prop_save_hook),
            encoder.update_layer.register_forward_hook(make_latent_hook("update")),
        ]
        try:
            model(**data)
        finally:
            # Without this the hooks stay attached and dump files on every later
            # forward pass.
            for handle in handles:
                handle.remove()

    def data_transform(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Copy every tensor in ``data`` to the current CUDA device via a side stream.

        The copies are issued on a separate stream and the host waits for them.
        The tensors are then consumed on the current stream. Each one is therefore
        registered with that stream via ``record_stream``, so the caching allocator
        does not recycle its memory for the side stream while work on the consumer
        stream may still be using it.
        """
        consumer = torch.cuda.current_stream()
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            data = {k: d.cuda(non_blocking=True) for k, d in data.items()}
        stream.synchronize()
        for tensor in data.values():
            tensor.record_stream(consumer)
        return data

    def _stats_tic(self) -> float | None:
        """Synchronise CUDA and return a start time, or None if stats are disabled."""
        if not self.collect_stats:
            return None
        torch.cuda.synchronize()
        return time.perf_counter()

    def _stats_toc(self, start: float | None) -> None:
        """Log elapsed step time; raise StopIteration once enough samples are collected."""
        if start is None:
            return
        torch.cuda.synchronize()
        self.collect_stats.log(time.perf_counter() - start)
        if self.collect_stats.finished():
            raise StopIteration  # Jump out of loop

    def train_step(
        self, data: dict[str, Tensor]
    ) -> tuple[dict[str, Tensor], dict[str, Tensor] | None]:
        """
        Standard training step, if you don't want to calculate
        performance during training, return None for predictions.
        return
            Losses: description of losses for logging purposes
            Predictions: predictions in dict
        """
        start = self._stats_tic()

        with record_function("train_inference"):
            pred = self.modules.model(**data)

        with record_function("criterion"):
            losses = {}
            for criterion in self.modules.criterion:
                losses.update(criterion(pred, data))

        self._stats_toc(start)

        # Debug toggle: uncomment to dump encoder IO and abort on a non-finite loss.
        # if any(not torch.isfinite(l) for l in losses.values()):
        #     self.trace_model_io(self.modules.model, data)
        #     raise TrainingError(f"nan loss found: {losses}")

        return losses, pred

    def val_step(
        self, data: dict[str, Tensor]
    ) -> tuple[dict[str, Tensor] | None, dict[str, Tensor]]:
        """
        Standard evaluation step, if you don't want to evaluate/track loss
        during evaluation, do not perform the calculation and return None
        in the loss part of the tuple.
        return:
            Losses: description of losses for logging purposes
            Predictions: predictions dict
        """
        start = self._stats_tic()

        with record_function("eval_inference"):
            pred = self.modules.model(**data)

        self._stats_toc(start)

        return None, pred

    def training_exception(self, err: Exception, data: dict[str, Tensor]) -> None:
        """Dump debugging artefacts for a failed step, then re-raise ``err``.

        Every rank writes its batch to ``input_{rank}.npz``, and rank 0 also saves
        ``exception_ckpt.pt``. When remote sync is configured, rank 0 pushes the
        checkpoint and local rank 0 pushes the ``.npz`` files.
        """
        rnk = comm.get_rank()
        is_main_rnk = rnk == 0
        np.savez_compressed(
            self.data_manager.workspace / f"input_{rnk}",
            # detach() so a tensor that requires grad cannot make the dump itself fail
            **{k: v.detach().cpu().numpy() for k, v in data.items()},
        )

        if is_main_rnk:  # Only main rank saves checkpoint
            self.data_manager.checkpointer.save("exception_ckpt.pt")

        if self.data_manager.remote_sync is not None:
            comm.synchronize()  # Ensure data saving finished before pushing
            if is_main_rnk:  # Only main rank pushes checkpoint
                self.data_manager.remote_sync.push_select([r"\Aexception_ckpt.pt\Z"])
            if comm.get_local_rank() == 0:  # Only local rank 0 pushes data
                self.data_manager.remote_sync.push_select([r".*\.npz\Z"])

        comm.synchronize()  # Sync before raise
        raise err


def add_backward_monitor(trainer: PyTorchTrainer):
    """Monitors the backward method of the decoder as that's where the nan is likely to begin

    A full backward hook on ``trainer.modules.get_model().decoder`` checks every gradient
    flowing into and out of the decoder. If any rank sees a non-finite value, each rank
    dumps its decoder gradients to ``decoder_grads_{rank}.npz`` in the workspace and
    raises ``TrainingError``. The per-rank verdicts are combined with ``comm.all_gather``.

    PyTorch passes ``None`` for gradient entries that do not exist (e.g. inputs that do
    not require grad). Those entries are skipped by the check and omitted from the dump.
    """

    def _all_finite(grads: Sequence[Tensor | None]) -> bool:
        return all(torch.isfinite(g).all() for g in grads if g is not None)

    def backward_hook(
        module,
        grad_inputs: Sequence[Tensor | None],
        grad_outputs: Sequence[Tensor | None],
    ):
        should_raise = False
        if not _all_finite(grad_inputs):
            should_raise = True
            print("nan detected in grad input of decoder")

        if not _all_finite(grad_outputs):
            should_raise = True
            print("nan detected in grad outputs of decoder")

        # Check if any rank raises. comm.all_gather is presumably a collective,
        # so every rank must reach this call for it to complete.
        should_raise = any(comm.all_gather(should_raise))
        if not should_raise:
            return

        # Keys keep the positional index so dumps line up with the hook arguments
        write_data: dict[str, np.ndarray] = {}
        for prefix, grads in (("input", grad_inputs), ("output", grad_outputs)):
            for idx, grad in enumerate(grads):
                if grad is not None:
                    write_data[f"{prefix}_{idx}"] = grad.cpu().numpy()

        np.savez_compressed(
            trainer.data_manager.workspace / f"decoder_grads_{comm.get_rank()}",
            **write_data,
        )

        raise TrainingError("Nan gradient in decoder")

    trainer.modules.get_model().decoder.register_full_backward_hook(backward_hook)


def apply_dali_pipe_kwargs(exp_cfg: ExperimentInitConfig, dali_params: DaliPipeParams):
    """Write DALI pipeline settings into the first dataset's config, in place.

    - dataset args ``prefetch_queue_depth`` <- ``dali_params.source_prefetch``
    - train and val loader args ``py_num_workers`` <- ``dali_params.py_workers``
    - train and val loader args ``prefetch_queue_depth`` <- ``dali_params.pipe_prefetch``
    """
    data_cfg = exp_cfg.data[0]
    data_cfg.dataset.args["prefetch_queue_depth"] = dali_params.source_prefetch

    loader_overrides = (
        ("py_num_workers", dali_params.py_workers),
        ("prefetch_queue_depth", dali_params.pipe_prefetch),
    )
    for loader in (data_cfg.train_loader, data_cfg.val_loader):
        for key, value in loader_overrides:
            loader.args[key] = value


def setup_init_config(
    workspace: Path,
    config_file: Path | None,
    run_hash: str | None,
    workers: int,
    dali_params: DaliPipeParams,
):
    """Do basic setup of ExperimentInitConfig

    Exactly one of ``config_file`` and ``run_hash`` must be given. A config file is
    loaded with ``ExperimentInitConfig.from_config(workspace, config_file)``; a run hash
    is loaded from ``workspace / run_hash`` with ``ExperimentInitConfig.from_run``.
    Next, the dataloader worker count is applied. If the first dataset's train loader
    is of type ``"DALI"``, the DALI pipeline parameters are also injected.

    Raises:
        RuntimeError: if both or neither of ``config_file`` / ``run_hash`` are set.
    """
    # Explicit check instead of `assert` so it is not stripped under `python -O`
    if config_file is not None and run_hash is not None:
        raise RuntimeError("config-file and run-hash are mutually exclusive, set only one")

    if config_file is not None:
        exp_cfg = ExperimentInitConfig.from_config(workspace, config_file)
    elif run_hash is not None:
        exp_cfg = ExperimentInitConfig.from_run(workspace / run_hash)
    else:
        raise RuntimeError("Either config-file or run-hash should be set")
    exp_cfg.set_workers(workers)

    if exp_cfg.data[0].train_loader.type == "DALI":
        apply_dali_pipe_kwargs(exp_cfg, dali_params)

    return exp_cfg
