import argparse
import gym
import numpy as np
import os
import torch
import json
import time
import shutil
import pickle
import yaml
import d4rl
from dataset.d4rl_dataset_fetch import download_dataset
from utils import utils
from utils.helpers import Data_Sampler
from utils.logger import logger, setup_logger
from agents.diffusion_ql import Diffusion_QL
from agents.flow_ql import Flow_QL
from agents.mmd_ql import MoMa_QL
from agents.mmd_ql_online import MoMa_QL as MoMa_QL_Online  # Online version for second phase
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from collections import deque


# Agent classes for the offline phase, keyed by ``args.model``.
available_agents = dict(
    dql=Diffusion_QL,
    mmd_ql=MoMa_QL,
    flow_ql=Flow_QL,
)

# Agent classes for the online phase. Only 'mmd_ql' has a dedicated online
# variant; every other model reuses its offline class.
# TODO: 'dql' and 'flow_ql' may need online versions.
available_agents_online = {**available_agents, 'mmd_ql': MoMa_QL_Online}

def _resolve_device(args):
    """Normalize ``args.device`` in place into a device string for torch.

    When CUDA is available and ``args`` has a ``device`` attribute, an integer
    becomes ``"cuda:<idx>"`` and any other value is kept unchanged. Otherwise
    the device falls back to ``"cpu"``.
    """
    if torch.cuda.is_available() and hasattr(args, 'device'):
        if isinstance(args.device, int):
            args.device = f"cuda:{args.device}"
    else:
        args.device = "cpu"


def _restore_best_offline_model(agent, evaluations, output_dir, checkpoints_saved):
    """Reload the offline checkpoint with the highest normalized eval score.

    Rows of ``evaluations`` use the layout appended by ``train_agent``:
    column 2 is the normalized score and the last column is the epoch.
    Checkpoints are only written when ``save_best_model`` is enabled, so when
    ``checkpoints_saved`` is False the current weights are kept instead.
    """
    if len(evaluations) == 0:
        utils.print_banner("Warning: No evaluation data, using current model")
        return

    offline_scores = np.array(evaluations)
    best_offline_id = np.argmax(offline_scores[:, 2])  # normalized score
    best_epoch = int(offline_scores[best_offline_id, -1])
    utils.print_banner(f"Best offline model at epoch {best_epoch} with normalized score {offline_scores[best_offline_id, 2]:.2f}")

    if not checkpoints_saved:
        # save_model() is never called without save_best_model, so there is
        # no file to load; continue from the current weights.
        utils.print_banner("Warning: save_best_model is disabled, no checkpoint to load; using current model")
        return

    agent.load_model(output_dir, id=best_epoch)
    utils.print_banner(f"Loaded best offline model from epoch {best_epoch}")


def _make_online_agent(agent, model_name, model_args):
    """Build the online-phase agent and copy the offline agent's networks into it.

    Transfers the actor (+ EMA copy), critic (+ target) and the step counter.
    Everything else on the new agent (e.g. optimizer state, if any) comes
    from its constructor.
    """
    online_agent = available_agents_online[model_name](**model_args)

    if hasattr(agent, 'actor'):
        online_agent.actor.load_state_dict(agent.actor.state_dict())
        if hasattr(agent, 'ema_model'):
            online_agent.ema_model.load_state_dict(agent.ema_model.state_dict())
    if hasattr(agent, 'critic'):
        online_agent.critic.load_state_dict(agent.critic.state_dict())
        online_agent.critic_target.load_state_dict(agent.critic_target.state_dict())

    # Keep the training step counter continuous across the phase switch.
    online_agent.step = agent.step
    return online_agent


def _collect_online_transitions(agent, env, state, num_steps, buffer):
    """Roll ``agent`` out in ``env`` for ``num_steps`` steps, storing transitions.

    ``buffer`` maps field names to bounded ``deque`` objects, so the oldest
    transitions are evicted automatically once capacity is reached.
    Returns the state to resume from on the next call.
    """
    for _ in range(num_steps):
        action = agent.sample_action(state)
        next_state, reward, done, info = env.step(action)
        # A gym TimeLimit cutoff ends the episode but is not a real terminal
        # state. Storing it as terminal would stop bootstrapping there, which
        # is inconsistent with D4RL 'terminals' (timeouts are separate).
        # NOTE(review): assumes Data_Sampler treats 'terminals' as the
        # not-done mask, as for the offline data — confirm.
        terminal = bool(done) and not info.get('TimeLimit.truncated', False)

        buffer['observations'].append(state)
        buffer['actions'].append(action)
        buffer['next_observations'].append(next_state)
        buffer['rewards'].append(reward)
        buffer['terminals'].append(terminal)

        state = env.reset() if done else next_state
    return state


def _buffer_to_sampler(buffer, device, reward_tune):
    """Snapshot the online buffer into numpy arrays wrapped in a Data_Sampler."""
    # list() first: deque random access is O(n), so never let numpy index it.
    online_dataset = {
        'observations': np.array(list(buffer['observations'])),
        'actions': np.array(list(buffer['actions'])),
        'next_observations': np.array(list(buffer['next_observations'])),
        'rewards': np.array(list(buffer['rewards'])).reshape(-1, 1),
        'terminals': np.array(list(buffer['terminals'])).reshape(-1, 1),
    }
    return Data_Sampler(online_dataset, device, reward_tune)


def _write_best_score(output_dir, ms, scores, best_id):
    """Write row ``best_id`` of ``scores`` to ``best_score_<ms>.txt`` as JSON."""
    best_res = {'model selection': ms, 'epoch': scores[best_id, -1],
                'best normalized score avg': scores[best_id, 2],
                'best normalized score std': scores[best_id, 3],
                'best raw score avg': scores[best_id, 0],
                'best raw score std': scores[best_id, 1]}
    with open(os.path.join(output_dir, f"best_score_{ms}.txt"), 'w') as f:
        f.write(json.dumps(best_res))


def _cleanup_checkpoints(output_dir, online_epoch, offline_epoch):
    """Keep only the selected checkpoints, under stable names.

    Copies ``actor_<epoch>.pth`` / ``critic_<epoch>.pth`` of the two selected
    epochs to ``{actor,critic}_{online,offline}.pth``. It then deletes every
    other ``.pth`` file in ``output_dir``. Copying each pair independently
    keeps all four outputs even when both selections pick the same epoch.
    A rename map keyed by source file would collapse in that case.
    """
    selected = [
        (f"actor_{online_epoch}.pth", "actor_online.pth"),
        (f"critic_{online_epoch}.pth", "critic_online.pth"),
        (f"actor_{offline_epoch}.pth", "actor_offline.pth"),
        (f"critic_{offline_epoch}.pth", "critic_offline.pth"),
    ]
    for src, dst in selected:
        src_path = os.path.join(output_dir, src)
        if os.path.exists(src_path):
            shutil.copyfile(src_path, os.path.join(output_dir, dst))

    keep = {dst for _, dst in selected}
    for file_name in os.listdir(output_dir):
        if file_name.endswith(".pth") and file_name not in keep:
            os.remove(os.path.join(output_dir, file_name))


def train_agent(output_dir, writer, args):
    """Train an agent offline on a D4RL dataset, then fine-tune it online.

    Phase 1 trains on the offline dataset for ``args.offline_epochs`` epochs
    (default: 20% of ``args.num_epochs``). Phase 2 then does the following:

    - restores the best offline checkpoint (if checkpoints are saved);
    - moves the weights into the online agent class;
    - alternates between collecting environment transitions and training on
      a mix of offline and online batches (``args.balanced_ratio``).

    After every chunk of ``args.eval_freq`` epochs the policy is evaluated,
    logged, and optionally checkpointed. Finally the best epoch is selected
    by normalized score (``ms == 'online'``) or by the ``top_k``-th lowest
    BC loss (``ms == 'offline'``).

    Args:
        output_dir: Directory for ``eval.npy``, checkpoints and score files.
        writer: TensorBoard ``SummaryWriter``; closed before returning.
        args: Namespace holding the merged YAML config. ``args.device`` is
            normalized in place.

    Raises:
        ValueError: if ``eval_freq * num_steps_per_epoch`` rounds down to 0,
            which would otherwise loop forever.
    """
    _resolve_device(args)

    # Load the offline dataset, downloading it on first use.
    dataset_path = f'dataset/{args.env_name}.pkl'
    if not os.path.exists(dataset_path):
        download_dataset(args.env_name)
    # NOTE: pickle.load can execute arbitrary code; only load dataset files
    # produced by download_dataset or another trusted source.
    with open(dataset_path, 'rb') as f:
        offline_dataset = pickle.load(f)

    # args.model_args is a list of dicts; merge them into one kwargs dict.
    model_args = {}
    for item in args.model_args:
        model_args.update(item)
    data_sampler = Data_Sampler(offline_dataset, args.device, model_args['reward_tune'])
    utils.print_banner('Loaded offline buffer')

    # Environment used for online interaction in phase 2.
    train_env = gym.make(args.env_name)
    train_env.seed(args.seed)

    # Bounded FIFO replay buffer for phase 2. deque(maxlen) evicts the oldest
    # transition in O(1), where list.pop(0) costs O(capacity).
    online_buffer_capacity = getattr(args, 'online_buffer_capacity', 200000)
    online_buffer = {
        key: deque(maxlen=online_buffer_capacity)
        for key in ('observations', 'actions', 'next_observations', 'rewards', 'terminals')
    }

    # Switch to online training after this many epochs (converted to iterations).
    offline_epochs = getattr(args, 'offline_epochs', int(args.num_epochs * 0.2))
    offline_timesteps = offline_epochs * args.num_steps_per_epoch

    # Mixing ratio for phase-2 batches: 1.0 = pure online, 0.0 = pure offline,
    # 0.5 = 50/50.
    balanced_ratio = getattr(args, 'balanced_ratio', 0.5)

    utils.print_banner(f'Will switch to online training after {offline_epochs} epochs ({offline_timesteps} iterations)')
    utils.print_banner(f'Balanced sampling ratio: {balanced_ratio} (1.0=pure online, 0.0=pure offline)')

    # Constructor arguments that come from the environment / runtime.
    model_args.update({
        'state_dim': args.state_dim,
        'action_dim': args.action_dim,
        'max_action': args.max_action,
        'device': args.device
    })
    # Configs may use the short key 'gn'; the agents expect 'grad_norm'.
    if 'gn' in model_args:
        model_args['grad_norm'] = model_args.pop('gn')

    agent = available_agents[args.model](**model_args)

    early_stop = False
    stop_check = utils.EarlyStopping(tolerance=1, min_delta=0.)

    # Resume previous evaluations if present, so best-model selection also
    # considers them.
    evaluations = []
    eval_file_path = os.path.join(output_dir, "eval.npy")
    if os.path.exists(eval_file_path):
        try:
            loaded_evals = np.load(eval_file_path, allow_pickle=True)
            evaluations = loaded_evals.tolist() if hasattr(loaded_evals, 'tolist') else list(loaded_evals)
        except Exception as err:
            # Best effort: an unreadable file just means starting fresh,
            # but say so instead of failing silently.
            utils.print_banner(f"Warning: could not load {eval_file_path} ({err}); starting fresh")
            evaluations = []

    training_iters = 0
    max_timesteps = args.num_epochs * args.num_steps_per_epoch
    # Training iterations per train/eval cycle (identical in both phases).
    iterations = int(args.eval_freq * args.num_steps_per_epoch)
    if iterations <= 0:
        raise ValueError(
            f"eval_freq * num_steps_per_epoch must be >= 1, got {args.eval_freq} * {args.num_steps_per_epoch}")
    # use_grad is switched on once training passes 10% of max_timesteps.
    use_grad_after = max_timesteps * 0.1

    metric = 100.
    is_online_mode = False
    online_state = None  # env state carried across online collection rounds
    utils.print_banner("Offline-to-Online Training Start (Train from scratch)", separator="*", num_star=90)
    start_time = time.time()

    utils.print_banner("Phase 1: OFFLINE Training", separator="=", num_star=90)

    while training_iters < max_timesteps and not early_stop:
        # One-time switch from the offline phase to the online phase.
        if not is_online_mode and training_iters >= offline_timesteps:
            is_online_mode = True
            utils.print_banner("Phase 1 Complete - Selecting Best Offline Model", separator="=", num_star=90)
            _restore_best_offline_model(agent, evaluations, output_dir, args.save_best_model)

            utils.print_banner("Phase 2: ONLINE Training (Direct Interaction)", separator="=", num_star=90)
            agent = _make_online_agent(agent, args.model, model_args)
            utils.print_banner("Switched to online agent with transferred weights")

        use_grad = training_iters > use_grad_after

        if is_online_mode:
            # 1) Collect fresh transitions with the current policy.
            utils.print_banner(f"Collecting {iterations} online transitions...")
            if online_state is None:
                online_state = train_env.reset()
            online_state = _collect_online_transitions(agent, train_env, online_state, iterations, online_buffer)

            # 2) Snapshot the buffer into a sampler.
            online_sampler = _buffer_to_sampler(online_buffer, args.device, model_args['reward_tune'])
            utils.print_banner(f"Online buffer size: {len(online_buffer['observations'])}")

            # 3) Train on a balanced mix of offline and online batches.
            loss_metric = agent.train_with_balanced_sampling(
                offline_sampler=data_sampler,
                online_sampler=online_sampler,
                iterations=iterations,
                batch_size=args.batch_size,
                balanced_ratio=balanced_ratio,
                log_writer=writer,
                use_grad=use_grad
            )
        else:
            # Offline phase: train on the static dataset only.
            loss_metric = agent.train(
                data_sampler,
                iterations=iterations,
                batch_size=args.batch_size,
                log_writer=writer,
                use_grad=use_grad
            )

        training_iters += iterations
        curr_epoch = int(training_iters // int(args.num_steps_per_epoch))
        used_time = time.time() - start_time

        bc_loss = np.mean(loss_metric['bc_loss'])
        ql_loss = np.mean(loss_metric['ql_loss'])
        actor_loss = np.mean(loss_metric['actor_loss'])
        critic_loss = np.mean(loss_metric['critic_loss'])

        # Logging
        mode_str = "ONLINE" if is_online_mode else "OFFLINE"
        utils.print_banner(f"Train step: {training_iters} [{mode_str}]", separator="*", num_star=90)
        logger.record_tabular('Trained Epochs', curr_epoch)
        logger.record_tabular('Training Mode', mode_str)
        logger.record_tabular('BC Loss', bc_loss)
        logger.record_tabular('QL Loss', ql_loss)
        logger.record_tabular('Actor Loss', actor_loss)
        logger.record_tabular('Critic Loss', critic_loss)
        logger.record_tabular('Time', used_time)

        writer.add_scalar("charts/time", used_time, training_iters)

        # Evaluation
        eval_res, eval_res_std, eval_norm_res, eval_norm_res_std = eval_policy(
            agent,
            args.env_name,
            args.seed,
            eval_episodes=args.eval_episodes,
            cfg_scale=model_args.get('cfg_scale', None)
        )

        # Row layout: raw avg/std, normalized avg/std, 4 losses, epoch.
        evaluations.append([eval_res, eval_res_std, eval_norm_res, eval_norm_res_std,
                            bc_loss, ql_loss, actor_loss, critic_loss, curr_epoch])
        np.save(os.path.join(output_dir, "eval"), evaluations)
        logger.record_tabular('Average Episodic Reward', eval_res)
        logger.record_tabular('Average Episodic N-Reward', eval_norm_res)
        logger.dump_tabular()

        eval_charts = {
            "bc_loss": bc_loss,
            "ql_loss": ql_loss,
            "actor_loss": actor_loss,
            "critic_loss": critic_loss,
            "eval_reward": eval_res,
            "eval_reward_std": eval_res_std,
            "eval_norm_reward": eval_norm_res,
            "eval_norm_reward_std": eval_norm_res_std,
        }
        for name, value in eval_charts.items():
            writer.add_scalar(f"eval_charts/{name}", value, training_iters)

        if args.early_stop:
            early_stop = stop_check(metric, bc_loss)
        metric = bc_loss

        if args.save_best_model:
            agent.save_model(output_dir, curr_epoch)

    train_env.close()

    if len(evaluations) == 0:
        utils.print_banner("Warning: no evaluations were recorded; skipping model selection")
        writer.close()
        return

    # Model selection: 'online' picks the best normalized score; 'offline'
    # picks the top_k-th lowest BC loss (top_k=0 -> lowest).
    scores = np.array(evaluations)
    online_best_id = np.argmax(scores[:, 2])
    bc_losses = scores[:, 4]
    top_k = min(len(bc_losses) - 1, args.top_k)
    offline_best_id = np.argsort(bc_losses)[top_k]

    if args.ms == 'online':
        _write_best_score(output_dir, args.ms, scores, online_best_id)
    elif args.ms == 'offline':
        _write_best_score(output_dir, args.ms, scores, offline_best_id)

    writer.close()

    if args.save_best_model:
        _cleanup_checkpoints(output_dir,
                             int(scores[online_best_id, -1]),
                             int(scores[offline_best_id, -1]))


# -----------------------------------------------------------------------------#
# eval policy -----------------------------#
# -----------------------------------------------------------------------------#

# Runs the policy for `eval_episodes` episodes and returns the mean and std of
# both the raw and the D4RL-normalized episode returns.
# A fixed seed (seed + 100) is used for the eval environment.
@torch.no_grad()
def eval_policy(policy, env_name, seed, eval_episodes=10, cfg_scale=None):
    """Roll out ``policy`` for ``eval_episodes`` episodes and summarize returns.

    A fresh environment seeded with ``seed + 100`` is created for every call
    and closed before returning, even if a rollout raises.

    Args:
        policy: Object exposing ``sample_action(state)``, and additionally
            accepting ``cfg_scale=`` when ``cfg_scale`` is given.
        env_name: Gym/D4RL environment id; the env must provide
            ``get_normalized_score``.
        seed: Base seed; the eval env uses ``seed + 100``.
        eval_episodes: Number of episodes to run.
        cfg_scale: If not None, forwarded to ``sample_action`` as ``cfg_scale``.

    Returns:
        ``(avg_reward, std_reward, avg_norm_score, std_norm_score)``.
    """
    eval_env = gym.make(env_name)
    eval_env.seed(seed + 100)

    # Only pass cfg_scale when set, so policies without that kwarg still work.
    sample_kwargs = {} if cfg_scale is None else {'cfg_scale': cfg_scale}

    scores = []
    try:
        for _ in range(eval_episodes):
            traj_return = 0.
            state, done = eval_env.reset(), False
            while not done:
                action = policy.sample_action(np.array(state), **sample_kwargs)
                state, reward, done, _ = eval_env.step(action)
                traj_return += reward
            scores.append(traj_return)

        avg_reward = np.mean(scores)
        std_reward = np.std(scores)

        normalized_scores = [eval_env.get_normalized_score(s) for s in scores]
        avg_norm_score = eval_env.get_normalized_score(avg_reward)
        std_norm_score = np.std(normalized_scores)
    finally:
        # A new env is created per evaluation; release it every time.
        eval_env.close()

    utils.print_banner(f"Evaluation over {eval_episodes} episodes: {avg_reward:.2f} {avg_norm_score:.2f}")
    return avg_reward, std_reward, avg_norm_score, std_norm_score


def main():
    """Parse the CLI, merge the YAML config into it, set up logging, and train.

    Config values overwrite command-line values. Results are written to
    ``<config dir>/<timestamp>|<model>|...``.
    """
    parser = argparse.ArgumentParser()
    # All hyper-parameters live in a YAML config file.
    parser.add_argument("--config_path", type=str, default="./configs/halfcheetah-medium-v2.yaml")
    # Load trained weights.
    parser.add_argument("--load_weights", action='store_true', help="Load trained weights")
    # Load weights from a specific epoch.
    parser.add_argument("--load_epoch", type=int, default=None, help="Load weights from specific epoch (e.g., 500). Automatically enables --load_weights")
    args = parser.parse_args()

    # Specifying an epoch implies loading weights.
    if args.load_epoch is not None:
        args.load_weights = True

    if args.config_path is None or not os.path.exists(args.config_path):
        raise ValueError(f"Config file {args.config_path} not found")
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    # Merge config into args; config values overwrite command-line values.
    # An empty YAML file loads as None, so treat it as an empty mapping.
    for key, value in (config or {}).items():
        setattr(args, key, value)

    # train_agent always trains from scratch and never reads these flags;
    # warn instead of silently ignoring them.
    if getattr(args, 'load_weights', False):
        utils.print_banner("Warning: --load_weights/--load_epoch are currently ignored; training starts from scratch")

    args.output_dir = f'{args.dir}'

    # Default model-selection settings when the config does not set them.
    if not hasattr(args, 'ms'):
        args.ms = 'offline'  # select by BC loss
    if not hasattr(args, 'top_k'):
        args.top_k = 0  # index into the BC-loss ranking; 0 = lowest BC loss

    # Build the run directory name, including model-selection info.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_name = f"{timestamp}|{args.model}"
    if args.ms == 'offline':
        file_name += f'|k-{args.top_k}'
    file_name += f'|{args.model_args[-3]}|{args.model_args[-1]}|{args.model_args[-2]}'
    file_name += f'|{args.model_args[0]}|{args.model_args[1]}'

    results_dir = os.path.join(args.output_dir, file_name)
    os.makedirs(results_dir, exist_ok=True)
    utils.print_banner(f"Saving location: {results_dir}")

    writer = SummaryWriter(log_dir=results_dir)
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # Setup logging: start from a clean variant file.
    if os.path.exists(os.path.join(results_dir, 'variant.json')):
        os.remove(os.path.join(results_dir, 'variant.json'))

    # Probe env: seeds RNGs and provides the state/action dimensions.
    env = gym.make(args.env_name)
    env.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])
    env.close()

    # vars(args) is args.__dict__, so these updates also set
    # args.state_dim / args.action_dim / args.max_action used by train_agent.
    variant = vars(args)
    variant.update(state_dim=state_dim)
    variant.update(action_dim=action_dim)
    variant.update(max_action=max_action)

    setup_logger(os.path.basename(results_dir), variant=variant, log_dir=results_dir)
    utils.print_banner(f"Env: {args.env_name}, state_dim: {state_dim}, action_dim: {action_dim}")

    train_agent(results_dir,
                writer,
                args)


if __name__ == "__main__":
    main()