import json
import logging
import time
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
from transformers import AutoTokenizer
from transformers.utils import logging as hf_logging

from src.model.config import ModelConfig
from src.run.dataloader import DataLoader, make_loaders, get_labels_batch_count, auto_detect_categories
from src.run.utils import ensure_dir, get_timestamp, set_seeds
from src.run.logger import setup_logger
from src.run.distributed import get_world_size, is_main_process, barrier, broadcast_object
from src.model.base import BaseTransformer

# --------------------------------------------------------------------------- #
# global log suppression for TorchDynamo recompilation warnings               #
# --------------------------------------------------------------------------- #


# Dynamo-related UserWarnings to drop, matched by message regex:
#   1. recompile-limit notices (filtered here rather than via TORCH_LOGS), and
#   2. complaints about DDP's _broadcast_coalesced, which Dynamo cannot trace.
for _noise_pattern in (
    r".*torch\._dynamo.*recompile_limit.*",
    r".*_broadcast_coalesced.*",
):
    warnings.filterwarnings(action="ignore", message=_noise_pattern, category=UserWarning)
del _noise_pattern  # keep the loop variable out of the module namespace

# Restrict torch._dynamo and transformers logging to ERROR and above
logging.getLogger("torch._dynamo").setLevel(logging.ERROR)
hf_logging.set_verbosity_error()


# --------------------------------------------------------------------------- #
# helpers                                                                     #
# --------------------------------------------------------------------------- #


def validate_stages(stages: list[dict]) -> None:
    """Validate stage names and the baseline dependency of post-hoc stages.

    Args:
        stages: Stage configuration dicts; only the ``"name"`` entry of each
            dict is inspected.

    Raises:
        AssertionError: If a stage name is not in the accepted set, or if a
            post-hoc unlearning stage (``rmu``, ``ascent``, ``maxent``) is
            requested without a ``baseline`` stage.
        KeyError: If a stage dict has no ``"name"`` entry.
    """
    acceptable_stages = [
        'baseline',
        'rmu',
        'ascent',
        'maxent',
        'filtering',
        'coreftaux',
        'routed',
    ]
    # Stages that require a baseline model. The accepted spelling is "ascent";
    # a check for "gradient_ascent" could never match a name that passed the
    # acceptance test, so the dependency was not enforced for that stage.
    posthoc_stages = ("rmu", "ascent", "maxent")

    stage_names = [stage["name"] for stage in stages]

    # Raised explicitly (instead of `assert`) so validation still runs under
    # `python -O`, while keeping the exception type callers may already catch.
    if not all(name in acceptable_stages for name in stage_names):
        raise AssertionError(f"Invalid stages: {stage_names}")

    needs_baseline = any(name in posthoc_stages for name in stage_names)
    if needs_baseline and "baseline" not in stage_names:
        raise AssertionError("Baseline model is required for posthoc unlearning")


def setup_tokenizer(data_dir_path: Path | str, logger: logging.Logger) -> AutoTokenizer:
    """Load the tokenizer named in a data directory's ``metadata.json``.

    Args:
        data_dir_path: Directory containing ``metadata.json``. Accepts a
            ``Path`` or a plain string; strings are converted so the ``/``
            join below works for both.
        logger: Logger used to report the tokenizer vocabulary size.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``.

    Raises:
        FileNotFoundError: If ``metadata.json`` does not exist in the directory.
        KeyError: If the metadata has no ``"all"`` entry.
        AssertionError: If ``len(tokenizer)`` differs from the ``vocab_size``
            recorded under ``metadata["all"]``.
    """
    metadata_path = Path(data_dir_path) / "metadata.json"
    if not metadata_path.exists():
        raise FileNotFoundError(f"metadata.json not found in {data_dir_path}. Please run prepare.py first.")
    # Explicit encoding so the JSON is decoded the same way on every platform.
    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)

    tokenizer_name = metadata["all"].get("tokenizer")
    if tokenizer_name is None:
        # Fallback used when the metadata does not name a tokenizer.
        tokenizer_name = "EleutherAI/gpt-neo-125M"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    vocab_size = len(tokenizer)
    logger.info(f"Tokenizer vocabulary: {vocab_size}")

    metadata_vocab_size = metadata["all"].get("vocab_size")
    # Raised explicitly (instead of `assert`) so the check survives `python -O`
    # while keeping the exception type unchanged.
    if vocab_size != metadata_vocab_size:
        raise AssertionError("Vocab size mismatch between tokenizer and metadata")

    return tokenizer


@dataclass
class RunConfig:
    """Run-level settings and runtime handles assembled by ``setup``.

    ``save_config`` writes most of these fields to ``config.json``;
    ``loaders``, ``res_dir`` and ``logger`` are not serialized there.
    """

    # Result of make_loaders. NOTE(review): save_config indexes this as
    # loaders[label]["train"], so the values look like nested mappings rather
    # than bare DataLoader objects — confirm and tighten the annotation.
    loaders: dict[str, DataLoader]
    # Output directory; config.json is written here by save_config.
    res_dir: Path
    epochs: int
    lr: float
    batch_size: int
    # NOTE(review): setup stores its `data_dirs` argument unchanged, and that
    # argument is annotated list[str] — confirm which type consumers expect.
    data_dirs: list[Path]
    aux_labels: list[str]
    core_labels: list[str]
    lr_schedule: bool
    # NOTE(review): setup passes torch.device("cuda") here, not a str.
    device: str
    timestamp: str
    # Seed actually used (setup replaces -1 with a time-based seed).
    seed: int
    # World size as reported by get_world_size() in setup.
    num_gpus: int
    logger: logging.Logger
    do_compile: bool
    # Batch counts; when no limit was given, setup fills these in from the
    # lengths of the train loaders.
    core_batch_num: int
    aux_batch_num: int
    arbsub: bool
    test_ood: bool
    accumulation_steps: int
    optimize_routed_training: bool
    # Parameter count of a BaseTransformer built from the model config.
    num_baseline_params: int

def save_config(
    stages: list[dict],
    model_config: ModelConfig,
    run_config: RunConfig,
) -> None:
    """Save configuration to JSON file (only on main process in distributed mode).

    Writes ``config.json`` into ``run_config.res_dir`` and logs the same JSON
    text. Values that ``json`` cannot serialize natively are stringified via
    ``default=str``.

    Args:
        stages: Stage configuration dicts, stored under ``"stages"``.
        model_config: Source of the fields stored under ``"model"``.
        run_config: Source of the fields stored under ``"run"``, plus the
            output directory, the logger and the loaders used to record the
            train-loader lengths (``len_core`` / ``len_aux``).
    """
    from src.run.distributed import is_main_process

    # Only save and log from main process
    if not is_main_process():
        return

    logger = run_config.logger
    loaders = run_config.loaders
    len_core = len(loaders["core"]["train"])
    len_aux = sum(len(loaders[label]["train"]) for label in run_config.aux_labels)

    config_data = {
        "stages": stages,
        "run": {
            "seed": run_config.seed,
            "data_dirs": run_config.data_dirs,
            "arbsub": run_config.arbsub,
            "test_ood": run_config.test_ood,
            "batch_size": run_config.batch_size,
            "epochs": run_config.epochs,
            "lr": run_config.lr,
            "aux_labels": run_config.aux_labels,
            "core_labels": run_config.core_labels,
            "lr_schedule": run_config.lr_schedule,
            "device": run_config.device,
            "timestamp": run_config.timestamp,
            "num_gpus": run_config.num_gpus,
            "do_compile": run_config.do_compile,
            "core_batch_num": run_config.core_batch_num,
            "aux_batch_num": run_config.aux_batch_num,
            "accumulation_steps": run_config.accumulation_steps,
            "optimize_routed_training": run_config.optimize_routed_training,
            "num_baseline_params": run_config.num_baseline_params,
            "len_core": len_core,
            "len_aux": len_aux,
        },
        "model": {
            "ctx_len": model_config.ctx_len,
            "vocab_size": model_config.vocab_size,
            "num_layers": model_config.num_layers,
            "target_layers": model_config.target_layers,
            "num_heads": model_config.num_heads,
            "num_key_value": model_config.num_key_value,
            "attn_bias": model_config.attn_bias,
            "eos_token_id": model_config.eos_token_id,
            "embed_dim": model_config.embed_dim,
            "mlp_dim": model_config.mlp_dim,
        },
    }

    out_str = json.dumps(config_data, indent=4, ensure_ascii=False, default=str)
    # ensure_ascii=False can emit non-ASCII characters, so the file encoding
    # must be explicit rather than whatever the platform default happens to be.
    with open(run_config.res_dir / "config.json", "w", encoding="utf-8") as f:
        f.write(out_str)

    logger.info(f"Saved configuration to {run_config.res_dir}/config.json")
    logger.info(out_str)


def _resolve_core_batch_num(
    core_batch_limit: str | int | float | None,
    max_core_batch_num: int,
    num_baseline_params: int,
    ctx_len: int,
    batch_size: int,
    num_gpus: int,
    logger: logging.Logger,
) -> int | None:
    """Turn ``core_batch_limit`` into a per-rank core batch count (``None`` = no limit)."""
    if core_batch_limit is None:
        # Pass None to indicate no limit
        return None

    # bool is a subclass of int but was never a valid limit; keep rejecting it.
    if isinstance(core_batch_limit, bool):
        raise ValueError(f"Invalid core batch limit: {core_batch_limit}")

    if core_batch_limit == "optimal":
        logger.info(f"Baseline model has {num_baseline_params:,} parameters")
        # Chinchilla-optimal token budget: 20 tokens per parameter
        optimal_tokens = num_baseline_params * 20
        # treat optimal batches per rank, not global batches
        optimal_batches = optimal_tokens / (ctx_len * batch_size * num_gpus)
        logger.info(f"Chincilla optimal core tokens: {optimal_tokens:,}, batch limit: {optimal_batches:.4f}")
        core_batch_num = int(round(optimal_batches))
        if core_batch_num > max_core_batch_num:
            logger.warning(f"Optimal core batch limit {core_batch_num} is greater than max core batch num {max_core_batch_num}")
        return core_batch_num

    if isinstance(core_batch_limit, float):
        # A float is a fraction of the maximum available core batches.
        return int(round(max_core_batch_num * core_batch_limit))

    if isinstance(core_batch_limit, int):
        return core_batch_limit

    raise ValueError(f"Invalid core batch limit: {core_batch_limit}")


def _resolve_aux_batch_num(
    aux_batch_limit: int | float | None,
    core_batch_num: int | None,
    max_core_batch_num: int,
) -> int | None:
    """Turn ``aux_batch_limit`` into a per-rank aux batch count (``None`` = no limit)."""
    if aux_batch_limit is None:
        # Pass None to indicate no limit
        return None

    # bool is a subclass of int but was never a valid limit; keep rejecting it.
    if isinstance(aux_batch_limit, bool):
        raise ValueError(f"Invalid aux batch limit: {aux_batch_limit}")

    if isinstance(aux_batch_limit, float):
        # A float is a fraction of the core batch count.
        # Use max_core_batch_num as base when core has no limit (core_batch_num is None)
        base_batch_num = core_batch_num if core_batch_num is not None else max_core_batch_num
        return int(round(base_batch_num * aux_batch_limit))

    if isinstance(aux_batch_limit, int):
        return aux_batch_limit

    raise ValueError(f"Invalid aux batch limit: {aux_batch_limit}")


def setup(

    # stage config
    stages: list[dict],

    # model config
    ctx_len: int,
    num_layers: int,
    embed_dim: int,
    mlp_dim: int,

    # run config
    arbsub: bool,
    test_ood: bool,
    data_dirs: list[str],
    aux_labels: list[str],
    core_labels: list[str] | None,
    do_compile: bool,
    seed: int,
    res_dir: str,

    batch_size: int,
    epochs: int,
    lr: float,
    log_level: str,
    lr_schedule: bool,
    aux_batch_limit: int | float | None,
    core_batch_limit: str | int | float | None,
    accumulation_steps: int,
    optimize_routed_training: bool,
    timestamp: Optional[str] = None,
    process_id: Optional[int] = None,

) -> dict[str, RunConfig | ModelConfig]:
    """Build the model and run configuration for a training run.

    Validates the stages, seeds the RNGs, configures CUDA/TF32, creates the
    result directory and logger, loads the tokenizer, resolves the core/aux
    batch limits, builds the data loaders and saves the configuration.

    Args:
        stages: Stage configuration dicts, checked by ``validate_stages``.
        ctx_len, num_layers, embed_dim, mlp_dim: Model dimensions forwarded to
            ``ModelConfig``.
        arbsub, test_ood: Stored on ``RunConfig`` unchanged; not interpreted here.
        data_dirs: Non-empty list of data directories. The first one supplies
            ``metadata.json`` for the tokenizer.
        aux_labels: Auxiliary labels; must not contain ``"core"``.
        core_labels: Core labels, or ``None`` to use every auto-detected
            category that is not an aux label.
        do_compile: Stored on ``RunConfig`` unchanged.
        seed: RNG seed; ``-1`` makes the main process pick a time-based seed
            that is broadcast to all ranks.
        res_dir: Output directory (created on the main process).
        batch_size, epochs, lr, lr_schedule: Training hyper-parameters stored
            on ``RunConfig``; ``batch_size`` also sizes the loaders.
        log_level: Level passed to ``setup_logger``.
        aux_batch_limit: ``None`` (no limit), an int batch count, or a float
            fraction of the core batch count.
        core_batch_limit: ``None`` (no limit), ``"optimal"`` (20 tokens per
            baseline parameter), an int batch count, or a float fraction of
            the maximum core batch count.
        accumulation_steps, optimize_routed_training: Stored on ``RunConfig``.
        timestamp: Timestamp for the logger name; generated when ``None``.
        process_id: Passed to ``setup_logger``.

    Returns:
        ``{"model_config": ModelConfig, "run_config": RunConfig}``.

    Raises:
        AssertionError: If ``"core"`` is an aux label, ``data_dirs`` is empty,
            CUDA is unavailable, or stage validation fails.
        ValueError: If a batch limit has an unsupported type or value.
    """
    # Raised explicitly (instead of `assert`) so the checks survive `python -O`
    # while keeping the exception type callers may already catch.
    if "core" in aux_labels:
        raise AssertionError("core cannot be an aux label")
    if len(data_dirs) == 0:
        raise AssertionError("data_dirs must be provided")

    # Validate stage dependencies
    validate_stages(stages)

    if seed == -1:
        # Only the main process draws the seed; broadcasting keeps ranks in sync.
        seed = int(time.time()) if is_main_process() else None
        seed = broadcast_object(seed, src=0)

    set_seeds(seed)

    # CUDA setup
    if not torch.cuda.is_available():
        raise AssertionError("CUDA is not available")
    device = torch.device("cuda")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

    res_dir = Path(res_dir)
    if timestamp is None:
        timestamp = get_timestamp()

    # Only create directory on main process to avoid duplicates in multigpu mode
    if is_main_process():
        ensure_dir(res_dir)

    # Ensure all processes wait for directory creation and use the same path
    barrier()

    # Get number of GPUs
    num_gpus = get_world_size()

    # Setup logger
    log_file = res_dir / "training.log"
    logger = setup_logger(
        name=f"training_{timestamp}",
        log_file=log_file,
        level=log_level,
        process_id=process_id,
    )

    logger.info(f"Number of GPUs: {num_gpus}")

    # Convert once up front: setup_tokenizer joins paths with `/`, which fails
    # on a plain str.
    data_dir_paths = [Path(d) for d in data_dirs]

    # Setup tokenizer
    tokenizer = setup_tokenizer(data_dir_paths[0], logger)

    # Create model configuration
    model_config = ModelConfig(
        tokenizer=tokenizer,
        embed_dim=embed_dim,
        ctx_len=ctx_len,
        vocab_size=len(tokenizer),
        num_layers=num_layers,
        target_layers=list(range(num_layers)),
        num_heads=8,
        num_key_value=2,
        attn_bias=True,
        mlp_dim=mlp_dim,
        eos_token_id=tokenizer.eos_token_id,
        aux_labels=aux_labels,
    )

    # Build a throwaway model only to count its parameters.
    temp_model = BaseTransformer(model_config)
    num_baseline_params = sum(p.numel() for p in temp_model.parameters())
    del temp_model

    categories = {}
    for data_dir in data_dir_paths:
        categories.update(auto_detect_categories(data_dir))
    all_labels = sorted(categories.keys())

    if core_labels is None:
        core_labels = sorted(set(all_labels) - set(aux_labels))

    # Core batch limit
    max_core_batch_num = get_labels_batch_count(
        data_dirs=data_dir_paths,
        labels=core_labels,
        B=batch_size,
        T=ctx_len,
        num_processes=num_gpus,
    )
    core_batch_num = _resolve_core_batch_num(
        core_batch_limit,
        max_core_batch_num,
        num_baseline_params,
        ctx_len,
        batch_size,
        num_gpus,
        logger,
    )

    # Aux batch limit
    aux_batch_num = _resolve_aux_batch_num(aux_batch_limit, core_batch_num, max_core_batch_num)

    logger.debug(f"core_batch_num per rank: {core_batch_num}, aux_batch_num per rank: {aux_batch_num}")

    # Setup data loaders
    loaders, core_labels = make_loaders(
        data_dirs=data_dirs,
        aux_labels=aux_labels,
        core_labels=core_labels,
        B=batch_size,
        T=ctx_len,
        seed=seed,
        device=device,
        core_batch_num=core_batch_num,
        aux_batch_num=aux_batch_num,
        max_num_test=200,
    )

    # With no explicit limit, record the batch counts the loaders actually have.
    if aux_batch_num is None:
        aux_batch_num = sum(len(loaders[label]["train"]) for label in aux_labels)

    if core_batch_num is None:
        core_batch_num = len(loaders["core"]["train"])

    # Create run configuration
    run_config = RunConfig(
        loaders=loaders,
        res_dir=res_dir,
        epochs=epochs,
        lr=lr,
        batch_size=batch_size,
        data_dirs=data_dirs,
        aux_labels=list(aux_labels),
        core_labels=list(core_labels),
        lr_schedule=lr_schedule,
        device=device,
        timestamp=timestamp,
        seed=seed,
        num_gpus=num_gpus,
        logger=logger,
        do_compile=do_compile,
        core_batch_num=core_batch_num,
        aux_batch_num=aux_batch_num,
        arbsub=arbsub,
        test_ood=test_ood,
        accumulation_steps=accumulation_steps,
        optimize_routed_training=optimize_routed_training,
        num_baseline_params=num_baseline_params,
    )

    # Save configuration
    save_config(stages, model_config, run_config)

    return {
        "model_config": model_config,
        "run_config": run_config,
    }