import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import wandb
import importlib
from utils import *
from dataloader import TimeSeriesLoader

# Registry of supported detectors: model name -> (module path, class name).
# The module is imported lazily (via importlib) only for the model being
# trained, so unused detectors' dependencies are never loaded.
MODEL_DISPATCH = dict(
    USAD=("methods.USAD", "USADAnomalyDetector"),
    DAGMM=("methods.DAGMM", "DAGMMAnomalyDetector"),
    LUAD=("methods.LUAD", "LUADAnomalyDetector"),
    lstmAE=("methods.lstmAE", "LSTMAEAnomalyDetector"),
    lstmVAE=("methods.lstmVAE", "LSTMVAEAnomalyDetector"),
    OmniAnomaly=("methods.OmniAnomaly", "OmniAnomalyDetector"),
    DeepSVDD=("methods.DeepSVDD", "DeepSVDDAnomalyDetector"),
    AnomalyTransformer=("methods.AnomalyTransformer", "AnomalyTransformerDetector"),
    TimesNet=("methods.TimesNet", "TimesNetAnomalyDetector"),
)

# Model-specific constructor keyword arguments for each detector. Every name
# here is both the detector's kwarg name and the wandb config key it is read
# from. `dataloader` and `input_dim` are supplied separately by `train`.
_MODEL_HPARAMS = {
    "USAD": ("latent_dim", "alpha", "beta"),
    "DAGMM": ("hidden_dim", "latent_dim", "lambda_energy", "lambda_cov_diag"),
    "LUAD": ("hidden_dim", "latent_dim", "use_PNF", "PNF_layers",
             "tcn_levels", "kernel_size", "dropout"),
    "lstmAE": ("hidden_dim",),
    "lstmVAE": ("hidden_dim", "latent_dim", "noise_std", "n_layers", "kld_coef"),
    "OmniAnomaly": ("hidden_dim", "latent_dim", "use_PNF", "PNF_layers"),
    "DeepSVDD": ("hidden_dim", "dropout_rate"),
    "AnomalyTransformer": ("hidden_dim", "num_layers"),
    "TimesNet": ("hidden_dim", "num_layers"),
}


def _load_detector_class(model_name):
    """Import the module registered for ``model_name`` and return its detector class."""
    mod_name, cls_name = MODEL_DISPATCH[model_name]
    module = importlib.import_module(mod_name)
    return getattr(module, cls_name)


def train():
    """
    Train function called by wandb.agent for each hyperparameter combination.

    Reads from ``wandb.config``:
      * ``model_name`` / ``dataset`` - which detector and dataset to use;
      * ``window_size``, ``step_size``, ``batch_size`` - data loader settings;
      * ``epochs``, ``learning_rate`` - training settings;
      * the model-specific keys listed in ``_MODEL_HPARAMS[model_name]``.

    Fits the detector on the train split, scores the test split and logs
    ``AUROC`` and ``AUCPR`` to the active wandb run.

    Raises:
        ValueError: if ``model_name`` is not a supported detector.
    """
    wandb.init()
    config = wandb.config
    model_name = config.model_name
    dataset_name = config.dataset

    # Validate before doing any expensive work. Without this check an unknown
    # name would surface as a bare KeyError from the MODEL_DISPATCH lookup.
    if model_name not in MODEL_DISPATCH or model_name not in _MODEL_HPARAMS:
        raise ValueError(f"Model {model_name} is not supported.")

    # Reproducibility
    set_seed(42)

    # Initialize data loader
    loader = TimeSeriesLoader(dataset_name=dataset_name,
                              window_size=config.window_size,
                              step_size=config.step_size,
                              batch_size=config.batch_size)

    # Number of features per time step, taken from the training data.
    input_dim = loader.train_ds.data.shape[1]

    # Dynamically import the model class and instantiate it with the
    # model-specific hyperparameters pulled from the sweep config.
    Detector = _load_detector_class(model_name)
    hparams = {key: getattr(config, key) for key in _MODEL_HPARAMS[model_name]}
    detector = Detector(dataloader=loader, input_dim=input_dim, **hparams)

    # Train & evaluate
    detector.fit(epochs=config.epochs, learning_rate=config.learning_rate, data_type="train")

    scores = detector.predict_score(data_type="test")
    metrics = cal_metric(loader.test_ds.labels, scores)
    wandb.log({"AUROC": metrics['aucroc'],
               "AUCPR": metrics['aucpr']})

if __name__ == "__main__":
    # One JSON file holds the search space for every model; SweepConfig
    # (from utils) is iterated as (model name, wandb sweep config) pairs.
    # NOTE(review): this path is resolved against the current working
    # directory, not this script's location - run from the directory that
    # contains `methods/`.
    sweep_space = SweepConfig("./methods/configs/sweep_space.json")

    # Register one wandb sweep per model, then run an agent for it in this
    # process. Each agent invokes `train` once per sampled configuration.
    # NOTE(review): wandb.agent is called without `count`; for sweep methods
    # that do not terminate on their own (e.g. random/bayes) the agent may
    # never return, so later models would not be launched - confirm.
    for name, sweep_definition in sweep_space.items():
        print(f"Launching sweep for: {name}")
        new_sweep_id = wandb.sweep(sweep_definition,
                                   project="lightweight",
                                   entity="lightweight")
        wandb.agent(new_sweep_id, function=train)