"""
训练辅助模块，提供训练HER+SAC模型的功能
"""
import logging
import os
import random
import time
from collections import deque

import numpy as np
import torch
import matplotlib.pyplot as plt
import yaml

from src.gift.agents.sac_agent import SAC_HER_Agent
from src.gift.buffers.her_buffer import HERReplayBuffer, collect_her_samples
from src.gift.utils.evaluator import evaluate_agent, evaluate_cip
from src.gift.utils.visualization import plot_training_history
from src.data.cip_dataset import CIPDataset, get_dataloader
from src.gift.agents.scrl_agent import SCRL_Agent

def set_random_seeds(seed=42):
    """
    Seed every random number generator used here so that runs are repeatable.

    Python's ``random``, NumPy's global RNG and PyTorch's CPU RNG are always
    seeded. When CUDA is available, all CUDA devices are seeded too and cuDNN
    is switched to deterministic, non-autotuned kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels and no benchmark autotuning: slower, but repeatable.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def train(dataset_collection, config, model_save_path='her_treatment_policy.pth', use_amp=False, load=False, logger=None):
    """
    Train the HER+SAC treatment-policy agent, or load a previously saved one.

    The checkpoint path is ``f"sac_{model_save_path}"``. When ``load`` is True
    the checkpoint is loaded; if loading fails for any reason, the agent is
    trained from scratch instead. Training fills the agent's replay buffer via
    ``collect_her_samples``, runs ``agent.train_offline`` and saves the agent.
    In every case the agent is then evaluated with ``evaluate``.

    Args:
        dataset_collection: dataset collection holding the train_f / val_f splits.
        config: configuration mapping with 'model', 'exp' and 'dataset' sections.
        model_save_path: checkpoint file name (prefixed with ``sac_``).
        use_amp: whether to use mixed-precision training.
        load: try to load an existing checkpoint instead of training.
        logger: logger forwarded to ``evaluate``.

    Returns:
        (agent, metrics, complexity_info): the trained or loaded agent, the
        per-tau metrics returned by ``evaluate``, and a dict (empty if
        ``evaluate`` returned None) extended with 'train_time' (seconds; 0 when
        the model was loaded) and 'test_time' (seconds).
    """
    print(config)
    model_cfg = config['model']
    her_params = model_cfg['her_params']
    sac_params = model_cfg['sac_params']
    dataset_cfg = config['dataset']
    exp_cfg = config['exp']

    future_length = her_params['future_length']
    use_attention = model_cfg['use_attention']
    agent = SAC_HER_Agent(
        dataset_collection,
        config,
        input_dim=dataset_cfg.get('input_size'),
        output_dim=dataset_cfg.get('output_size'),
        treatment_dim=dataset_cfg.get('treatment_size'),
        static_dim=dataset_cfg.get('static_size'),
        hidden_dim=model_cfg['hidden_dim'],
        future_length=future_length,
        buffer_size=her_params['buffer_size'],
        batch_size=exp_cfg['batch_size'],
        goal_threshold=her_params['goal_threshold'],
        k_future=her_params['k_future'],
        use_amp=use_amp,
        reward_mode=her_params['reward_mode'],
        use_attention=use_attention,
        num_heads=model_cfg['attention_heads'] if use_attention else None,
        discount=sac_params['discount'],
        beta=sac_params['beta'],
        lr=sac_params['lr'],
        alpha=sac_params['alpha'],
        use_automatic_entropy=sac_params['use_automatic_entropy'],
        DR=sac_params['DR'],
        recover=sac_params['recover'],
        action_diff=sac_params['action_diff'],
        use_cql=sac_params['use_cql'],
        input_x=dataset_cfg.get('input_x'),
    )
    save_path = f"sac_{model_save_path}"

    if load:
        try:
            print(f"\n尝试加载模型 {save_path}...")
            agent.load(save_path)
            print(f"模型加载成功！")
        except FileNotFoundError:
            print(f"模型文件 {save_path} 不存在，将进行训练...")
            load = False
        except Exception as e:
            # Deliberate best-effort: any load failure falls back to training.
            print(f"加载模型时出错: {e}")
            print("将进行训练...")
            load = False

    # Stays 0 when a checkpoint was loaded; previously this name was unbound
    # in that case and raised NameError below.
    train_time = 0
    if not load:
        replay_buffer = agent.memory
        collect_her_samples(
            dataset_collection,
            replay_buffer,
            min_history_length=her_params['min_history_length'],
            max_history_length=her_params['max_history_length'],
            future_length=future_length,
        )

        # Count buffer entries whose third field equals 0 ("hits"); presumably
        # that field is the reward - TODO confirm against HERReplayBuffer.
        hits = sum(1 for item in replay_buffer.buffer if item[2] == 0)
        print(f"all:{len(replay_buffer.buffer)}, hit:{hits}")

        # max_epochs passes over the buffer, one batch_size step at a time.
        training_iterations = exp_cfg['max_epochs'] * (len(replay_buffer) // exp_cfg['batch_size'])
        print(f"training_iterations:{training_iterations}, len(replay_buffer):{len(replay_buffer)}")
        print(f"\n开始训练 SAC 智能体...")
        start_time = time.time()
        agent.train_offline(training_iterations, progress_interval=exp_cfg['log_freq'])
        train_time = time.time() - start_time
        agent.save(save_path)
        print(f"模型已保存到 {save_path}")

    print("\n评估模型...")
    start_time = time.time()
    metrics, complexity_info = evaluate(agent, dataset_collection, config, logger=logger)
    test_time = time.time() - start_time
    # evaluate() returns None for complexity_info outside 'cip' mode.
    if complexity_info is None:
        complexity_info = {}
    complexity_info['train_time'] = train_time
    complexity_info['test_time'] = test_time

    return agent, metrics, complexity_info



def train_scrl(dataset_collection, config, model_save_path='scrl_treatment_policy.pth', use_amp=False, load=False, logger=None):
    """
    Train the SCRL treatment-policy agent, or load a previously saved one.

    The checkpoint path is ``f"scrl_{model_save_path}"``. When ``load`` is True
    the checkpoint is loaded; if loading fails for any reason, the agent is
    trained from scratch instead. Training fills the agent's replay buffer via
    ``collect_her_samples``, runs ``agent.train_offline`` and saves the agent.
    In every case the agent is then evaluated with ``evaluate``.

    Args:
        dataset_collection: dataset collection holding the train_f / val_f splits.
        config: configuration mapping with 'model', 'exp' and 'dataset' sections.
            ``model.scrl_params`` may set 'bc_reg_lambda' (default 0.1) and
            'use_data_aug' (default True).
        model_save_path: checkpoint file name (prefixed with ``scrl_``).
        use_amp: whether to use mixed-precision training.
        load: try to load an existing checkpoint instead of training.
        logger: logger forwarded to ``evaluate``.

    Returns:
        (agent, metrics, complexity_info): the trained or loaded agent, the
        per-tau metrics returned by ``evaluate``, and a dict (empty if
        ``evaluate`` returned None) extended with 'train_time' (seconds; 0 when
        the model was loaded) and 'test_time' (seconds).
    """
    print("--- 训练 SCRL Agent ---")
    print(config)

    model_cfg = config['model']
    her_params = model_cfg['her_params']
    sac_params = model_cfg['sac_params']
    scrl_params = model_cfg['scrl_params']
    dataset_cfg = config['dataset']
    exp_cfg = config['exp']

    future_length = her_params['future_length']
    use_attention = model_cfg['use_attention']
    agent = SCRL_Agent(
        dataset_collection,
        config,
        input_dim=dataset_cfg.get('input_size'),
        output_dim=dataset_cfg.get('output_size'),
        treatment_dim=dataset_cfg.get('treatment_size'),
        static_dim=dataset_cfg.get('static_size'),
        future_length=future_length,
        buffer_size=her_params['buffer_size'],
        batch_size=exp_cfg['batch_size'],
        goal_threshold=her_params['goal_threshold'],
        k_future=her_params['k_future'],
        use_amp=use_amp,
        reward_mode=her_params['reward_mode'],
        use_attention=use_attention,
        num_heads=model_cfg['attention_heads'] if use_attention else None,
        discount=sac_params['discount'],
        lr=sac_params['lr'],
        bc_reg_lambda=scrl_params.get('bc_reg_lambda', 0.1),
        use_data_aug=scrl_params.get('use_data_aug', True),
    )
    save_path = f"scrl_{model_save_path}"

    if load:
        try:
            print(f"\n尝试加载模型 {save_path}...")
            agent.load(save_path)
            print(f"模型加载成功！")
        except FileNotFoundError:
            print(f"模型文件 {save_path} 不存在，将进行训练...")
            load = False
        except Exception as e:
            # Deliberate best-effort: any load failure falls back to training.
            print(f"加载模型时出错: {e}")
            print("将进行训练...")
            load = False

    # Stays 0 when a checkpoint was loaded.
    train_time = 0
    if not load:
        replay_buffer = agent.memory
        collect_her_samples(
            dataset_collection,
            replay_buffer,
            min_history_length=her_params['min_history_length'],
            max_history_length=her_params['max_history_length'],
            future_length=future_length,
        )

        # Count buffer entries whose third field equals 0 ("hits"); presumably
        # that field is the reward - TODO confirm against the buffer class.
        hits = sum(1 for item in replay_buffer.buffer if item[2] == 0)
        print(f"Buffer populated. all:{len(replay_buffer.buffer)}, hit:{hits}")

        # max_epochs passes over the buffer, one batch_size step at a time.
        training_iterations = exp_cfg['max_epochs'] * (len(replay_buffer) // exp_cfg['batch_size'])
        print(f"training_iterations:{training_iterations}, len(replay_buffer):{len(replay_buffer)}")
        print(f"\n开始训练 SCRL 智能体...")
        start_time = time.time()
        agent.train_offline(training_iterations, progress_interval=exp_cfg['log_freq'])
        train_time = time.time() - start_time
        agent.save(save_path)
        print(f"模型已保存到 {save_path}")

    print("\n评估模型...")
    start_time = time.time()
    metrics, complexity_info = evaluate(agent, dataset_collection, config, logger=logger)
    test_time = time.time() - start_time
    # evaluate() returns None for complexity_info outside 'cip' mode.
    if complexity_info is None:
        complexity_info = {}
    complexity_info['train_time'] = train_time
    complexity_info['test_time'] = test_time

    return agent, metrics, complexity_info

def _evaluation_batch_size(config):
    """
    Return the CIP evaluation batch size, sized to cover the whole val/test split.

    The split is 'test' when ``config['exp']['test']`` is truthy, else 'val'.

    Raises:
        ValueError: if the dataset name contains neither 'mimic' nor 'tumor'.
    """
    dataset_cfg = config['dataset']
    name = dataset_cfg['name']
    split = 'test' if config['exp']['test'] else 'val'
    if 'mimic' in name:
        # max_number scaled by the split fraction.
        return int(dataset_cfg['max_number'] * dataset_cfg['split'][split])
    if 'tumor' in name:
        return dataset_cfg['num_patients'][split]
    raise ValueError(
        f"Cannot determine the evaluation batch size for dataset {name!r}: "
        "expected a name containing 'mimic' or 'tumor'."
    )


def evaluate(agent, dataset_collection, config, max_tau=6, logger=None):
    """
    Evaluate the agent for every horizon tau in 1..max_tau.

    For each tau, ``agent.future_length`` and ``config['exp']['tau']`` are set
    to tau before evaluating. Both are restored to their original values
    afterwards, even if evaluation raises.

    Args:
        agent: trained agent.
        dataset_collection: dataset collection. In 'cip' evaluation mode, the
            ``.data`` of ``test_f`` (if ``config['exp']['test']``) or ``val_f``
            is evaluated.
        config: configuration mapping (dict-style access).
        max_tau: largest tau to evaluate; if None, ``config['exp']['tau']`` is used.
        logger: logger for per-tau results; falls back to this module's logger
            when None.

    Returns:
        (all_metrics, complexity_info): ``all_metrics`` maps tau to that tau's
        metrics. ``complexity_info`` is a copy of the tau=1 complexity info
        from ``evaluate_cip`` in 'cip' mode, and None otherwise (or when no tau
        was evaluated).

    Raises:
        ValueError: in 'cip' mode, when the dataset name contains neither
            'mimic' nor 'tumor'.
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    orig_tau = config['exp']['tau']
    if max_tau is None:
        max_tau = orig_tau

    is_cip = config['exp']['evaluation_mode'] == 'cip'
    if is_cip:
        batch_size = _evaluation_batch_size(config)
        # Dict-style access works for both plain dicts and mapping-style configs.
        if config['exp']['test']:
            split_data = dataset_collection.test_f.data
        else:
            split_data = dataset_collection.val_f.data

    all_metrics = {}
    complexity_info = None
    missing = object()
    orig_future_length = getattr(agent, 'future_length', missing)
    try:
        for tau in range(1, max_tau + 1):
            agent.future_length = tau
            logger.info(f"\n评估模型，tau={tau}...")
            config['exp']['tau'] = tau

            if is_cip:
                # Rebuilt for every tau after config['exp']['tau'] is updated;
                # presumably CIPDataset depends on that horizon - TODO confirm.
                data_loader = get_dataloader(
                    CIPDataset(split_data, config),
                    batch_size=batch_size,
                    shuffle=False,
                )
                metrics, tau_complexity = evaluate_cip(
                    agent,
                    data_loader,
                    dataset_collection,
                    config,
                )
                if tau == 1:
                    # Complexity is reported once, for the first horizon only.
                    complexity_info = tau_complexity.copy()
            else:
                metrics = evaluate_agent(
                    agent,
                    dataset_collection,
                    num_episodes=config['exp']['eval_episodes'],
                )

            all_metrics[tau] = metrics
            print_evaluation_results(metrics, tau, logger)
    finally:
        # Do not leak the per-tau overrides into the caller's config or agent.
        config['exp']['tau'] = orig_tau
        if orig_future_length is not missing:
            agent.future_length = orig_future_length

    print(f"complexity_info:{complexity_info}")
    return all_metrics, complexity_info

def print_evaluation_results(metrics, tau, logger=None):
    """
    Log a short summary of the evaluation metrics for one horizon tau.

    Args:
        metrics: mapping with 'success_rate', 'avg_mse', 'avg_rmse',
            'avg_steps_used' and 'early_stop_rate' entries.
        tau: evaluated horizon; also shown as the denominator of the average
            number of steps used.
        logger: logger to write to; falls back to this module's logger when
            None (callers such as ``evaluate`` default to ``logger=None``).
    """
    if logger is None:
        logger = logging.getLogger(__name__)
    logger.info(f"GIFT 评估结果 (tau={tau}):")
    logger.info(f"  成功率: {metrics['success_rate']:.2%}")
    logger.info(f"  平均MSE: {metrics['avg_mse']:.6f}")
    logger.info(f"  平均RMSE: {metrics['avg_rmse']:.6f}")
    logger.info(f"  平均使用步数: {metrics['avg_steps_used']:.2f}/{tau}")
    logger.info(f"  提前停止率: {metrics['early_stop_rate']:.2%}")