import argparse
import gym
import numpy as np
import os
import torch
import json
import time
import shutil
import pickle
import yaml
import d4rl
from dataset.d4rl_dataset_fetch import download_dataset
from utils import utils
from utils.helpers import Data_Sampler
from utils.logger import logger, setup_logger
from agents.diffusion_ql import Diffusion_QL
from agents.flow_ql import Flow_QL
from agents.mmd_ql import MoMa_QL
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


# Registry of trainable agents: maps the model identifier (the value of
# `args.model`, looked up in `train_agent`) to the agent class to instantiate.
available_agents = {
    'dql': Diffusion_QL,
    'mmd_ql': MoMa_QL,
    'flow_ql': Flow_QL
}

def _resolve_device(args):
    """Return the torch device string to train on.

    Falls back to "cpu" when CUDA is unavailable or `args` has no `device`
    attribute. An integer device is interpreted as a CUDA index; any other
    value is returned unchanged.
    """
    if not (torch.cuda.is_available() and hasattr(args, 'device')):
        return "cpu"
    if isinstance(args.device, int):
        return f"cuda:{args.device}"
    return args.device


def _load_previous_evaluations(output_dir):
    """Return the evaluation rows stored in `<output_dir>/eval.npy` as a list.

    Returns an empty list when the file does not exist or cannot be read, so
    that training can start from scratch.
    """
    eval_file_path = os.path.join(output_dir, "eval.npy")
    if not os.path.exists(eval_file_path):
        return []
    try:
        # NOTE: allow_pickle=True unpickles the file content; only load
        # eval files written by this script.
        loaded_evals = np.load(eval_file_path, allow_pickle=True)
        # ensure python list
        return loaded_evals.tolist() if hasattr(loaded_evals, 'tolist') else list(loaded_evals)
    except Exception as err:
        # Best effort: a broken history must not prevent training, but say so.
        utils.print_banner(f"Could not load {eval_file_path} ({err}); starting with no previous evaluations")
        return []


def _keep_selected_checkpoints(output_dir, online_epoch, offline_epoch):
    """Keep only the checkpoints picked by model selection, under fixed names.

    Copies `actor_<epoch>.pth` / `critic_<epoch>.pth` of the online and the
    offline selection to `*_online.pth` / `*_offline.pth`, then deletes every
    other `.pth` file in `output_dir`.

    A list of (source, target) pairs is used instead of a dict keyed by the
    source name: when both selections point at the same epoch the sources are
    identical, and a dict would silently drop the `*_online.pth` targets.
    """
    selected = [
        (f"actor_{online_epoch}.pth", "actor_online.pth"),
        (f"critic_{online_epoch}.pth", "critic_online.pth"),
        (f"actor_{offline_epoch}.pth", "actor_offline.pth"),
        (f"critic_{offline_epoch}.pth", "critic_offline.pth"),
    ]
    for source_name, target_name in selected:
        source_path = os.path.join(output_dir, source_name)
        if not os.path.exists(source_path):
            utils.print_banner(f"Checkpoint {source_name} not found; cannot create {target_name}")
            continue
        shutil.copyfile(source_path, os.path.join(output_dir, target_name))

    kept_names = {target_name for _, target_name in selected}
    for file_name in os.listdir(output_dir):
        # Only checkpoint files are cleaned up; logs and eval data stay.
        if file_name.endswith(".pth") and file_name not in kept_names:
            os.remove(os.path.join(output_dir, file_name))


def train_agent(output_dir, writer, args):
    """Train the agent selected by `args.model`, evaluating it periodically.

    Each loop iteration trains for `eval_freq * num_steps_per_epoch` steps,
    evaluates the policy with `eval_policy`, appends the results to
    `<output_dir>/eval.npy` and logs them to `logger` and `writer`. After
    training, the best evaluation is chosen online (highest normalized score)
    and offline (by BC loss rank `args.top_k`); the one matching `args.ms` is
    written to `best_score_<ms>.txt`.

    Args:
        output_dir: Directory for evaluations, checkpoints and score files.
        writer: TensorBoard-style writer (`add_scalar`, `close`); it is closed
            by this function.
        args: Namespace with the run configuration. Fields read here:
            device, env_name, model, model_args (list of dicts), state_dim,
            action_dim, max_action, num_epochs, num_steps_per_epoch,
            eval_freq, batch_size, seed, eval_episodes, early_stop,
            save_best_model, load_weights, load_epoch, ms, top_k.
            `args.device` is overwritten with the resolved device string.

    Raises:
        ValueError: If `eval_freq * num_steps_per_epoch` is not at least one
            training step (the loop could never make progress).
    """
    # Process device first
    args.device = _resolve_device(args)

    # Load buffer
    dataset_path = f'dataset/{args.env_name}.pkl'
    if not os.path.exists(dataset_path):
        download_dataset(args.env_name)

    # NOTE: pickle.load can execute arbitrary code; the dataset file must come
    # from a trusted source.
    with open(dataset_path, 'rb') as f:
        dataset = pickle.load(f)

    # turn the model_args (a list of single-purpose dicts) into one dict
    model_args = {}
    for item in args.model_args:
        model_args.update(item)
    data_sampler = Data_Sampler(dataset, args.device, model_args['reward_tune'])
    utils.print_banner('Loaded buffer')

    # Add the required constructor arguments to model_args
    model_args.update({
        'state_dim': args.state_dim,
        'action_dim': args.action_dim,
        'max_action': args.max_action,
        'device': args.device
    })

    # Map config parameter names to constructor names where they differ
    if 'gn' in model_args:
        model_args['grad_norm'] = model_args.pop('gn')  # gn -> grad_norm

    agent = available_agents[args.model](**model_args)

    early_stop = False
    stop_check = utils.EarlyStopping(tolerance=1, min_delta=0.)

    # Load previous evaluations if exist (to keep best-model selection consistent when resuming)
    evaluations = _load_previous_evaluations(output_dir)

    # Default start
    training_iters = 0
    max_timesteps = args.num_epochs * args.num_steps_per_epoch
    metric = 100.
    utils.print_banner("Training Start", separator="*", num_star=90)
    start_time = time.time()

    # load the weights if needed
    if args.load_weights:
        if args.load_epoch is not None:
            agent.load_model(output_dir, id=args.load_epoch)
            utils.print_banner(f"Loaded weights from epoch {args.load_epoch}")
            # resume training iters and agent inner step to keep global_step consistent
            try:
                training_iters = int(args.load_epoch) * int(args.num_steps_per_epoch)
            except (TypeError, ValueError, OverflowError):
                training_iters = 0
            # sync agent.step with the resumed iteration count
            if hasattr(agent, 'step'):
                agent.step = int(training_iters)
        else:
            agent.load_model(output_dir)
            utils.print_banner("Loaded weights without epoch ID")

    # Number of gradient steps between two evaluations (loop invariant).
    iterations = int(args.eval_freq * args.num_steps_per_epoch)
    if iterations <= 0:
        raise ValueError(
            "eval_freq * num_steps_per_epoch must be at least 1 training step, "
            f"got {args.eval_freq} * {args.num_steps_per_epoch}"
        )

    while (training_iters < max_timesteps) and (not early_stop):
        # A schedule enabling use_grad once more than 20% of the total steps
        # are done used to be computed here, but it is overridden: use_grad is
        # forced to False for BC.
        use_grad = False
        loss_metric = agent.train(
            data_sampler,
            iterations=iterations,
            batch_size=args.batch_size,
            log_writer=writer,
            use_grad=use_grad
        )

        training_iters += iterations
        curr_epoch = int(training_iters // int(args.num_steps_per_epoch))
        used_time = time.time() - start_time

        # Mean losses over this training chunk; reused for logging, the
        # evaluation history and early stopping.
        bc_loss = np.mean(loss_metric['bc_loss'])
        ql_loss = np.mean(loss_metric['ql_loss'])
        actor_loss = np.mean(loss_metric['actor_loss'])
        critic_loss = np.mean(loss_metric['critic_loss'])

        # Logging
        utils.print_banner(f"Train step: {training_iters}", separator="*", num_star=90)
        logger.record_tabular('Trained Epochs', curr_epoch)
        logger.record_tabular('BC Loss', bc_loss)
        logger.record_tabular('QL Loss', ql_loss)
        logger.record_tabular('Actor Loss', actor_loss)
        logger.record_tabular('Critic Loss', critic_loss)
        logger.record_tabular('Time', used_time)

        writer.add_scalar("charts/time", used_time, training_iters)

        # Evaluation
        eval_res, eval_res_std, eval_norm_res, eval_norm_res_std = eval_policy(
            agent,
            args.env_name,
            args.seed,
            eval_episodes=args.eval_episodes,
            cfg_scale=model_args.get('cfg_scale', None)
        )

        # Row layout (relied upon by the model selection below):
        # 0 raw avg, 1 raw std, 2 norm avg, 3 norm std,
        # 4 bc loss, 5 ql loss, 6 actor loss, 7 critic loss, 8 epoch
        evaluations.append([eval_res, eval_res_std, eval_norm_res, eval_norm_res_std,
                            bc_loss, ql_loss, actor_loss, critic_loss, curr_epoch])
        np.save(os.path.join(output_dir, "eval"), evaluations)
        logger.record_tabular('Average Episodic Reward', eval_res)
        logger.record_tabular('Average Episodic N-Reward', eval_norm_res)
        logger.dump_tabular()

        writer.add_scalar("eval_charts/bc_loss", bc_loss, training_iters)
        writer.add_scalar("eval_charts/ql_loss", ql_loss, training_iters)
        writer.add_scalar("eval_charts/actor_loss", actor_loss, training_iters)
        writer.add_scalar("eval_charts/critic_loss", critic_loss, training_iters)
        writer.add_scalar("eval_charts/eval_reward", eval_res, training_iters)
        writer.add_scalar("eval_charts/eval_reward_std", eval_res_std, training_iters)
        writer.add_scalar("eval_charts/eval_norm_reward", eval_norm_res, training_iters)
        writer.add_scalar("eval_charts/eval_norm_reward_std", eval_norm_res_std, training_iters)

        if args.early_stop:
            # Compares the previous chunk's BC loss with the current one.
            early_stop = stop_check(metric, bc_loss)

        metric = bc_loss

        if args.save_best_model:
            # Every evaluated epoch is saved; the unselected ones are removed
            # after model selection.
            agent.save_model(output_dir, curr_epoch)

    if not evaluations:
        # Nothing was evaluated (e.g. resumed at or beyond max_timesteps), so
        # there is nothing to select from.
        utils.print_banner("No evaluations recorded; skipping model selection")
        writer.close()
        return

    # Model Selection: online (based on normalized score) or offline (based on bc loss)
    scores = np.array(evaluations)
    # online ms
    online_best_id = int(np.argmax(scores[:, 2]))
    # offline ms: the entry with the (top_k + 1)-th smallest bc loss, clamped
    # to the number of available evaluations
    bc_loss = scores[:, 4]
    top_k = min(len(bc_loss) - 1, args.top_k)
    offline_best_id = int(np.argsort(bc_loss)[top_k])

    best_ids = {'online': online_best_id, 'offline': offline_best_id}
    if args.ms in best_ids:
        best_row = scores[best_ids[args.ms]]
        best_res = {'model selection': args.ms, 'epoch': float(best_row[-1]),
                    'best normalized score avg': float(best_row[2]),
                    'best normalized score std': float(best_row[3]),
                    'best raw score avg': float(best_row[0]),
                    'best raw score std': float(best_row[1])}
        with open(os.path.join(output_dir, f"best_score_{args.ms}.txt"), 'w') as f:
            f.write(json.dumps(best_res))

    writer.close()

    if args.save_best_model:
        # Clean up other pth files except for the selected ones
        _keep_selected_checkpoints(
            output_dir,
            online_epoch=int(scores[online_best_id, -1]),
            offline_epoch=int(scores[offline_best_id, -1]),
        )


# -----------------------------------------------------------------------------#
# eval policy -----------------------------#
# -----------------------------------------------------------------------------#

# Runs policy for X episodes and returns average reward
# A fixed seed is used for the eval environment
@torch.no_grad()
def eval_policy(policy, env_name, seed, eval_episodes=10, cfg_scale=None):
    """Roll out `policy` in a fresh environment and summarize the returns.

    Args:
        policy: Object exposing `sample_action(state)`; it must also accept a
            `cfg_scale` keyword when `cfg_scale` is given.
        env_name: Environment id passed to `gym.make`.
        seed: Base seed; the evaluation environment is seeded with `seed + 100`.
        eval_episodes: Number of episodes to run.
        cfg_scale: Optional value forwarded to `policy.sample_action` as
            `cfg_scale`; omitted from the call when None.

    Returns:
        Tuple `(avg_reward, std_reward, avg_norm_score, std_norm_score)`:
        mean and standard deviation of the episode returns, the normalized
        score of the mean return, and the standard deviation of the
        per-episode normalized scores (via `eval_env.get_normalized_score`).
    """
    eval_env = gym.make(env_name)
    try:
        eval_env.seed(seed + 100)

        # Only pass cfg_scale when it is set, so policies whose sample_action
        # does not take that keyword keep working.
        sample_kwargs = {} if cfg_scale is None else {'cfg_scale': cfg_scale}

        scores = []
        for _ in range(eval_episodes):
            traj_return = 0.
            state, done = eval_env.reset(), False
            while not done:
                action = policy.sample_action(np.array(state), **sample_kwargs)
                state, reward, done, _ = eval_env.step(action)
                traj_return += reward
            scores.append(traj_return)

        avg_reward = np.mean(scores)
        std_reward = np.std(scores)

        normalized_scores = [eval_env.get_normalized_score(s) for s in scores]
        avg_norm_score = eval_env.get_normalized_score(avg_reward)
        std_norm_score = np.std(normalized_scores)
    finally:
        # A new environment is created on every call; release it so repeated
        # evaluations do not accumulate open environments.
        eval_env.close()

    utils.print_banner(f"Evaluation over {eval_episodes} episodes: {avg_reward:.2f} {avg_norm_score:.2f}")
    return avg_reward, std_reward, avg_norm_score, std_norm_score


if __name__ == "__main__":
    # Command line: hyper-parameters live in a YAML file; the flags only pick
    # that file and control whether trained weights are restored.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, default="./configs/halfcheetah-medium-v2.yaml")
    parser.add_argument("--load_weights", action='store_true', help="Load trained weights")
    parser.add_argument("--load_epoch", type=int, default=None, help="Load weights from specific epoch (e.g., 500). Automatically enables --load_weights")
    args = parser.parse_args()

    # Asking for a specific epoch implies loading weights.
    if args.load_epoch is not None:
        args.load_weights = True

    if args.config_path is None or not os.path.exists(args.config_path):
        raise ValueError(f"Config file {args.config_path} not found")

    with open(args.config_path, 'r') as cfg_file:
        config = yaml.safe_load(cfg_file)

    # Every config entry becomes an attribute of args; on a name clash the
    # config value replaces the command-line value.
    for key, value in config.items():
        setattr(args, key, value)

    args.output_dir = f'{args.dir}'

    # Defaults for model selection when the config does not set them:
    # offline selection, taking the entry with the lowest BC loss (top_k = 0).
    args.ms = getattr(args, 'ms', 'offline')
    args.top_k = getattr(args, 'top_k', 0)

    # Run directory name: timestamp, model, optional top-k tag, followed by
    # selected model_args entries, all joined with '|'.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    name_parts = [f"{timestamp}|{args.model}"]
    if args.ms == 'offline':
        name_parts.append(f'k-{args.top_k}')
    name_parts.extend(f'{args.model_args[idx]}' for idx in (-3, -1, -2, 0, 1))
    run_name = '|'.join(name_parts)

    results_dir = os.path.join(args.output_dir, run_name)
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)
    utils.print_banner(f"Saving location: {results_dir}")

    # Record the full configuration as a markdown table in TensorBoard.
    writer = SummaryWriter(log_dir=results_dir)
    param_rows = "\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])
    writer.add_text("hyperparameters", "|param|value|\n|-|-|\n" + param_rows)

    # Setup Logging: drop a stale variant.json before the logger is set up.
    variant_path = os.path.join(results_dir, 'variant.json')
    if os.path.exists(variant_path):
        os.remove(variant_path)

    # Build the environment once to seed it and read the space dimensions.
    env = gym.make(args.env_name)
    env.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    # vars(args) is the namespace's own dict, so this update also makes
    # args.state_dim / args.action_dim / args.max_action available to
    # train_agent.
    variant = vars(args)
    variant.update(state_dim=state_dim, action_dim=action_dim, max_action=max_action)

    setup_logger(os.path.basename(results_dir), variant=variant, log_dir=results_dir)
    utils.print_banner(f"Env: {args.env_name}, state_dim: {state_dim}, action_dim: {action_dim}")

    train_agent(results_dir, writer, args)