import numpy as np
import random
from collections import deque
import time
import matplotlib.pyplot as plt
import torch
import os

def create_single_sample(data, patient_idx, split_time, history_length, future_length, time_keys, static_keys):
    """
    Create one (history, future, goal) sample for a single patient at a given split time.

    Args:
        data: data dictionary of arrays indexed by patient on axis 0
        patient_idx: patient index
        split_time: time index separating history ``[.., split_time)`` from
            future ``[split_time, split_time + future_length)``
        history_length: number of history steps to keep before ``split_time``
            (full history starting at t=0 if None)
        future_length: number of future steps
        time_keys: keys whose arrays have a time axis on axis 1
        static_keys: keys of static (time-independent) features

    Returns:
        (history_dict, future_dict, goal): the two dicts map each key to a copied
        slice that keeps a leading patient axis of length 1; ``goal`` is
        ``data['outputs'][patient_idx, split_time + future_length - 1]``, i.e. the
        output at the last future step.
    """
    history_dict = {}
    future_dict = {}

    # History starts at t=0 (full history) or history_length steps before the split.
    start_time = 0 if history_length is None else max(0, split_time - history_length)
    future_end = split_time + future_length

    for key in time_keys:
        # Slice with patient_idx:patient_idx+1 to keep the leading patient axis.
        series = data[key][patient_idx:patient_idx + 1]
        history_dict[key] = series[:, start_time:split_time].copy()
        future_dict[key] = series[:, split_time:future_end].copy()

    for key in static_keys:
        # Static features are identical in history and future; copy separately so
        # the two dicts never share memory.
        history_dict[key] = data[key][patient_idx:patient_idx + 1].copy()
        future_dict[key] = history_dict[key].copy()

    # Goal: output at the last future time step.
    goal = data['outputs'][patient_idx, future_end - 1].copy()

    return (history_dict, future_dict, goal)


def get_patient_sequence_length(data, patient_idx):
    """
    Return the valid sequence length of one patient as a plain ``int``.

    Lookup order:
      1. ``data['active_entries']``: index of the last active time step + 1, or 0
         when the patient has no active entry. Indexed as
         ``[patient_idx, :, 0]``, so it must have at least 3 dimensions.
      2. ``data['sequence_lengths'][patient_idx]``.
      3. Axis-1 length of the first key whose array has a time dimension
         (ndim >= 2 and more than one step on axis 1).
      4. 0 when none of the above is available.

    Args:
        data: data dictionary
        patient_idx: patient index

    Returns:
        seq_length: valid sequence length, always cast to ``int`` so callers can
        pass it straight to ``range(...)`` even if ``sequence_lengths`` is stored
        as a float array.
    """
    if 'active_entries' in data:
        active_indices = np.where(data['active_entries'][patient_idx, :, 0] > 0)[0]
        # Gaps inside the sequence are counted: the length runs up to the last active step.
        return int(active_indices[-1]) + 1 if len(active_indices) > 0 else 0

    if 'sequence_lengths' in data:
        return int(data['sequence_lengths'][patient_idx])

    # Only scan for time-dimension keys when nothing more specific is available.
    time_keys = [key for key in data if len(data[key].shape) >= 2 and data[key].shape[1] > 1]
    return int(data[time_keys[0]].shape[1]) if time_keys else 0


def create_history_treatment_goal_samples(data, min_history_length=15, max_history_length=30,
                                          future_length=5, use_tail=False):
    """
    Create (history H, future F, goal) triple samples from a data dictionary.

    For every patient, each split time ``t`` starting at ``max_history_length``
    produces one sample whose history is the full prefix ``[0, t)`` and whose
    future is the window starting at ``t``. The goal is read from
    ``data['outputs']`` at the last future step.

    Args:
        data: dictionary of numpy arrays indexed by patient on axis 0. Arrays with
            ndim >= 2 and more than one step on axis 1 are treated as time series;
            all others as static features.
        min_history_length: currently unused; kept for API compatibility.
        max_history_length: first split time considered, i.e. the minimum number
            of history steps in every sample.
        future_length: number of future steps per sample.
        use_tail: if True, late split times are also used and their future window
            is truncated to the remaining steps (``min(seq_length - t,
            future_length)``); if False, only splits whose full
            ``future_length`` window lies inside the sequence are used.

    Returns:
        samples: list of triples [(history_dict, future_dict, goal), ...]
    """
    samples = []

    # Keys with a time dimension (ndim >= 2 and more than one step) vs. static keys.
    time_keys = []
    static_keys = []
    for key, value in data.items():
        if len(value.shape) >= 2 and value.shape[1] > 1:
            time_keys.append(key)
        else:
            static_keys.append(key)

    print(f"Time Related Keys: {time_keys}")
    print(f"Static keys: {static_keys}")

    # Patient count: prefer active_entries, but do not require it
    # (sequence lengths can also come from 'sequence_lengths' or the time axis).
    if 'active_entries' in data:
        num_patients = data['active_entries'].shape[0]
    elif time_keys:
        num_patients = data[time_keys[0]].shape[0]
    else:
        num_patients = 0

    for patient in range(num_patients):
        # int() guards against float-typed lengths, which range() rejects.
        seq_length = int(get_patient_sequence_length(data, patient))
        if seq_length == 0:
            continue  # patient has no active entries

        if use_tail:
            # Tail mode: keep late splits, truncating the future window to what remains.
            for t in range(max_history_length, seq_length - 1):
                current_future_length = min(seq_length - t, future_length)
                samples.append(create_single_sample(data, patient, t, None, current_future_length,
                                                    time_keys, static_keys))
        else:
            # Fixed future length: only splits whose full future window fits in the sequence.
            for t in range(max_history_length, seq_length - future_length):
                samples.append(create_single_sample(data, patient, t, None, future_length,
                                                    time_keys, static_keys))

    print(f"{len(samples)} (historical, future, target) samples created")

    # Print the structure of the first sample as a sanity check.
    if samples:
        history, future, goal = samples[0]
        print("\nExample Sample Structure:")
        for label, part in (("history", history), ("future", future)):
            print(f"Keys included in {label} data: {list(part.keys())}")
            for key, value in part.items():
                if isinstance(value, np.ndarray):
                    print(f"{key} Shape: {value.shape}")
                else:
                    print(f"{key} Type: {type(value)}")

        if goal is not None:
            print(f"Goal shape: {goal.shape if isinstance(goal, np.ndarray) else type(goal)}")
        else:
            print("Goal: None")

    return samples

def convert_dataloader_to_samples(dataloader, output_dir='./results/yvalues'):
    """
    Convert the batches of a dataloader into (history H, future F, goal) triple samples.

    While converting, the last time step of ``target_batch['outputs']`` is
    collected from every batch. The overall mean of these values is written to
    ``<output_dir>/overall_mean.txt`` with 6 decimals. The dataloader is
    iterated exactly once, so one-shot iterables (e.g. generators) work and the
    data is only loaded a single time.

    Args:
        dataloader: iterable of ``(history_batch, target_batch)`` pairs, each a
            dict mapping key -> batched value (e.g. a DataLoader created with
            CIPDataset). Values are presumably torch tensors with the batch on
            dim 0 — confirm against the dataset.
        output_dir: directory that receives ``overall_mean.txt``; created if
            missing. Defaults to './results/yvalues'.

    Returns:
        samples: list of triples [(history_dict, future_dict, goal), ...].
            Dict values are per-sample numpy slices keeping a leading axis of
            length 1; ``goal`` is the last time step of ``target_batch['outputs']``
            (or of 'cancer_volume' / the first target key when 'outputs' is absent).
    """
    samples = []
    time_keys = None
    static_keys = None
    # Last-time-step 'outputs' of every batch, used for the overall mean.
    last_step_outputs = []

    for H_batch, target_batch in dataloader:
        # Classify keys once, on the first batch: tensors with at least 2 dims are
        # treated as time-dependent, everything else as static.
        if time_keys is None:
            time_keys = []
            static_keys = []
            for key, value in H_batch.items():
                if isinstance(value, torch.Tensor) and len(value.shape) >= 2:
                    time_keys.append(key)
                else:
                    static_keys.append(key)

            print(f"Time Related Keys: {time_keys}")
            print(f"Static keys: {static_keys}")

        # Goal source: 'outputs' when present, otherwise 'cancer_volume' or the first target key.
        if 'outputs' in target_batch:
            goal_source = target_batch['outputs']
            last_step_outputs.append(goal_source[:, -1].detach().cpu())
        else:
            goal_key = 'cancer_volume' if 'cancer_volume' in target_batch else next(iter(target_batch))
            goal_source = target_batch[goal_key]

        batch_size = next(iter(H_batch.values())).shape[0]

        for i in range(batch_size):
            history_dict = {}
            future_dict = {}

            for key in time_keys:
                if key in H_batch:
                    history_dict[key] = H_batch[key][i:i + 1].detach().cpu().numpy()
                if key in target_batch:
                    future_dict[key] = target_batch[key][i:i + 1].detach().cpu().numpy()

            for key in static_keys:
                if key in H_batch:
                    value = H_batch[key][i:i + 1]
                    if isinstance(value, torch.Tensor):
                        value = value.detach().cpu().numpy()
                    history_dict[key] = value
                    # Static features are the same in history and in the future.
                    future_dict[key] = value.copy()

            goal = goal_source[i, -1].detach().cpu().numpy()
            samples.append((history_dict, future_dict, goal))

    # Overall mean of the last-step outputs; skipped when nothing was collected
    # (torch.cat on an empty list would raise).
    if last_step_outputs:
        mean_value = torch.cat(last_step_outputs, dim=0).mean().item()
        os.makedirs(output_dir, exist_ok=True)
        out_path = os.path.join(output_dir, 'overall_mean.txt')
        with open(out_path, 'w') as f:
            f.write(f"{mean_value:.6f}")
        print(f"All outputs last-step mean = {mean_value:.6f}, saved to {out_path}")
    else:
        print("No 'outputs' targets found; overall last-step mean not written")

    print(f"{len(samples)} (historical, future, target) samples created")

    # Print the structure of the first sample as a sanity check.
    if samples:
        history = samples[0][0]
        if 'outputs' in history:
            print(f"samples[0]:{history['outputs']}")
        print("\nExample Sample Structure:")
        print(f"Keys included in history data: {list(history.keys())}")
        for key, value in history.items():
            if isinstance(value, np.ndarray):
                print(f"{key} Shape: {value.shape}")
            else:
                print(f"{key} Type: {type(value)}")

    return samples

