import argparse
import importlib
import importlib.util
import os
import pandas as pd
# 新加的
import numpy as np
import random
import importlib
from pytorch_lightning.loggers import CSVLogger, WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
# 新加结束

import lightning.pytorch as L
import numpy
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

from easytsf.util.util import cal_conf_hash
from easytsf.util.util import load_module_from_path
import time

from ConfigSpace import Configuration, ConfigurationSpace
from smac import HyperparameterOptimizationFacade, Scenario

# The order/count hyperparameter Bn is not fixed: it is searched with SMAC (see the __main__ block).

# Set at import time, before any data loading starts. This environment variable is read by the
# HuggingFace `tokenizers` library to disable its internal parallelism.
# NOTE(review): nothing visible here imports `tokenizers`; presumably a data module or runner
# does — confirm this setting is still needed.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import optuna_opti


def objective_epoch(conf, Bn, max_epochs=2, results=None):
    """Train and test one model for a single ``Bn`` value and return its test MSE.

    This is the objective that SMAC minimizes in the ``__main__`` block. ``conf``
    is mutated in place: ``Bn``, ``exp_dir`` and ``max_epochs`` are written into
    it before it is forwarded to the data module and the runner.

    Args:
        conf: Fused experiment configuration (see ``load_config``). Keys read here
            include ``seed``, ``save_root``, ``model_name``, ``dataset_name``,
            ``conf_hash``, ``val_metric``, ``es_patience``, ``gradient_clip_val``,
            ``dm``, ``runner`` and optionally ``use_wandb``, ``use_ray``,
            ``devices``, ``precision``, ``gradient_clip_algorithm``.
        Bn: Value of the searched ``Bn`` hyperparameter for this trial.
        max_epochs: Training epoch budget per trial. It always overwrites
            ``conf["max_epochs"]``. The default of 2 is the value previously
            hard-coded here; it keeps each of the many SMAC trials cheap.
        results: Dict that collects ``{"mse": ..., "mae": ...}`` keyed by
            ``str(Bn)``. Defaults to the module-level ``result`` dict, which is
            created if it does not exist yet.

    Returns:
        The ``test/mse`` reported by ``trainer.test`` on the best checkpoint
        (lowest ``conf["val_metric"]``).
    """
    if results is None:
        # Backward-compatible default: record into the module-level ``result`` dict
        # the __main__ block creates. Create it if this function is called from
        # elsewhere (a bare ``global result`` raised NameError in that case).
        results = globals().setdefault("result", {})

    conf["Bn"] = Bn
    L.seed_everything(conf["seed"])
    save_dir = os.path.join(conf["save_root"], f'{conf["model_name"]}_{conf["dataset_name"]}_{Bn}')
    version = f'seed_{conf["seed"]}'

    # Weights & Biases logging when enabled, otherwise local CSV logs.
    logger_cls = WandbLogger if conf.get("use_wandb") else CSVLogger
    run_logger = logger_cls(save_dir=save_dir, name=conf["conf_hash"], version=version)

    # Mirrors the directory layout produced by the logger above.
    conf["exp_dir"] = os.path.join(save_dir, conf["conf_hash"], version)
    conf["max_epochs"] = max_epochs

    callbacks = [
        # Keep only the best checkpoint by validation metric; trainer.test loads it below.
        ModelCheckpoint(
            monitor=conf["val_metric"],
            mode="min",
            save_top_k=1,
            save_last=False,
            every_n_epochs=1,
        ),
        EarlyStopping(
            monitor=conf["val_metric"],
            mode='min',
            patience=conf["es_patience"],
        ),
        LearningRateMonitor(logging_interval="epoch"),
    ]

    # Report the validation metric to Ray Tune when running under it.
    if conf.get("use_ray"):
        callbacks.append(TuneReportCheckpointCallback(
            {conf["val_metric"]: conf["val_metric"]}, save_checkpoints=False, on="validation_end"))

    # --devices is a GPU *index* ("which card, 0 or 1"). Lightning reads a bare int
    # as a device *count*, so wrap it in a list to select exactly that card.
    # Passing 'auto' instead would grab every GPU (and DDP would relaunch this script).
    devices = conf.get("devices", "auto")
    if isinstance(devices, int) and devices >= 0:
        devices = [devices]

    trainer = L.Trainer(
        accelerator="cuda",
        devices=devices,
        precision=conf.get("precision", "32-true"),
        logger=run_logger,
        callbacks=callbacks,
        max_epochs=conf["max_epochs"],
        gradient_clip_algorithm=conf.get("gradient_clip_algorithm", "norm"),
        gradient_clip_val=conf["gradient_clip_val"],
        default_root_dir=conf["save_root"],
    )

    # Data module and runner classes are resolved by name from the config.
    DataIF = importlib.import_module(f'easytsf.data.{conf["dm"]}').DataInterface
    data_module = DataIF(**conf)
    ModelRunner = importlib.import_module(f'easytsf.runner.{conf["runner"]}').Runner

    model = ModelRunner(**conf)
    print("============================================================")
    print(data_module)
    trainer.fit(model=model, datamodule=data_module)

    # Evaluate the best checkpoint (per ModelCheckpoint above) on the test split.
    rst = trainer.test(model, datamodule=data_module, ckpt_path='best')

    mse = rst[0]['test/mse']
    mae = rst[0]['test/mae']
    print(f"Bn={Bn}, Test MSE: {mse}, Test MAE: {mae}")
    # Insert or overwrite this Bn's metrics (other keys in its entry are kept).
    results.setdefault(str(Bn), {}).update(mse=mse, mae=mae)
    return mse


def load_config(exp_conf_path, args):
    """Build the fused experiment configuration from three sources.

    * ``exp_conf``: the ``exp_conf`` dict of the Python file at
      ``exp_conf_path``. Its ``layer_type`` is overridden by ``args.layer_type``.
    * ``task_conf``: ``config.base_conf.<exp_conf['task']>.task_conf``, with
      ``pred_len`` overridden by ``args.pred_len``.
    * ``data_conf``: ``config.base_conf.datasets.<args.dataset_name>_conf``.

    Precedence when keys collide (later wins): task_conf < data_conf < exp_conf.

    Args:
        exp_conf_path: Path to the experiment config file.
        args: Parsed CLI namespace providing ``layer_type``, ``pred_len`` and
            ``dataset_name``.

    Returns:
        A new dict containing the merged configuration.

    Raises:
        AttributeError: If ``config.base_conf.datasets`` has no
            ``<dataset_name>_conf`` attribute.
    """
    exp_conf = load_module_from_path("exp_conf", exp_conf_path).exp_conf
    exp_conf["layer_type"] = args.layer_type

    task_conf_module = importlib.import_module(f"config.base_conf.{exp_conf['task']}")
    # Copy so the CLI override does not leak into the imported module's shared dict.
    task_conf = dict(task_conf_module.task_conf)
    task_conf["pred_len"] = args.pred_len

    data_conf_module = importlib.import_module('config.base_conf.datasets')
    # getattr instead of eval(): the name comes from the command line, and eval
    # would execute arbitrary expressions embedded in --dataset_name.
    data_conf = getattr(data_conf_module, f"{args.dataset_name}_conf")

    fused_conf = {**task_conf, **data_conf}
    fused_conf.update(exp_conf)
    return fused_conf


if __name__ == '__main__':
    time0 = time.time()
    parser = argparse.ArgumentParser()
    # Experiment config file: it selects which model runs on which dataset
    # (details in /data0/Wendy/code/KAN/LRKAN/config/setting.py).
    parser.add_argument("-c", "--config", default="config/basicSetting.py", type=str)
    # Dataset root directory.
    parser.add_argument("-d", "--data_root", default="/media/longin/data/data/Wendy/code/learn/dataset", type=str, help="data root")
    # Output root for logs, checkpoints and the results CSV.
    parser.add_argument("-s", "--save_root", default="/media/longin/data/data/Wendy/code/learn/save", help="save root")
    # Which GPU card to use (0 or 1).
    parser.add_argument("--devices", default=0, type=int, help="The devices to use, detail rules is show in README")
    # 1 = log to Weights & Biases (view losses on the website); 0 = local CSV logs.
    parser.add_argument("--use_wandb", default=0, type=int, help="use wandb")
    parser.add_argument("--seed", type=int, default=6, help="seed")

    # Arguments that override values coming from the config files.
    parser.add_argument("--dataset_name", type=str, default="ECL", help="dataset_name")
    parser.add_argument("--layer_type", type=str, default="BnKAN", help="layer_type")
    parser.add_argument("--pred_len", type=int, default=96, help="length of prediction")
    args = parser.parse_args()
    conf = load_config(args.config, args)

    # Command-line training options take precedence over the fused config.
    conf.update({
        "seed": int(args.seed),
        "data_root": args.data_root,
        "save_root": args.save_root,
        "devices": args.devices,
        "use_wandb": args.use_wandb,
    })
    conf['conf_hash'] = cal_conf_hash(conf, hash_len=10)
    conf["device"] = conf["devices"]

    def objective(config: Configuration, seed: int):
        """SMAC target function: train/test with the sampled Bn and return its test MSE.

        ``seed`` is part of the signature SMAC calls with; it is not used here.
        """
        return objective_epoch(conf, config["Bn"])

    # Per-Bn test metrics; objective_epoch fills this through the module-level name ``result``.
    result = {}
    # Integer bounds -> ConfigSpace builds an integer hyperparameter over [1, 10].
    configspace = ConfigurationSpace({"Bn": (1, 10)})
    # NOTE(review): only 10 distinct Bn values exist, so n_trials=200 is an upper bound
    # that presumably is never reached. seed=-1 lets SMAC choose a random seed.
    scenario = Scenario(configspace, deterministic=True, n_trials=200, seed=-1)
    smac = HyperparameterOptimizationFacade(scenario, objective)
    incumbent = smac.optimize()

    print("\nOptimization finished. Best configuration found:")
    best_config = dict(incumbent)

    time_end = time.time()
    best_metrics = result[str(best_config['Bn'])]
    mse = best_metrics["mse"]
    mae = best_metrics["mae"]
    print(f"Best Bn: {best_config['Bn']}", "  mse: ", mse, "   mae:", mae)

    # Append this run's best metrics as one row of <save_root>/pred_len_results.csv.
    save_path = conf['save_root']
    results_file = os.path.join(save_path, "pred_len_results.csv")
    os.makedirs(save_path, exist_ok=True)

    new_row = pd.DataFrame({'H': [conf["pred_len"]], 'MAE': [mae], 'MSE': [mse]})
    try:
        previous_df = pd.read_csv(results_file)
    except FileNotFoundError:
        # First run: start the file with this row. Concatenating onto an empty frame
        # triggers a pandas FutureWarning and leaves object dtypes.
        results_df = new_row
    else:
        results_df = pd.concat([previous_df, new_row], ignore_index=True)

    results_df.to_csv(results_file, index=False)
    print("time cost", time_end - time0, "s")

