import os
import torch
import wandb
import random
import logging
import numpy as np
from torch import nn
from typing import Optional
from omegaconf import OmegaConf, DictConfig
from hydra.core.hydra_config import HydraConfig


def seed_everything(seed: int) -> int:
    """Seed the Python, NumPy and PyTorch random number generators.

    If ``seed`` lies outside the unsigned 32-bit range accepted by NumPy,
    a notice is printed and a replacement seed is drawn with
    ``random.randint`` from that range.

    Args:
        seed: Requested seed value.

    Returns:
        The seed that was actually applied (the replacement when the
        requested one was out of range).
    """
    uint32_limits = np.iinfo(np.uint32)
    lowest, highest = uint32_limits.min, uint32_limits.max

    in_bounds = lowest <= seed <= highest
    if not in_bounds:
        print(f"{seed} is not in bounds, numpy accepts from {lowest} to {highest}")
        seed = random.randint(lowest, highest)

    # Apply the same seed to every RNG, in a fixed order.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    return seed

def set_log_dir(cfg: DictConfig) -> str:
    """Create the run's log directory, dump the configs and start WandB.

    The directory is ``<log_dir>/<benchmark>/<log_version>`` taken from
    ``cfg``. The root logger is redirected there via ``setup_logger``,
    the run config and the Hydra config are written as YAML files, and a
    WandB run named after ``log_version`` is initialised.

    Args:
        cfg: Run configuration; must provide ``log_dir``, ``benchmark``
            and a non-None ``log_version``. ``wandb_entity`` and
            ``wandb_project`` are optional.

    Returns:
        Path of the log directory.
    """
    assert cfg.log_version is not None, "👀 Please set log_version in the config file for model saving."
    path_parts = [cfg[key] for key in ('log_dir', 'benchmark', 'log_version')]
    log_dir = os.path.join(*path_parts)

    # --- Set up logging ---
    setup_logger(log_dir)
    print(f"📁 log_dir: {log_dir}")
    os.makedirs(log_dir, exist_ok=True)

    # --- Persist the run config, then the Hydra config ---
    run_cfg_file = os.path.join(log_dir, "config.yaml")
    OmegaConf.save(config=cfg, f=run_cfg_file)
    hydra_cfg_file = os.path.join(log_dir, "hydra_config.yaml")
    OmegaConf.save(config=HydraConfig.get(), f=hydra_cfg_file)

    # --- Set up wandb ---
    wandb_kwargs = {
        "entity": cfg.get('wandb_entity', None),
        "project": cfg.get('wandb_project', None),
        "name": cfg['log_version'],
        "config": OmegaConf.to_container(cfg, resolve=True),
    }
    wandb.init(**wandb_kwargs)
    print("⛓️ WandB initialized")
    return log_dir
    

def setup_logger(log_dir: str, name: str = "training_steps_log.log") -> None:
    """Route the root logger to a single file inside ``log_dir``.

    Every handler currently attached to the root logger is removed and
    closed, then a ``FileHandler`` in append mode is attached that writes
    bare messages (format ``%(message)s``) at INFO level and above.

    Args:
        log_dir: Directory for the log file; created if it does not exist.
        name: File name of the log inside ``log_dir``.
    """
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, name)

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # Detach AND close the previous handlers: merely clearing the list
    # would leave earlier FileHandlers holding their files open, leaking a
    # file descriptor on every repeated call. Iterate over a copy because
    # removeHandler mutates the list.
    for old_handler in list(root_logger.handlers):
        root_logger.removeHandler(old_handler)
        old_handler.close()

    file_handler = logging.FileHandler(log_path, mode="a")
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(
        '%(message)s'
    ))
    root_logger.addHandler(file_handler)
    

def save_model(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    epoch: int,
    model_config: DictConfig,
    log_dir: str,
    modality: Optional[str] = None
) -> None:
    """Write a training checkpoint to ``log_dir``.

    The checkpoint is a dict with the keys ``epoch`` (stored as
    ``epoch + 1``), ``model_state_dict``, ``optimizer_state_dict``,
    ``scheduler_state_dict`` and ``config``. It is saved as
    ``checkpoint_epoch_<epoch+1>.pth``, prefixed with ``<modality>_`` when
    a non-empty ``modality`` is given.

    Args:
        model: Model whose ``state_dict`` is saved.
        optimizer: Optimizer whose ``state_dict`` is saved.
        scheduler: LR scheduler whose ``state_dict`` is saved.
        epoch: Epoch index; the stored value and the file name use
            ``epoch + 1``.
        model_config: Configuration object stored under ``config``.
        log_dir: Destination directory; created if it does not exist.
        modality: Optional file-name prefix. ``None`` or an empty string
            means no prefix.
    """
    epoch_number = epoch + 1
    checkpoint = {
        'epoch': epoch_number,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'config': model_config
    }

    prefix = f"{modality}_" if modality else ""
    file_name = f"{prefix}checkpoint_epoch_{epoch_number}.pth"

    # torch.save does not create missing parent directories, so make sure
    # the target exists instead of failing after the state dicts are built.
    os.makedirs(log_dir, exist_ok=True)
    torch.save(checkpoint, os.path.join(log_dir, file_name))