

import os
import shutil
import warnings
from datetime import datetime
from typing import Optional, Dict, Any

import numpy as np
import torch
import yaml
from torch.utils.tensorboard import SummaryWriter

from pipeline.registry import registry


class ExperimentManager:
    """Own the on-disk layout and logging sinks of a single experiment run.

    Creates ``<base_dir>/<name>_<timestamp>/`` with ``tensorboard/``,
    ``checkpoints/`` and ``logs/`` sub-directories and offers helpers for
    TensorBoard logging, a plain-text ``train.log`` and checkpoint I/O.
    Usable as a context manager: leaving the ``with`` block closes the
    TensorBoard writer.
    """

    def __init__(
        self,
        name: str,
        base_dir: str = "experiments",
        timestamp: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None
    ):
        """Create the run directories and optionally persist ``config``.

        Args:
            name: Experiment name, used as the run-directory prefix.
            base_dir: Parent directory holding all runs.
            timestamp: Run-directory suffix. When falsy, the current local time
                formatted as ``%Y%m%d_%H%M%S`` is used. Passing the timestamp
                of an earlier run points this manager at that run's directory;
                the directories are created with ``exist_ok=True``.
            config: Mapping written to ``config.yaml``. It is skipped when
                ``None`` or empty.
        """
        self.name = name
        self.base_dir = base_dir
        self.timestamp = timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")

        self.exp_dir = os.path.join(base_dir, f"{name}_{self.timestamp}")
        self.tensorboard_dir = os.path.join(self.exp_dir, "tensorboard")
        self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoints")
        self.log_dir = os.path.join(self.exp_dir, "logs")

        for directory in (self.tensorboard_dir, self.checkpoint_dir, self.log_dir):
            os.makedirs(directory, exist_ok=True)

        if config:
            self.save_config(config)

        # Created lazily by the ``writer`` property. Constructing a manager
        # therefore does not instantiate a SummaryWriter until something is
        # logged.
        self._writer: Optional[SummaryWriter] = None
        # Not referenced anywhere in this class. Kept for attribute compatibility.
        self._log_file = None

        print(f"[Experiment] Experiment dir: {self.exp_dir}")
        print(f"[Experiment] TensorBoard: {self.tensorboard_dir}")
        print(f"[Experiment] Checkpoints: {self.checkpoint_dir}")
        print(f"[Experiment] Logs: {self.log_dir}")

    def save_config(self, config: Dict[str, Any]):
        """Write ``config`` as YAML to ``<exp_dir>/config.yaml``, overwriting it.

        The file is opened as UTF-8 explicitly. ``allow_unicode=True`` emits
        non-ASCII characters verbatim, which would raise ``UnicodeEncodeError``
        if the platform's default encoding is not UTF-8.
        """
        config_path = os.path.join(self.exp_dir, "config.yaml")
        with open(config_path, 'w', encoding='utf-8') as f:
            yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
        print(f"[Experiment] Config saved: {config_path}")

    @property
    def writer(self) -> SummaryWriter:
        """Return the SummaryWriter for ``tensorboard_dir``.

        The writer is created on first access, and again after :meth:`close`.
        """
        if self._writer is None:
            self._writer = SummaryWriter(self.tensorboard_dir)
        return self._writer

    def log_scalar(self, tag: str, value: float, step: int):
        """Log one scalar ``value`` under ``tag`` at ``step``."""
        self.writer.add_scalar(tag, value, step)

    def log_scalars(self, log_dict: Dict[str, float], step: int):
        """Log every ``tag -> value`` pair of ``log_dict`` as a scalar at ``step``."""
        for tag, value in log_dict.items():
            self.writer.add_scalar(tag, value, step)

    def log(self, log_dict: Dict[str, float], step: Optional[int] = None):
        """Same as :meth:`log_scalars`, but ``step`` may be omitted.

        With ``step=None``, ``None`` is forwarded as the step to ``add_scalar``.
        """
        self.log_scalars(log_dict, step)

    def log_image(self, tag: str, img_tensor, step: int):
        """Log an image via ``add_image``. ``None`` is ignored."""
        if img_tensor is not None:
            self.writer.add_image(tag, img_tensor, step)

    def log_histogram(self, tag: str, values, step: int):
        """Log a histogram of ``values`` at ``step``. ``None`` is ignored.

        Tensors are moved to CPU and cast to float64 before ``.numpy()``,
        because ``Tensor.numpy()`` raises ``TypeError`` for dtypes NumPy does
        not support (e.g. ``torch.bfloat16``). NumPy arrays are flattened into
        a contiguous float64 buffer.

        Any remaining ``TypeError`` is reported as a ``RuntimeWarning`` rather
        than raised. This keeps histogram logging best-effort without hiding
        the failure.
        """
        if values is None:
            return
        try:
            if isinstance(values, torch.Tensor):
                values = values.detach().cpu().double().numpy()
            if isinstance(values, np.ndarray):
                values = np.ascontiguousarray(values.flatten(), dtype=np.float64)
            self.writer.add_histogram(tag, values, step)
        except TypeError as err:
            # Under Python's default filters this is shown once per distinct
            # message and call site, so a training loop is not flooded.
            warnings.warn(
                f"[Experiment] Skipped histogram '{tag}': {err}",
                RuntimeWarning,
                stacklevel=2,
            )

    def log_figure(self, tag: str, figure, step: int):
        """Log a figure via ``add_figure``. ``None`` is ignored."""
        if figure is not None:
            self.writer.add_figure(tag, figure, step)

    def log_text(self, message: str, also_print: bool = True):
        """Append ``[YYYY-mm-dd HH:MM:SS] message`` to ``logs/train.log``.

        The line is also printed to stdout when ``also_print`` is true. The
        log is written as UTF-8 so non-ASCII messages work on any locale.
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        log_message = f"[{timestamp}] {message}"

        log_path = os.path.join(self.log_dir, "train.log")
        with open(log_path, 'a', encoding='utf-8') as f:
            f.write(log_message + "\n")

        if also_print:
            print(log_message)

    def save_checkpoint(
        self,
        state_dict: Dict[str, Any],
        is_best: bool = False,
        filename: str = "model_latest.pth"
    ):
        """Save ``state_dict`` with ``torch.save`` to ``checkpoints/<filename>``.

        When ``is_best`` is true, the file is also copied to
        ``checkpoints/model_best.pth`` and the event is recorded through
        :meth:`log_text`. If ``filename`` already is ``model_best.pth``, the
        copy is skipped, because ``shutil.copy`` raises ``SameFileError`` when
        copying a file onto itself.
        """
        save_path = self.get_checkpoint_path(filename)
        torch.save(state_dict, save_path)

        if is_best:
            best_path = self.get_checkpoint_path("model_best.pth")
            if os.path.abspath(best_path) != os.path.abspath(save_path):
                shutil.copy(save_path, best_path)
            self.log_text(f"Saved best model: {best_path}")

    def load_checkpoint(
        self,
        filename: str = "model_best.pth",
        map_location=None
    ) -> Dict[str, Any]:
        """Load ``checkpoints/<filename>`` with ``torch.load``.

        Args:
            filename: Checkpoint file name inside ``checkpoint_dir``.
            map_location: Forwarded to ``torch.load``, e.g. ``"cpu"`` to load
                a GPU-saved checkpoint on a CPU-only host. The default ``None``
                matches the previous behaviour. Other ``torch.load`` options
                (such as ``weights_only``) keep the installed PyTorch's defaults.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        load_path = self.get_checkpoint_path(filename)
        if not os.path.exists(load_path):
            raise FileNotFoundError(f"Checkpoint not found: {load_path}")
        return torch.load(load_path, map_location=map_location)

    def get_checkpoint_path(self, filename: str = "model_best.pth") -> str:
        """Return the path of ``filename`` inside ``checkpoint_dir``.

        The file is not required to exist.
        """
        return os.path.join(self.checkpoint_dir, filename)

    def flush(self):
        """Flush pending TensorBoard data if a writer has been created."""
        if self._writer is not None:
            self._writer.flush()

    def close(self):
        """Close the TensorBoard writer, if any.

        A later ``writer`` access creates a new one.
        """
        if self._writer is not None:
            self._writer.close()
            self._writer = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Exceptions from the ``with`` body are not suppressed (returns None).
        self.close()


@registry.register_utils("experiment_manager")
class ExperimentManagerWrapper:
    """Registry-facing adapter that builds an ``ExperimentManager`` from ``cfg``.

    Reads the optional ``experiment`` section for ``name``, ``base_dir`` and
    ``timestamp``. The name falls back to ``cfg['pipeline']['name']`` and then
    to ``'experiment'``. The whole ``cfg`` is passed as the manager's
    ``config``. Attributes not found on the wrapper are forwarded to
    ``self.manager``.
    """

    def __init__(self, cfg: Dict[str, Any]):
        # ``or {}`` also covers sections that are present but null in YAML
        # (e.g. a bare ``experiment:`` line), which ``.get(key, {})`` does not.
        exp_config = cfg.get('experiment') or {}
        pipeline_config = cfg.get('pipeline') or {}
        name = exp_config.get('name', pipeline_config.get('name', 'experiment'))
        base_dir = exp_config.get('base_dir', 'experiments')
        timestamp = exp_config.get('timestamp', None)

        self.manager = ExperimentManager(
            name=name,
            base_dir=base_dir,
            timestamp=timestamp,
            config=cfg
        )

    def __getattr__(self, name):
        # Only invoked when normal lookup fails. ``manager`` itself is refused
        # so that a half-initialised wrapper raises AttributeError instead of
        # recursing forever. That happens when ``__init__`` raised, or when
        # copy/pickle probe attributes before ``__init__`` has run.
        if name == 'manager':
            raise AttributeError(name)
        return getattr(self.manager, name)

    # Implicit dunder lookups (as used by ``with``) bypass ``__getattr__``,
    # so the context-manager protocol is forwarded explicitly.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.manager.close()


def create_experiment_from_config(cfg: Dict[str, Any]) -> ExperimentManager:
    """Build an ``ExperimentManager`` from a pipeline config mapping.

    Reads the optional ``experiment`` section:

    - ``name``: defaults to ``cfg['task']``, then ``'experiment'``.
    - ``base_dir``: defaults to ``'experiments'``.
    - ``timestamp``: defaults to ``None``, so the manager generates one.

    The whole ``cfg`` is passed as the manager's ``config``.

    NOTE(review): ``ExperimentManagerWrapper`` falls back to
    ``cfg['pipeline']['name']`` for the name instead of ``cfg['task']``. The two
    entry points can therefore produce differently named run directories for
    the same config. Confirm which default is intended before unifying them.
    """
    # ``or {}`` also covers an ``experiment`` key that is present but null
    # in YAML, which would otherwise fail with ``None.get``.
    exp_config = cfg.get('experiment') or {}

    default_name = cfg.get('task', 'experiment')

    name = exp_config.get('name', default_name)
    base_dir = exp_config.get('base_dir', 'experiments')
    timestamp = exp_config.get('timestamp', None)

    return ExperimentManager(
        name=name,
        base_dir=base_dir,
        timestamp=timestamp,
        config=cfg
    )
