import csv
import os
from abc import abstractmethod
from datetime import datetime
from os.path import join as pjoin

import torch
from collections import OrderedDict
from omegaconf import OmegaConf
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from torch.utils.data import DataLoader, DistributedSampler
from transformers import get_scheduler, logging as hf_logging

from common.callbacks import build_callbacks
from common.config import import_item
from common.entrypoint import Entrypoint


class BaseModel(LightningModule):
    """LightningModule base class with config-driven optimizer/scheduler setup.

    The constructor requires ``train_dataset`` and ``eval_dataset`` keyword
    arguments, which are stored as attributes. The remaining constructor
    arguments are recorded via ``save_hyperparameters``; the methods below
    read ``self.hparams.config`` and ``self.hparams.output_dir`` from them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Datasets are kept out of the saved hyperparameters and exposed as
        # plain attributes instead.
        self.save_hyperparameters(ignore=["train_dataset", "eval_dataset"])
        self.eval_dataset = kwargs["eval_dataset"]
        self.train_dataset = kwargs["train_dataset"]

    def configure_optimizers(self):
        """Build the optimizer and LR scheduler described in ``hparams.config``.

        ``config.optimizer`` must contain a ``name`` entry; every other entry
        is forwarded to the optimizer constructor as a keyword argument. A
        dotted ``name`` is resolved through ``import_item``; a bare name is
        looked up on ``torch.optim``. ``config.lr_schedule`` is forwarded
        verbatim to ``transformers.get_scheduler``.

        Returns:
            A dict with the optimizer and a scheduler entry whose ``interval``
            is ``"step"``.
        """
        # Optimizer
        optim_config = dict(self.hparams.config.optimizer)
        optim_name = optim_config.pop("name")
        if "." in optim_name:
            # "pkg.module.Class" -> import "Class" from "pkg.module".
            import_path, import_name = optim_name.rsplit(".", 1)
            optim_class = import_item(import_path, import_name)
        else:
            optim_class = getattr(torch.optim, optim_name)
        optimizer = optim_class(self.parameters(), **optim_config)

        # Scheduler
        lr_scheduler_config = dict(self.hparams.config.lr_schedule)
        lr_scheduler = get_scheduler(optimizer=optimizer, **lr_scheduler_config)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
        }

    @rank_zero_only
    def log_progress_bar(self, metrics_dict):
        """Merge ``metrics_dict`` into the trainer's progress-bar metrics."""
        self.trainer.progress_bar_metrics.update(metrics_dict)

    @staticmethod
    def _metric_column(value):
        """Return ``value`` as a list of plain Python cells for one CSV column."""
        if isinstance(value, torch.Tensor):
            # tolist() yields a Python scalar for 0-d tensors and a (nested)
            # list otherwise.
            value = value.detach().cpu().tolist()
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    @rank_zero_only
    def save_metrics_csv(self):
        """
        Save all log_dict metrics to ``<hparams.output_dir>/metrics.csv``.

        The header row holds the metric names. Scalar metrics produce a single
        data row; sequence-valued metrics produce one row per element
        (truncated to the shortest column, as ``zip`` does). Nothing is
        written when no metrics were logged.
        """
        metrics = self.trainer.logged_metrics
        if not metrics:
            return

        output_dir = self.hparams.output_dir
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        csv_path = os.path.join(output_dir, "metrics.csv")

        # Scalar tensors cannot be iterated, so every value is normalised to a
        # list before the columns are transposed into rows.
        columns = [self._metric_column(value) for value in metrics.values()]

        # newline="" is required by the csv module to avoid blank lines on
        # platforms that translate line endings.
        with open(csv_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(metrics.keys())
            writer.writerows(zip(*columns))

    def on_test_epoch_end(self):
        """
        Perform post-epoch operations for the test phase.
        """
        if self.trainer.is_global_zero:
            # save all log_dict metrics to csv
            self.save_metrics_csv()


class HFMTrainer(Entrypoint):
    """Entrypoint that wires a config into a PyTorch Lightning ``Trainer``.

    Subclasses implement ``configure_datasets`` and ``configure_model``.
    ``fit``/``test`` read ``self.model``, ``self.train_dataset``,
    ``self.eval_dataset`` and ``self.data_collator``, which are not assigned
    in this class (presumably set by those two hooks — verify in subclasses).
    """

    def entrypoint(self):
        """Prepare directories, callbacks, loggers, data, model and the Trainer."""
        config = self.config
        # NOTE: seeding (seed_everything with config.training.seed) is
        # currently disabled.

        self.configure_output_dir()
        self.configure_callbacks()
        self.configure_loggers()
        self.configure_datasets()
        self.configure_model()

        # Treat a missing ``torch_compile`` key as "off", like every other
        # optional setting read below.
        if config.training.get("torch_compile", False):
            self.model = torch.compile(self.model, mode="reduce-overhead")

        # Best effort: a failure here is logged and otherwise ignored.
        try:
            torch.set_float32_matmul_precision("high")
        except Exception as e:
            self.logger.warning(f"Failed to set float32_matmul_precision: {e}")

        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        os.environ["NCCL_DEBUG"] = "WARN"
        hf_logging.set_verbosity_error()
        torch.multiprocessing.set_sharing_strategy("file_system")

        training_cfg = config.training
        self.trainer = Trainer(
            default_root_dir=self.dir_params["logging_dir"],
            max_epochs=training_cfg.get("train_epochs", 1),
            max_steps=training_cfg.get("train_steps", -1),
            logger=self.loggers,
            callbacks=self.callbacks,
            # NOTE(review): the "eval_steps" key feeds an epoch-based Trainer
            # argument — confirm the intended unit.
            check_val_every_n_epoch=training_cfg.get("eval_steps", 1),
            val_check_interval=training_cfg.get("eval_interval", 1.0),
            accelerator=config.get("accelerator", "auto"),
            devices=config.get("devices", "auto"),
            num_nodes=config.get("num_nodes", 1),
            strategy=config.get("strategy", "ddp"),
            precision=config.get("precision", None),
            log_every_n_steps=1,
            benchmark=False,
            deterministic=False,
            gradient_clip_val=training_cfg.get("gradient_clip", None),
            accumulate_grad_batches=training_cfg.get("accumulate_grad_batches", 1),
        )

    @abstractmethod
    def configure_datasets(self):
        """Set up the datasets; implemented by subclasses."""
        ...

    @abstractmethod
    def configure_model(self):
        """Set up the model; implemented by subclasses."""
        ...

    def configure_callbacks(self):
        """Build ``self.callbacks`` from the config and the run directories."""
        self.callbacks = build_callbacks(
            self.config,
            output_dir=self.dir_params["output_dir"],
            logging_dir=self.dir_params["logging_dir"],
            logger=self.logger,
        )

    def configure_loggers(self):
        """Create one logger per entry of ``config.logging.report_to``.

        Raises:
            ValueError: if an entry is neither "tensorboard" nor "wandb".
        """
        self.loggers = []
        logging_dir = self.dir_params["logging_dir"]
        for logger_name in self.config.logging.report_to:
            if logger_name == "tensorboard":
                self.loggers.append(
                    TensorBoardLogger(
                        save_dir=logging_dir,
                        name="tensorboard",
                        version="",
                    )
                )
            elif logger_name == "wandb":
                self.loggers.append(
                    WandbLogger(
                        name=self.config.training.name,
                        project=self.config.training.project,
                        save_dir=logging_dir,
                        # W&B runs offline whenever the top-level "debug"
                        # config value is truthy.
                        offline=bool(self.config.get("debug", False)),
                        id=None,
                        version="",
                    )
                )
            else:
                raise ValueError(f"Unknown logger: {logger_name}")

    def configure_output_dir(self):
        """Compute ``self.dir_params`` and save the config on rank zero.

        The logging directory is
        ``<running_dir>/<training.name>/logging_<timestamp>``. The output
        directory is ``config.output_dir`` when set, otherwise
        ``<running_dir>/<training.name>/output_<timestamp>``. Only rank zero
        creates the logging directory and writes ``<timestamp>.yaml`` into it;
        the output directory is not created here.
        """
        running_dir = self.config.get("running_dir", "runs")
        run_dir = pjoin(running_dir, self.config.training.name)
        # NOTE(review): every process evaluates datetime.now() on its own, so
        # ranks started in different seconds would compute different paths —
        # TODO confirm whether that matters for multi-process runs.
        time_str = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
        logging_dir = pjoin(run_dir, f"logging_{time_str}")

        output_dir = self.config.get("output_dir", None)
        if output_dir is None:
            output_dir = pjoin(run_dir, f"output_{time_str}")

        self.dir_params = {
            "output_dir": output_dir,
            "logging_dir": logging_dir,
        }
        if rank_zero_only.rank == 0:
            os.makedirs(logging_dir, exist_ok=True)
            OmegaConf.save(
                config=self.config,
                f=pjoin(logging_dir, f"{time_str}.yaml"),
            )

    def load_model(self):
        """
        Load weights into ``self.model`` from ``config.model.load_path``.

        Does nothing when no load path is configured. Checkpoint entries whose
        key contains "pad_buffer" are replaced by the model's current values
        for those keys; all other entries come from the checkpoint.
        """
        load_path = self.config.model.get("load_path", None)
        if load_path is None:
            return

        # SECURITY: weights_only=False unpickles arbitrary Python objects;
        # only load checkpoints from trusted sources.
        checkpoint = torch.load(
            load_path,
            map_location="cpu",
            weights_only=False,
        )
        state_dict = checkpoint["state_dict"]
        init_state_dict = self.model.state_dict()
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # Keep the freshly initialised value for pad buffers instead of
            # the checkpointed one.
            new_state_dict[k] = init_state_dict[k] if "pad_buffer" in k else v
        print(f"Loading model from {self.config.model.load_path}")

        self.model.load_state_dict(new_state_dict)

    def _dataloader_params(self, default_num_workers, fit_options):
        """Build DataLoader keyword arguments from ``config.data``.

        ``prefetch_factor`` and ``persistent_workers`` are only passed when
        worker processes are used, because DataLoader rejects them when
        ``num_workers`` is 0. ``pin_memory`` and ``persistent_workers`` are
        only included when ``fit_options`` is true.
        """
        data_cfg = self.config.data
        num_workers = data_cfg.get("num_workers", default_num_workers)
        params = {
            "collate_fn": self.data_collator,
            "num_workers": num_workers,
        }
        if fit_options:
            params["pin_memory"] = data_cfg.get("pin_memory", True)
        if num_workers:
            params["prefetch_factor"] = data_cfg.get("prefetch_factor", 2)
            if fit_options:
                params["persistent_workers"] = data_cfg.get(
                    "persistent_workers", True
                )
        return params

    def fit(self):
        """Train ``self.model``, resuming from ``last.ckpt`` when configured.

        If ``config.training.resume_path`` is set, training resumes from
        ``<resume_path>/last.ckpt``; otherwise it starts from scratch.
        """
        dataload_params = self._dataloader_params(
            default_num_workers=4, fit_options=True
        )

        resume_path = self.config.training.get("resume_path", None)
        ckpt_path = os.path.join(resume_path, "last.ckpt") if resume_path else None

        train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config.training.train_batch_size,
            shuffle=True,
            drop_last=False,
            **dataload_params,
        )
        eval_loader = DataLoader(
            self.eval_dataset,
            batch_size=self.config.training.eval_batch_size,
            shuffle=False,
            **dataload_params,
        )
        self.trainer.fit(
            self.model,
            train_loader,
            eval_loader,
            ckpt_path=ckpt_path,
        )

    def test(self):
        """Run the Trainer's test loop on ``self.eval_dataset``."""
        dataload_params = self._dataloader_params(
            default_num_workers=8, fit_options=False
        )
        self.trainer.test(
            self.model,
            dataloaders=DataLoader(
                self.eval_dataset,
                batch_size=self.config.training.eval_batch_size,
                shuffle=False,
                **dataload_params,
            ),
        )
