# Entry-point: parse args, load YAMLs, merge hyperparams, dispatch experiment
import argparse
import os
import yaml
import torch

from data.loader import DatasetLoader
from energy_monitor import EnergyMonitor
from experiment import Experiment, DEFAULT_HP
import utils

import sys
# Make the TabZilla checkout importable. This runs at module import time, before
# main() performs its lazy `TabZilla.*` imports.
# NOTE(review): the path is absolute and specific to one user's home directory,
# so the script will not find TabZilla on any other machine — consider deriving
# it from this file's location or an environment variable.
sys.path.append("/home/filip/code/Thesis-code/TabZilla")

def load_yaml(path):
    """Parse the YAML file at ``path`` and return the resulting Python object.

    The file is parsed with ``yaml.safe_load``, so only plain YAML data types
    are constructed.

    Args:
        path: Filesystem path of the YAML document to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale-default text encoding.
    with open(path, 'r', encoding="utf-8") as f:
        return yaml.safe_load(f)

def merge_dict(base, override):
    """Overlay ``override`` onto ``base`` in place and hand ``base`` back.

    Every key in ``override`` is written into ``base``, replacing any value
    already stored under that key. The merge is shallow: nested containers
    are assigned as-is, not merged recursively.

    Args:
        base: Mapping that receives the entries; it is mutated.
        override: Mapping whose entries win, or a falsy value (``None``,
            ``{}``) meaning "nothing to overlay".

    Returns:
        The same ``base`` object that was passed in.
    """
    if not override:
        # None or an empty mapping: leave base untouched.
        return base
    for name, value in override.items():
        base[name] = value
    return base

# Lower-cased model name (as written in the experiment YAML) -> the
# (module path, class name) pair that implements it. Modules are imported
# lazily, only for the model that was actually selected.
_MODEL_REGISTRY = {
    "mlp": ("models.mlp", "MLP"),
    "imlp": ("models.imlp", "IncrementalMLP"),
    "tabm": ("models.baseline_NNs.tabm", "TabMClassifier"),
    "tabr": ("models.baseline_NNs.tabr", "TabRClassifier"),
    "modernnca": ("models.baseline_NNs.modernnca", "ModernNCAClassifier"),
    "realmlp": ("models.baseline_NNs.realmlp", "RealMLPClassifier"),
    "tabpfnv2": ("models.baseline_NNs.tabpfnv2", "TabPFNv2Wrapper"),
    "tabpfn": ("TabZilla.models.tabpfn", "TabPFNModel"),
    "xgboost": ("TabZilla.models.tree_models", "XGBoost"),
    "catboost": ("TabZilla.models.tree_models", "CatBoost"),
    "lightgbm": ("TabZilla.models.tree_models", "LightGBM"),
    "linearmodel": ("TabZilla.models.baseline_models", "LinearModel"),
    "knn": ("TabZilla.models.baseline_models", "KNN"),
    "svm": ("TabZilla.models.baseline_models", "SVM"),
    "decisiontree": ("TabZilla.models.baseline_models", "DecisionTree"),
    "randomforest": ("TabZilla.models.baseline_models", "RandomForest"),
    "saint": ("TabZilla.models.saint", "SAINT"),
    "resnet": ("TabZilla.models.rtdl", "rtdl_ResNet"),
    "fttransformer": ("TabZilla.models.rtdl", "rtdl_FTTransformer"),
    "danet": ("TabZilla.models.danet", "DANet"),
    "stg": ("TabZilla.models.stochastic_gates", "STG"),
    "tabnet": ("TabZilla.models.tabnet", "TabNet"),
    "node": ("TabZilla.models.node", "NODE"),
    "vime": ("TabZilla.models.vime", "VIME"),
}

# Models that receive only a fixed subset of the merged hyperparameters.
# Keys absent from the hyperparameter dict are simply not forwarded.
_PARAM_ALLOWLIST = {
    "danet": ("layer", "base_outdim", "k", "drop_rate", "optimizer_params"),
    "stg": ("hidden_dims", "weight_decay"),
    "tabnet": ("n_d", "n_steps", "gamma", "optimizer_fn", "optimizer_params"),
    "node": ("total_tree_count", "num_layers", "tree_depth", "tree_output_dim"),
}


def _load_model_class(key):
    """Import and return the model class registered under ``key``."""
    import importlib

    module_path, class_name = _MODEL_REGISTRY[key]
    return getattr(importlib.import_module(module_path), class_name)


def _build_model(key, model_cls, hp, model_args, loader, device):
    """Instantiate ``model_cls`` with the constructor convention used for ``key``."""
    if key == "mlp":
        return model_cls(loader.input_size, loader.num_classes).to(device)

    if key == "imlp":
        return model_cls(
            loader.input_size,
            loader.num_classes,
            use_attention=hp["use_attention"],
            window_size=hp["window_size"],
        ).to(device)

    if key == "tabpfn":
        # TabPFN is built with a fixed parameter set; the merged `hp` is not
        # forwarded to it.
        # NOTE(review): `tuning_applied` reported to Experiment can still be
        # True for this model when a tuning file exists — confirm that is the
        # intended bookkeeping.
        return model_cls(params={"ignore_pretraining_limits": True}, args=model_args)

    if key == "fttransformer":
        # Training batch size is forced to 1; val_batch_size is left as-is.
        # NOTE(review): an earlier comment claimed the intended value was 16 —
        # confirm which one is correct.
        model_args.batch_size = 1

    allowed = _PARAM_ALLOWLIST.get(key)
    if allowed is None:
        params = hp
    else:
        params = {name: hp[name] for name in allowed if name in hp}
    return model_cls(params=params, args=model_args)


def main():
    """Command-line entry point: configure and run a single experiment.

    Steps:
      1. Parse CLI arguments, seed via ``utils.set_seed`` and call
         ``utils.get_logger``.
      2. Load the experiment YAML (``model``, ``mode``, optional
         ``hyperparameters``).
      3. Unless ``--no_tuning`` is given, overlay
         ``tuning/task_<task>_hyperparams.yml`` when that file exists.
      4. Fill any still-missing hyperparameters from ``DEFAULT_HP``.
      5. Build the dataset loader, the model, and the energy monitor.
      6. Hand everything to ``Experiment`` and run it.

    The energy-log directory is rooted at ``$ABLATION_OUTPUT_DIR`` when that
    environment variable is set and non-empty, otherwise at ``results``.

    Raises:
        KeyError: If the experiment YAML lacks ``model`` or ``mode``.
        ValueError: If the requested model name is not registered.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--experiment", required=True)
    parser.add_argument("--task", type=int, required=True)
    parser.add_argument("--device", choices=["cpu", "cuda"], default="cpu")
    parser.add_argument("--min_segments", type=int, default=5)
    parser.add_argument("--no_tuning", action="store_true")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    utils.set_seed(args.seed)
    utils.get_logger()

    # 1) load experiment YAML
    exp_cfg = load_yaml(args.experiment)
    raw_model = exp_cfg["model"]
    key = raw_model.lower()
    mode = exp_cfg["mode"]
    # `hyperparameters:` with no value parses to None; treat it like a missing key.
    hp = exp_cfg.get("hyperparameters") or {}

    # Reject unknown models before the dataset is loaded.
    if key not in _MODEL_REGISTRY:
        raise ValueError(f"Unknown model {raw_model}")

    # 2) overlay tuned hyperparameters, if available and not disabled
    tuning_f = f"tuning/task_{args.task}_hyperparams.yml"
    print(f"Checking for tuning file: {tuning_f}")
    # Evaluated once so the flag passed to Experiment always matches the
    # branch that was actually taken.
    tuning_applied = not args.no_tuning and os.path.isfile(tuning_f)
    if tuning_applied:
        print("Found tuning file! Merging...")
        merge_dict(hp, load_yaml(tuning_f))
    else:
        print("Tuning file not found or --no_tuning specified.")

    # 3) fill defaults (explicit and tuned values take precedence)
    for name, default in DEFAULT_HP.items():
        hp.setdefault(name, default)

    # 4) loader and the argument namespace passed to model constructors
    ds_name = f"openml_task_{args.task}"
    loader = DatasetLoader(ds_name, min_segments=args.min_segments)
    model_args = argparse.Namespace(
        num_features=loader.input_size,
        num_classes=loader.num_classes,
        objective="classification",
        device=args.device,
        use_gpu=args.device == "cuda",
        gpu_ids=[0],
        epochs=hp.get("epochs", 500),
        logging_period=100,
        early_stopping_rounds=10,
        data_parallel=False,
        batch_size=hp.get("batch_size", 256),
        val_batch_size=hp.get("batch_size", 256),
        cat_idx=[],
        cat_dims=[],
        model_name=key,
        dataset=str(args.task),
    )

    # 5) resolve the model class and 6) instantiate it
    model_cls = _load_model_class(key)
    model = _build_model(
        key, model_cls, hp, model_args, loader, torch.device(args.device)
    )

    # 7) energy monitor; an ablation run redirects logs via ABLATION_OUTPUT_DIR
    ablation_output_dir = os.environ.get('ABLATION_OUTPUT_DIR')
    if ablation_output_dir:
        energy_output_dir = os.path.join(
            ablation_output_dir, "energy_logs", key, str(args.task)
        )
        print(f"Using ablation output directory: {energy_output_dir}")
    else:
        energy_output_dir = os.path.join("results", "energy_logs", key, str(args.task))

    mon = EnergyMonitor(output_dir=energy_output_dir)

    # 8) run; the ablation directory (possibly None) is forwarded to Experiment
    exp = Experiment(
        loader=loader,
        model=model,
        monitor=mon,
        mode=mode,
        hyperparams=hp,
        device=args.device,
        model_key=key,
        task_id=args.task,
        tuning_applied=tuning_applied,
        seed=args.seed,
        ablation_output_dir=ablation_output_dir,
    )
    exp.run()

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()