import numpy as np
import torch
import gym
import argparse
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import random
import sys
import json
from pathlib import Path
import yaml
import time

import setproctitle
from algo.ROOT import ROOT
import algo.utils as utils
from envs.env_utils import call_terminal_func
from envs.common import call_env
from tensorboardX import SummaryWriter
from algo.get_normalized_score import get_normalized_score
import d4rl
import robust_ot_solver
import gc

def eval_policy(policy, env, eval_episodes=10, eval_cnt=None):
    """Roll ``policy`` out on ``env`` and return the mean return per episode.

    Every episode runs until the environment reports ``done``. Rewards from
    all episodes are summed and the total is divided by ``eval_episodes``.
    The environment is expected to follow the 4-tuple ``step`` API
    ``(obs, reward, done, info)``, and ``reset()`` to return only the
    observation.

    Args:
        policy: object exposing ``select_action(np.ndarray)``.
        env: environment used for the rollouts; it is stepped in place.
        eval_episodes: number of episodes to average over. Passing 0 raises
            ``ZeroDivisionError``.
        eval_cnt: label shown in the printed summary line only.

    Returns:
        The summed reward over all episodes divided by ``eval_episodes``.
    """
    episode_total = 0.
    for _ in range(eval_episodes):
        obs = env.reset()
        done = False
        while not done:
            act = policy.select_action(np.array(obs))
            obs, rew, done, _info = env.step(act)
            episode_total += rew

    mean_return = episode_total / eval_episodes
    print("[{}] Evaluation over {} episodes: {}".format(eval_cnt, eval_episodes, mean_return))
    return mean_return


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", default="./logs")
    parser.add_argument("--policy", default="ROOT")
    parser.add_argument("--env", default="halfcheetah")
    parser.add_argument("--seed", default=0, type=int)            
    parser.add_argument("--save-model", action="store_true")        # Save model and optimizer parameters
    parser.add_argument('--srctype', default='medium', type=str)
    parser.add_argument("--tartype", default='medium', type=str)
    parser.add_argument("--steps", default=1e6, type=int)
    parser.add_argument("--weight", action="store_true")
    parser.add_argument("--noreg", action="store_true")
    parser.add_argument("--reg_weight", default=0.5, type=float)
    
    # Parameter used to calculate optimal transport
    parser.add_argument("--epsilon", default=0.01, type=float, help="Entropy regularization")
    parser.add_argument("--metric", default='euclidean', type=str)     # metric used in optimal transport
    parser.add_argument("--lambda_src", default=0.05, type=float, help="Source filtering strength")
    parser.add_argument("--lambda_tar", default=0.5, type=float, help="Target optimism strength")

    # Parameter used to filter low quality datasets
    parser.add_argument("--filter_threshold", default=1.0, type=float, help= "ratio used to filter dataset")
    
    
    
    args = parser.parse_args()

    with open(f"{str(Path(__file__).parent.absolute())}/config/{args.env.replace('-', '_')}.yaml", 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    print("------------------------------------------------------------")
    print("Env: {}, Seed: {}".format( args.env, args.seed))
    print("------------------------------------------------------------")
    
    base_exp_name = args.env + '-srcdatatype-' + args.srctype + '-tardatatype-' + args.tartype
    hyper_param_str = f"eps-{args.epsilon}|src-{args.lambda_src}|tar-{args.lambda_tar}|filter-{args.filter_threshold}"
    proctitle_name= f"{base_exp_name}|{args.filter_threshold}|{args.lambda_src}|{args.lambda_tar}|{args.seed}"
    setproctitle.setproctitle(f"{proctitle_name}")
    outdir = os.path.join(args.dir,  base_exp_name, hyper_param_str, f"r{args.seed}")


    done_file_path = os.path.join(outdir, 'DONE')
    if os.path.exists(done_file_path):
        print(f"[{time.ctime()}] Experiment already finished (found DONE file). Skipping: {outdir}")
        sys.exit(0) # 直接退出，Shell 脚本会继续执行下一条命令
    
    # --- [新增代码] 准备断点续训变量 ---
    state_file_path = os.path.join(outdir, 'training_state.json')
    start_step = 0
    resume_mode = False

    if os.path.exists(state_file_path) and args.save_model:
        try:
            with open(state_file_path, 'r') as f:
                training_state = json.load(f)
                start_step = training_state['step']
                resume_mode = True
            print(f"[{time.ctime()}] Found checkpoint. Resuming from step {start_step}...")
        except Exception as e:
            print(f"Error loading checkpoint state: {e}. Starting from scratch.")

    #writer = SummaryWriter('{}/tb'.format(outdir))
    log_dir = '{}/tb'.format(outdir)
    writer = SummaryWriter(log_dir=log_dir, purge_step=start_step if resume_mode else None)
    if args.save_model and not os.path.exists("{}/models".format(outdir)):
        os.makedirs("{}/models".format(outdir))
    
    if '_' in args.env:
        args.env = args.env.replace('_', '-')
    
    # train env
    src_env_name = args.env.split('-')[0] + '-' + args.srctype + '-v2'
    src_env = gym.make(src_env_name)
    src_env.seed(args.seed)
    # test env
    tar_env = call_env(config['tar_env_config'])
    tar_env.seed(args.seed)
    # eval env
    src_eval_env = gym.make(src_env_name)
    src_eval_env.seed(args.seed + 100)
    tar_eval_env = call_env(config['tar_env_config'])
    tar_eval_env.seed(args.seed + 100)

    # seed all
    src_env.action_space.seed(args.seed)
    tar_env.action_space.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)

    state_dim = src_env.observation_space.shape[0]
    action_dim = src_env.action_space.shape[0] 
    max_action = float(src_env.action_space.high[0])
    min_action = -max_action
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    config['metric'] = args.metric

    weight = True if args.weight else False
    noreg = True if args.noreg else False

    config.update({
        'state_dim': state_dim,
        'action_dim': action_dim,
        'max_action': max_action,
        'weight': weight,
        'noreg': noreg,
        'reg_weight': args.reg_weight,
        'filter_threshold': args.filter_threshold
    })

    policy = ROOT(config, device)
    
    if resume_mode:
        model_path = '{}/models/model'.format(outdir)
        # 注意：这里假设你的 policy.load 能够处理路径，或者你需要指向具体的 .pt 文件
        # 如果你的 save 存的是 model_actor.pt 等，这里可能需要调整
        try:
            policy.load(model_path) 
            print("Successfully loaded model weights.")
        except Exception as e:
            print(f"Warning: Could not load model weights ({e}). Training from scratch.")
            start_step = 0 # 加载失败则重头开始

    ## write logs to record training parameters
    with open(os.path.join(outdir, 'log.txt'), 'w') as f:
        f.write('\n Env: {}, seed: {}'.format( args.env, args.seed))
        for item in config.items():
            f.write('\n {}'.format(item))

    src_replay_buffer = utils.OTReplayBuffer(state_dim, action_dim, device)
    tar_replay_buffer = utils.ReplayBuffer(state_dim, action_dim, device)

    # load offline datasets
    src_dataset = d4rl.qlearning_dataset(src_env)
    tar_dataset = utils.call_tar_dataset(args.env, args.tartype)

    src_replay_buffer.convert_D4RL(src_dataset)
    tar_replay_buffer.convert_D4RL(tar_dataset)

    weight_filename = "robust_weights.hdf5"
    weight_file_path = os.path.join(outdir, weight_filename)
    
    print(f"Target weight file path: {weight_file_path}")

    # 2. 检查或计算权重
    if os.path.exists(weight_file_path) :
        print("Found existing weights in log dir. Loading directly...")
        robust_weights = utils.load_robust_weights(weight_file_path)
    else:
        print("Weights not found in log dir. Calculating now...")
        
        # 计算并保存到 outdir
        robust_weights = robust_ot_solver.compute_and_save_weights(
            src_replay_buffer, 
            tar_replay_buffer, 
            weight_file_path, 
            args
        )
    
    # 3. Empty PyTorch Cache (in case JAX/PyTorch interactions left fragments)
    torch.cuda.empty_cache()
    # 4. 将权重注入 Buffer

    global_mean = robust_weights.mean() + 1e-12
    global_max = robust_weights.max()
    
    print(f"Global Weights Stats -> Mean: {global_mean:.2e}, Max: {global_max:.2e}")
    normalized_weights = robust_weights / global_mean

    src_replay_buffer.set_weights(normalized_weights)
    print(f"Robust weights ready. Mean: {normalized_weights.mean():.4e}")
    eval_cnt = start_step // config['eval_freq']

    # whether to pretrain VAE
    if not noreg and not resume_mode: 
        print("Starting VAE pretraining...")
        policy.train_vae(tar_replay_buffer, config['batch_size'], writer)
    elif not noreg and resume_mode:
        print("Skipping VAE pretraining (Resuming mode).")
    
    start_time = time.time()
    for t in range(start_step, int(args.steps)):
        policy.train(src_replay_buffer, tar_replay_buffer, config['batch_size'], writer)

        if (t + 1) % config['eval_freq'] == 0:
            end_time = time.time()
            total_time = end_time - start_time
            print(f"\nTotal training time {total_time:.2f} seconds, current epoch is {int((t + 1)/ config['eval_freq'])}")
            src_eval_return = eval_policy(policy, src_eval_env, eval_cnt=eval_cnt)
            tar_eval_return = eval_policy(policy, tar_eval_env, eval_cnt=eval_cnt)
            norm_tar_return = get_normalized_score(args.env, tar_eval_return)
            writer.add_scalar('test/source return', src_eval_return, global_step = t+1)
            writer.add_scalar('test/target return', tar_eval_return, global_step = t+1)
            writer.add_scalar('test/target_normalized_score', norm_tar_return, global_step=t+1)
            print("[{}] Normalized Score of Target Domain is {}".format(eval_cnt, norm_tar_return))
            print("*"*30)
            eval_cnt += 1

            if args.save_model:
                policy.save('{}/models/model'.format(outdir))
                
                # --- [新增代码] 保存当前的步数，用于下次续训 ---
                with open(state_file_path, 'w') as f:
                    json.dump({'step': t + 1}, f)
    writer.close()
    with open(done_file_path, 'w') as f:
        f.write(f"Finished at {time.ctime()}")
    print(f"Experiment completed. DONE file created at {done_file_path}")