import numpy as np
import random
import torch
import time
import torch.nn as nn

# Device on which evaluation/profiling tensors are placed: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

from src.data.her_data_generator import create_history_treatment_goal_samples, convert_dataloader_to_samples

import torch
from fvcore.nn import FlopCountAnalysis
import numpy as np
from src.utils.utils import set_seed
from src.data.cip_dataset import CIPDataset, get_dataloader
from thop import profile, clever_format

def count_flops_params(agent, history_dict_batch, goal_batch):
    """
    Count parameters and FLOPs of an SAC_HER_Agent-style model.

    Inputs are prepared the same way as the agent's
    ``generate_treatment_plan_batch``: history dicts are collated key by key and
    goals are stacked into one batch tensor on ``DEVICE``.

    Args:
        agent: Agent exposing ``encoder``, ``actor`` and ``critic`` modules and,
            optionally, a ``behavior_policy`` module and an ``algorithm`` name.
        history_dict_batch: List[Dict], one dict per sample. ndarray values are
            converted to float tensors; tensor-valued keys are concatenated
            along dim 0 (presumably each sample carries a leading batch dim of
            1 -- TODO confirm), other values are kept as per-sample lists.
        goal_batch: List[np.ndarray | torch.Tensor], one goal per sample.
            ndarray goals and 1-D tensor goals get a leading batch dim before
            concatenation.

    Returns:
        tuple: ``(total_params, mflops_per_sample)`` -- the summed parameter
        count of encoder, actor, critic (and behavior policy if present), and
        total FLOPs / 1e6 / batch size.

    Raises:
        ValueError: if the batch is empty, a goal has an unsupported type, or
            ``agent.algorithm == "SCRL"`` but the encoder does not return a
            ``(history_encoding, goal_encoding)`` tuple.
    """
    device = DEVICE
    batch_size = len(history_dict_batch)
    if batch_size == 0:
        raise ValueError("history_dict_batch must contain at least one sample")

    # Collate histories key by key; numpy arrays become float tensors.
    H_t_batch = {}
    for history_dict in history_dict_batch:
        for key, value in history_dict.items():
            if isinstance(value, np.ndarray):
                value = torch.FloatTensor(value)
            H_t_batch.setdefault(key, []).append(value)
    for key, values in H_t_batch.items():
        # Only tensor-valued keys are concatenated and moved to the device;
        # anything else stays a per-sample list.
        if isinstance(values[0], torch.Tensor):
            H_t_batch[key] = torch.cat(values, dim=0).to(device)

    goal_tensors = []
    for goal in goal_batch:
        if isinstance(goal, np.ndarray):
            goal_tensors.append(torch.FloatTensor(goal).unsqueeze(0))
        elif isinstance(goal, torch.Tensor):
            goal_tensors.append(goal.unsqueeze(0) if goal.dim() == 1 else goal)
        else:
            raise ValueError("goal_batch contains unsupported type")
    goal_tensor_batch = torch.cat(goal_tensors, dim=0).to(device)

    # thop reports multiply-accumulates (MACs); x2 converts them to FLOPs.
    # NOTE(review): fvcore's FlopCountAnalysis (used for actor/critic below)
    # counts one fused multiply-add as one flop, so the summed total mixes two
    # conventions -- confirm which one is intended before comparing numbers.
    macs, _ = profile(agent.encoder, (H_t_batch, goal_tensor_batch), verbose=False)
    flops_encoder = macs * 2

    with torch.no_grad():
        encoded_state = agent.encoder(H_t_batch, goal_tensor_batch)
    if isinstance(encoded_state, tuple):
        # Encoder returned (history_encoding, goal_encoding): the actor sees
        # them concatenated along dim 1.
        h_encoding, g_encoding = encoded_state
        actor_input_state = torch.cat(encoded_state, dim=1)
    else:
        h_encoding = g_encoding = None
        actor_input_state = encoded_state

    flops_actor = FlopCountAnalysis(agent.actor, (actor_input_state,))
    with torch.no_grad():
        action, _ = agent.actor(actor_input_state)

    if getattr(agent, 'algorithm', None) == "SCRL":
        # The SCRL critic consumes the two encodings separately.
        if h_encoding is None:
            raise ValueError(
                "SCRL critic expects the encoder to return a "
                "(history_encoding, goal_encoding) tuple"
            )
        flops_critic = FlopCountAnalysis(agent.critic, (h_encoding, action, g_encoding))
    else:
        flops_critic = FlopCountAnalysis(agent.critic, (actor_input_state, action))

    def _num_params(module):
        # Total element count over all parameters of ``module``.
        return sum(p.numel() for p in module.parameters())

    params_encoder = _num_params(agent.encoder)
    params_actor = _num_params(agent.actor)
    params_critic = _num_params(agent.critic)
    total_flops = flops_encoder + flops_actor.total() + flops_critic.total()

    params_behavior = 0
    if hasattr(agent, 'behavior_policy'):
        flops_behavior = FlopCountAnalysis(agent.behavior_policy, (encoded_state, action))
        params_behavior = _num_params(agent.behavior_policy)
        total_flops += flops_behavior.total()

    total_params = params_encoder + params_actor + params_critic + params_behavior
    print(f"params_encoder:{params_encoder}, params_actor: {params_actor}, params_critic:{params_critic}, params_behavior:{params_behavior}")
    mflops_per_sample = total_flops / 1e6 / batch_size

    print(f'Total Parameters: {total_params:,}')
    print(f'Total FLOPs: {total_flops:,} FLOPs, {mflops_per_sample:.2f} MFLOPs per sample (batch size={batch_size})')

    return total_params, mflops_per_sample

def evaluate_and_log_case_studies(agent, dataset_collection, config, logger, model_name, max_tau, case_study_results, size=100):
    """
    Run per-patient case studies and record each model's predicted outcome trajectory.

    Samples are built as in ``evaluate_cip`` (``convert_dataloader_to_samples``)
    from the validation split, or from the test split when
    ``config['exp']['test']`` is truthy. Patient ids ``0 .. size-1`` are planned
    one at a time with ``agent.generate_treatment_plan_batch`` (no early stop).
    Per-patient errors are logged and skipped.

    Side effects: reseeds via ``set_seed(config['exp']['seed'])`` and sets
    ``config['exp']['tau'] = 6`` on the caller's config before the dataset is
    built. ``agent.actor`` / ``agent.encoder`` (when present) are switched to
    eval mode and put back into train mode on exit, even if an error escapes.

    Args:
        agent: Trained agent instance.
        dataset_collection: Provides ``val_f.data`` / ``test_f.data`` and is
            passed through to the planner.
        config: Experiment configuration dict (mutated, see above).
        logger: Logger used for progress and error messages.
        model_name (str): Key under which this model's trajectory is stored
            (e.g. 'GIFT', 'VCIP').
        max_tau (int): Planning horizon passed as ``future_length``.
        case_study_results (dict): Accumulator keyed by patient id; updated in
            place. New entries hold 'initial_outcome', 'goal',
            'ground_truth_outcomes' and a 'models' dict.
        size (int): Number of patient ids to process.

    Returns:
        tuple: ``(case_study_results, seconds_per_case)`` where
        ``seconds_per_case`` is the total elapsed time (including data
        preparation) divided by ``size`` (0.0 when ``size`` is not positive).
    """
    logger.info(f"--- Starting Case Study for model: {model_name} ---")
    start_time = time.time()
    set_seed(config['exp']['seed'])
    config['exp']['tau'] = 6
    if hasattr(agent, 'actor'):
        agent.actor.eval()
    if hasattr(agent, 'encoder'):
        agent.encoder.eval()
    try:
        split = dataset_collection.test_f if config['exp']['test'] else dataset_collection.val_f
        data = split.data
        # One batch holding the whole split, so samples match evaluate_cip.
        dataloader = get_dataloader(CIPDataset(data, config), batch_size=len(data['outputs']), shuffle=False)
        all_samples = convert_dataloader_to_samples(dataloader)

        if all_samples:
            print(f"all_samples[0]: {all_samples[0][-1]}")
        for patient_id in range(size):
            try:
                if patient_id >= len(all_samples):
                    logger.warning(f"Patient ID {patient_id} is out of bounds for the generated samples ({len(all_samples)}). Skipping.")
                    continue

                logger.info(f"Processing patient_id: {patient_id}")
                history_dict, future_dict, goal = all_samples[patient_id]
                goal_np = goal if isinstance(goal, np.ndarray) else goal.cpu().numpy()
                # Computed unconditionally: it is printed below even when this
                # patient was already recorded by a previously evaluated model.
                ground_truth_outcomes = future_dict['outputs'].squeeze()
                if patient_id not in case_study_results:
                    initial_outcome = history_dict['outputs'][:, -1, :]
                    case_study_results[patient_id] = {
                        'initial_outcome': initial_outcome.squeeze(),
                        'goal': goal_np.squeeze(),
                        'ground_truth_outcomes': np.array(ground_truth_outcomes),
                        'models': {}
                    }
                _, outputs_batch, _ = agent.generate_treatment_plan_batch(
                    [history_dict],
                    [goal],
                    dataset_collection=dataset_collection,
                    future_dict_batch=[future_dict],
                    future_length=max_tau,
                    early_stop=False
                )
                outcome_trajectory = outputs_batch[0]

                if patient_id == 0:
                    print('-' * 100)
                    print(f"evaluate log out: {outcome_trajectory.squeeze()}")
                    print(f"evaluate log turth: {ground_truth_outcomes}")
                case_study_results[patient_id]['models'][model_name] = np.array(outcome_trajectory).squeeze()
                logger.info(f"Successfully recorded case study for patient {patient_id} with model {model_name}.")

            except IndexError as e:
                logger.error(f"IndexError while processing patient_id {patient_id}: {e}. Skipping.")
            except Exception as e:
                logger.error(f"An unexpected error occurred for patient_id {patient_id}: {e}")
                import traceback
                traceback.print_exc()
    finally:
        # Always hand the modules back in train mode.
        if hasattr(agent, 'actor'):
            agent.actor.train()
        if hasattr(agent, 'encoder'):
            agent.encoder.train()

    logger.info(f"--- Case Study for model: {model_name} Finished ---")
    used_time = time.time() - start_time
    return case_study_results, (used_time / size if size > 0 else 0.0)


def save_evaluation_results(agent, val_samples, predictions_list, treatments_list, true_treatments_list, output_path='./results/HER'):
    """
    Save evaluation results: one JSON file per sample plus an ``index.json``.

    Files go to ``<output_path>/results_<agent.algorithm>``, or to
    ``<output_path>/results`` when ``agent`` is falsy. Directories are created
    as needed. The four sequences are consumed in lockstep via ``zip``, so the
    shortest one bounds the number of samples written.

    Args:
        agent: Object with an ``algorithm`` attribute, or a falsy value.
        val_samples: Iterable of ``(history_dict, future_dict, goal)`` tuples;
            ``history_dict['outputs']`` is indexed as ``[0, :, :]``.
        predictions_list: Per-sample predictions; the last element
            (``predictions[-1]``) is stored as ``final_output``.
        treatments_list: Per-sample predicted treatments (numpy arrays).
        true_treatments_list: Per-sample ground-truth treatments (numpy arrays).
        output_path: Root output directory.

    Returns:
        str: Path of the folder the results were written to.
    """
    import os
    import json

    os.makedirs(output_path, exist_ok=True)
    print(f"绝对路径: {os.path.abspath(output_path)}")

    folder_name = f"results_{agent.algorithm}" if agent else "results"
    results_folder = os.path.join(output_path, folder_name)
    os.makedirs(results_folder, exist_ok=True)

    index_data = []
    rows = zip(val_samples, predictions_list, treatments_list, true_treatments_list)
    for i, ((history_dict, _future_dict, goal), predictions, treatments, true_treatments) in enumerate(rows):
        sample_id = f"sample_{i}"
        file_name = f"{sample_id}.json"
        sample_data = {
            "sample_id": sample_id,
            "history_outputs": history_dict['outputs'][0, :, :].astype(float).tolist(),
            "goal_value": [float(x) for x in goal],
            "treatments": treatments.astype(float).tolist(),
            "true_treatments": true_treatments.astype(float).tolist(),
            "final_output": [float(x) for x in predictions[-1]],
        }
        with open(os.path.join(results_folder, file_name), 'w', encoding='utf-8') as f:
            json.dump(sample_data, f, indent=2)
        index_data.append({
            "sample_id": sample_id,
            "history_length": len(sample_data["history_outputs"]),
            "treatment_length": len(sample_data["treatments"]),
            "goal_value": sample_data["goal_value"],
            "final_output": sample_data["final_output"],
            "file": file_name,
        })

    with open(os.path.join(results_folder, "index.json"), 'w', encoding='utf-8') as f:
        json.dump(index_data, f, indent=2)

    print(f"评估结果已保存到 {results_folder}")
    return results_folder

def evaluate_agent(agent, dataset_collection, num_episodes=200, config=None):
    """
    Evaluate the agent on samples built by ``create_history_treatment_goal_samples``.

    Up to ``num_episodes`` samples are drawn at random from the validation split
    (history length 10-20, future length ``agent.future_length``). Each sample is
    planned with ``predict_trajectory``, whose MSE is already expressed in the
    dataset's scaled units. Samples whose history lacks 'outputs' (or its
    'prev_outputs' alias), 'static_features' or 'current_treatments' are skipped.

    ``agent.actor`` / ``agent.encoder`` are put in eval mode for the run and back
    into train mode afterwards, even if an error escapes.

    Args:
        agent: Agent with ``future_length``, ``goal_threshold`` and ``algorithm``.
        dataset_collection: Provides ``val_f.data`` and the scaling parameters.
        num_episodes: Maximum number of samples to evaluate.
        config: Experiment configuration; ``config['dataset']['name']`` selects
            the output scaling in ``predict_trajectory``. Required.

    Returns:
        dict: 'success_rate', 'avg_mse', 'avg_rmse', 'avg_treatment_similarity',
        'num_evaluated', 'avg_steps_used', 'early_stop_rate'.

    Raises:
        ValueError: if ``config`` is not provided.
    """
    if config is None:
        raise ValueError("evaluate_agent requires `config` (used for dataset-specific output scaling)")

    has_actor = hasattr(agent, 'actor')
    if has_actor:
        agent.actor.eval()
        agent.encoder.eval()

    mse_values = []
    success_count = 0
    treatment_similarities = []
    steps_used_list = []
    try:
        val_samples = create_history_treatment_goal_samples(
            dataset_collection.val_f.data,
            min_history_length=10,
            max_history_length=20,
            future_length=agent.future_length
        )
        if len(val_samples) > num_episodes:
            val_samples = random.sample(val_samples, num_episodes)

        print(f"使用 {len(val_samples)} 个验证样本进行评估")
        for history_dict, future_dict, goal in val_samples:
            # 'prev_outputs' is accepted as an alias for a missing 'outputs'.
            if 'outputs' not in history_dict and 'prev_outputs' in history_dict:
                history_dict['outputs'] = history_dict['prev_outputs']
            if any(key not in history_dict for key in ('outputs', 'static_features', 'current_treatments')):
                continue

            predictions, treatments, mse, steps_used = predict_trajectory(
                agent,
                history_dict,
                goal,
                dataset_collection,
                config,
                future_dict
            )
            mse_values.append(mse)
            steps_used_list.append(steps_used)
            if mse < agent.goal_threshold:
                success_count += 1

            if 'current_treatments' in future_dict:
                # NOTE(review): assumes 2 treatment dimensions per step -- confirm.
                actual_treatments = future_dict['current_treatments'].reshape(-1, 2)
                min_len = min(len(actual_treatments), len(treatments))
                if min_len > 0:
                    actual_treatments = actual_treatments[:min_len]
                    pred_treatments = treatments[:min_len]
                    actual_norm = np.linalg.norm(actual_treatments)
                    pred_norm = np.linalg.norm(pred_treatments)
                    if actual_norm > 0 and pred_norm > 0:
                        # Cosine similarity between the flattened treatment sequences.
                        similarity = np.sum(actual_treatments * pred_treatments) / (actual_norm * pred_norm)
                        treatment_similarities.append(similarity)
    finally:
        if has_actor:
            agent.actor.train()
            agent.encoder.train()

    success_rate = success_count / len(mse_values) if mse_values else 0
    avg_mse = np.mean(mse_values) if mse_values else float('inf')
    # predict_trajectory already scales the error, so no extra scaling here
    # (consistent with evaluate_cip).
    avg_rmse = np.sqrt(avg_mse)
    avg_similarity = np.mean(treatment_similarities) if treatment_similarities else 0.0
    avg_steps = np.mean(steps_used_list) if steps_used_list else 0

    metrics = {
        'success_rate': success_rate,
        'avg_mse': avg_mse,
        'avg_rmse': avg_rmse,
        'avg_treatment_similarity': avg_similarity,
        'num_evaluated': len(mse_values),
        'avg_steps_used': avg_steps,
        'early_stop_rate': 1.0 - (avg_steps / agent.future_length) if steps_used_list else 0
    }

    print(f"评估结果 ({agent.algorithm}):")
    print(f" 成功率: {metrics['success_rate']:.2%}")
    print(f" 平均MSE: {metrics['avg_mse']:.6f}")
    print(f" 平均RMSE: {metrics['avg_rmse']:.6f}")
    print(f" 平均干预相似度: {metrics['avg_treatment_similarity']:.4f}")
    print(f" 平均使用步数: {metrics['avg_steps_used']:.2f}/{agent.future_length}")
    print(f" 提前停止率: {metrics['early_stop_rate']:.2%}")

    return metrics

def evaluate_cip(agent, dataloader, dataset_collection, config):
    """
    Batched CIP evaluation of ``agent`` on the samples produced by ``dataloader``.

    Samples come from ``convert_dataloader_to_samples(dataloader)``. At most
    ``config['exp']['eval_episodes']`` of them are drawn at random. Samples whose
    history lacks 'outputs' (or its 'prev_outputs' alias), 'static_features' or
    'current_treatments' are skipped.

    The remaining samples are planned in chunks via
    ``agent.generate_treatment_plan_batch`` with early stopping. The chunk size
    equals the size of the dataloader's first batch. A sample's error is the mean
    over the last axis of the squared difference between the final planned outcome
    and the goal, each multiplied by a dataset-specific scale ('mimic' ->
    ``output_means``, 'tumor' -> ``cancer_volume`` scaling, otherwise 1.0).

    When ``config['exp']['tau'] == 1``, the first chunk is also used to measure
    model complexity with ``count_flops_params``. ``agent.actor`` /
    ``agent.encoder`` are put in eval mode for the run and back into train mode
    afterwards, even if an error escapes.

    Returns:
        tuple: ``(metrics, complexity_info)``. ``metrics`` has 'success_rate',
        'avg_mse', 'avg_rmse', 'num_evaluated', 'avg_steps_used' and
        'early_stop_rate'. ``complexity_info`` has 'params' and 'mflops' when
        measured, else it is empty.
    """
    complexity_info = {}
    num_episodes = config['exp']['eval_episodes']

    # Planning chunk size = size of the dataloader's first batch (None if empty).
    batch_size = None
    for H_t, _ in dataloader:
        batch_size = H_t['outputs'].shape[0]
        break

    has_actor = hasattr(agent, 'actor')
    if has_actor:
        agent.actor.eval()
        agent.encoder.eval()

    mse_values = []
    success_count = 0
    steps_used_list = []
    try:
        val_samples = convert_dataloader_to_samples(dataloader)
        if val_samples:
            print(f"val_samples[0]: {val_samples[0][-1]}")
        if len(val_samples) > num_episodes:
            val_samples = random.sample(val_samples, num_episodes)

        print(f"使用 {len(val_samples)} 个验证样本进行评估，批量大小: {batch_size}")
        valid_samples = []
        for history_dict, future_dict, goal in val_samples:
            has_outputs = 'outputs' in history_dict or 'prev_outputs' in history_dict
            if not (has_outputs and 'static_features' in history_dict and 'current_treatments' in history_dict):
                continue
            if 'outputs' not in history_dict:
                history_dict['outputs'] = history_dict['prev_outputs']
            valid_samples.append((history_dict, future_dict, goal))

        if valid_samples:
            chunk_size = batch_size or len(valid_samples)
            # Loop-invariant: dataset-specific scale applied to outcome errors.
            dataset_name = config['dataset']['name']
            if "mimic" in dataset_name:
                scale_param = dataset_collection.train_f.scaling_params['output_means']
            elif "tumor" in dataset_name:
                scale_param = dataset_collection.train_scaling_params[1]['cancer_volume']
            else:
                scale_param = 1.0

            for batch_idx, start_idx in enumerate(range(0, len(valid_samples), chunk_size)):
                batch_samples = valid_samples[start_idx:start_idx + chunk_size]
                history_dict_batch = [s[0] for s in batch_samples]
                future_dict_batch = [s[1] for s in batch_samples]
                goal_batch = [s[2] for s in batch_samples]

                if batch_idx == 0 and config['exp']['tau'] == 1:
                    complexity_info['params'], complexity_info['mflops'] = count_flops_params(agent, history_dict_batch, goal_batch)
                goal_batch_np = np.stack([g if isinstance(g, np.ndarray) else g.cpu().numpy() for g in goal_batch])
                actions_batch, outputs_batch, steps_taken_batch = agent.generate_treatment_plan_batch(
                    history_dict_batch,
                    goal_batch,
                    dataset_collection,
                    future_dict_batch,
                    future_length=agent.future_length,
                    early_stop=True
                )

                print(f"evaluate_cip outputs_batch[0]: {outputs_batch[0].squeeze()}")
                print(f"evaluate_cip future_dict_batch: {future_dict_batch[0]['outputs'].squeeze()}")
                outputs_batch_np = np.array(outputs_batch)
                # Last planned outcome per sample; squeeze(1) drops a singleton
                # axis (presumably the time/batch dim of each step -- TODO confirm).
                last_outputs = np.array([outputs[-1] if len(outputs) > 0 else np.zeros_like(goal_batch_np[0])
                                         for outputs in outputs_batch_np]).squeeze(1)
                print(f"last_outputs:{last_outputs.shape}")
                scaled_diff = (last_outputs - goal_batch_np) * scale_param
                mse_batch = np.mean(np.square(scaled_diff), axis=-1)

                mse_values.extend(mse_batch.tolist())
                success_count += int(np.sum(mse_batch < agent.goal_threshold))
                steps_used_list.extend(steps_taken_batch)
    finally:
        if has_actor:
            agent.actor.train()
            agent.encoder.train()

    total_samples = len(mse_values)
    success_rate = success_count / total_samples if total_samples > 0 else 0
    avg_mse = np.mean(mse_values) if total_samples > 0 else float('inf')
    avg_rmse = np.sqrt(avg_mse)
    avg_steps = np.mean(steps_used_list) if steps_used_list else 0

    metrics = {
        'success_rate': success_rate,
        'avg_mse': avg_mse,
        'avg_rmse': avg_rmse,
        'num_evaluated': total_samples,
        'avg_steps_used': avg_steps,
        'early_stop_rate': 1.0 - (avg_steps / agent.future_length) if steps_used_list else 0
    }

    print(f"CIP评估结果 ({agent.algorithm}):")
    print(f" 成功率: {metrics['success_rate']:.2%}")
    print(f" 平均MSE: {metrics['avg_mse']:.6f}")
    print(f" 平均RMSE: {metrics['avg_rmse']:.6f}")
    print(f" 平均使用步数: {metrics['avg_steps_used']:.2f}/{agent.future_length}")
    print(f" 提前停止率: {metrics['early_stop_rate']:.2%}")

    return metrics, complexity_info

def predict_trajectory(agent, history_dict, goal, dataset_collection, config, future_dict):
    """
    Plan a treatment trajectory toward ``goal`` and score its final outcome.

    Calls ``agent.generate_treatment_plan`` with early stopping over
    ``agent.future_length`` steps. It then computes the mean squared error
    between the last planned outcome and the goal, after multiplying their
    difference by a dataset-specific scale: ``output_means`` for 'mimic'
    datasets, the ``cancer_volume`` scaling parameter for 'tumor' datasets.

    Args:
        agent: Agent exposing ``generate_treatment_plan`` and ``future_length``.
        history_dict: Observed history passed to the planner.
        goal: Target outcome (numpy array, array-like, or torch tensor).
        dataset_collection: Provides the scaling parameters.
        config: ``config['dataset']['name']`` selects the scaling.
        future_dict: Future data passed to the planner.

    Returns:
        tuple: ``(predictions, treatments, mse, steps_used)``. ``predictions``
        is the last planned outcome (``None`` if the plan is empty); ``mse``
        is ``inf`` in that case.

    Raises:
        ValueError: if the dataset name contains neither "mimic" nor "tumor".
    """
    treatments, outputs, steps_used = agent.generate_treatment_plan(
        history_dict,
        goal,
        dataset_collection,
        future_dict,
        future_length=agent.future_length,
        early_stop=True
    )
    predictions = outputs[-1] if len(outputs) > 0 else None

    dataset_name = config['dataset']['name']
    if "mimic" in dataset_name:
        scale_param = dataset_collection.train_f.scaling_params['output_means']
    elif "tumor" in dataset_name:
        scale_param = dataset_collection.train_scaling_params[1]['cancer_volume']
    else:
        # Raise instead of calling exit(): a library helper must not kill the process.
        raise ValueError(f"No dataset named {dataset_name}!")

    if predictions is None:
        return predictions, treatments, float('inf'), steps_used

    if hasattr(goal, 'detach'):
        # torch tensor: detach so tensors that require grad can be converted.
        goal_np = goal.detach().cpu().numpy()
    else:
        # float32 matches the original FloatTensor conversion of numpy goals.
        goal_np = np.asarray(goal, dtype=np.float32)
    mse = (((predictions - goal_np) * scale_param) ** 2).mean()

    return predictions, treatments, mse, steps_used
