import gc
import os
import random

import numpy as np
import torch

from exp.exp_anomaly_detection import Exp_Anomaly_Detection
from exp.exp_classification import Exp_Classification
from exp.exp_imputation import Exp_Imputation
from exp.exp_long_term_forecasting import Exp_Long_Term_Forecast
from utils.config import get_args
from utils.metrics import save_results


def set_seed(seed: int = 42, *, deterministic: bool = True):
    """
    Set random seeds for reproducibility across random, numpy, and torch.

    Args:
        seed (int): The seed value to set.
        deterministic (bool): Whether to enforce deterministic cuDNN behavior in
            PyTorch. When True (the default), cuDNN autotuning is disabled and
            deterministic algorithms are requested. When False, cuDNN may
            benchmark and pick the fastest (possibly non-deterministic) kernels.
    """
    # Python & NumPy
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch: CPU generator plus every visible CUDA device. ``manual_seed_all``
    # already covers the current device, so a separate ``cuda.manual_seed`` is
    # unnecessary. Both calls are safe on machines without CUDA.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN autotuning selects kernels by timing them, which can differ between
    # runs, so it is disabled whenever determinism is requested.
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.deterministic = deterministic



def get_exp(args):
    """
    Instantiate the experiment class that matches ``args.task_name``.

    Any task name not listed below falls back to the long-term forecasting
    experiment.
    """
    experiment_by_task = {
        "long_term_forecast": Exp_Long_Term_Forecast,
        "imputation": Exp_Imputation,
        "anomaly_detection": Exp_Anomaly_Detection,
        "classification": Exp_Classification,
    }
    experiment_cls = experiment_by_task.get(args.task_name, Exp_Long_Term_Forecast)
    return experiment_cls(args)


def clean(ckpt=None):
    """
    Free cached GPU memory, run the garbage collector, and optionally delete a checkpoint.

    Args:
        ckpt (str | None): Path to a checkpoint file. If it is falsy or does not
            exist, no file is removed. After the file is deleted, its containing
            directory is also removed if it is now empty.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    if not ckpt or not os.path.exists(ckpt):
        return

    os.remove(ckpt)
    parent_dir = os.path.dirname(ckpt)
    # A bare filename yields an empty dirname: os.listdir("") would raise, and
    # the current working directory must never be removed anyway.
    if parent_dir and parent_dir != os.curdir and not os.listdir(parent_dir):
        os.rmdir(parent_dir)


def get_setting(args):
    """
    Compose the run identifier from the hyper-parameters stored on ``args``.

    The identifier is a ``_``-separated sequence of fields. When ``args.model``
    contains "diffkanformer" (case-insensitive), diffusion-specific fields are
    appended. Classification runs additionally append the classifier name.
    """
    fields = [
        f"{args.task_name}",
        f"{args.model_id}",
        f"{args.model}",
        f"attndropout{args.attn_dropout}",
        f"hiddendim{args.hidden_dim}",
        f"mlp-ratio{args.mlp_ratio}",
        f"n-depth{args.n_depth}",
        f"n-emb{args.n_emb}",
        f"n-heads{args.n_heads}",
        f"epochs{args.train_epochs}",
        f"batch{args.batch_size}",
        f"lr{args.learning_rate}",
    ]

    if "diffkanformer" in args.model.lower():
        fields += [
            f"timesteps{args.timesteps}",
            f"{args.des}",
            f"use_cond{args.use_cond}",
            f"use_tphi{args.use_tphi}",
            f"normalize{args.normalize}",
            f"tphi-loss{args.tphi_loss}",
        ]

    if args.task_name == "classification":
        fields.append(f"classifier{args.classifier}")

    return "_".join(fields)


if __name__ == "__main__":
    set_seed(42)
    args = get_args()
    setting = get_setting(args)

    import json

    # Load previously saved results for this task, if any (skipped for
    # classification). resultdict is always bound so later code can rely on it.
    resultdict = {}
    filename = f"results/{args.filename or args.task_name}_results.json"
    if os.path.exists(filename) and args.task_name != "classification":
        # Context manager closes the handle instead of leaking it until GC.
        with open(filename, encoding="utf-8") as results_file:
            resultdict = json.load(results_file)

    if args.wandb:
        import wandb

        wandb.init(
            project="diffkanformer",
            name=setting,
            config=vars(args),
            reinit=True,
        )
    exp = get_exp(args)

    if args.is_training:
        print(f">>>>>>>TRAINING : \n{setting}\n<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
        exp.train(setting)

    print(f">>>>>>>TESTING : \n{setting}\n<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
    # NOTE(review): ckpt is not used below and clean() is never called here —
    # confirm whether the checkpoint should be removed via clean(ckpt).
    ckpt = exp.test(setting)

    print("-------------------------------------------------")
