"""
@date: 2026-01-02
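@desc: Efficiently run multiple ordered gradient-routing configurations by training each shared base phase once and reusing in-memory snapshots (checkpoints) of it for every fine-tuning config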

export OMP_NUM_THREADS=16 && torchrun --nproc_per_node=8 -m src.run.orchestrate.experiment.exp_09_datamixing
"""

import torch
from copy import deepcopy
from pathlib import Path
from typing import Optional
import warnings

# Silence a known, noisy UserWarning whose message contains
# "Online softmax is disabled"; it would otherwise flood the run logs.
warnings.filterwarnings(
    action="ignore",
    message=".*Online softmax is disabled.*",
    category=UserWarning,
)

from src.model.config import ModelConfig
from src.model.moe import MoETransformer
from src.model.lora import LoRATransformer
from src.run.train.routed import do_routed_ordered
from src.run.config import RunConfig, setup
from src.run.eval import do_eval
from src.run.utils import (
    make_model, 
    copy_model, 
    get_timestamp, 
    get_routed_dims, 
    calc_lora_rank,
    convert_ordered_args,
)
from src.run.distributed import (
    setup_distributed,
    cleanup_distributed,
    get_raw_model,
)

def run_ordered_fast(
    stage: dict,
    ordered_configs: list[dict],
    model_config: ModelConfig,
    run_config: RunConfig,
) -> None:
    """Train and evaluate several ordered gradient-routing configs, sharing base phases.

    Configs are grouped by ``core_prc``. Groups are processed in increasing
    ``core_prc`` order. For each group, the base phase (``do_ft=False``) is
    trained once, continuing from the previous group's model and state.

    The result is snapshotted to CPU. Every config in the group then
    fine-tunes its own copy of that snapshot (``do_ft=True``). Each fine-tuned
    model is evaluated once per forget target.

    Args:
        stage: Base stage settings (``arch``, ``expert_dist``, LoRA options, ...).
            Not mutated; each config is evaluated with ``{**stage, **config}``.
        ordered_configs: One dict per run. Each dict is completed in place via
            ``convert_ordered_args`` so it carries ``alpha``, ``beta``,
            ``aux_prc`` and ``core_prc``.
        model_config: Model hyper-parameters forwarded to ``make_model``.
        run_config: Run settings (device, logger, loaders, aux labels, ...).

    Raises:
        ValueError: If ``stage["arch"]`` is neither ``"moe"`` nor ``"lora"``.
    """

    device = run_config.device
    logger = run_config.logger
    do_compile = run_config.do_compile
    aux_labels = set(run_config.aux_labels)
    all_labels = aux_labels | {"core"}

    # Validate up front: an unknown arch previously fell through both model
    # branches below and only failed later on an unbound `model`.
    arch = stage.get("arch", "lora")
    if arch not in ("moe", "lora"):
        raise ValueError(f"Unsupported arch {arch!r}; expected 'moe' or 'lora'")

    len_aux = sum(len(run_config.loaders[label]["train"]) for label in aux_labels)
    len_core = len(run_config.loaders["core"]["train"])

    # Forget targets are built from the aux labels:
    # - every leave-one-out subset, plus the full aux set;
    # - each stored as a sorted tuple, with duplicates removed;
    # - empty subsets dropped (these arise when there is a single aux label).
    forget_targets = [aux_labels - {x} for x in aux_labels] + [aux_labels]
    forget_targets = sorted({tuple(sorted(x)) for x in forget_targets if x})

    expert_dist = stage.get("expert_dist", "add")
    mlp_dim, aux_dim = get_routed_dims(model_config, run_config, expert_dist)

    if arch == "moe":

        logger.info(f"ROUTED - MLP dim: {mlp_dim}, AUX dim: {aux_dim}")
        model = make_model(
            MoETransformer,
            model_config,
            run_config,
            extra_args={
                "mlp_dim": mlp_dim,
                "aux_dim": aux_dim,
            }
        )

    else:  # arch == "lora" (validated above)

        lora_attn = stage.get("lora_attn", False)
        lora_mlp = stage.get("lora_mlp", True)
        lora_rank = stage.get("lora_rank", -1)
        # -1 means "derive the rank from the model/routing dims".
        if lora_rank == -1:
            lora_rank = calc_lora_rank(model_config, lora_attn, lora_mlp, mlp_dim, aux_dim)
        logger.info(f"ROUTED - MLP dim: {mlp_dim}, LoRA rank: {lora_rank}")
        model = make_model(
            LoRATransformer,
            model_config,
            run_config,
            extra_args={
                "mlp_dim": mlp_dim,
                "lora_rank": lora_rank,
                "lora_attn": lora_attn,
                "lora_mlp": lora_mlp,
            }
        )

    logger.info(f"Num Configs: {len(ordered_configs)}")

    # Ensure every config has alpha, beta, aux_prc, core_prc (updated in place).
    for config in ordered_configs:
        args = convert_ordered_args(
            alpha=config.get("alpha"),
            beta=config.get("beta"),
            aux_prc=config.get("aux_prc"),
            core_prc=config.get("core_prc"),
            len_aux=len_aux,
            len_core=len_core
        )
        config.update(args)

    # Group configs sharing a base (core-only) phase.
    core_2_config: dict = {}
    for config in ordered_configs:
        core_2_config.setdefault(config['core_prc'], []).append(config)

    # Ascending order so each base phase continues from the previous one.
    core_prcs = sorted(core_2_config)
    saved_model = model
    saved_state = {}

    num_seen_configs = 0
    for group_idx, core_prc in enumerate(core_prcs):

        logger.info(f"core_prc: {core_prc}")

        cur_configs = sorted(core_2_config[core_prc], key=lambda x: x['aux_prc'])

        # Shared base phase. The group's smallest aux_prc is passed as ft_aux_prc.
        saved_model, saved_state = do_routed_ordered(
            model=saved_model,
            config=run_config,
            core_prc=core_prc,
            ft_aux_prc=cur_configs[0]['aux_prc'],
            state=saved_state,
            do_ft=False,
        )

        logger.info(
            f"BASE PHASE COMPLETE: core_prc={core_prc:.4f}, "
            f"scheduler_last_epoch={saved_state.get('scheduler_epoch', 'N/A')}"
        )

        # Park the base snapshot on CPU while the fine-tuned copies run.
        # NOTE(review): this casts weights to bfloat16. If the model trains in
        # higher precision, later base phases resume from rounded weights —
        # confirm this is intended.
        saved_model = get_raw_model(saved_model).to("cpu", dtype=torch.bfloat16)

        for config in cur_configs:

            aux_prc = config['aux_prc']
            logger.info(f"CONFIG: core_prc: {core_prc}, aux_prc: {aux_prc}, beta: {config['beta']}, alpha: {config['alpha']}")
            # Fresh per-config view of the stage, so keys from one config
            # cannot leak into the next (and the caller's dict is untouched).
            eval_stage = {**stage, **config}

            # Independent copies so fine-tuning never touches the base snapshot.
            aux_model = copy_model(saved_model, device, do_compile)
            aux_state = deepcopy(saved_state)

            aux_model, aux_state = do_routed_ordered(
                model=aux_model,
                config=run_config,
                core_prc=core_prc,
                ft_aux_prc=aux_prc,
                state=aux_state,
                do_ft=True,
            )

            logger.info(
                f"FT PHASE COMPLETE: core_prc={core_prc:.4f}, aux_prc={aux_prc}, "
                f"scheduler_last_epoch={aux_state.get('scheduler_epoch', 'N/A')}"
            )

            num_seen_configs += 1
            logger.info(f"Start Eval: {num_seen_configs} / {len(ordered_configs)}")

            for ablate_labels in forget_targets:

                experts = sorted(all_labels - set(ablate_labels))
                ablate_labels = sorted(ablate_labels)

                logger.info(f"Experts: {experts}")
                logger.info(f"Ablating: {ablate_labels}")

                do_eval(
                    stage=eval_stage,
                    model=aux_model,
                    config=run_config,
                    expert_labels=experts,
                    log={
                        "target": ablate_labels,
                        "finetune": None,
                        "elicited": False,
                    }
                )

            # Release the fine-tuned model AND its training state before
            # emptying the cache; aux_state may hold device tensors — confirm.
            del aux_model, aux_state
            torch.cuda.empty_cache()

        # Move the snapshot back to the device only if a later (larger core_prc)
        # base phase will continue from it. Nothing trains on it after the last group.
        if group_idx < len(core_prcs) - 1:
            saved_model = copy_model(saved_model, device, do_compile)

    del saved_model
    torch.cuda.empty_cache()

def run(
    # stage config
    stage: dict,
    ordered_configs: list[dict],
    # model config
    ctx_len: int,
    num_layers: int,
    embed_dim: int,
    mlp_dim: int,
    # run config
    arbsub: bool,
    test_ood: bool,
    data_dirs: list[str],
    aux_labels: list[str],
    core_labels: list[str] | None,
    do_compile: bool,
    seed: int,
    res_dir: str,
    batch_size: int,
    epochs: int,
    lr: float,
    log_level: str,
    core_batch_limit: str | int | None,
    aux_batch_limit: int | float | None,
    lr_schedule: bool,
    do_cleanup_distributed: bool,
    optimize_routed_training: bool,
    accumulation_steps: int,
    timestamp: Optional[str] = None,
    process_id: Optional[int] = None,
) -> None:
    """Entry point: build the run/model configs and launch ``run_ordered_fast``.

    One stage dict is built per entry of ``ordered_configs`` by overlaying that
    config on a deep copy of ``stage``. These stages go to ``setup`` together
    with the model and run hyper-parameters. ``cleanup_distributed`` runs on
    exit, including on error, whenever ``do_cleanup_distributed`` is true.
    """
    per_config_stages = [{**deepcopy(stage), **overrides} for overrides in ordered_configs]

    # Set up distributed training if launched with torchrun (idempotent).
    setup_distributed()

    try:
        configs = setup(
            stages=per_config_stages,
            # model config
            ctx_len=ctx_len,
            num_layers=num_layers,
            embed_dim=embed_dim,
            mlp_dim=mlp_dim,
            # run config
            arbsub=arbsub,
            test_ood=test_ood,
            data_dirs=data_dirs,
            aux_labels=aux_labels,
            core_labels=core_labels,
            do_compile=do_compile,
            seed=seed,
            res_dir=res_dir,
            timestamp=timestamp,
            batch_size=batch_size,
            epochs=epochs,
            lr=lr,
            core_batch_limit=core_batch_limit,
            aux_batch_limit=aux_batch_limit,
            log_level=log_level,
            lr_schedule=lr_schedule,
            process_id=process_id,
            accumulation_steps=accumulation_steps,
            optimize_routed_training=optimize_routed_training,
        )
        run_config = configs["run_config"]

        start_logger = run_config.logger
        for name, split_loaders in run_config.loaders.items():
            start_logger.info(f"Loader {name} train: {len(split_loaders['train'])}, test: {len(split_loaders['test'])}")

        run_ordered_fast(
            stage=stage,
            ordered_configs=ordered_configs,
            model_config=configs["model_config"],
            run_config=run_config,
        )

        # Re-read logger/res_dir from the run config after the run finishes.
        end_logger = run_config.logger
        final_dir = run_config.res_dir
        end_logger.info("-" * 40)
        end_logger.info(f"Finished. See {final_dir}")

    finally:
        # Tear down distributed state unless the caller manages its lifecycle.
        if do_cleanup_distributed:
            cleanup_distributed()

if __name__ == "__main__":

    # Release any cached CUDA memory before starting.
    torch.cuda.empty_cache()

    root_dir = Path('src').absolute()
    timestamp = get_timestamp()

    # --- Simple Stories Experiment (alternative; uncomment to use) --- #
    # import json
    # from src.run.orchestrate.config import StoriesBaseArgs
    # num_aux_labels = 4
    # data_dir = root_dir / "data/stories"
    # with open(data_dir / "metadata.json") as f:
    #     dataset_metadata = json.load(f)
    # all_labels = dataset_metadata["all"]["labels"]
    # base_args = deepcopy(StoriesBaseArgs)
    # base_args['aux_labels'] = all_labels[:num_aux_labels]
    # base_args['res_dir'] = root_dir / f"results/stories/09/results_{timestamp}"

    # --- Realistic Experiment --- #
    from src.run.orchestrate.config import RealisticBaseArgs, calc_realistic_model_params
    base_args = deepcopy(RealisticBaseArgs)
    base_args['aux_labels'] = ["bigcode", "biology", "nuclear", "cyber"]
    base_args['res_dir'] = root_dir / f"results/realistic/09/results_{timestamp}"
    model_params = calc_realistic_model_params(700e6)
    base_args.update(model_params)

    # DEBUG args
    # aux_prcs = [0.5, 0.6]
    # core_prcs = [0.9, 0.95]
    # base_args['do_compile'] = False
    # model_params = calc_realistic_model_params(50e6)
    # base_args.update(model_params)
    # base_args['core_batch_limit'] = 2000
    # base_args['aux_batch_limit'] = 100
    # base_args['optimize_routed_training'] = True

    # aux_prcs  = [0.5, 0.7, 0.9] #proportion of ft that is aux
    # core_prcs = [0.7, 0.8, 0.9] #proportion of core-only run

    # configs = []
    # for aux_prc in aux_prcs:
    #     for core_prc in core_prcs:
    #         configs += [{"aux_prc": aux_prc, "core_prc": core_prc}]

    # configs += [{"aux_prc": 1.0, "core_prc": 1.0}] #just normal LoRA

    configs = [
        {"core_prc": 0.7, "aux_prc": 0.3},
        {"core_prc": 0.7, "aux_prc": 0.7},
        {"core_prc": 0.8, "aux_prc": 0.5},
        {"core_prc": 0.8, "aux_prc": 0.7},
        {"core_prc": 0.8, "aux_prc": 0.9},
        {"core_prc": 0.8, "aux_prc": 1.0},
    ]

    stage = {
        "name": "routed", "ft_forget": False,
        "arch": "lora", "ordered": True,
        "lora_attn": True, "lora_mlp": True,
        "expert_dist": "prc_one",
        "equal_compute": True,
        "do_checkpoint": False,
    }

    # `run` takes one base `stage` plus `ordered_configs` instead of `stages`.
    # pop() rather than del, so base args without a 'stages' entry don't raise KeyError.
    base_args.pop('stages', None)
    base_args.update(
        stage=stage,
        ordered_configs=configs,
        timestamp=timestamp,
        do_cleanup_distributed=True,
        log_level="DEBUG",
        seed=42,
    )

    run(**base_args)