"""
@date: 2025-12-19

main.py - experimental multiclass gradient-routing pipeline.

Outputs to src/results/<experiment>/results_YYYY-MM-DD_HH-MM-SS (e.g. src/results/stories/results_... for the default CLI run)
"""

import argparse
import os
import torch
import json
from copy import deepcopy
from pathlib import Path
from typing import Callable, Dict, Any, Optional
import warnings

warnings.filterwarnings("ignore", message=".*Online softmax is disabled.*", category=UserWarning)

from src.model.config import ModelConfig, Transformer
from src.model.base import BaseTransformer
from src.model.moe import MoETransformer
from src.model.lora import LoRATransformer
from src.model.demix import DemixTransformer

from src.run.train.routed import do_routed_ordered, do_routed_unordered, do_routed_unordered_arbsub
from src.run.train.demix import do_demix
from src.run.train.base import do_train
from src.run.train.finetune import do_finetune
from src.run.train.coreftaux import do_coreftaux
from src.run.train.ascent import do_gradient_ascent
from src.run.train.maxent import do_maxent
from src.run.train.rmu import do_rmu

from src.run.config import RunConfig, setup
from src.run.eval import do_eval
from src.run.utils import (
    make_model, 
    copy_model, 
    get_timestamp, 
    get_routed_dims, 
    calc_lora_rank,
    convert_ordered_args,
    restore_state,
)
from src.run.distributed import (
    setup_distributed,
    cleanup_distributed,
    get_raw_model,
)

def run_experiment(
    stage: dict,
    model: Transformer,
    config: RunConfig,
    func: Callable,
    func_args: Optional[dict] = None,
    eval_args: Optional[dict] = None,
) -> tuple[Transformer, dict | None]:
    """
    Run a training stage with optional checkpoint loading and train/eval gating.

    Stage config options:
        checkpoint: Optional[str] - path to load weights from before training;
            the state returned by ``restore_state`` is forwarded to ``func`` as
            the ``state`` keyword argument.
        do_train: bool (default True) - whether to run the training function
        do_eval: bool (default True) - whether to run evaluation after

    Args:
        stage: stage config dict; also passed through to ``do_eval``.
        model: model to (optionally) restore, train and evaluate.
        config: run config; its ``logger`` and ``device`` are used here.
        func: training function, called as ``func(model=..., config=..., **func_args)``
            and unpacked as ``(model, state)``.
        func_args: extra keyword arguments for ``func``. The caller's dict is
            not modified.
        eval_args: extra keyword arguments for ``do_eval``.

    Returns:
        ``(model, state)`` where ``state`` is what ``func`` returned if training
        ran, else the restored checkpoint state if a checkpoint was loaded,
        else ``func_args["state"]`` (default ``{}``).
    """
    logger = config.logger
    device = config.device

    # Work on a shallow copy so the checkpoint "state" injected below does not
    # leak into (and persist in) the caller's dict.
    func_args = dict(func_args) if func_args is not None else dict()
    if eval_args is None:
        eval_args = dict()

    # Load checkpoint if specified (returns full state for resumption)
    state = func_args.get("state", {})
    checkpoint_path = stage.get("checkpoint")
    if checkpoint_path is not None:
        model, state = restore_state(model, checkpoint_path, device, logger)  # loads model weights and returns state dict
        func_args["state"] = state
        logger.info(f"Loaded checkpoint from {checkpoint_path}")

    if stage.get("do_train", True):
        model, state = func(
            model=model,
            config=config,
            **func_args,
        )

    if stage.get("do_eval", True):
        do_eval(
            stage=stage,
            model=model,
            config=config,
            **eval_args,
        )

    return model, state

def _build_forget_targets(aux_labels: set[str], arbsub: bool) -> list[tuple[str, ...]]:
    """
    Enumerate the label sets to forget / ablate as sorted, de-duplicated tuples.

    Always includes every leave-one-out subset of ``aux_labels`` plus the full
    set. With ``arbsub`` it additionally includes every singleton and the empty
    set; without it, empty sets are dropped.
    """
    leave_one_out = [aux_labels - {label} for label in aux_labels]
    if arbsub:
        candidates = leave_one_out + [{label} for label in aux_labels] + [aux_labels, set()]
    else:
        # all labels minus each individual label, plus all labels
        candidates = [c for c in leave_one_out + [aux_labels] if len(c) > 0]
    return sorted({tuple(sorted(c)) for c in candidates})


def _apply_ordered_args(stage: dict, len_aux: int, len_core: int) -> None:
    """
    Resolve ordered-training hyperparameters and merge them into ``stage`` in place.

    Passes the stage's ``alpha``/``beta``/``aux_prc``/``core_prc`` (``None`` when
    absent) and the train-loader lengths to ``convert_ordered_args`` and merges
    the returned mapping into ``stage``. Callers read ``stage["aux_prc"]`` and
    ``stage["core_prc"]`` afterwards, so the returned mapping presumably always
    contains those keys.
    """
    ordered_args = convert_ordered_args(
        alpha=stage.get("alpha"),
        beta=stage.get("beta"),
        aux_prc=stage.get("aux_prc"),
        core_prc=stage.get("core_prc"),
        len_aux=len_aux,
        len_core=len_core,
    )
    stage.update(ordered_args)


def _adversarial_finetune(
    stage: dict,
    model: Transformer,
    run_config: RunConfig,
    labels,
    log: dict,
    expert_labels: Optional[list[str]] = None,
) -> None:
    """
    Fine-tune a fresh copy of ``model`` on each label in ``labels`` and evaluate it.

    ``model`` itself is never modified: every label gets its own ``copy_model``
    copy, which is deleted (and the CUDA cache emptied) afterwards. Callers in
    this module move ``model`` to CPU first so only the copy uses device memory.

    The stage's ``do_train`` / ``do_eval`` flags gate fine-tuning and evaluation
    the same way ``run_experiment`` does, but the stage's ``checkpoint`` is
    deliberately NOT applied: it describes the stage's primary model, and
    restoring it here would overwrite the copy under attack and hand its
    training state to ``do_finetune``.

    Args:
        stage: the current stage config (passed unchanged to ``do_eval``).
        model: model to attack; copied per label.
        run_config: run config (``device`` and ``do_compile`` are read).
        labels: iterable of data labels to fine-tune on, one run per label.
        log: base log fields (e.g. ``target``); ``finetune`` and ``elicited``
            are appended for each label.
        expert_labels: if given, forwarded as ``expert_labels`` to both
            ``do_finetune`` and ``do_eval``.
    """
    device = run_config.device
    do_compile = run_config.do_compile

    for label in labels:
        ft_forget_model = copy_model(model, device, do_compile)

        func_args: dict = {"data_labels": [label]}
        eval_args: dict = {"data_labels": [label]}
        if expert_labels is not None:
            func_args["expert_labels"] = expert_labels
            eval_args["expert_labels"] = expert_labels
        eval_args["log"] = {**log, "finetune": label, "elicited": True}

        if stage.get("do_train", True):
            ft_forget_model = do_finetune(model=ft_forget_model, config=run_config, **func_args)[0]

        if stage.get("do_eval", True):
            do_eval(stage=stage, model=ft_forget_model, config=run_config, **eval_args)

        del ft_forget_model
        torch.cuda.empty_cache()


def _run_baseline(
    stage: dict,
    model_config: ModelConfig,
    run_config: RunConfig,
    aux_labels: set[str],
) -> Transformer:
    """
    Train/evaluate the baseline ``BaseTransformer`` on all data.

    Returns the raw (unwrapped) baseline moved to CPU in bfloat16: its weights
    are still needed by the posthoc-unlearning stages, but it should not hold
    VRAM in the meantime. With ``ft_forget`` set, every aux label is also
    adversarially fine-tuned against a copy of the baseline.
    """
    logger = run_config.logger
    logger.info("BASELINE START")

    baseline_model = make_model(BaseTransformer, model_config, run_config)
    save_args = {
        "dir": "baseline",
        "prefix": "baseline",
        "do_checkpoint": stage.get("do_checkpoint", False),
    }

    baseline_model, _ = run_experiment(
        stage=stage,
        model=baseline_model,
        config=run_config,
        func=do_train,
        func_args={
            "data_labels": ["all"],
            "save_args": save_args,
        },
        eval_args={"log": {
            "target": [],
            "finetune": None,
            "elicited": False,
        }},
    )

    # Put baseline model on cpu - still need weights but don't wanna use vram
    baseline_model = get_raw_model(baseline_model).to("cpu", dtype=torch.bfloat16)

    if stage.get("ft_forget", False):
        logger.info("BASELINE - ADVERSARIAL FT START")
        _adversarial_finetune(
            stage,
            baseline_model,
            run_config,
            labels=sorted(aux_labels),
            log={"target": []},
        )

    return baseline_model


def _run_posthoc_unlearning(
    stage: dict,
    baseline_model: Optional[Transformer],
    model_config: ModelConfig,
    run_config: RunConfig,
    forget_targets: list[tuple[str, ...]],
) -> None:
    """
    Apply a posthoc unlearning method ("rmu", "ascent" or "maxent") to copies of the baseline.

    Each non-empty target tuple starts from a fresh copy of ``baseline_model``;
    the unlearned model is released at the end of its iteration.
    """
    logger = run_config.logger
    device = run_config.device
    do_compile = run_config.do_compile
    stage_name = stage["name"]

    logger.info(f"UNLEARNING - {stage_name} START")

    assert baseline_model is not None, "Baseline model is required for posthoc unlearning"

    funcs: Dict[str, Callable] = {
        "rmu": do_rmu,
        "ascent": do_gradient_ascent,
        "maxent": do_maxent,
    }
    unlearn_func = funcs[stage_name]

    for targets in forget_targets:

        if len(targets) == 0:
            continue

        model = copy_model(baseline_model, device, do_compile)

        func_args = {"data_labels": targets}
        if stage_name == "rmu":
            # RMU compares activations against a frozen copy of the starting model.
            func_args["frozen_model"] = copy_model(model, device, do_compile).eval()
            func_args["act_layer_inds"] = list(range(model_config.num_layers - 2))
        elif stage_name == "maxent":
            func_args["me_steps"] = stage.get("me_steps", 400)
            func_args["me_alpha_retain"] = stage.get("me_alpha_retain", 15.0)

        model, _ = run_experiment(
            stage=stage,
            model=model,
            config=run_config,
            func=unlearn_func,
            func_args=func_args,
            eval_args={"log": {
                "target": targets,
                "finetune": None,
                "elicited": False,
            }},
        )

        if stage_name == "rmu":
            del func_args["frozen_model"]
            torch.cuda.empty_cache()

        if stage.get("ft_forget", False):  # robust unlearning evaluation
            logger.info(f"UNLEARNING - {stage_name} - ADVERSARIAL FT START")
            model = get_raw_model(model).to("cpu", dtype=torch.bfloat16)
            _adversarial_finetune(stage, model, run_config, labels=targets, log={"target": targets})

        # The next target restarts from a fresh baseline copy, so release this
        # model now instead of keeping it resident while the next copy is made.
        del model
        torch.cuda.empty_cache()


def _run_filtering(
    stage: dict,
    model_config: ModelConfig,
    run_config: RunConfig,
    all_labels: set[str],
    forget_targets: list[tuple[str, ...]],
) -> None:
    """
    Data-filtering baseline: train a fresh ``BaseTransformer`` without each target set.

    For every non-empty target tuple (optionally restricted to the stage's
    ``targets`` whitelist) trains from scratch on ``all_labels`` minus the
    targets, evaluates, optionally runs adversarial fine-tuning, then frees it.
    """
    logger = run_config.logger
    logger.info("FILTERING START")

    # Optional whitelist, normalised to the same sorted-tuple form as forget_targets.
    allowed_targets = [tuple(sorted(set(x))) for x in stage.get("targets", [])]

    for targets in forget_targets:

        if len(targets) == 0:
            continue
        if allowed_targets and targets not in allowed_targets:
            continue

        logger.info(f"Filtering labels: {targets}")

        train_labels = all_labels - set(targets)
        # Check before building the model so a skipped target allocates nothing.
        if len(train_labels) == 0:
            continue

        model = make_model(BaseTransformer, model_config, run_config)

        targets_str = "_".join(sorted(targets))
        save_args = {
            "dir": "filtering",
            "prefix": f"filtering_{targets_str}",
            "do_checkpoint": stage.get("do_checkpoint", False),
        }

        model, _ = run_experiment(
            stage=stage,
            model=model,
            config=run_config,
            func=do_train,
            func_args={
                "data_labels": train_labels,
                "save_args": save_args,
            },
            eval_args={"log": {
                "target": targets,
                "finetune": None,
                "elicited": False,
            }},
        )

        if stage.get("ft_forget", False):
            logger.info("FILTERING - ADVERSARIAL FT START")
            model = get_raw_model(model).to("cpu", dtype=torch.bfloat16)
            _adversarial_finetune(stage, model, run_config, labels=targets, log={"target": targets})

        del model
        torch.cuda.empty_cache()


def _run_coreftaux(
    stage: dict,
    model_config: ModelConfig,
    run_config: RunConfig,
    aux_labels: set[str],
    len_aux: int,
    len_core: int,
) -> None:
    """
    Core-then-aux fine-tuning: train on "core", then fine-tune a copy on each aux label.

    The core-phase model is also evaluated on its own (the "core" entry of the
    loop, which skips the aux phase). Converted ordered args are written back
    into ``stage``.
    """
    logger = run_config.logger
    device = run_config.device
    do_compile = run_config.do_compile

    logger.info("CORE FT AUX - CORE PHASE START")

    _apply_ordered_args(stage, len_aux, len_core)

    model = make_model(BaseTransformer, model_config, run_config)

    model, state = run_experiment(
        stage=stage,
        model=model,
        config=run_config,
        func=do_coreftaux,
        func_args={
            "data_label": "core",
            "aux_prc": stage["aux_prc"],
            "core_prc": stage["core_prc"],
            "do_ft": False,
        },
        eval_args={"log": {
            "train_label": "core",
            "target": sorted(aux_labels),
            "finetune": None,
            "elicited": False,
        }},
    )

    model = get_raw_model(model).to("cpu", dtype=torch.bfloat16)

    for train_label in ["core"] + sorted(aux_labels):

        ft_model = copy_model(model, device, do_compile)
        targets = sorted(aux_labels - {train_label})

        if train_label != "core":

            logger.info(f"CORE FT AUX - AUX PHASE: {train_label}")

            # NOTE(review): this reuses `stage`, so a stage-level "checkpoint" is
            # restored into ft_model again here — confirm that is intended.
            ft_model, _ = run_experiment(
                stage=stage,
                model=ft_model,
                config=run_config,
                func=do_coreftaux,
                func_args={
                    "data_label": train_label,
                    # Each aux run resumes from its own copy of the core-phase state.
                    "state": deepcopy(state),
                    "aux_prc": stage["aux_prc"],
                    "core_prc": stage["core_prc"],
                    "do_ft": True,
                },
                eval_args={"log": {
                    "train_label": train_label,
                    "target": targets,
                    "finetune": None,
                    "elicited": False,
                }},
            )

        if stage.get("ft_forget", False) and len(targets) > 0:
            logger.info(f"CORE FT AUX - AUX PHASE: {train_label} - ADVERSARIAL FT START")
            ft_model = get_raw_model(ft_model).to("cpu", dtype=torch.bfloat16)
            _adversarial_finetune(
                stage,
                ft_model,
                run_config,
                labels=targets,
                log={"train_label": train_label, "target": targets},
            )

        # Always release the per-label copy (previously skipped via `continue`).
        del ft_model
        torch.cuda.empty_cache()

    del model
    torch.cuda.empty_cache()


def _make_routed_model(
    stage: dict,
    arch: str,
    model_config: ModelConfig,
    run_config: RunConfig,
    mlp_dim,
    aux_dim,
) -> tuple[Transformer, str]:
    """
    Build the routed model for ``arch`` and return it with its base results prefix.

    "demix" / "moe" get ``mlp_dim`` + ``aux_dim`` experts; "lora" gets a LoRA
    adapter whose rank is read from the stage or derived via ``calc_lora_rank``
    when it is -1.
    """
    logger = run_config.logger

    if arch == "lora":
        lora_attn = stage.get("lora_attn", False)
        lora_mlp = stage.get("lora_mlp", True)
        assert lora_attn or lora_mlp, "At least one of lora_attn or lora_mlp must be True"

        lora_rank = stage.get("lora_rank", -1)
        if lora_rank == -1:
            lora_rank = calc_lora_rank(model_config, lora_attn, lora_mlp, mlp_dim, aux_dim)

        logger.info(f"ROUTED - MLP dim: {mlp_dim}, LoRA rank: {lora_rank}")

        model = make_model(
            LoRATransformer,
            model_config,
            run_config,
            extra_args={
                "mlp_dim": mlp_dim,
                "lora_rank": lora_rank,
                "lora_attn": lora_attn,
                "lora_mlp": lora_mlp,
            }
        )

        lora_attn_str = "attn" if lora_attn else "noattn"
        lora_mlp_str = "mlp" if lora_mlp else "nomlp"
        return model, f"routed_lora_{lora_attn_str}_{lora_mlp_str}_rank-{lora_rank}"

    model_cls = DemixTransformer if arch == "demix" else MoETransformer

    logger.info(f"ROUTED - MLP dim: {mlp_dim}, AUX dim: {aux_dim}")

    model = make_model(
        model_cls,
        model_config,
        run_config,
        extra_args={
            "mlp_dim": mlp_dim,
            "aux_dim": aux_dim,
        },
    )
    return model, f"routed_{arch}"


def _routed_train_setup(stage: dict, len_aux: int, len_core: int) -> tuple[Callable, dict, str]:
    """
    Choose the routed training function, its kwargs, and the results-prefix suffix.

    ``ordered`` stages use ``do_routed_ordered`` (after converting ordered args
    into ``stage``); unordered stages use ``do_routed_unordered`` or, with
    ``train_arbsub``, ``do_routed_unordered_arbsub``.
    """
    if stage.get("ordered", False):
        _apply_ordered_args(stage, len_aux, len_core)
        func_args = {
            "ft_aux_prc": stage["aux_prc"],
            "core_prc": stage["core_prc"],
            "equal_compute": stage.get("equal_compute", True),
        }
        suffix = f"_ordered_aux_prc-{stage['aux_prc']}_core_prc-{stage['core_prc']}"
        return do_routed_ordered, func_args, suffix

    aux_route_prc = stage.get("aux_route_prc", 0.75)

    if not stage.get("train_arbsub", False):
        robust_prc = stage.get("robust_prc", 0.5)
        core_prc = stage.get("core_prc", 1.0)
        func_args = {
            "aux_route_prc": aux_route_prc,
            "robust_prc": robust_prc,
            "core_prc": core_prc
        }
        suffix = f"_unordered_aux_route_prc-{aux_route_prc}_robust_prc-{robust_prc}_core_prc-{core_prc}"
        return do_routed_unordered, func_args, suffix

    core_robust_prc = stage.get("core_robust_prc", 0.5)
    aux_robust_prc = stage.get("aux_robust_prc", core_robust_prc)
    func_args = {
        "aux_route_prc": aux_route_prc,
        "core_robust_prc": core_robust_prc,
        "aux_robust_prc": aux_robust_prc
    }
    suffix = (
        f"_unordered_aux_route_prc-{aux_route_prc}"
        f"_core_robust_prc-{core_robust_prc}_aux_robust_prc-{aux_robust_prc}_arbsub"
    )
    return do_routed_unordered_arbsub, func_args, suffix


def _run_routed(
    stage: dict,
    model_config: ModelConfig,
    run_config: RunConfig,
    all_labels: set[str],
    forget_targets: list[tuple[str, ...]],
    len_aux: int,
    len_core: int,
) -> None:
    """
    Train one routed model, then evaluate it with each forget set's experts ablated.

    Checkpoint/do_train/do_eval gating mirrors ``run_experiment``, but evaluation
    runs once per ablation (keeping only experts outside the forget set).
    """
    logger = run_config.logger
    device = run_config.device

    logger.info("ROUTED START")

    arch = stage.get("arch", "moe")  # "demix", "moe", "lora"
    assert arch in ["demix", "moe", "lora"], f"Invalid arch: {arch}"

    expert_dist = stage.get("expert_dist", "add")  # "add", "equal_one", "equal_sum", "prc_one", "prc_sum"
    assert expert_dist in ["add", "equal_one", "equal_sum", "prc_one", "prc_sum"], f"Invalid expert_dist: {expert_dist}"

    aux_exp_prc = stage.get("aux_exp_prc", None)
    mlp_dim, aux_dim = get_routed_dims(model_config, run_config, expert_dist, aux_prc=aux_exp_prc)

    model, stage_str = _make_routed_model(stage, arch, model_config, run_config, mlp_dim, aux_dim)

    if arch == "demix":
        # demix has a single training recipe and ignores the ordered/unordered options.
        func, func_args = do_demix, dict()
    else:
        func, func_args, suffix = _routed_train_setup(stage, len_aux, len_core)
        stage_str += suffix

    # Load checkpoint if specified
    checkpoint_path = stage.get("checkpoint")
    if checkpoint_path is not None:
        model, state = restore_state(model, checkpoint_path, device, logger)
        func_args["state"] = state

    # TRAIN
    func_args["save_args"] = {
        "dir": "routed",
        "prefix": stage_str,
        "do_checkpoint": stage.get("do_checkpoint", False),
    }

    if stage.get("do_train", True):
        model = func(
            model=model,
            config=run_config,
            **func_args,
        )[0]

    # ABLATE
    for ablate_labels in forget_targets:

        experts = all_labels - set(ablate_labels)
        ablate_labels = sorted(ablate_labels)

        if arch == "demix":
            if len(experts) > 1:
                experts -= {"core"}
            assert len(experts) == 1, "demix num_experts must be 1"

        experts = sorted(experts)

        logger.info(f"Experts: {experts}")
        logger.info(f"Ablating: {ablate_labels}")

        if stage.get("do_eval", True):
            do_eval(
                stage=stage,
                model=model,
                config=run_config,
                expert_labels=experts,
                log={
                    "target": ablate_labels,
                    "finetune": None,
                    "elicited": False,
                },
            )

        # ADVERSARIAL FT
        if stage.get("ft_forget", False) and len(ablate_labels) > 0:
            logger.info("ROUTED - ADVERSARIAL FT START")
            model = get_raw_model(model).to("cpu", dtype=torch.bfloat16)
            _adversarial_finetune(
                stage,
                model,
                run_config,
                labels=ablate_labels,
                log={"target": ablate_labels},
                expert_labels=experts,
            )
            # The next ablation evaluates this same model, so it must be back on device.
            model = model.to(device, dtype=torch.bfloat16)

    del model
    torch.cuda.empty_cache()


def run_experiments(
    stages: list[dict],
    model_config: ModelConfig,
    run_config: RunConfig,
) -> None:
    """
    Run every configured stage against the shared model/run configuration.

    The "baseline" stage is always run first (stable sort), because the posthoc
    unlearning stages ("rmu", "ascent", "maxent") start from a copy of the
    baseline model. Other supported stages are "filtering", "coreftaux" and
    "routed". Stage names that match none of these are logged and skipped.

    Note: "coreftaux" and ordered "routed" stages write their converted
    ordered-training args back into the stage dict.
    """
    logger = run_config.logger
    aux_labels = set(run_config.aux_labels)
    all_labels = aux_labels | {"core"}
    len_core = len(run_config.loaders["core"]["train"])
    len_aux = sum(len(run_config.loaders[label]["train"]) for label in aux_labels)

    forget_targets = _build_forget_targets(aux_labels, run_config.arbsub)

    # ensure baseline is first
    stages = sorted(stages, key=lambda x: 0 if x["name"] == "baseline" else 1)
    baseline_model = None

    logger.debug(f"forget_targets: {forget_targets}")

    for stage in stages:

        logger.info(f"Stage Start: {json.dumps(stage, default=str, ensure_ascii=False)}")

        stage_name = stage["name"]

        if stage_name == "baseline":
            baseline_model = _run_baseline(stage, model_config, run_config, aux_labels)
        elif stage_name in ("rmu", "ascent", "maxent"):
            _run_posthoc_unlearning(stage, baseline_model, model_config, run_config, forget_targets)
        elif stage_name == "filtering":
            _run_filtering(stage, model_config, run_config, all_labels, forget_targets)
        elif stage_name == "coreftaux":
            _run_coreftaux(stage, model_config, run_config, aux_labels, len_aux, len_core)
        elif stage_name == "routed":
            _run_routed(stage, model_config, run_config, all_labels, forget_targets, len_aux, len_core)
        else:
            logger.warning(f"Unknown stage name {stage_name!r}; skipping")


def run(
    # stage config
    stages: list[dict],

    # model config
    ctx_len: int,
    num_layers: int,
    embed_dim: int,
    mlp_dim: int,

    # run config
    arbsub: bool,
    test_ood: bool,
    data_dirs: list[str],
    aux_labels: list[str],
    core_labels: list[str] | None,
    do_compile: bool,
    seed: int,
    res_dir: str,
    batch_size: int,
    epochs: int,
    lr: float,
    log_level: str,
    core_batch_limit: str | int | None,
    aux_batch_limit: int | float | None,
    lr_schedule: bool,
    do_cleanup_distributed: bool,
    accumulation_steps: int,
    optimize_routed_training: bool,
    timestamp: Optional[str] = None,
    process_id: Optional[int] = None,
) -> None:
    """
    Build the model/run configs with ``setup`` and execute all stages.

    ``setup_distributed`` is called first (the original comment notes it only
    takes effect when launched with torchrun and is idempotent). Teardown via
    ``cleanup_distributed`` happens in a ``finally`` block, and only when
    ``do_cleanup_distributed`` is true, so callers that own the distributed
    lifecycle can opt out.
    """
    setup_distributed()

    try:
        configs = setup(
            # stage config
            stages=stages,

            # model config
            ctx_len=ctx_len,
            num_layers=num_layers,
            embed_dim=embed_dim,
            mlp_dim=mlp_dim,

            # run config
            arbsub=arbsub,
            test_ood=test_ood,
            data_dirs=data_dirs,
            aux_labels=aux_labels,
            core_labels=core_labels,
            do_compile=do_compile,
            seed=seed,
            res_dir=res_dir,
            timestamp=timestamp,
            batch_size=batch_size,
            epochs=epochs,
            lr=lr,
            core_batch_limit=core_batch_limit,
            aux_batch_limit=aux_batch_limit,
            log_level=log_level,
            lr_schedule=lr_schedule,
            process_id=process_id,
            accumulation_steps=accumulation_steps,
            optimize_routed_training=optimize_routed_training,
        )
        run_config = configs["run_config"]

        # Summarise dataset sizes per loader before any training starts.
        for name, splits in run_config.loaders.items():
            summary = f"Loader {name} train: {len(splits['train'])}, test: {len(splits['test'])}"
            if run_config.test_ood:
                summary += f", test_ood: {len(splits['test_ood'])}"
            run_config.logger.info(summary)

        run_experiments(stages, **configs)

        done_logger = run_config.logger
        done_logger.info("-" * 40)
        done_logger.info(f"Finished. See {run_config.res_dir}")

    finally:
        # Caller may manage the distributed lifecycle itself.
        if do_cleanup_distributed:
            cleanup_distributed()


# --------------------------------------------------------------------------- #
# CLI                                                                         #
# --------------------------------------------------------------------------- #

if __name__ == "__main__":

    # clear cache
    torch.cuda.empty_cache()

    root_dir = Path('src').absolute()
    timestamp = get_timestamp()

    # --- Stories Experiment --- #
    from src.run.orchestrate.config import StoriesBaseArgs
    defaults = deepcopy(StoriesBaseArgs)
    defaults['res_dir'] = root_dir / f"results/stories/results_{timestamp}"
    # Close the metadata file deterministically; Path() also accepts str data dirs.
    metadata_path = Path(defaults["data_dirs"][0]) / "metadata.json"
    with open(metadata_path, encoding="utf-8") as metadata_file:
        dataset_metadata = json.load(metadata_file)
    all_labels = sorted(dataset_metadata["all"]["labels"])
    num_labels = 4
    defaults['aux_labels'] = all_labels[:num_labels]

    # --- Realistic Experiment (alternative; replace the Stories block above) --- #
    # from src.run.orchestrate.config import RealisticBaseArgs, calc_realistic_model_params
    # defaults = deepcopy(RealisticBaseArgs)
    # defaults['res_dir'] = root_dir / f"results/realistic/results_{timestamp}"
    # defaults['aux_labels'] = ["bigcode", "biology", "nuclear", "cyber"]
    # model_params = calc_realistic_model_params(700e6)
    # defaults.update(model_params)

    defaults['do_compile'] = False
    defaults['core_batch_limit'] = 10*44
    defaults['aux_batch_limit'] = 10*4

    defaults['log_level'] = "DEBUG"
    # Single effective seed (a former `seed = 42` was always overwritten here).
    defaults['seed'] = 0
    defaults['stages'] = [
        {"name": "baseline", "ft_forget": False, "do_checkpoint": True},
        {"name": "maxent", "ft_forget": True, "me_alpha_retain": 30},
        {"name": "coreftaux", "ft_forget": False, "core_prc": 0.9, "aux_prc": 0.5},
        {
            "name": "routed", "arch": "lora", "ordered": True, "ft_forget": False,
            "core_prc": 0.9, "aux_prc": 0.5, "lora_attn": True, "lora_mlp": True,
            "expert_dist": "prc_one", "equal_compute": True,
        },
        {
            "name": "routed", "arch": "moe", "ordered": False, "ft_forget": True,
            "aux_route_prc": 0.75, "robust_prc": 0.5, "expert_dist": "prc_one",
        },
        {"name": "filtering", "ft_forget": True},
    ]

    run(**defaults)