from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import yaml
import torch
from hydra.utils import instantiate
from torch.utils.data import DataLoader
from src.data import MIMIC3SyntheticDatasetCollection
from src.data.mimic_iii.semi_synthetic_dataset import simulate_output_after_actions


# Read the `dataset` section of the semi-synthetic MIMIC-III config.
with open('config/dataset/mimic3_synthetic.yaml', 'r') as f:
    config = yaml.safe_load(f)
dataset_config = config['dataset']
path = dataset_config['path']

# Feature names; their positions become column indices further down.
vital_list = dataset_config['vital_list']
static_list = dataset_config['static_list']

# Turn each Hydra config entry into a live object via `instantiate`, preserving config order.
synth_outcomes_list = [
    instantiate(outcome_cfg) for outcome_cfg in dataset_config['synth_outcomes_list']
]
synth_treatments_list = [
    instantiate(treatment_cfg) for treatment_cfg in dataset_config['synth_treatments_list']
]

treatment_outcomes_influence = dataset_config['treatment_outcomes_influence']

# Sequence / split parameters, read from the dataset config in a single pass
# (same keys, same order as individual lookups).
(
    min_seq_length,
    max_seq_length,
    max_number,
    split,
    projection_horizon,
    n_treatments_seq,
) = (
    dataset_config[key]
    for key in (
        'min_seq_length',
        'max_seq_length',
        'max_number',
        'split',
        'projection_horizon',
        'n_treatments_seq',
    )
)

# Build the semi-synthetic collection. Both seeds are hard-coded to 100
# (presumably so repeated runs see the same data — confirm with the dataset class).
dataset_collection = MIMIC3SyntheticDatasetCollection(
    path=path,
    synth_outcomes_list=synth_outcomes_list,
    synth_treatments_list=synth_treatments_list,
    treatment_outcomes_influence=treatment_outcomes_influence,
    min_seq_length=min_seq_length,
    max_seq_length=max_seq_length,
    max_number=max_number,
    seed=100,
    data_seed=100,
    split=split,
    projection_horizon=projection_horizon,
    n_treatments_seq=n_treatments_seq,
)
batch_size = 64
train_loader = DataLoader(dataset_collection.train_f, batch_size=batch_size, shuffle=False)
# Map feature names to their position in the config lists so the simulator can look them up by name.
vital_name_to_idx = {name: idx for idx, name in enumerate(vital_list)}
static_name_to_idx = {name: idx for idx, name in enumerate(static_list)}

# Batch keys forwarded to the simulator as the history dict `Ht`.
HISTORY_KEYS = (
    'sequence_lengths',
    'prev_treatments',
    'vitals',
    'current_treatments',
    'unscaled_outputs',
    'static_features',
    'active_entries',
)

all_actual_outcomes = []
all_predicted_outcomes = []
n_batches = len(train_loader)
for batch_idx, batch_data in enumerate(train_loader):
    print(f"Processing batch {batch_idx+1}/{n_batches}")
    Ht = {key: batch_data[key] for key in HISTORY_KEYS}
    batch_predicted_outcomes = simulate_output_after_actions(
        Ht,
        synth_outcomes_list,
        synth_treatments_list,
        treatment_outcomes_influence,
        projection_horizon,
        vital_name_to_idx,
        static_name_to_idx,
    )

    # Take `unscaled_outputs[i, seq_len_i - 1]` for every sample i in one advanced-indexing
    # call instead of a per-sample Python loop. `.long()` truncates just like `int()` did;
    # `reshape(-1)` keeps the row/step index tensors aligned if lengths arrive as (B, 1).
    last_step = batch_data['sequence_lengths'].reshape(-1).long() - 1
    rows = torch.arange(last_step.shape[0], device=last_step.device)
    batch_actual_outcomes = batch_data['unscaled_outputs'][rows, last_step]

    # detach() keeps .numpy() from failing should the simulator return a grad-tracking tensor.
    all_predicted_outcomes.append(batch_predicted_outcomes.detach().cpu().numpy())
    all_actual_outcomes.append(batch_actual_outcomes.cpu().numpy())

if all_actual_outcomes:
    actual_outcomes = np.vstack(all_actual_outcomes)
    predicted_outcomes = np.vstack(all_predicted_outcomes)
else:
    # np.vstack raises ValueError on an empty list, which would crash an empty loader
    # before any emptiness check could run; fall back to empty arrays (size == 0) instead.
    actual_outcomes = np.empty((0,))
    predicted_outcomes = np.empty((0,))
print(f"predicted_outcomes: {predicted_outcomes}")
print(f"actual_outcomes: {actual_outcomes}")

# Mean squared error between simulated and observed outcomes, computed only when any were collected.
if actual_outcomes.size > 0:
    # NumPy broadcasting would silently turn e.g. (N, 1) vs (N,) into an (N, N) difference
    # and report a meaningless MSE, so insist on identical shapes.
    if predicted_outcomes.shape != actual_outcomes.shape:
        raise ValueError(
            f"Shape mismatch: predicted_outcomes {predicted_outcomes.shape} "
            f"vs actual_outcomes {actual_outcomes.shape}"
        )
    mse = np.mean((predicted_outcomes - actual_outcomes) ** 2)
    print(f"Mean Squared Error: {mse}")
else:
    print("No outcomes collected; skipping MSE computation.")