import os
import random
import logging
from datetime import datetime
from typing import Dict, Any
import numpy as np
import torch


def set_seed(seed: int):
    """Seed every random number generator used here and pin cuDNN to deterministic mode.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global generator, torch's CPU generator and the generators of all
    CUDA devices. cuDNN is then switched to deterministic kernels with
    autotuning (``benchmark``) turned off.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    # Tuple assignment runs left to right: deterministic is set first, then benchmark.
    cudnn_backend.deterministic, cudnn_backend.benchmark = True, False


def setup_logging(log_dir: str, experiment_name: str = None) -> logging.Logger:
    """Configure the shared ``"TAP"`` logger to write to a log file and the console.

    Args:
        log_dir: Directory that receives ``<experiment_name>.log``; created if
            it does not exist.
        experiment_name: Base name of the log file. When ``None``, the current
            local time formatted as ``%Y%m%d_%H%M%S`` is used.

    Returns:
        The ``"TAP"`` logger at INFO level with exactly two handlers: a
        UTF-8 ``FileHandler`` for the log file and a ``StreamHandler``
        (stderr). Handlers installed by earlier calls are closed and removed.
    """
    os.makedirs(log_dir, exist_ok=True)
    if experiment_name is None:
        experiment_name = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"{experiment_name}.log")

    logger = logging.getLogger("TAP")
    logger.setLevel(logging.INFO)
    # Keep records from also flowing to the root logger's handlers.
    logger.propagate = False
    # Close, don't just detach, handlers left by a previous call.
    # `handlers.clear()` alone left each old FileHandler's file open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    file_handler = logging.FileHandler(log_file, encoding='utf-8')
    file_handler.setLevel(logging.INFO)
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)

    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    return logger


def _atomic_torch_save(obj, path: str):
    """Serialize ``obj`` to ``path`` so readers never observe a partially written file."""
    tmp_path = f"{path}.tmp"
    try:
        torch.save(obj, tmp_path)
        # os.replace is atomic when source and target are on the same filesystem.
        os.replace(tmp_path, path)
    finally:
        # The temp file only survives if torch.save or os.replace failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def save_checkpoint(
    checkpoint_dir: str,
    step: int,
    agent,
    env_state: Dict[str, Any],
    best_metric: float = None
):
    """Write the agent's training state to ``latest_checkpoint.pt`` and ``checkpoint_step{step}.pt``.

    Args:
        checkpoint_dir: Output directory; created if missing.
        step: Training step stored in the checkpoint and used in the per-step file name.
        agent: Object exposing ``policy`` and ``optimizer`` (both with
            ``state_dict()``) and ``z_0``. An optional ``ref_policy`` is saved
            only if it has a ``state_dict`` method.
        env_state: Environment state stored verbatim under ``'env_state'``.
        best_metric: Best metric seen so far; stored as-is (may be ``None``).
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint = {
        'step': step,
        'policy_state_dict': agent.policy.state_dict(),
        'optimizer_state_dict': agent.optimizer.state_dict(),
        'z_0': agent.z_0,
        'env_state': env_state,
        'best_metric': best_metric
    }
    # Tolerate agents that have no `ref_policy` attribute at all, not just a
    # ref_policy without state_dict().
    ref_policy = getattr(agent, 'ref_policy', None)
    if hasattr(ref_policy, 'state_dict'):
        checkpoint['ref_policy_state_dict'] = ref_policy.state_dict()

    # Atomic writes: a crash mid-save must not corrupt the previous "latest" checkpoint.
    _atomic_torch_save(checkpoint, os.path.join(checkpoint_dir, "latest_checkpoint.pt"))
    _atomic_torch_save(checkpoint, os.path.join(checkpoint_dir, f"checkpoint_step{step}.pt"))


def load_checkpoint(
    checkpoint_path: str,
    agent,
    device: str = "cpu",
    weights_only: bool = None,
) -> Dict[str, Any]:
    """Restore the agent's training state from a checkpoint file.

    Args:
        checkpoint_path: Path to a file written by ``torch.save``.
        agent: Object whose ``policy`` and ``optimizer`` receive
            ``load_state_dict``. ``ref_policy`` is restored only if the
            checkpoint contains it and the agent's ``ref_policy`` has
            ``load_state_dict``. ``agent.z_0`` is set from the checkpoint
            (``0.0`` if absent).
        device: ``map_location`` passed to ``torch.load``.
        weights_only: Forwarded to ``torch.load`` only when not ``None``, so
            the default call is unchanged. Pass ``False`` on torch >= 2.6 if
            ``env_state`` holds objects the restricted unpickler rejects.

    Returns:
        Dict with ``'step'``, ``'env_state'`` and ``'best_metric'``
        (``None`` if the checkpoint did not store one).

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    load_kwargs = {"map_location": device}
    if weights_only is not None:
        load_kwargs["weights_only"] = weights_only
    # SECURITY: unless weights_only is True, torch.load unpickles arbitrary
    # objects. Only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, **load_kwargs)

    agent.policy.load_state_dict(checkpoint['policy_state_dict'])
    # Tolerate agents without a `ref_policy` attribute (mirrors save_checkpoint).
    ref_policy = getattr(agent, 'ref_policy', None)
    if 'ref_policy_state_dict' in checkpoint and hasattr(ref_policy, 'load_state_dict'):
        ref_policy.load_state_dict(checkpoint['ref_policy_state_dict'])
    agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    agent.z_0 = checkpoint.get('z_0', 0.0)

    return {
        'step': checkpoint['step'],
        'env_state': checkpoint['env_state'],
        'best_metric': checkpoint.get('best_metric', None)
    }
