import numpy as np
import random
import torch
import time
import torch.nn as nn

# Torch device shared by the helpers in this module: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from src.data.her_data_generator import create_history_treatment_goal_samples, convert_dataloader_to_samples

import torch
from fvcore.nn import FlopCountAnalysis
import numpy as np
from src.utils.utils import set_seed
from src.data.cip_dataset import CIPDataset, get_dataloader
from thop import profile, clever_format
# from calflops import get_model_analysis

def count_flops_params(agent, history_dict_batch, goal_batch):
    """
    Estimate the parameter count and per-sample FLOPs of the agent's networks.

    The per-sample histories are merged key by key along dim 0 and the goals
    are stacked into one tensor (the original author notes this mirrors the
    input handling of ``generate_treatment_plan_batch``). The encoder is
    profiled with ``thop.profile``; the actor, the critic and, when present,
    ``agent.behavior_policy`` are profiled with ``fvcore``'s
    ``FlopCountAnalysis``.

    Both tools count one multiply-accumulate as a single operation, so every
    count is doubled to report FLOPs in the same unit for all sub-networks.

    Args:
        agent: Object exposing ``encoder``, ``actor`` and ``critic`` modules and
            optionally ``behavior_policy``.
        history_dict_batch: List of dicts, one per sample, whose values are
            numpy arrays or tensors.
        goal_batch: List of numpy arrays or tensors, one goal per sample.

    Returns:
        total_params: int, total number of parameters over all profiled modules.
        mflops_per_sample: float, total FLOPs / 1e6 / ``len(history_dict_batch)``.

    Raises:
        ValueError: If the batch is empty or a goal has an unsupported type.
    """
    batch_size = len(history_dict_batch)
    if batch_size == 0:
        # Fail early with a clear message instead of an opaque torch.cat error
        # or a division by zero further down.
        raise ValueError("history_dict_batch must contain at least one sample")

    device = DEVICE

    # Gather every history entry per key, converting numpy arrays to float tensors.
    H_t_batch = {}
    for history_dict in history_dict_batch:
        for key, value in history_dict.items():
            if isinstance(value, np.ndarray):
                value = torch.FloatTensor(value)
            H_t_batch.setdefault(key, []).append(value)

    # Concatenate along the batch dimension; non-tensor entries stay as lists.
    for key, values in H_t_batch.items():
        if isinstance(values[0], torch.Tensor):
            H_t_batch[key] = torch.cat(values, dim=0).to(device)

    # Goals: 1-D inputs get a leading batch axis before concatenation.
    goal_tensors = []
    for goal in goal_batch:
        if isinstance(goal, np.ndarray):
            goal_tensors.append(torch.FloatTensor(goal).unsqueeze(0))
        elif isinstance(goal, torch.Tensor):
            goal_tensors.append(goal.unsqueeze(0) if goal.dim() == 1 else goal)
        else:
            raise ValueError("goal_batch contains unsupported type")
    goal_tensor_batch = torch.cat(goal_tensors, dim=0).to(device)  # [batch_size, goal_dim]

    # Encoder: thop reports multiply-accumulates, hence the factor of two.
    macs, _ = profile(agent.encoder, (H_t_batch, goal_tensor_batch), verbose=False)
    flops_encoder = macs * 2
    params_encoder = sum(p.numel() for p in agent.encoder.parameters())

    with torch.no_grad():
        encoded_state = agent.encoder(H_t_batch, goal_tensor_batch)
        action, _ = agent.actor(encoded_state)

    # fvcore counts one fused multiply-add as one flop, so it is doubled as well
    # to stay consistent with the encoder figure above.
    flops_actor = 2 * FlopCountAnalysis(agent.actor, (encoded_state,)).total()
    params_actor = sum(p.numel() for p in agent.actor.parameters())

    flops_critic = 2 * FlopCountAnalysis(agent.critic, (encoded_state, action)).total()
    params_critic = sum(p.numel() for p in agent.critic.parameters())

    # Optional behavior policy.
    params_behavior = 0
    flops_behavior = 0
    if hasattr(agent, 'behavior_policy'):
        flops_behavior = 2 * FlopCountAnalysis(agent.behavior_policy, (encoded_state, action)).total()
        params_behavior = sum(p.numel() for p in agent.behavior_policy.parameters())

    total_flops = flops_encoder + flops_actor + flops_critic + flops_behavior
    total_params = params_encoder + params_actor + params_critic + params_behavior
    print(f"params_encoder:{params_encoder}, params_actor: {params_actor}, params_critic:{params_critic}, params_behavior:{params_behavior}")
    mflops_per_sample = total_flops / 1e6 / batch_size

    print(f'Total Parameters: {total_params:,}')
    print(f'Total FLOPs: {total_flops:,} FLOPs, {mflops_per_sample:.2f} MFLOPs per sample (batch size={batch_size})')

    return total_params, mflops_per_sample

def evaluate_and_log_case_studies(agent, dataset_collection, config, logger, model_name, max_tau, case_study_results, size=100):
    """
    Roll out the agent on the first ``size`` samples and record each trajectory.

    Samples are built from ``dataset_collection.test_f.data`` when
    ``config['exp']['test']`` is truthy, otherwise from
    ``dataset_collection.val_f.data``, via ``CIPDataset`` -> ``get_dataloader``
    -> ``convert_dataloader_to_samples``. For every sample index, a per-patient
    entry is created in ``case_study_results`` on first sight, and the outcome
    trajectory produced by ``agent.generate_treatment_plan_batch`` is stored
    under ``['models'][model_name]``.

    Args:
        agent: A trained model/agent exposing ``generate_treatment_plan_batch``.
        dataset_collection: Object providing ``val_f.data`` / ``test_f.data``;
            also forwarded to the agent.
        config: Experiment configuration dict. NOTE: ``config['exp']['tau']``
            is overwritten with 6 in place.
        logger: Logger used for progress and error messages.
        model_name (str): Key under which this model's trajectories are stored.
        max_tau (int): Planning horizon passed as ``future_length``.
        case_study_results (dict): Accumulator keyed by patient index; updated
            in place so several models can share it.
        size (int): Number of leading sample indices to process.

    Returns:
        tuple: ``(case_study_results, seconds_per_patient)`` where the second
        item is the total wall-clock time divided by ``size`` (0.0 when
        ``size`` is not positive).
    """
    logger.info(f"--- Starting Case Study for model: {model_name} ---")
    start_time = time.time()

    # Seed for reproducibility.
    set_seed(config['exp']['seed'])
    # Hard-coded by the original code; mutates the caller's config.
    config['exp']['tau'] = 6

    # Switch to evaluation mode.
    if hasattr(agent, 'actor'):
        agent.actor.eval()
    if hasattr(agent, 'encoder'):
        agent.encoder.eval()

    try:
        # 1. Build all samples in a single full-size batch.
        data = dataset_collection.val_f.data if not config['exp']['test'] else dataset_collection.test_f.data
        dataloader = get_dataloader(CIPDataset(data, config), batch_size=len(data['outputs']), shuffle=False)
        all_samples = convert_dataloader_to_samples(dataloader)

        if all_samples:
            print(f"all_samples[0]: {all_samples[0][-1]}")

        # 2. Walk over the requested sample indices.
        for patient_id in range(size):
            try:
                if patient_id >= len(all_samples):
                    logger.warning(f"Patient ID {patient_id} is out of bounds for the generated samples ({len(all_samples)}). Skipping.")
                    continue

                logger.info(f"Processing patient_id: {patient_id}")

                history_dict, future_dict, goal = all_samples[patient_id]
                goal_np = goal if isinstance(goal, np.ndarray) else goal.cpu().numpy()

                # Computed for every call (not only when the patient entry is
                # created) so the debug print below also works when a second
                # model is evaluated against an existing results dict.
                ground_truth_outcomes = future_dict['outputs'].squeeze()

                # 3. Create the per-patient record the first time we see this id.
                if patient_id not in case_study_results:
                    case_study_results[patient_id] = {
                        # Initial state is the last point of the history.
                        'initial_outcome': history_dict['outputs'][:, -1, :].squeeze(),
                        'goal': goal_np.squeeze(),
                        'ground_truth_outcomes': np.array(ground_truth_outcomes),
                        'models': {}
                    }

                # 4. Generate the treatment plan and outcome trajectory.
                _, outputs_batch, _ = agent.generate_treatment_plan_batch(
                    [history_dict],
                    [goal],
                    dataset_collection=dataset_collection,
                    future_dict_batch=[future_dict],
                    future_length=max_tau,
                    early_stop=False  # no early stopping: keep the full trajectory
                )

                # Single-sample batch: take the first (and only) element.
                outcome_trajectory = outputs_batch[0]

                if patient_id == 0:
                    print('-' * 100)
                    print(f"evaluate log out: {outcome_trajectory.squeeze()}")
                    print(f"evaluate log turth: {ground_truth_outcomes}")

                # 5. Store this model's trajectory without singleton dimensions.
                case_study_results[patient_id]['models'][model_name] = np.array(outcome_trajectory).squeeze()
                logger.info(f"Successfully recorded case study for patient {patient_id} with model {model_name}.")

            except IndexError as e:
                logger.error(f"IndexError while processing patient_id {patient_id}: {e}. Skipping.")
            except Exception as e:
                # Best effort: one failing patient must not abort the whole study.
                logger.error(f"An unexpected error occurred for patient_id {patient_id}: {e}")
                import traceback
                traceback.print_exc()
    finally:
        # Always hand the networks back in training mode, even on failure.
        if hasattr(agent, 'actor'):
            agent.actor.train()
        if hasattr(agent, 'encoder'):
            agent.encoder.train()

    logger.info(f"--- Case Study for model: {model_name} Finished ---")
    used_time = time.time() - start_time
    seconds_per_patient = used_time / size if size > 0 else 0.0
    return case_study_results, seconds_per_patient


def save_evaluation_results(agent, val_samples, predictions_list, treatments_list, true_treatments_list, output_path='./results/HER'):
    """
    Write one JSON file per evaluated sample plus an ``index.json`` summary.

    Files go to ``<output_path>/results_<agent.algorithm>`` when ``agent`` is
    truthy, otherwise to ``<output_path>/results``. Missing directories are
    created. The four sequences are consumed in lockstep with ``zip``, so
    iteration stops at the shortest one.

    Args:
        agent: Object with an ``algorithm`` attribute, or a falsy value.
        val_samples: Iterable of ``(history_dict, future_dict, goal)`` tuples;
            ``history_dict['outputs']`` is indexed as ``[0, :, :]``.
        predictions_list: Per-sample predictions; the last entry of each is
            saved as ``final_output``.
        treatments_list: Per-sample predicted treatments (array-like).
        true_treatments_list: Per-sample reference treatments (array-like).
        output_path: Root folder for the results.

    Returns:
        str: Path of the folder the files were written to.
    """
    import os
    import json

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_path, exist_ok=True)
    print(f"Absolute path: {os.path.abspath(output_path)}")

    if agent:
        results_folder = os.path.join(output_path, f"results_{agent.algorithm}")
    else:
        results_folder = os.path.join(output_path, "results")
    os.makedirs(results_folder, exist_ok=True)

    index_data = []
    records = zip(val_samples, predictions_list, treatments_list, true_treatments_list)

    for i, ((history_dict, _future_dict, goal), predictions, treatments, true_treatments) in enumerate(records):
        sample_id = f"sample_{i}"
        file_name = f"{sample_id}.json"

        # np.asarray accepts ndarrays as well as nested lists and converts to
        # plain Python floats via tolist(), which json can serialize.
        sample_data = {
            "sample_id": sample_id,
            "history_outputs": np.asarray(history_dict['outputs'], dtype=float)[0, :, :].tolist(),
            "goal_value": [float(x) for x in goal],
            "treatments": np.asarray(treatments, dtype=float).tolist(),
            "true_treatments": np.asarray(true_treatments, dtype=float).tolist(),
            "final_output": [float(x) for x in predictions[-1]]
        }

        with open(os.path.join(results_folder, file_name), 'w', encoding='utf-8') as f:
            json.dump(sample_data, f, indent=2)

        index_data.append({
            "sample_id": sample_id,
            "history_length": len(sample_data["history_outputs"]),
            "treatment_length": len(sample_data["treatments"]),
            "goal_value": sample_data["goal_value"],
            "final_output": sample_data["final_output"],
            "file": file_name
        })

    with open(os.path.join(results_folder, "index.json"), 'w', encoding='utf-8') as f:
        json.dump(index_data, f, indent=2)

    print(f"Evaluation results saved to {results_folder}")
    return results_folder

def evaluate_agent(agent, dataset_collection, num_episodes=200, config=None):
    """
    Evaluate the agent sample by sample on generated validation samples.

    Validation samples come from ``create_history_treatment_goal_samples`` on
    ``dataset_collection.val_f.data``; at most ``num_episodes`` of them are
    kept (random subsample). Each sample is rolled out with
    ``predict_trajectory``, which returns an MSE that is already multiplied by
    the dataset scale parameter, so the RMSE here is simply its square root.

    Args:
        agent: Agent exposing ``future_length``, ``goal_threshold``,
            ``algorithm`` and whatever ``predict_trajectory`` needs.
        dataset_collection: Dataset container providing ``val_f.data``.
        num_episodes: Maximum number of validation samples to evaluate.
        config: Experiment configuration forwarded to ``predict_trajectory``
            (it reads ``config['dataset']['name']``). Required; the default of
            ``None`` only keeps the positional signature backward-compatible.

    Returns:
        dict: ``success_rate``, ``avg_mse``, ``avg_rmse``,
        ``avg_treatment_similarity``, ``num_evaluated``, ``avg_steps_used``
        and ``early_stop_rate``.

    Raises:
        ValueError: If ``config`` is not provided.
    """
    if config is None:
        # The previous implementation referenced an undefined name here and
        # failed with NameError in the middle of the evaluation.
        raise ValueError("evaluate_agent requires `config` (uses config['dataset']['name'])")

    # Evaluation mode.
    if hasattr(agent, 'actor'):
        agent.actor.eval()
    if hasattr(agent, 'encoder'):
        agent.encoder.eval()

    mse_values = []
    success_count = 0
    treatment_similarities = []
    steps_used_list = []
    required_keys = ('outputs', 'static_features', 'current_treatments')

    try:
        val_samples = create_history_treatment_goal_samples(
            dataset_collection.val_f.data,
            min_history_length=10,
            max_history_length=20,
            future_length=agent.future_length
        )

        # Too many samples: keep a random subset.
        if len(val_samples) > num_episodes:
            val_samples = random.sample(val_samples, num_episodes)

        print(f"Evaluate using {len(val_samples)} validation samples")

        for history_dict, future_dict, goal in val_samples:
            # 'prev_outputs' may stand in for a missing 'outputs' entry.
            if 'outputs' not in history_dict and 'prev_outputs' in history_dict:
                history_dict['outputs'] = history_dict['prev_outputs']
            if any(key not in history_dict for key in required_keys):
                continue

            predictions, treatments, mse, steps_used = predict_trajectory(
                agent,
                history_dict,
                goal,
                dataset_collection,
                config,
                future_dict
            )

            mse_values.append(mse)
            steps_used_list.append(steps_used)

            # Success means the (scaled) MSE is below the agent's threshold.
            if mse < agent.goal_threshold:
                success_count += 1

            # Cosine similarity between planned and actual future treatments.
            if 'current_treatments' in future_dict:
                # NOTE(review): the treatment dimension is hard-coded to 2 here
                # - confirm this matches every dataset in use.
                actual_treatments = future_dict['current_treatments'].reshape(-1, 2)
                min_len = min(len(actual_treatments), len(treatments))
                if min_len > 0:
                    actual_treatments = actual_treatments[:min_len]
                    pred_treatments = treatments[:min_len]
                    actual_norm = np.linalg.norm(actual_treatments)
                    pred_norm = np.linalg.norm(pred_treatments)
                    if actual_norm > 0 and pred_norm > 0:
                        similarity = np.sum(actual_treatments * pred_treatments) / (actual_norm * pred_norm)
                        treatment_similarities.append(similarity)
    finally:
        # Back to training mode, also when evaluation fails.
        if hasattr(agent, 'actor'):
            agent.actor.train()
        if hasattr(agent, 'encoder'):
            agent.encoder.train()

    success_rate = success_count / len(mse_values) if mse_values else 0
    avg_mse = np.mean(mse_values) if mse_values else float('inf')
    # No extra scaling: the per-sample MSE is already in scaled units.
    avg_rmse = np.sqrt(avg_mse)
    avg_similarity = np.mean(treatment_similarities) if treatment_similarities else 0.0
    avg_steps = np.mean(steps_used_list) if steps_used_list else 0

    metrics = {
        'success_rate': success_rate,
        'avg_mse': avg_mse,
        'avg_rmse': avg_rmse,
        'avg_treatment_similarity': avg_similarity,
        'num_evaluated': len(mse_values),
        'avg_steps_used': avg_steps,
        'early_stop_rate': 1.0 - (avg_steps / agent.future_length) if steps_used_list else 0
    }

    print(f"Evaluation results ({agent.algorithm}):")
    print(f"Success rate: {metrics['success_rate']: .2%}")
    print(f"Average MSE: {metrics['avg_mse']: .6f}")
    print(f"Average RMSE: {metrics['avg_rmse']: .6f}")
    print(f"Average intervention similarity: {metrics['avg_treatment_similarity']: .4f}")
    print(f"Average steps used: {metrics['avg_steps_used']: .2f}/{agent.future_length}")
    print(f"Early stop rate: {metrics['early_stop_rate']: .2%}")

    return metrics

def evaluate_cip(agent, dataloader, dataset_collection, config):
    """
    Batched goal-reaching evaluation of the agent on samples from a dataloader.

    Samples are produced by ``convert_dataloader_to_samples``, randomly
    subsampled to at most ``config['exp']['eval_episodes']``, filtered to those
    whose history has ``outputs`` (or ``prev_outputs``), ``static_features``
    and ``current_treatments``, and then rolled out in chunks of the loader's
    batch size with ``agent.generate_treatment_plan_batch(early_stop=True)``.
    The error is the scaled squared difference between each trajectory's last
    output and its goal.

    When ``config['exp']['tau'] == 1`` the first batch is also passed to
    ``count_flops_params`` to fill ``complexity_info``.

    Args:
        agent: Agent exposing ``generate_treatment_plan_batch``,
            ``future_length``, ``goal_threshold`` and ``algorithm``.
        dataloader: Iterable of ``(H_t, _)`` batches; ``H_t['outputs']`` gives
            the batch size.
        dataset_collection: Dataset container, forwarded to the agent and used
            to look up the scale parameter.
        config: Experiment configuration dict.

    Returns:
        tuple: ``(metrics, complexity_info)``. ``metrics`` holds
        ``success_rate``, ``avg_mse``, ``avg_rmse``, ``num_evaluated``,
        ``avg_steps_used`` and ``early_stop_rate``; ``complexity_info`` holds
        ``params`` and ``mflops`` when they were computed, else it is empty.
    """
    complexity_info = {}
    num_episodes = config['exp']['eval_episodes']

    # Peek at the first batch only to learn the loader's batch size.
    batch_size = None
    for H_t, _ in dataloader:
        batch_size = H_t['outputs'].shape[0]
        break
    if not batch_size:
        # Empty loader (or empty first batch): any positive value keeps the
        # chunking arithmetic below well-defined.
        batch_size = 1

    # The scale factor depends only on the dataset, so resolve it once.
    dataset_name = config['dataset']['name']
    if "mimic" in dataset_name:
        scale_param = dataset_collection.train_f.scaling_params['output_means']
    elif "tumor" in dataset_name:
        scale_param = dataset_collection.train_scaling_params[1]['cancer_volume']
    else:
        scale_param = 1.0

    if hasattr(agent, 'actor'):
        agent.actor.eval()
    if hasattr(agent, 'encoder'):
        agent.encoder.eval()

    mse_values = []
    success_count = 0
    steps_used_list = []

    # Collected per sample but not returned or saved by this function.
    predictions_list = []
    treatments_list = []
    true_treatments_list = []
    processed_samples = []

    try:
        val_samples = convert_dataloader_to_samples(dataloader)

        if val_samples:
            print(f"val_samples[0]: {val_samples[0][-1]}")

        # Downsample to the requested number of episodes.
        if len(val_samples) > num_episodes:
            val_samples = random.sample(val_samples, num_episodes)

        print(f"Evaluated with {len(val_samples)} validation samples, batch size: {batch_size}")

        # Keep samples with all key fields; 'prev_outputs' may replace 'outputs'.
        required_keys = ('outputs', 'static_features', 'current_treatments')
        valid_samples = []
        for history_dict, future_dict, goal in val_samples:
            missing = [key for key in required_keys if key not in history_dict]
            if missing == ['outputs'] and 'prev_outputs' in history_dict:
                history_dict['outputs'] = history_dict['prev_outputs']
                missing = []
            if not missing:
                valid_samples.append((history_dict, future_dict, goal))

        for batch_idx, start_idx in enumerate(range(0, len(valid_samples), batch_size)):
            batch_samples = valid_samples[start_idx:start_idx + batch_size]

            history_dict_batch = [s[0] for s in batch_samples]
            future_dict_batch = [s[1] for s in batch_samples]
            goal_batch = [s[2] for s in batch_samples]

            if batch_idx == 0 and config['exp']['tau'] == 1:
                complexity_info['params'], complexity_info['mflops'] = count_flops_params(agent, history_dict_batch, goal_batch)

            # Goals as one (batch, goal_dim) numpy array.
            goal_batch_np = np.stack([g if isinstance(g, np.ndarray) else g.cpu().numpy() for g in goal_batch])
            goal_batch_np = goal_batch_np.reshape(len(goal_batch), -1)

            actions_batch, outputs_batch, steps_taken_batch = agent.generate_treatment_plan_batch(
                history_dict_batch,
                goal_batch,
                dataset_collection,
                future_dict_batch,
                future_length=agent.future_length,
                early_stop=True
            )

            print(f"evaluate_cip outputs_batch[0]: {outputs_batch[0].squeeze()}")
            print(f"evaluate_cip future_dict_batch: {future_dict_batch[0]['outputs'].squeeze()}")

            # Take the last step of every trajectory individually: with early
            # stopping the trajectories may differ in length, so they cannot be
            # packed into one rectangular array first. Flattening gives every
            # row the same (goal_dim,) layout as the goals.
            goal_dim = goal_batch_np.shape[1]
            last_outputs = np.stack([
                np.asarray(outputs[-1]).reshape(-1) if len(outputs) > 0
                else np.zeros(goal_dim, dtype=goal_batch_np.dtype)
                for outputs in outputs_batch
            ])
            print(f"last_outputs:{last_outputs.shape}")
            scaled_diff = (last_outputs - goal_batch_np) * scale_param
            mse_batch = np.mean(np.square(scaled_diff), axis=-1)  # one MSE per sample

            mse_values.extend(mse_batch.tolist())
            success_count += int(np.sum(mse_batch < agent.goal_threshold))
            steps_used_list.extend(steps_taken_batch)

            predictions_list.extend(last_outputs)
            treatments_list.extend(actions_batch)
            true_treatments_list.extend([fd['current_treatments'] for fd in future_dict_batch])
            processed_samples.extend(batch_samples)
    finally:
        # Switch back to training mode, also when evaluation fails.
        if hasattr(agent, 'actor'):
            agent.actor.train()
        if hasattr(agent, 'encoder'):
            agent.encoder.train()

    total_samples = len(mse_values)
    success_rate = success_count / total_samples if total_samples > 0 else 0
    avg_mse = np.mean(mse_values) if total_samples > 0 else float('inf')
    avg_rmse = np.sqrt(avg_mse)
    avg_steps = np.mean(steps_used_list) if steps_used_list else 0

    metrics = {
        'success_rate': success_rate,
        'avg_mse': avg_mse,
        'avg_rmse': avg_rmse,
        'num_evaluated': total_samples,
        'avg_steps_used': avg_steps,
        'early_stop_rate': 1.0 - (avg_steps / agent.future_length) if steps_used_list else 0
    }

    print(f"CIP Assessment Results ({agent.algorithm}):")
    print(f"Success rate: {metrics['success_rate']: .2%}")
    print(f"Average MSE: {metrics['avg_mse']: .6f}")
    print(f"Average RMSE: {metrics['avg_rmse']: .6f}")
    print(f"Average steps used: {metrics['avg_steps_used']: .2f}/{agent.future_length}")
    print(f"Early stop rate: {metrics['early_stop_rate']: .2%}")

    return metrics, complexity_info

def predict_trajectory(agent, history_dict, goal, dataset_collection, config, future_dict):
    """
    Roll out one treatment plan toward ``goal`` and score its final output.

    The plan is produced by ``agent.generate_treatment_plan`` with early
    stopping enabled. The score is the mean squared difference between the
    last output and the goal, after multiplying the difference by the
    dataset's scale parameter.

    Args:
        agent: Agent exposing ``generate_treatment_plan`` and ``future_length``.
        history_dict: History of the sample, forwarded to the agent.
        goal: Target value as a numpy array, tensor or other array-like.
        dataset_collection: Dataset container, forwarded to the agent and used
            to look up the scale parameter.
        config: Configuration dict; ``config['dataset']['name']`` must contain
            ``"mimic"`` or ``"tumor"``.
        future_dict: Future data of the sample, forwarded to the agent.

    Returns:
        tuple: ``(predictions, treatments, mse, steps_used)``. ``predictions``
        is the last output, or ``None`` with ``mse == inf`` when the agent
        produced no outputs.

    Raises:
        ValueError: If the dataset name is not recognised.
    """
    # Resolve the scale first so a bad configuration fails before the rollout.
    dataset_name = config['dataset']['name']
    if "mimic" in dataset_name:
        scale_param = dataset_collection.train_f.scaling_params['output_means']
    elif "tumor" in dataset_name:
        scale_param = dataset_collection.train_scaling_params[1]['cancer_volume']
    else:
        # Raise instead of calling exit(): callers can handle the error and the
        # process no longer terminates with a success status.
        raise ValueError(f"No dataset named {dataset_name}!")

    # Generate the treatment plan while progressively updating the history.
    treatments, outputs, steps_used = agent.generate_treatment_plan(
        history_dict,
        goal,
        dataset_collection,
        future_dict,
        future_length=agent.future_length,
        early_stop=True
    )

    if len(outputs) == 0:
        return None, treatments, float('inf'), steps_used

    # Final prediction is the output of the last step.
    predictions = outputs[-1]

    # Bring the goal to numpy on the host; float32 matches the precision the
    # former FloatTensor round trip produced.
    if isinstance(goal, torch.Tensor):
        goal_np = goal.detach().cpu().numpy()
    else:
        goal_np = np.asarray(goal, dtype=np.float32)

    mse = (((predictions - goal_np) * scale_param) ** 2).mean()

    return predictions, treatments, mse, steps_used
