import numpy as np
import random
from collections import deque
import time
import matplotlib.pyplot as plt
import torch
import os

def create_single_sample(data, patient_idx, split_time, history_length, future_length, time_keys, static_keys):
    """
    Build one (history, future, goal) sample for a single patient split at ``split_time``.

    Args:
        data: dict of arrays indexed by patient along axis 0; arrays listed in
            ``time_keys`` are additionally indexed by time along axis 1, and
            ``data['outputs']`` must exist (it supplies the goal).
        patient_idx: index of the patient (row along axis 0).
        split_time: time index separating history ``[start, split_time)`` from
            future ``[split_time, split_time + future_length)``.
        history_length: number of history steps kept before ``split_time``;
            ``None`` keeps the whole history starting at time 0.
        future_length: number of future steps kept from ``split_time``. Should
            be >= 1: the goal index is ``split_time + future_length - 1``, so
            smaller values make the goal fall back into the history window.
        time_keys: keys of time-varying arrays to slice along time.
        static_keys: keys of per-patient arrays copied unchanged into both dicts.

    Returns:
        (history_dict, future_dict, goal): every dict value keeps a leading
        patient axis of size 1; ``goal`` is a copy of
        ``data['outputs'][patient_idx, split_time + future_length - 1]``.
    """
    history_dict = {}
    future_dict = {}
    # The history window start does not depend on the key, so compute it once.
    # None means "use all history", i.e. start at time 0.
    start_time = 0 if history_length is None else max(0, split_time - history_length)
    patient_rows = slice(patient_idx, patient_idx + 1)  # keep the patient axis (size 1)

    for key in time_keys:
        series = data[key]
        history_dict[key] = series[patient_rows, start_time:split_time].copy()
        future_dict[key] = series[patient_rows, split_time:split_time + future_length].copy()

    for key in static_keys:
        # A single slice works for 1-D and N-D static arrays alike.
        history_dict[key] = data[key][patient_rows].copy()
        # Independent copy so mutating one dict never affects the other.
        future_dict[key] = history_dict[key].copy()

    goal = data['outputs'][patient_idx, split_time + future_length - 1].copy()
    return (history_dict, future_dict, goal)


def get_patient_sequence_length(data, patient_idx):
    """
    Return the effective sequence length of one patient as a plain ``int``.

    Resolution order:
      1. ``data['active_entries']``: one past the last time step whose flag is
         > 0 (0 if no step is flagged). For arrays with more than one axis per
         patient, the flag is read from index 0 of the trailing axis, matching
         the ``[patient, :, 0]`` indexing of a 3-D array; a 2-D array is also
         accepted.
      2. ``data['sequence_lengths'][patient_idx]`` (cast to ``int``).
      3. The size of axis 1 of the first array whose axis 1 is longer than 1.
      4. 0 if none of the above applies.

    Args:
        data: dict of arrays indexed by patient along axis 0.
        patient_idx: index of the patient.

    Returns:
        seq_length: effective sequence length.
    """
    if 'active_entries' in data:
        flags = np.asarray(data['active_entries'][patient_idx])
        if flags.ndim > 1:
            flags = flags[:, 0]
        active_indices = np.flatnonzero(flags > 0)
        return int(active_indices[-1]) + 1 if active_indices.size else 0

    if 'sequence_lengths' in data:
        # Cast so callers can feed the result to range() even when lengths
        # are stored as numpy scalars or as a float array.
        return int(data['sequence_lengths'][patient_idx])

    # Only scan for a time-varying array when it is actually needed.
    for key in data:
        array = data[key]
        if len(array.shape) >= 2 and array.shape[1] > 1:
            return int(array.shape[1])
    return 0


def create_history_treatment_goal_samples(data, min_history_length=15, max_history_length=30,
                                          future_length=5, use_tail=False):
    """
    Create (history H, future F, goal) triplet samples from a patient data dict.

    For every patient with a non-zero effective length (as reported by
    ``get_patient_sequence_length``), one sample is created per split time
    ``t`` starting at ``max_history_length``. The history of each sample is
    always the full prefix ``[0, t)``; only the range of split times is bounded.

    Args:
        data: dict of arrays indexed by patient along axis 0. Arrays whose
            axis 1 is longer than 1 are treated as time-varying; all others are
            treated as static. ``data['outputs']`` supplies the goal.
        min_history_length: currently unused; kept for interface compatibility.
        max_history_length: first split time considered for every patient.
        future_length: number of future steps per sample.
        use_tail: if True, split times run up to ``seq_length - 2`` and the
            future window is truncated to what remains of the sequence;
            otherwise split times stop at ``seq_length - future_length - 1`` so
            every future window has exactly ``future_length`` steps inside
            ``seq_length``.

    Returns:
        samples: list of ``(history_dict, future_dict, goal)`` triplets.
    """
    samples = []

    # Arrays with a time axis longer than 1 are sliced per split time;
    # everything else is copied per patient unchanged.
    time_keys = []
    static_keys = []
    for key, value in data.items():
        if len(value.shape) >= 2 and value.shape[1] > 1:
            time_keys.append(key)
        else:
            static_keys.append(key)

    if 'active_entries' in data:
        num_patients = data['active_entries'].shape[0]
    else:
        # get_patient_sequence_length supports data without 'active_entries',
        # so derive the patient count from any array instead of raising KeyError.
        num_patients = next((value.shape[0] for value in data.values() if len(value.shape) >= 1), 0)

    print(f"时间相关键: {time_keys}")
    print(f"静态键: {static_keys}")
    for i in range(num_patients):
        seq_length = get_patient_sequence_length(data, i)
        if seq_length == 0:
            continue

        if use_tail:
            # Near the end of the sequence the future window is truncated.
            for t in range(max_history_length, seq_length - 1):
                current_future_length = min(seq_length - t, future_length)
                samples.append(create_single_sample(
                    data, i, t, None, current_future_length, time_keys, static_keys))
        else:
            # NOTE(review): t stops at seq_length - future_length - 1, so the split
            # t = seq_length - future_length (whose window would end exactly at
            # seq_length) is never used -- confirm this is intentional.
            for t in range(max_history_length, seq_length - future_length):
                samples.append(create_single_sample(
                    data, i, t, None, future_length, time_keys, static_keys))

    print(f"创建了 {len(samples)} 个 (历史, 未来, 目标) 样本")
    if samples:
        history, future, goal = samples[0]
        print("\n示例样本结构:")
        for title, part in (("历史数据包含的键", history), ("未来数据包含的键", future)):
            print(f"{title}: {list(part.keys())}")
            for key, value in part.items():
                if isinstance(value, np.ndarray):
                    print(f"  {key} 形状: {value.shape}")
                else:
                    print(f"  {key} 类型: {type(value)}")

        if goal is not None:
            print(f"目标形状: {goal.shape if isinstance(goal, np.ndarray) else type(goal)}")
        else:
            print("目标: None")

    return samples

def convert_dataloader_to_samples(dataloader, output_dir='./results/yvalues'):
    """
    Convert dataloader batches into (history H, future F, goal) triplet samples.

    Each batch must be a ``(H_batch, target_batch)`` pair of dicts. Every batch
    element becomes one sample:
      * ``history_dict``: that element's slice (leading axis of size 1) of each
        ``H_batch`` entry, tensors converted to numpy;
      * ``future_dict``: the ``target_batch`` slices of time-varying keys plus
        copies of the static history entries;
      * ``goal``: last time step of ``target_batch['outputs']``, or of
        ``'cancer_volume'`` / the first target key when ``'outputs'`` is absent.

    Key roles (time-varying = tensor with >= 2 dims, otherwise static) are
    inferred once from the first batch.

    Side effect: the mean of all last-step ``target_batch['outputs'][:, -1, :]``
    values is written to ``<output_dir>/overall_mean.txt`` (skipped if no batch
    had ``'outputs'``).

    The dataloader is consumed in a single pass, so one-shot iterators work and
    the mean is computed from exactly the batches that produced the samples.

    Args:
        dataloader: iterable of ``(H_batch, target_batch)`` pairs, e.g. a torch
            ``DataLoader`` (originally one built from ``CIPDataset``).
        output_dir: directory for ``overall_mean.txt``; created if missing.

    Returns:
        samples: list of ``(history_dict, future_dict, goal)`` triplets.
    """
    samples = []
    time_keys = None
    static_keys = None
    last_outputs_chunks = []

    os.makedirs(output_dir, exist_ok=True)
    for H_batch, target_batch in dataloader:
        if 'outputs' in target_batch:
            # Indexing [:, -1, :] assumes outputs are (batch, time, features).
            last_outputs_chunks.append(target_batch['outputs'][:, -1, :].detach().cpu())

        if time_keys is None or static_keys is None:
            time_keys = []
            static_keys = []
            for key, value in H_batch.items():
                if isinstance(value, torch.Tensor) and len(value.shape) >= 2:
                    time_keys.append(key)
                else:
                    static_keys.append(key)
            print(f"时间相关键: {time_keys}")
            print(f"静态键: {static_keys}")

        # The goal key depends only on the batch, not on the element.
        if 'outputs' in target_batch:
            goal_key = 'outputs'
        elif 'cancer_volume' in target_batch:
            goal_key = 'cancer_volume'
        else:
            goal_key = next(iter(target_batch.keys()))
        goal_source = target_batch[goal_key]

        batch_size = next(iter(H_batch.values())).shape[0]
        for i in range(batch_size):
            history_dict = {}
            future_dict = {}
            for key in time_keys:
                if key in H_batch:
                    history_dict[key] = H_batch[key][i:i + 1].detach().cpu().numpy()
                if key in target_batch:
                    future_dict[key] = target_batch[key][i:i + 1].detach().cpu().numpy()
            for key in static_keys:
                if key in H_batch:
                    value = H_batch[key][i:i + 1]
                    if isinstance(value, torch.Tensor):
                        value = value.detach().cpu().numpy()
                    history_dict[key] = value
                    future_dict[key] = value.copy()
            goal = goal_source[i, -1].detach().cpu().numpy()
            samples.append((history_dict, future_dict, goal))

    if last_outputs_chunks:
        mean_value = torch.cat(last_outputs_chunks, dim=0).mean().item()
        out_path = os.path.join(output_dir, 'overall_mean.txt')
        with open(out_path, 'w') as f:
            f.write(f"{mean_value:.6f}")
        print(f"All outputs last-step mean = {mean_value:.6f}, saved to {out_path}")
    else:
        # torch.cat on an empty list would raise; report instead of crashing.
        print("No 'outputs' targets found; overall mean not written.")

    print(f"创建了 {len(samples)} 个 (历史, 未来, 目标) 样本")
    if samples:
        history, future, goal = samples[0]
        # .get: 'outputs' is not guaranteed to be among the history keys.
        print(f"samples[0]:{history.get('outputs')}")
        print("\n示例样本结构:")
        print(f"历史数据包含的键: {list(history.keys())}")
        for key, value in history.items():
            if isinstance(value, np.ndarray):
                print(f"  {key} 形状: {value.shape}")
            else:
                print(f"  {key} 类型: {type(value)}")

    return samples

