import os
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Literal, Mapping, Optional, Set

import torch
import torch._dynamo.config
from accelerate import DistributedDataParallelKwargs
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from tqdm import tqdm

from dae.utils.generic_utils import (
    ModulesRegister,
    TaskState,
    UpAccelerator,
    ensure_path,
    split_dict,
)
from dae.utils.torch_utils import count_parameters, to_torch_type, unwrap
from dae.utils.train_utils import (
    aggregate_losses,
    auto_compile,
    build_optimizer,
)
from dae.utils.train_utils import init_weights as mutils_init_weights

from ..datasets import dataset_from_name
from ..log.loggers import BaseTaskLogger, MetricLogger
from ..models.all import *  # pylint: disable=unused-wildcard-import, wildcard-import
from ..models.blocks.ema import EMA, EMAWrapper
from ..utils import load_checkpoint, load_training_state, save_training_state

####################################################################
# Tasks register
####################################################################


class TasksRegister(ModulesRegister):
    """Register that exposes every ``task_<task_name>`` attribute of a module as ``<module_name>.<task_name>``.

    All generated keys point to the same ``module`` object; the key is lower-cased when
    the register's ``lower`` flag is set.
    """

    def __setitem__(self, module_name, module):
        prefix = "task_"
        for attr_name in dir(module):
            # Only attributes following the task_<name> naming convention are registered
            if not attr_name.startswith(prefix):
                continue
            key = f"{module_name}.{attr_name[len(prefix):]}"
            self._modules[key.lower() if self.lower else key] = module


# Module-level register of tasks, keyed as "<module_name>.<task_name>" (see TasksRegister.__setitem__)
TASKS = TasksRegister("tasks")

####################################################################
# Common initialization
####################################################################


class BaseTask:
    """Base class of all tasks: owns the task state, the accelerator, the models and the logger.

    Construction runs ``setup_job_env`` (process-wide settings) followed by ``setup``
    (accelerator, state initialization, logger). A task is executed by calling the
    instance, which dispatches to the ``task_<task_name>`` method.
    """

    def __init__(self, cfg):
        """Create the task state from ``cfg`` and run both setup phases."""
        self.state = TaskState(cfg=cfg)

        self.setup_job_env()
        self.setup()

    @property
    def cfg(self):
        """Configuration object stored in the task state."""
        return self.state.cfg

    @property
    def accelerator(self):
        """Accelerator stored in the task state (set by ``setup``)."""
        return self.state.accelerator

    @property
    def models(self):
        """Models container stored in the task state (filled by ``prepare_model``)."""
        return self.state.models

    @property
    def optimizers(self):
        """Optimizers container stored in the task state."""
        return self.state.optimizers

    @property
    def logger(self):
        """Task logger stored in the task state (set by ``setup``)."""
        return self.state.logger

    def setup_job_env(self):
        """Setup the job environment

        Changes the working directory, exports the configured environment variables,
        seeds the RNGs and applies the global torch settings. All of these are
        process-wide side effects.
        """
        # Ensure working inside the run directory
        # NOTE(review): ensure_path presumably creates the directories when missing - confirm
        ensure_path(self.cfg.run_dir)
        ensure_path(self.cfg.cache_dir)
        ensure_path(self.cfg.checkpoint_path)
        os.chdir(self.cfg.run_dir)

        # Set environment variables from config
        for var_name, var_value in self.cfg.env.items():
            os.environ[var_name] = str(var_value)

        # Set seed
        set_seed(self.cfg.seed)

        # Set torch configuration
        torch.set_default_dtype(to_torch_type(self.cfg.dtype))
        torch._dynamo.config.cache_size_limit = 64
        torch._dynamo.config.optimize_ddp = False
        torch.set_float32_matmul_precision("high")
        # torch.multiprocessing.set_start_method("forkserver", force=True)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.deterministic = self.cfg.deterministic

        # Anomaly detection
        assert self.cfg.detect_errors in [True, False]
        self.state.detect_errors = self.cfg.detect_errors
        torch.autograd.set_detect_anomaly(self.state.detect_errors)

    def setup(self):
        """Setup the job modules

        Builds the accelerator, initializes the state (``init_state``), creates the task
        logger and lowers the verbosity of the accelerate loggers.
        """

        ### Prepare accelerator and DDP ###
        ddp_kwargs = DistributedDataParallelKwargs()
        # init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))

        self.state.accelerator = UpAccelerator(
            kwargs_handlers=[
                ddp_kwargs,
                # init_kwargs,
            ],
            gradient_accumulation_steps=self.cfg.training.grad_accumulate,
            step_scheduler_with_optimizer=False,
            mixed_precision=(self.cfg.mixed_precision or "no"),
            verbose=self.cfg.verbose,
        )

        # Display job information
        self.print("Detecting errors in torch:", self.state.detect_errors)
        self.print("Will use DDP")
        self.init_state()

        # Logger (built after init_state so that subclasses can rely on the state they set there)
        self.state.logger = self.build_task_logger()

        # Accelerate configuration
        get_logger("accelerate.accelerator").setLevel("WARNING")
        get_logger("accelerate.checkpointing").setLevel("WARNING")

    def init_state(self):
        """Initialize the configuration & state with job-specific settings"""
        self.state.num_processes = self.accelerator.num_processes
        # The state object is saved / restored together with the accelerator checkpoints
        self.accelerator.register_for_checkpointing(self.state)

    def build_task_logger(self):
        """Return the logger used by this task; subclasses may override it."""
        return BaseTaskLogger(self.state)

    def prepare_model(
        self,
        model: torch.nn.Module,
        *,
        name: Optional[str] = None,
        training: Optional[bool] = None,
        checkpoint: Optional[str] = None,
        compile: bool = False,
        ema: Optional[Mapping] = None,
        freeze: Literal["auto", True, False] = "auto",
        model_init: Optional[Mapping] = None,
        remove_from_checkpointing: bool = False,
    ):
        """Utility to load a model, freeze it if not main model, etc.

        Args:
            model: Module to initialize, wrap and register.
            name: Key under which the model is stored in ``self.models``; defaults to
                the model's class name. Must not already be registered.
            training: Train/eval mode to apply. ``None`` falls back to ``self.training``,
                which this base class does not define (subclasses are expected to set it).
            checkpoint: If not ``None``, forwarded to ``load_checkpoint`` after the
                optional EMA wrapping.
            compile: Forwarded to ``auto_compile``.
            ema: Keyword arguments for ``EMAWrapper``. EMA is only enabled when
                ``ema.decay`` is truthy and the model is in training mode.
            freeze: Whether to disable gradients; ``"auto"`` freezes exactly when the
                model is not training.
            model_init: Keyword arguments forwarded to the weight initialization helper.
            remove_from_checkpointing: If true, the prepared modules are removed from the
                accelerator's tracked models and the name is not recorded in
                ``state.registered_models``.

        Returns:
            The prepared (and possibly compiled) model, also stored in ``self.models[name]``.
        """
        # Args
        assert training in [True, False, None]
        assert freeze in ["auto", True, False]
        if training is None:
            training = self.training
        if freeze == "auto":
            freeze = not training

        if name is None:
            name = model.__class__.__name__
        assert name not in self.models, f"Model with name {name} already exists"

        # 1. Init weights (random init or checkpoint)
        mutils_init_weights(model, **(model_init or {}))

        # 2. Wrap model
        use_ema = False
        if ema is not None and ema.decay and training:
            use_ema = True
            model = EMAWrapper(model, **ema)

        # 3. Load checkpoint (after wrapping, so the EMA wrapper is what gets loaded when enabled)
        if checkpoint is not None:
            load_checkpoint(self.accelerator, model, checkpoint, log_dir=self.cfg.log_dir, model_name=name)

        # 4. Set model & gradients state
        model.train(training)
        if freeze:
            model.requires_grad_(False)

        # 5. Wrap with accelerator
        # With EMA, the online model and the EMA copy are prepared separately
        prep_modules = [model] if not use_ema else [model.model, model.ema.ema_model]
        for i, m in enumerate(prep_modules):
            m = self.accelerator.prepare(m)
            if remove_from_checkpointing:
                self.accelerator._models.remove(m)
            prep_modules[i] = m

        if not remove_from_checkpointing:
            self.state.registered_models.append(name)
        if use_ema:
            model.model = prep_modules[0]
            model.ema.ema_model = prep_modules[1]
            if not remove_from_checkpointing:
                self.state.registered_models.append(name + "_ema")
        else:
            model = prep_modules[0]

        # 6. Compile if needed
        model = auto_compile(compile, model)

        self.models[name] = model
        return model

    def run(self, task_name=None):
        """Dispatch to the ``task_<task_name>`` method and return its result.

        Args:
            task_name: Name of the task to run; defaults to the last dot-separated
                component of ``cfg.task``.

        Raises:
            ValueError: If the instance has no ``task_<task_name>`` attribute.
        """
        if task_name is None:
            task_name = self.cfg.task.split(".")[-1]

        method_name = f"task_{task_name}"
        if not hasattr(self, method_name):
            raise ValueError(f"Run function {method_name} for task {self.cfg.task} inside {self.__class__.__name__} not found")
        # getattr follows the same lookup as the hasattr check above (including __getattr__)
        return getattr(self, method_name)()

    def __call__(self, task_name=None):
        """Run the task inside the logger's task context and return its result.

        The seed is reset before running. ``accelerator.end_training()`` is called once
        the task context has exited without raising.
        """
        # Run task
        with self.logger.on_task_run() as task_log:
            set_seed(self.cfg.seed)
            task_result = self.run(task_name)

            # End task
            task_log.results = task_result
        self.accelerator.end_training()

        return task_result

    def print(self, *args, **kwargs):
        """Forward to ``accelerator.print``."""
        self.accelerator.print(*args, **kwargs)


####################################################################
# Generic tasks with commonly-used routines and pipelines
####################################################################


class BaseTrainEvalTask(BaseTask):
    """Generic train / eval task: loads data and models, builds optimizers and runs the epoch loop.

    Subclasses must implement ``load_models``, ``task_eval`` and, when the generic
    training loop is used, ``_compute_train_loss``.
    """

    # Child-module names whose parameter counts are expanded recursively by show_model;
    # set to None or False to skip the per-module parameter breakdown entirely.
    SHOW_MODEL_PARTS = []

    def __init__(self, cfg):
        """Set up the task, then load data and models and, when training, optimizers and training state."""
        super().__init__(cfg)

        self.load_data()
        self.load_models()
        self.show_model()

        if self.training:
            self._task_train_prepare()
            # Load training state once model & optimizer are ready
            load_training_state(self.state, [self.cfg.checkpoint_path, self.cfg.training.resume_from])

        # The summary queries the CUDA allocator, so it is skipped on hosts without CUDA
        if torch.cuda.is_available():
            self.print("Memory summary:\n" + torch.cuda.memory_summary("cuda"))

    def load_models(self):
        """Build the task models; must be implemented by subclasses."""
        raise NotImplementedError

    def _compute_train_loss(self, batch: Any, models: Set[str], train_ctx: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Return the named loss tensors of ``batch`` for the models named in ``models``; must be implemented by subclasses."""
        raise NotImplementedError

    def _compute_gan_loss(self, batch: Any, fwd_output: Any) -> Dict[str, torch.Tensor]:
        """Hook for GAN losses; not called by the generic loop in this class."""
        raise NotImplementedError

    @contextmanager
    def save_model_on_error(self, step="unknown", enable=True):
        """Context manager that logs any exception, saves a debug checkpoint, then re-raises it.

        Args:
            step: Label of the wrapped step, used in the log message.
            enable: When false, the exception is only logged and re-raised (no checkpoint).
        """
        try:
            yield
        except Exception as e:
            self.print(f"Error during step {step}: {e}", level="error")
            if enable and self.accelerator.is_main_process:
                path = self.checkpoint_path / "error"
                self.print(f"Saving model checkpoint to {path} for debugging error", level="error")
                save_training_state(self.state, path)
            # Bare raise re-raises the active exception without adding this frame to its traceback
            raise

    def _show_model_subparameter_count(self, model, recursive_on=None, depth=-1, name=None):
        """Print the parameter counts of ``model`` and of its children.

        Args:
            model: Model to describe.
            recursive_on: Names of children to expand recursively; dotted names
                (``"parent.child"``) are resolved relative to each expanded parent.
            depth: ``-1`` for the top-level call (prints the total, then the children),
                otherwise the indentation level of the children being printed.
            name: Display name of the model at the top level (defaults to ``"model"``).
        """
        if depth == -1:
            model = unwrap(model, unw_ema=True)
            name = name or "model"
            self.print(f"{name} parameters count:")
            self.print(f"Total: #{count_parameters(model)}   (trainable: #{count_parameters(model, trainable=True)})")
            self._show_model_subparameter_count(model, recursive_on, depth=0, name=name)
        else:
            for child_name, child in unwrap(model).named_children():
                self.print("     " * depth + f"- {child_name}: #{count_parameters(child)}   (trainable: #{count_parameters(child, trainable=True)})")
                if recursive_on and child_name in recursive_on:
                    # Strip the "<child_name>." prefix so nested names match the grand-children
                    sub_rec = [n[len(child_name) + 1 :] if n.startswith(child_name + ".") else n for n in recursive_on]
                    self._show_model_subparameter_count(child, sub_rec, depth=depth + 1, name=child_name)

    def show_model(self):
        """Print every registered model, preceded by its parameter breakdown unless disabled."""
        for m_name, m in self.models.items():
            if self.SHOW_MODEL_PARTS not in [None, False]:
                self._show_model_subparameter_count(m, recursive_on=self.SHOW_MODEL_PARTS, name=m_name)
            self.print(f"{m_name}:", m)

    def init_state(self):
        """Initialize counters and task attributes, and reconcile the batch-size settings.

        Exactly one of ``cfg.training.batch_size`` / ``cfg.training.gpu_batch_size`` is
        needed; the other is derived from the number of processes and gradient
        accumulation steps. When both are set they must be consistent.

        Raises:
            ValueError: If neither batch size is set, or if both are set but inconsistent.
            AssertionError: If the total batch size is not divisible by the number of batch splits.
        """
        super().init_state()
        acc = self.accelerator
        cfg = self.cfg

        # Initialize registers
        self.state.cur_epoch = 0
        self.state.cur_steps = 0

        self.opti_models = []  # pylint: disable=W0201
        self.training = cfg.task.endswith("train")  # pylint: disable=W0201
        self.checkpoint_path = Path(cfg.checkpoint_path)  # pylint: disable=W0201

        # Compute effective batch size
        bs_factor = acc.num_processes * cfg.training.grad_accumulate
        if not cfg.training.batch_size and not cfg.training.gpu_batch_size:
            raise ValueError("Either batch_size or gpu_batch_size must be set")
        if cfg.training.batch_size and cfg.training.gpu_batch_size:
            if cfg.training.batch_size != bs_factor * cfg.training.gpu_batch_size:
                raise ValueError(
                    f"Total batch size is set to {cfg.training.batch_size} but should be {bs_factor * cfg.training.gpu_batch_size}, as GPU batch size is set to {cfg.training.gpu_batch_size} with {acc.num_processes} processes and {cfg.training.grad_accumulate} gradient accumulation steps"
                )
        elif cfg.training.batch_size:
            cfg.training.gpu_batch_size = cfg.training.batch_size // bs_factor
            # Report the total batch size (not the freshly derived per-GPU one) in the message
            assert cfg.training.gpu_batch_size * bs_factor == cfg.training.batch_size, f"Batch size ({cfg.training.batch_size}) must be divisible by the number of batch splits ({bs_factor})"
        else:
            cfg.training.batch_size = bs_factor * cfg.training.gpu_batch_size

        # Testing falls back to the training per-GPU batch size when unset
        cfg.testing.gpu_batch_size = cfg.testing.gpu_batch_size or cfg.training.gpu_batch_size

    def build_task_logger(self):
        """Return the metric logger for this task."""
        return MetricLogger(self.state, train=self.training)

    def load_data(self):
        """Build the datasets and loaders from the config and prepare the loaders with the accelerator."""
        (train_dataset, test_dataset), (self.train_loader, self.test_loader) = dataset_from_name(self.cfg)
        self._train_dataset = train_dataset
        self._test_dataset = test_dataset
        self.train_loader = self.accelerator.prepare(self.train_loader)
        self.test_loader = self.accelerator.prepare_test_data(self.test_loader)

        self.print(f"Loaded datasets (train={self.cfg.dataset.name}, test={self.cfg.test_dataset.name}):", {"train": train_dataset, "test": test_dataset})

    def _build_model(self, registry: ModulesRegister, **kwargs):
        """Build a model from ``registry`` and pass it through ``prepare_model``.

        The keyword arguments understood by ``prepare_model`` (listed below) are split
        off; all remaining ones are given to ``registry.build``.
        """
        prep_args = ["checkpoint", "name", "compile", "ema", "freeze", "model_init", "remove_from_checkpointing"]
        prep_kwargs, module_kwargs = split_dict(kwargs, prep_args)
        model = registry.build(module_kwargs)
        return self.prepare_model(model, **prep_kwargs)

    def _task_train_prepare(self):
        """Build one optimizer per configured entry, ordered by the entries' ``index``.

        ``None`` entries of ``cfg.training.optimizers`` are ignored. Each optimizer is
        appended to ``self.optimizers`` and the models it optimizes to ``self.opti_models``,
        so both lists stay aligned.
        """
        # Drop disabled (None) entries before sorting: the sort key reads `.index`
        optimizer_cfgs = [opt_cfg for opt_cfg in self.cfg.training.optimizers.values() if opt_cfg is not None]
        optimizer_cfgs.sort(key=lambda x: x.index)
        self.state._optimizer_cfgs = optimizer_cfgs

        # Build optimizers
        for opt_cfg in self.state._optimizer_cfgs:
            # Build optimizer
            opt_models = {m_name: m for m_name, m in self.models.items() if m_name in opt_cfg.models}

            optimizer = build_optimizer(opt_models, opt_cfg.name, opt_cfg.lr, opt_cfg.get("args", {}))
            optimizer.eval()

            # Log optimizers
            model_descs = ", ".join([f"#{name}={count_parameters(model)}" for name, model in opt_models.items()])
            self.print(f"Optimizer for {model_descs}: {optimizer}")

            # Prepare with accelerator and store
            optimizer = self.accelerator.prepare(optimizer)
            self.optimizers.append(optimizer)
            self.opti_models.append(opt_models)

    def _train_end_step_trigger(self, i_batch):
        """Hook called at the end of each training step; optionally saves a debug checkpoint."""
        if self.cfg.training.save_on_log and i_batch % self.cfg.training.log_freq == 0 and self.accelerator.is_main_process:  # For debug only
            path = self.checkpoint_path / "debug_logstep"
            self.print(f"Saving model checkpoint to {path} for debugging")
            save_training_state(self.state, path)

    def _train_do_step(self, optimizer: torch.optim.Optimizer, models: Dict[str, torch.nn.Module], batch: Any, train_ctx: Dict[str, Any]):
        """Run forward, backward and optimizer step for one optimizer on one batch.

        Args:
            optimizer: Optimizer to step.
            models: Models optimized by ``optimizer``, keyed by name.
            batch: Batch given to ``_compute_train_loss``.
            train_ctx: Context dict given to ``_compute_train_loss``.

        Returns:
            The aggregated losses, detached from the graph.
        """
        acc = self.accelerator

        with acc.autocast(), self.logger.part_timers.train_forward:  # pylint: disable=no-member
            losses = self._compute_train_loss(batch, set(models.keys()), train_ctx)
        assert isinstance(losses, dict) and all(isinstance(v, torch.Tensor) for v in losses.values()), f"Losses should be a dict of tensors, got {losses}"
        assert len(losses) > 0, f"No losses returned by _compute_train_loss with models {list(models.keys())}"

        with self.save_model_on_error(step="training losses"):
            sum_loss, losses = aggregate_losses(self.cfg, losses)

        with self.save_model_on_error(step="backward"):
            with self.logger.part_timers.train_backward:  # pylint: disable=no-member
                acc.backward(sum_loss)

        # Debug aid: report parameters whose gradient memory layout differs from the parameter's
        if self.state.detect_errors:
            for m in models.values():
                for name, param in m.named_parameters():
                    if param.grad is not None:
                        grad_strides = param.grad.stride()
                        param_strides = param.stride()
                        if grad_strides != param_strides:
                            print(f"[WARNING] Stride mismatch in param: {name}")
                            print(f"  Param strides: {param_strides}")
                            print(f"  Grad strides:  {grad_strides}")

        with self.save_model_on_error(step="optimizers step"):
            # Clip only on steps where gradients are synchronized
            if self.cfg.training.grad_clip and acc.sync_gradients:
                for m in models.values():
                    acc.clip_grad_norm_(m.parameters(), self.cfg.training.grad_clip)
            with self.logger.part_timers.optimizer_step:  # pylint: disable=no-member
                optimizer.step()
            optimizer.zero_grad(set_to_none=True)  # Keep here to ensure grad accumulation works correctly

        return {k: v.detach() for k, v in losses.items()}

    def train(self, is_training=True):
        """Switch every model and every optimizer to train (or eval) mode."""
        for m in self.models.values():
            m.train(is_training)

        for opt in self.optimizers:
            if is_training:
                opt.train()
            else:
                opt.eval()

    def _task_train_one_epoch(self):
        """Generic training loop. Can be overwritten for more specific tasks. If used, you need to define _compute_train_loss"""
        acc = self.accelerator
        assert self.training, "Not in training mode"
        self.train()

        with self.logger.on_epoch(self.train_loader):
            timed_train_loader = self.logger.part_timers.load_train_data.timed_iter(self.train_loader)  # pylint: disable=no-member

            for i_batch, batch in enumerate(timed_train_loader):
                # Start step
                with self.logger.on_batch(i_batch) as batch_log:
                    train_ctx = {
                        "losses": batch_log.losses,
                        "i_batch": i_batch,
                        "cur_epoch": self.state.cur_epoch,
                        "cur_steps": self.state.cur_steps,
                    }

                    # Optimizers are stepped one after the other on the same batch
                    for optimizer, opt_models in zip(self.optimizers, self.opti_models):
                        # Do steps
                        with acc.accumulate(*opt_models.values()):
                            batch_log.losses.update(self._train_do_step(optimizer, opt_models, batch, train_ctx=train_ctx))

                        # EMA
                        for m in opt_models.values():
                            EMA.update_ema_modules(m)

                    # End step
                    self._train_end_step_trigger(i_batch)
                acc.wait_for_everyone()
                self.state.cur_steps += 1
                # self.print("Memory summary:\n" + torch.cuda.memory_summary("cuda"))

    def _task_train_post_eval(self, did_eval):
        """Save the "last" checkpoint plus, when due, the periodic and "best" copies.

        Args:
            did_eval: Whether an evaluation was just run; the "best" checkpoint is only
                considered in that case.
        """
        # Storing last checkpoint
        ckpt = self.checkpoint_path / "last"
        self.accelerator.print(f"Storing model checkpoint inside {ckpt}")
        save_training_state(self.state, ckpt)

        # NOTE(review): task_train increments cur_epoch before calling this method, so
        # `cur_epoch + 1` looks one ahead of the number of completed epochs - confirm intent
        elapsed_epochs = self.state.cur_epoch + 1
        if self.cfg.training.save_every_epoch and elapsed_epochs % self.cfg.training.save_every_epoch == 0:
            ckpt = self.checkpoint_path / f"epoch_{elapsed_epochs}"
            self.accelerator.print(f"Storing a copy of the model checkpoint to {ckpt}")
            save_training_state(self.state, ckpt)

        if did_eval and self.logger.epochs_since_best_score() == 0:
            ckpt = self.checkpoint_path / "best"
            self.accelerator.print(f"Best {self.cfg.training.save_on_best} so far, storing a copy of the model checkpoint to {ckpt}")
            save_training_state(self.state, ckpt)

    def task_train(self):
        """Train until ``should_stop_training`` says stop, evaluating every ``eval_freq`` epochs.

        A final evaluation is run when the last trained epoch was not evaluated (or when
        no epoch was trained at all).
        """
        # Start directly at state.cur_epoch (even if > 0)
        did_eval_last_epoch = False
        while not self.should_stop_training(did_eval_last_epoch):
            self._task_train_one_epoch()

            # Recomputed every epoch: a sticky flag would skip the final evaluation and let
            # non-evaluated epochs be stored as "best"
            did_eval_last_epoch = (self.state.cur_epoch + 1) % self.cfg.training.eval_freq == 0
            if did_eval_last_epoch:
                self.task_eval()

            # Increment before saving so that the stored state points to the next epoch
            self.state.cur_epoch += 1
            self._task_train_post_eval(did_eval_last_epoch)

        # Ensure last eval
        if not did_eval_last_epoch:
            self.accelerator.print("Training stopped, final evaluation")
            self.task_eval()
            # An evaluation was just run, so the "best" checkpoint must be considered
            self._task_train_post_eval(True)

    def task_eval(self):
        """Evaluate the models; must be implemented by subclasses."""
        raise NotImplementedError

    def should_stop_training(self, did_eval):
        """Return True when training must stop (patience exhausted or maximum epochs reached).

        Args:
            did_eval: Whether the last epoch was evaluated; the patience criterion is
                only checked in that case.
        """
        if did_eval and self.cfg.training.stop_patience and self.logger.epochs_since_best_score() >= self.cfg.training.stop_patience:
            self.accelerator.print(f"{self.cfg.training.save_on_best} did not improve for {self.cfg.training.stop_patience} epochs, stopping")
            return True
        if self.cfg.training.epochs and self.state.cur_epoch >= self.cfg.training.epochs:
            self.accelerator.print(f"Reached maximum epochs, stopping (epoch {self.state.cur_epoch} / {self.cfg.training.epochs})")
            return True
        return False


class BaseAutoencodingTask(BaseTrainEvalTask):
    """Train / eval task for autoencoding models, evaluated by reconstructing the test set.

    Subclasses must implement ``_generate_for_eval``.
    """

    def __init__(self, cfg, *args, **kwargs):
        """Forward all arguments to the parent constructor."""
        super().__init__(cfg, *args, **kwargs)

    def _generate_for_eval(self, x, y=None, generator=None):
        """Generate (if x is None) or reconstruct (if x is set) a sample. Might be conditioned on y."""
        raise NotImplementedError("Not implemented yet")

    @torch.no_grad()
    def task_eval(self):
        """Reconstruct the test set, update the metrics and return a copy of the last metric values.

        On the main process, the first ``cfg.logging.samples.n`` ground-truth and
        reconstructed samples of the last batch are attached to the eval log.
        """
        acc = self.accelerator
        self.train(False)
        # Fresh generator seeded from the config for every evaluation
        self.generator = torch.Generator(device=acc.device)
        self.generator.manual_seed(self.cfg.seed)

        tqdm_dis = not acc.is_main_process or not self.cfg.verbose

        # Stay None when the test loader yields no batch, so the sample logging below can be skipped
        rgb_test_samples = rgb_rec_samples = None

        with self.logger.on_eval(self.test_loader) as eval_log:
            # Eval reconstruction of test set
            enum_tests = tqdm(self.test_loader, desc="Reconstructing from test set", disable=tqdm_dis)
            for test_samples, label in enum_tests:
                test_samples = test_samples.to(acc.device)
                with acc.autocast(), self.logger.eval_part_timers.model:  # pylint: disable=no-member
                    rec_samples = self._generate_for_eval(test_samples, generator=self.generator)

                with self.logger.eval_part_timers.metrics:  # pylint: disable=no-member
                    rgb_test_samples = self.to_rgb(test_samples)
                    rgb_rec_samples = self.to_rgb(rec_samples)

                    self.logger.metrics.update(
                        x_gt=rgb_test_samples,
                        x_pred=rgb_rec_samples,
                        y_gt=label,
                    )

            # Generate displayed samples (taken from the last evaluated batch)
            if self.cfg.logging.samples.n and acc.is_main_process and rgb_test_samples is not None:
                n_samples = self.cfg.logging.samples.n
                eval_log.gt_samples = rgb_test_samples[:n_samples]
                eval_log.rec_samples = rgb_rec_samples[:n_samples]

        return deepcopy(self.logger.metrics.last_m_vals)

    ##### Utils #####

    def to_rgb(self, x):  # x should be in normalized space ; output will be in [0;1]
        """Undo the dataset normalization and return a 3-channel image batch in [0, 1].

        Values are quantized to multiples of 1/255. ``x`` must have 4 dimensions; a
        single-channel input (size 1 along dimension 1) is repeated to 3 channels.
        """
        x = x * self.cfg.dataset.normalize.std + self.cfg.dataset.normalize.mean
        x = torch.clamp(255 * x, 0, 255).round() / 255

        assert x.ndim == 4
        if x.shape[1] == 1:  # Go from grayscale to RGB
            x = x.repeat(1, 3, 1, 1)
        return x
