import glob
import logging
import os
import time
from dataclasses import dataclass

import torch
from misc_utils import time_formatter
from model.vae import clip_loss, vae_loss, VAEConfig
from omegaconf import DictConfig, ListConfig, MISSING, OmegaConf

from optimizer.optimizer_lib import OptimizerConfig
from scheduler.scheduler_lib import SchedulerConfig
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard.writer import SummaryWriter
from torchmetrics import MeanMetric
from tqdm import tqdm

from .checkpoint import CheckpointManager
from .collate import EmbTrainCollate, TTbatch
from .logger import AutoLogger

# Maps the string names accepted by ``EmbTrainConfig.dtype`` to torch dtypes.
# Each key is exactly the attribute name of the dtype on the ``torch`` module.
DTYPE_MAP = {
    dtype_name: getattr(torch, dtype_name)
    for dtype_name in ("float16", "float32", "float64", "bfloat16")
}


@dataclass
class EmbTrainConfig:
    """Configuration consumed by ``EmbeddingTrainer``.

    Fields set to ``MISSING`` are OmegaConf mandatory values and must be
    supplied (e.g. from a config file) before use.
    """

    model: VAEConfig = MISSING  # model sub-config
    optimizer: OptimizerConfig = MISSING  # optimizer sub-config
    scheduler: SchedulerConfig = MISSING  # LR scheduler sub-config
    epochs: int = 100  # total number of epochs (resumed runs continue up to this)
    batch_size: int = 128  # DataLoader batch size for both train and eval
    step_logging: int = 1000  # log running train losses every N train batches
    # NOTE(review): not read by EmbeddingTrainer; presumably used when building
    # the datasets — confirm.
    data_path: str = "data/random"
    save_every: int = 1  # forwarded to CheckpointManager as save_freq
    # Enables torch.autocast with ``dtype`` during forward passes.
    # NOTE(review): autocast does not accept float32 as a target dtype, so with
    # the default dtype below autocast is expected to be disabled with a
    # warning — confirm.
    use_amp: bool = True
    dtype: str = "float32"  # key into DTYPE_MAP; the model is cast to this dtype
    grad_norm_clip: float = 1.0  # currently not applied (clipping is commented out)
    device: str = "cuda"  # combined with device_id as f"{device}:{device_id}"
    device_id: int = 0
    dl_workers: int = 4  # DataLoader num_workers; workers persist when > 0

    # logging
    logger_type: str = "wandb"  # backend name passed to AutoLogger
    log_dir: str = "logs"
    # The fields below are not read directly by EmbeddingTrainer; presumably
    # AutoLogger consumes them via the config object it receives — confirm.
    mlflow_uri: str = "http://127.0.0.1:5000"
    project_name: str = "ShortCircuit"
    experiment_name: str = "Truth Table Embeddings [8-input]"
    run_id: str | None = None
    # Resume source: a run directory or its "checkpoints" sub-directory; the
    # "latest" checkpoint there is restored if the path exists.
    checkpoint_dir: str | None = None


class EmbeddingTrainer:
    """Train a VAE on paired batches with a combined VAE + CLIP objective.

    Each batch provides two views, ``norm_float`` and ``neg_float``. Both are
    passed through the model. The loss is the VAE loss of both views,
    normalised by ``2 * batch_size``, plus a CLIP loss between the two latent
    means. Running means of the losses and of a reconstruction accuracy are
    kept in ``torchmetrics.MeanMetric`` objects and reported via an
    ``AutoLogger``. A ``CheckpointManager`` saves checkpoints, tracking
    ``eval/Accuracy`` in ``"max"`` mode.
    """

    def __init__(
        self,
        config: EmbTrainConfig,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
        train_dataset: Dataset,
        eval_dataset: Dataset,
    ):
        """Set up device/dtype, logging, data loaders and checkpointing.

        If ``config.checkpoint_dir`` exists, the "latest" checkpoint found
        there is restored, and training resumes after the epoch it recorded.

        Args:
            config: Training configuration.
            model: Model to train. It is cast to ``config.dtype`` and moved to
                ``f"{config.device}:{config.device_id}"``.
            optimizer: Optimizer over ``model``'s parameters.
            lr_scheduler: Scheduler stepped once per epoch.
            train_dataset: Dataset iterated (shuffled) for training.
            eval_dataset: Dataset iterated (unshuffled) for evaluation.
        """
        self.config = config
        self.dtype = DTYPE_MAP[config.dtype]
        self.device = torch.device(f"{config.device}:{config.device_id}")
        # Batches are moved to self.device in _run_batch, so the model must
        # live there too. ``Module.to`` modifies the module in place.
        self.model = model.to(device=self.device, dtype=self.dtype)
        self.start_time: float = time.time()

        # parameters
        self.n_total_epochs = self.config.epochs
        self.n_seen_points = 0  # training samples processed so far
        self.epochs_run = 0  # last completed epoch (1-based; 0 = none)
        self.step_logging = self.config.step_logging
        self.log_dir = self.config.log_dir
        self.batch_size = self.config.batch_size

        # logging
        self.console_logger = logging.getLogger(__name__)
        self.exp_logger = AutoLogger(self.config.logger_type, config=self.config)
        # Use the fully-qualified device (incl. device_id) so the metric state
        # sits next to the loss tensors it is updated with.
        self.vae_loss = MeanMetric().to(self.device)
        self.clp_loss = MeanMetric().to(self.device)
        self.accuracy = MeanMetric().to(self.device)

        # Metric name -> step axis it is plotted against.
        metrics_def = {
            # Epoch-based metrics
            "train/VAE_Loss": "epoch",
            "train/CLIP_Loss": "epoch",
            "train/Accuracy": "epoch",
            "train/Total_Loss": "epoch",
            "eval/VAE_Loss": "epoch",
            "eval/CLIP_Loss": "epoch",
            "eval/Accuracy": "epoch",
            "eval/Total_Loss": "epoch",
            # Total samples-based metrics (with x-total suffix)
            "train/VAE_Loss x-total": "x-total",
            "train/CLIP_Loss x-total": "x-total",
            "train/Total_Loss x-total": "x-total",
        }
        self.exp_logger.define_metrics(metrics_def)

        # NOTE(review): process-wide side effect, presumably so logger traffic
        # to a local tracking server bypasses an HTTP proxy — confirm. It runs
        # after AutoLogger has been constructed.
        os.environ["HTTP_PROXY"] = ""

        # dataset
        self.train_loader = self._prepare_dataloader(train_dataset, train=True)
        self.eval_loader = self._prepare_dataloader(eval_dataset, train=False)

        self._grads = None
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        checkpoint_dir = os.path.join(self.exp_logger.get_run_dir(), "checkpoints")
        self.checkpoint_manager = CheckpointManager(
            save_dir=checkpoint_dir,
            model=self.model,
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
            logger=self.exp_logger,
            save_freq=self.config.save_every,
            metric_name="eval/Accuracy",
            mode="max",
        )

        # load model from checkpoint
        resume_dir = self.config.checkpoint_dir
        if resume_dir is not None and os.path.exists(resume_dir):
            # Accept either a run directory or its "checkpoints" sub-directory.
            # normpath strips trailing separators so "run/checkpoints/" is
            # recognised instead of becoming "run/checkpoints/checkpoints".
            resume_dir = os.path.normpath(resume_dir)
            if not resume_dir.endswith("checkpoints"):
                resume_dir = os.path.join(resume_dir, "checkpoints")
            checkpoint_data = self.checkpoint_manager.load_from_directory(
                resume_dir, "latest"
            )
            if checkpoint_data:
                snapshot, _ = checkpoint_data
                self.epochs_run = self.checkpoint_manager.restore(snapshot)
                # Each train epoch yields len(train_loader) full batches
                # (drop_last=True). _run_batch counts samples via len(batch);
                # assumes that equals batch_size — TODO confirm TTbatch.__len__.
                self.n_seen_points = (
                    self.epochs_run * len(self.train_loader) * self.batch_size
                )
                self.console_logger.info(f"Resumed from epoch {self.epochs_run}")

        self._log_params()

    def _model_loss(self, batch: TTbatch) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the losses for one batch and update the running metrics.

        Returns:
            ``(vae_loss, clip_loss)``. The VAE term is the sum of both views'
            ``vae_loss`` divided by ``2 * batch_size``.
        """
        # forward pass on both views
        x1_hat, mu1, logvar1 = self.model(batch.norm_float)
        x2_hat, mu2, logvar2 = self.model(batch.neg_float)

        vae_loss1 = vae_loss(x1_hat, batch.norm_float, mu1, logvar1)
        vae_loss2 = vae_loss(x2_hat, batch.neg_float, mu2, logvar2)

        # NOTE(review): dividing by 2*B suggests vae_loss returns a batch sum.
        tot_vae_loss = (vae_loss1 + vae_loss2) / (2 * batch.norm_float.size(0))
        self.vae_loss.update(tot_vae_loss)
        clp_loss = clip_loss(mu1, mu2)
        self.clp_loss.update(clp_loss)

        with torch.no_grad():
            # Decode the latent means of both views in one call (stacked on a
            # new leading axis). An entry counts as correct only when both
            # views reconstruct it correctly (prod over that axis); the mean
            # is then taken over everything else. Thresholding at 0 presumably
            # means the decoder emits logits — confirm.
            x_rec = self.model.decoder(torch.stack([mu1, mu2]))  # type: ignore
            self.accuracy.update(
                ((x_rec > 0) == torch.stack([batch.normal, batch.negate]))
                .float()
                .prod(dim=0)
                .mean()
            )

        return tot_vae_loss, clp_loss

    def _run_epoch(
        self,
        epoch: int,
        dataloader: DataLoader,
        train: bool = True,
    ):
        """Run one pass over ``dataloader``, training when ``train`` is True.

        Resets the running metrics first, so after this call they hold
        this epoch's averages.
        """
        # Put layers such as dropout/batch-norm into the matching mode.
        self.model.train(train)
        for metric in (self.vae_loss, self.clp_loss, self.accuracy):
            metric.reset()

        desc = self._get_desc(epoch, train, False)
        with tqdm(unit="batch", total=len(dataloader), desc=desc) as progress_bar:
            for step, batch in enumerate(dataloader):
                self._run_batch(batch, train)
                progress_bar.update()

                # update bar every 100 batches
                if step % 100 == 0:
                    desc = self._get_desc(epoch, train, True)
                    progress_bar.set_description(desc)
                    progress_bar.refresh()

                # step_logging <= 0 disables step logging (avoids modulo by 0)
                if train and self.step_logging > 0 and step % self.step_logging == 0:
                    self._step_log_info()

            # log final loss
            desc = self._get_desc(epoch, train, True)
            progress_bar.set_description(desc)
            progress_bar.refresh()

    def _get_desc(
        self,
        epoch: int,
        train: bool,
        include_loss: bool = False,
    ) -> str:
        """Build the tqdm description, optionally with the running metrics."""
        step_type = "Train" if train else "Eval "

        desc = (
            f"[{step_type}][Epoch: {epoch:4} / {self.n_total_epochs}]"
            f"[Total time: {self._get_elapsed_time()}]"
        )
        if include_loss:
            # `.3f` = three decimals (the previous `3f` meant width 3).
            desc += f"[VAE Loss: {self.vae_loss.compute().item():.3f}]"
            desc += f"[CLIP Loss: {self.clp_loss.compute().item():.3f}]"
            desc += f"[Accuracy: {self.accuracy.compute().item():.3f}]"

        return desc

    def _get_elapsed_time(self) -> str:
        """Return the formatted wall-clock time since the trainer was built."""
        return time_formatter(
            time.time() - self.start_time,
            show_ms=False,
        )

    def _log_info(self, epoch: int, train: bool) -> dict[str, float]:
        """Log the epoch's averaged metrics and return them.

        The train record is not committed, so it is merged with the eval
        record of the same epoch, which commits.
        """
        step_type = "train" if train else "eval"

        epoch_vae_loss = self.vae_loss.compute().item()
        epoch_clp_loss = self.clp_loss.compute().item()
        epoch_accuracy = self.accuracy.compute().item()

        metrics = {
            f"{step_type}/VAE_Loss": epoch_vae_loss,
            f"{step_type}/CLIP_Loss": epoch_clp_loss,
            f"{step_type}/Accuracy": epoch_accuracy,
            f"{step_type}/Total_Loss": epoch_vae_loss + epoch_clp_loss,
            "epoch": epoch,
        }
        self.exp_logger.log_metric(
            metrics, step_key="epoch", commit=step_type == "eval"
        )
        return metrics

    def _step_log_info(self):
        """Log the running train losses against the number of samples seen."""
        train_epoch_vae_loss = self.vae_loss.compute().item()
        train_epoch_clip_loss = self.clp_loss.compute().item()

        metrics = {
            "train/VAE_Loss x-total": train_epoch_vae_loss,
            # must match the key registered in define_metrics
            "train/CLIP_Loss x-total": train_epoch_clip_loss,
            "train/Total_Loss x-total": train_epoch_vae_loss + train_epoch_clip_loss,
            "x-total": self.n_seen_points,
        }
        self.exp_logger.log_metric(metrics, step_key="x-total", commit=True)

    def _log_params(self):
        """Log model architecture hyper-parameters and the trainable-param count."""
        self.exp_logger.log_param("input_dim", self.model.input_dim)

        # Output width of every Linear layer nested one level inside the
        # encoder's Sequential children.
        layer_dims = []
        for layer in self.model.encoder.children():  # type: ignore
            if isinstance(layer, torch.nn.Sequential):
                for sublayer in layer.children():
                    if isinstance(sublayer, torch.nn.Linear):
                        layer_dims.append(sublayer.out_features)
        for i, dim in enumerate(layer_dims):
            self.exp_logger.log_param(f"layer_{i}_dim", dim)

        self.exp_logger.log_param("latent_dim", self.model.latent_dim)

        # Log total number of trainable parameters
        total_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )
        self.exp_logger.log_param("total_parameters", total_params)

    def _run_batch(self, batch: TTbatch, train: bool = True) -> None:
        """Forward one batch and, when ``train`` is True, take an optimizer step."""
        batch = batch.to(self.device, non_blocking=True)
        with torch.set_grad_enabled(train), torch.autocast(
            device_type=self.device.type,
            dtype=self.dtype,
            enabled=self.config.use_amp,
        ):
            recon_loss, contrastive_loss = self._model_loss(batch)
            loss = recon_loss + contrastive_loss

        if not train:
            return

        self.n_seen_points += len(batch)
        self.optimizer.zero_grad(set_to_none=True)
        # Backward runs outside the autocast region, as the PyTorch AMP docs
        # recommend; the graph built above already carries autocast dtypes.
        loss.backward()
        # NOTE: gradient clipping is disabled; config.grad_norm_clip is unused.
        self.optimizer.step()

    def _prepare_dataloader(self, dataset: Dataset, train: bool = True) -> DataLoader:
        """Build a DataLoader; only training data is shuffled.

        ``drop_last=True`` keeps every batch at exactly ``batch_size``.
        """
        collate_fn = EmbTrainCollate(
            device=self.device,
            dtype=self.dtype,
        )
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=train,
            num_workers=self.config.dl_workers,
            collate_fn=collate_fn,
            pin_memory=True,
            drop_last=True,
            persistent_workers=self.config.dl_workers > 0,
        )

    def train(self):
        """Run the remaining epochs, then end the logging run.

        Each epoch trains, logs the train metrics, evaluates, logs the eval
        metrics, steps the LR scheduler and hands the eval metrics to the
        checkpoint manager.
        """
        self.console_logger.info("Starting Training")
        # Epochs are 1-based; after resuming from epoch k, continue with k + 1.
        for epoch in range(self.epochs_run + 1, self.n_total_epochs + 1):
            self._run_epoch(epoch, self.train_loader, train=True)
            self._log_info(epoch, train=True)

            self._run_epoch(epoch, self.eval_loader, train=False)
            metrics = self._log_info(epoch, train=False)

            # adjust learning rate
            self.lr_scheduler.step()

            # save model
            self.checkpoint_manager.save_checkpoint(epoch, metrics)

        self.exp_logger.end_run()
