
import pandas as pd
from types import SimpleNamespace
import os
import wandb
import re
from utils.wandb_utils.key  import WANDB_KEY
# Publish the W&B API key through the WANDB_API_KEY environment variable so the
# wandb.Api() clients created further down can authenticate. This runs at
# import time, so merely importing this module mutates os.environ.
os.environ["WANDB_API_KEY"] = WANDB_KEY


# D4RL environment name -> W&B run names (checkpoints) trained on it.
# Each run name is "<env without -v2>-<yy-mm-dd-HH-MM-SS>".
checkpoints_dic = {
    # final
    "halfcheetah-medium-v2": [
        "halfcheetah-medium-25-04-19-03-16-09",
        "halfcheetah-medium-25-04-19-05-09-36",
        "halfcheetah-medium-25-04-18-20-31-45",
    ],
    "halfcheetah-medium-replay-v2": [
        "halfcheetah-medium-replay-25-05-04-23-41-03",
        "halfcheetah-medium-replay-25-05-05-01-17-25",
        "halfcheetah-medium-replay-25-05-05-03-00-41",
    ],
    # final
    "halfcheetah-medium-expert-v2": [
        "halfcheetah-medium-expert-25-05-06-00-28-32",
        "halfcheetah-medium-expert-25-05-06-00-28-14",
        "halfcheetah-medium-expert-25-05-05-05-52-55",
    ],
    "walker2d-medium-v2": [
        "walker2d-medium-25-04-21-14-22-26",
        "walker2d-medium-25-04-22-18-04-03",
    ],
    # final
    "walker2d-medium-replay-v2": [
        "walker2d-medium-replay-25-04-19-23-48-56",
        "walker2d-medium-replay-25-04-20-03-18-17",
        "walker2d-medium-replay-25-04-20-05-42-06",
    ],
    "walker2d-medium-expert-v2": [
        "walker2d-medium-expert-25-04-27-20-43-51",
        "walker2d-medium-expert-25-04-27-18-23-56",
        "walker2d-medium-expert-25-04-27-14-06-46",
    ],
    "hopper-medium-v2": [
        "hopper-medium-25-04-20-23-52-14",
        "hopper-medium-25-04-20-21-07-23",
        "hopper-medium-25-04-18-01-09-52",
    ],
    "hopper-medium-replay-v2": [
        "hopper-medium-replay-25-04-21-01-31-23",
        "hopper-medium-replay-25-04-21-03-06-35",
        "hopper-medium-replay-25-04-19-05-45-56",
    ],
    "hopper-medium-expert-v2": [
        "hopper-medium-expert-25-05-06-20-40-09",
        "hopper-medium-expert-25-05-06-23-46-30",
        "hopper-medium-expert-25-05-07-01-08-56",
    ],
    "antmaze-large-diverse-v2": [
        "antmaze-large-diverse-25-04-29-23-05-02",
        "antmaze-large-diverse-25-05-01-04-57-50",
        "antmaze-large-diverse-25-05-01-08-08-49",
        "antmaze-large-diverse-25-04-29-23-03-02",
        "antmaze-large-diverse-25-05-02-10-25-24",
    ],
    "antmaze-medium-diverse-v2": [
        "antmaze-medium-diverse-25-04-29-14-51-15",
        "antmaze-medium-diverse-25-04-29-14-57-00",
    ],
    "antmaze-umaze-diverse-v2": [
        "antmaze-umaze-diverse-25-04-29-22-57-07",
        "antmaze-umaze-diverse-25-04-30-04-50-53",
        "antmaze-umaze-diverse-25-04-30-07-52-11",
    ],
    "antmaze-umaze-v2": [
        "antmaze-umaze-25-04-29-22-43-06",
    ],
}



# antmaze_ld_checkpoints = ["antmaze-large-diverse-25-04-29-23-05-02", "antmaze-large-diverse-25-05-01-04-57-50",
#                           "antmaze-large-diverse-25-05-01-08-08-49", "antmaze-large-diverse-25-04-29-23-03-02",
#                           "antmaze-large-diverse-25-05-02-10-25-24"]  # final
# antmaze_md_checkpoints = ["antmaze-medium-diverse-25-04-29-14-51-15",
#                           "antmaze-medium-diverse-25-04-29-14-57-00"]  # final
# antmaze_ud_checkpoints = ["antmaze-umaze-diverse-25-04-29-22-57-07", "antmaze-umaze-diverse-25-04-30-04-50-53",
#                           "antmaze-umaze-diverse-25-04-30-07-52-11", ]  # final
# antmaze_um_checkpoints = ["antmaze-umaze-25-05-01-10-58-40", "antmaze-umaze-25-05-01-07-59-40",
#                           "antmaze-umaze-25-05-01-05-01-05", ]  # final

# Absolute path of the directory that contains this file.
script_dir = os.path.dirname(os.path.abspath(__file__))
# NOTE(review): this was described as the project "root directory", but it has
# always resolved to the very same path as script_dir. Confirm whether
# os.path.dirname(script_dir) was intended before relying on it.
project_dir = script_dir


def find_latest_checkpoint(directory):
    """Return the path of the highest-epoch checkpoint entry in *directory*.

    Only entries whose name starts with ``epoch_<digits>`` are considered; the
    epoch number is parsed from that prefix. When several entries share the
    highest epoch, the one listed first by ``os.listdir`` wins.

    Args:
        directory: directory to scan.

    Returns:
        ``os.path.join(directory, name)`` for the winning entry, or ``None``
        if no entry matches.
    """
    best_name = None
    best_epoch = None
    for entry in os.listdir(directory):
        if not re.match(r'epoch_\d+.*', entry):
            continue
        epoch = int(re.search(r'epoch_(\d+)', entry).group(1))
        # Strict ">" keeps the first entry on ties, like max().
        if best_epoch is None or epoch > best_epoch:
            best_name, best_epoch = entry, epoch

    if best_name is None:
        return None
    return os.path.join(directory, best_name)


def get_args(run_names, run_infos=None):
    """Rebuild the training arguments of the named W&B runs.

    Args:
        run_names: iterable of W&B run names.
        run_infos: optional mapping ``run name -> run info dict`` (each with a
            ``"config"`` dict), as returned by ``get_wandb_run_info``. Pass it
            to reuse one W&B fetch across several calls; when ``None`` the runs
            are fetched from W&B here.

    Returns:
        list of ``SimpleNamespace`` objects built from each run's config, in
        the same order as *run_names*.

    Raises:
        ValueError: if a name in *run_names* has no matching run.
    """
    if run_infos is None:
        run_infos = get_wandb_run_info()

    args = []
    for run_name in run_names:
        if run_name not in run_infos:
            raise ValueError(f"No W&B run found with name {run_name!r}")
        args.append(SimpleNamespace(**run_infos[run_name]["config"]))
    return args



def get_wandb_run_info(project="ATDT", entity="Niloofar"):
    """Fetch metadata for every run in a W&B project, keyed by run name.

    Args:
        project: W&B project name.
        entity: W&B entity (user or team) that owns the project.

    Returns:
        dict mapping ``run.name`` to a dict with the keys ``name``, ``id``,
        ``state``, ``config``, ``summary``, ``tags`` and ``created_at``.
        If several runs share a name, the one iterated last overwrites the
        earlier ones.
    """
    api = wandb.Api()
    runs = api.runs(f"{entity}/{project}")

    runs_data = {}
    for run in runs:
        runs_data[run.name] = {
            "name": run.name,
            "id": run.id,
            "state": run.state,
            "config": run.config,
            # NOTE(review): _json_dict is a private attribute of the wandb
            # summary object; confirm it still exists when upgrading wandb.
            "summary": run.summary._json_dict,
            "tags": run.tags,
            "created_at": run.created_at,
        }
    return runs_data

def load_wandb_run(run_name, project="DecisionTransformer2", entity="Niloofar"):
    """Return the first W&B run in ``entity/project`` named *run_name*.

    NOTE(review): the default project differs from the one used by
    ``get_wandb_run_info`` ("ATDT"); confirm which project is intended.

    Args:
        run_name: run name to look for.
        project: W&B project name.
        entity: W&B entity (user or team) that owns the project.

    Returns:
        The matching ``wandb`` run object.

    Raises:
        ValueError: if no run in the project has that name.
    """
    api = wandb.Api()
    for run in api.runs(f"{entity}/{project}"):
        if run.name == run_name:
            return run

    raise ValueError(f"No W&B run named {run_name!r} found in {entity}/{project}")

def build_aqdt_command(args):
    """Build the ``scripts/AQDT.py`` command line that reproduces a run.

    Args:
        args: namespace holding a run's config (as returned by ``get_args``).
            It must provide env, dataset, batch_size, num_steps_per_iter,
            grad_norm, step_start_critic, max_iters, learning_rate,
            critic_lr, seed, train_critic_every_n_epoch, q_scale, k_rewards
            and lr_decay. ``q_min`` is optional.

    Returns:
        The command as a single string.
    """
    cmd = (
        "python3 scripts/AQDT.py"
        f" --env {args.env} --dataset {args.dataset}"
        f" --batch_size {args.batch_size} --num_steps_per_iter {args.num_steps_per_iter}"
        f" --grad_norm {args.grad_norm} --step_start_critic {args.step_start_critic}"
        f" --max_iters {args.max_iters} --learning_rate {args.learning_rate}"
        f" --critic_lr {args.critic_lr} --seed {args.seed}"
        f" --train_critic_every_n_epoch {args.train_critic_every_n_epoch} --q_scale {args.q_scale}"
        f" --k_rewards {args.k_rewards}"
    )
    if hasattr(args, "q_min"):
        cmd += f" --q_min {args.q_min}"
    # lr_decay is a presence switch. The command previously also emitted
    # "--lr_decay <value>", which repeated the option with a stray value.
    if args.lr_decay:
        cmd += " --lr_decay"
    return cmd


if __name__ == "__main__":
    # Resolve every checkpoint's config with a single W&B query instead of
    # re-fetching the whole project once per checkpoint.
    all_checkpoints = [ckpt for ckpts in checkpoints_dic.values() for ckpt in ckpts]
    args_by_checkpoint = dict(zip(all_checkpoints, get_args(all_checkpoints)))

    for env_name, checkpoints in checkpoints_dic.items():
        print(env_name)
        for checkpoint in checkpoints:
            args = args_by_checkpoint[checkpoint]
            if args.seed == 42:
                print(checkpoint)
                print(build_aqdt_command(args))

        print("-" * 20)

