import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
import os
import pickle
import random
from src.data.cip_dataset import CIPDataset, get_dataloader
from src.utils.utils import set_seed
from src.utils.helper_functions import generate_perturbed_sequences
from src.data.her_data_generator import create_history_treatment_goal_samples, convert_dataloader_to_samples
from src.gift.utils.evaluator import save_evaluation_results
from scipy.stats import spearmanr
import time
import pickle

class BaseCausalModel:
    def generate_treatment_plan_batch(self, history_dict_batch, goal_batch, dataset_collection, future_dict_batch, future_length=None, early_stop=True, sample=5, eval_simulation=False, treatments_load=None, return_model_predictions=False, **kwargs):
        """
        Roll out a treatment plan for a batch of patients, one action per step.

        At every step an action is chosen for each patient. Either gradient descent
        on the action or random sampling is used, depending on
        ``config.exp.action_selection_strategy`` (default ``'random'``). The action
        sequence chosen so far is passed, together with the patient's original
        history, to ``dataset_collection.val_f.simulate_output_after_actions``. The
        per-patient history copy is then updated via ``_update_patient_history``.

        Args:
            history_dict_batch: list of per-patient history dicts. Values are
                converted to tensors when they are numpy arrays and concatenated
                along dim 0 to form the model input.
            goal_batch: list of per-patient goal outcomes (ndarray or tensor).
            dataset_collection: provides ``val_f.simulate_output_after_actions`` and
                ``train_scaling_params``.
            future_dict_batch: list of per-patient future dicts. Their
                ``'current_treatments'`` entries are read only when ground-truth
                treatments are replayed (``eval_simulation=True`` and
                ``treatments_load is None``).
            future_length: number of steps to roll out; defaults to ``self.tau``.
            early_stop: if True, record the first step whose simulated output lies
                within ``config.dataset.goal_threshold * 0.001`` (L2 norm) of the
                goal. The rollout still runs for ``future_length`` steps.
            sample: number of random candidates / gradient steps per action.
            eval_simulation: if True, the searched action is replaced by a given
                treatment (``treatments_load`` if provided, otherwise the future
                ``'current_treatments'``), indexed as ``[:, t:t+1, :]``.
            treatments_load: optional treatments tensor, used only with
                ``eval_simulation``.
            return_model_predictions: if True, also return, per step, the output of
                ``get_predictions_after_tau_steps`` for the chosen actions.
            **kwargs: forwarded to the action-selection helpers.

        Returns:
            ``(actions_batch, outputs_batch, steps_taken_batch)``, plus
            ``model_predictions_batch`` when ``return_model_predictions`` is True.
        """
        if future_length is None:
            future_length = self.tau

        device = 'cuda'
        batch_size = len(history_dict_batch)

        # Get configuration parameters
        strategy = getattr(self.config.get('exp', {}), 'action_selection_strategy', 'random')

        # Optional encoder (not every model defines one)
        encoder = self.encoder if hasattr(self, 'encoder') else None

        # ===== Batch the history: concatenate every key along dim 0 =====
        H_t_batch = {}
        for history_dict in history_dict_batch:
            for key in history_dict:
                value = history_dict[key]
                if isinstance(value, np.ndarray):
                    value = torch.FloatTensor(value)
                H_t_batch.setdefault(key, []).append(value)
        for key in H_t_batch:
            H_t_batch[key] = torch.cat(H_t_batch[key], dim=0).to(device)

        # Treatments to replay when eval_simulation is set. They are built lazily:
        # the future 'current_treatments' are only read when no treatments_load is
        # given, and treatments_load is moved to the device once rather than per step.
        replay_treatments = None
        if eval_simulation:
            if treatments_load is not None:
                replay_treatments = treatments_load.to(device)
            else:
                true_treatments = []
                for future_dict in future_dict_batch:
                    value = future_dict['current_treatments']
                    if isinstance(value, np.ndarray):
                        value = torch.FloatTensor(value)
                    true_treatments.append(value)
                replay_treatments = torch.cat(true_treatments, dim=0).to(device)

        # Goals as one batched tensor (model input) and as numpy arrays (early-stop check)
        goal_tensor_batch = []
        goal_np_batch = []
        for goal in goal_batch:
            if isinstance(goal, np.ndarray):
                goal_tensor_batch.append(torch.FloatTensor(goal).unsqueeze(0))
                goal_np_batch.append(goal)
            else:
                goal_tensor_batch.append(goal.unsqueeze(0) if goal.dim() == 1 else goal)
                goal_np_batch.append(goal.cpu().numpy())
        goal_tensor_batch = torch.cat(goal_tensor_batch, dim=0).to(device)

        # Per-patient numpy copies of the history, extended after every step
        updated_history_batch = []
        for history_dict in history_dict_batch:
            updated_history = {}
            for key in history_dict:
                value = history_dict[key]
                if isinstance(value, np.ndarray):
                    updated_history[key] = value.copy()
                else:
                    updated_history[key] = value.cpu().numpy().copy() if hasattr(value, 'cpu') else value
            updated_history_batch.append(updated_history)

        current_H_t_batch = H_t_batch
        actions_batch = [[] for _ in range(batch_size)]
        outputs_batch = [[] for _ in range(batch_size)]
        steps_taken_batch = [future_length] * batch_size
        action_dim = H_t_batch['current_treatments'].shape[-1]

        model_predictions_batch = [[] for _ in range(batch_size)] if return_model_predictions else None

        for t in range(future_length):
            # Choose an action for every patient according to the configured strategy
            if strategy == 'gradient_descent':
                best_actions = self._select_actions_gradient_descent(
                    current_H_t_batch, goal_tensor_batch, action_dim, sample, encoder, **kwargs
                )
            else:  # Default: random sampling
                best_actions = self._select_actions_random_sampling(
                    current_H_t_batch, goal_tensor_batch, action_dim, sample, encoder, **kwargs
                )

            if replay_treatments is not None:
                # NOTE(review): the action search above still runs (and consumes RNG)
                # even though its result is replaced here; kept to preserve behavior.
                best_actions = replay_treatments[:, t:t+1, :]

            # Optionally record the model's own prediction for the chosen actions
            if return_model_predictions and hasattr(self, 'get_predictions_after_tau_steps'):
                with torch.no_grad():
                    model_pred_at_step = self.get_predictions_after_tau_steps(current_H_t_batch, None, best_actions, encoder)
                    for i in range(batch_size):
                        model_predictions_batch[i].append(model_pred_at_step[i].cpu().numpy())

            for i in range(batch_size):
                best_action = best_actions[i:i+1]
                action_np = best_action.cpu().numpy()
                actions_batch[i].append(action_np)

                # Full action sequence so far, concatenated along the time axis (axis 1)
                if len(actions_batch[i]) == 1:
                    actions_tensor = actions_batch[i][0]
                else:
                    actions_tensor = np.concatenate(actions_batch[i], axis=1)
                actions_tensor = torch.FloatTensor(actions_tensor).to(device)

                # Simulate from the ORIGINAL history with the whole action sequence
                output = dataset_collection.val_f.simulate_output_after_actions(
                    history_dict_batch[i], actions_tensor, dataset_collection.train_scaling_params
                )
                outputs_batch[i].append(output)

                # Early-stop bookkeeping: remember the first step that reached the goal
                if early_stop and np.linalg.norm(output - goal_np_batch[i]) < self.config.dataset.goal_threshold * 0.001:
                    if steps_taken_batch[i] > t + 1:
                        steps_taken_batch[i] = t + 1

                # Update history (if not the last step)
                if t < future_length - 1:
                    updated_history_batch[i] = self._update_patient_history(
                        updated_history_batch[i], action_np, output
                    )

            # Re-batch the updated histories for the next step
            if t < future_length - 1:
                current_H_t_batch = {}
                for key in updated_history_batch[0]:
                    batch_data = []
                    for i in range(batch_size):
                        value = updated_history_batch[i][key]
                        if isinstance(value, np.ndarray):
                            value = torch.FloatTensor(value)
                        batch_data.append(value)
                    current_H_t_batch[key] = torch.cat(batch_data, dim=0).to(device)

        # Convert per-patient step lists into arrays
        actions_batch = [np.array(actions) if actions else np.array([]) for actions in actions_batch]
        outputs_batch = [np.array(outputs) if outputs else np.array([]) for outputs in outputs_batch]

        if encoder is not None:
            encoder.train()
        self.train()

        if return_model_predictions:
            model_predictions_batch = [np.array(preds) if preds else np.array([]) for preds in model_predictions_batch]
            return actions_batch, outputs_batch, steps_taken_batch, model_predictions_batch
        return actions_batch, outputs_batch, steps_taken_batch

    def _select_actions_random_sampling(self, current_H_t_batch, goal_tensor_batch, action_dim, sample, encoder, **kwargs):
        """
        Pick one action per patient by scoring ``sample`` uniformly random candidates.

        Each round draws one candidate per patient with ``torch.rand`` (values in
        [0, 1), shape ``(batch, 1, action_dim)``) and scores it. The score comes from
        ``_candidate_action_objective`` when the model defines one; otherwise it is
        the L2 distance between ``get_predictions_after_tau_steps`` and the goal.
        For every patient the lowest-scoring candidate is kept (the first one on
        ties).

        Returns:
            Tensor of shape ``(batch, 1, action_dim)`` on the history's device.
        """
        reference = current_H_t_batch['current_treatments']
        n_patients = reference.shape[0]
        target_device = reference.device

        def _score_of(entry, row):
            # Per-patient score, or the shared scalar when no per-row scores exist
            return entry[row] if isinstance(entry, (list, np.ndarray)) else entry

        with torch.no_grad():
            candidates = []
            scores = []

            for _ in range(sample):
                proposal = torch.rand(n_patients, 1, action_dim).to(target_device)
                candidates.append(proposal)

                if not hasattr(self, '_candidate_action_objective'):
                    prediction = self.get_predictions_after_tau_steps(current_H_t_batch, None, proposal, encoder)
                    scores.append(torch.norm(prediction - goal_tensor_batch, dim=-1).cpu().numpy())
                    continue

                value = self._candidate_action_objective(
                    current_H_t_batch, goal_tensor_batch.cpu().numpy(),
                    proposal, encoder=encoder, **kwargs
                )
                if isinstance(value, torch.Tensor):
                    value = value.cpu().numpy()
                elif not isinstance(value, (list, np.ndarray)):
                    value = [value] * n_patients
                scores.append(value)

            chosen = []
            for row in range(n_patients):
                winner = int(np.argmin([_score_of(entry, row) for entry in scores]))
                chosen.append(candidates[winner][row:row + 1])

            return torch.cat(chosen, dim=0)

    def _select_actions_gradient_descent(self, current_H_t_batch, goal_tensor_batch, action_dim, gradient_steps, encoder, **kwargs):
        """
        Optimise one action per patient with Adam on a sigmoid-parameterised action.

        The learnable tensor ``actions`` is initialised with ``torch.rand`` and mapped
        through ``torch.sigmoid``, so every evaluated action lies in (0, 1). The loss
        is ``_candidate_action_objective(..., reduce=True)`` when the model defines
        it. Otherwise it is the mean L2 distance between
        ``get_predictions_after_tau_steps`` and the goal. The learning rate is
        ``config.exp.action_learning_rate`` (default 0.01).

        Gradients are taken with respect to ``actions`` only, so the search does not
        accumulate ``.grad`` on the model's or encoder's parameters.

        Args:
            current_H_t_batch: batched history dict; ``'current_treatments'`` fixes
                the batch size and device.
            goal_tensor_batch: batched goal tensor.
            action_dim: size of the last action dimension.
            gradient_steps: number of Adam steps (0 returns the initial action).
            encoder: optional encoder forwarded to the model.

        Returns:
            Detached tensor of shape ``(batch, 1, action_dim)``: the action evaluated
            at the last optimisation step (the one whose loss was computed last).
        """
        batch_size = current_H_t_batch['current_treatments'].shape[0]
        device = current_H_t_batch['current_treatments'].device

        lr = getattr(self.config.get('exp', {}), 'action_learning_rate', 0.01)
        # NOTE(review): sigmoid of U[0, 1) means the search starts in [0.5, ~0.73).
        actions = torch.rand(batch_size, 1, action_dim, device=device, requires_grad=True)
        optimizer = torch.optim.Adam([actions], lr=lr)

        # Defined before the loop so gradient_steps == 0 returns the initial action
        # instead of raising UnboundLocalError.
        actions_clamped = torch.sigmoid(actions)
        for _ in range(gradient_steps):
            optimizer.zero_grad()

            # Keep actions inside (0, 1)
            actions_clamped = torch.sigmoid(actions)

            if hasattr(self, '_candidate_action_objective'):
                loss = self._candidate_action_objective(
                    current_H_t_batch, goal_tensor_batch.cpu().numpy(), actions_clamped, encoder=encoder, reduce=True, **kwargs
                )
            else:
                pred = self.get_predictions_after_tau_steps(current_H_t_batch, None, actions_clamped, encoder)
                loss = torch.norm(pred - goal_tensor_batch, dim=-1).mean()

            # Differentiate w.r.t. the action only. loss.backward() would also add
            # gradients to every model parameter on the graph, and those are never
            # cleared here.
            (action_grad,) = torch.autograd.grad(loss, actions, allow_unused=True)
            actions.grad = action_grad
            optimizer.step()

        return actions_clamped.detach()

    def generate_treatment_plan(self, history_dict, goal, dataset_collection, future_dict_batch, future_length=None, early_stop=True, sample=5, **kwargs):
        """
        Single-patient wrapper around ``generate_treatment_plan_batch``.

        Args:
            history_dict: history of one patient.
            goal: goal outcome for that patient.
            dataset_collection: forwarded unchanged.
            future_dict_batch: the patient's future dict, either wrapped in a
                one-element list (as the batch API expects) or as a bare dict.
            future_length, early_stop, sample: forwarded unchanged.
            **kwargs: forwarded unchanged.

        Returns:
            ``(actions, outputs, steps_taken)`` for the single patient. Model
            predictions requested via ``return_model_predictions`` in ``kwargs`` are
            dropped so that this 3-tuple interface is kept.
        """
        # The batch API iterates over future dicts; a bare dict would be iterated by key.
        if isinstance(future_dict_batch, dict):
            future_dict_batch = [future_dict_batch]
        results = self.generate_treatment_plan_batch(
            [history_dict], [goal], dataset_collection, future_dict_batch, future_length, early_stop, sample, **kwargs
        )
        actions_batch, outputs_batch, steps_taken_batch = results[:3]
        return actions_batch[0], outputs_batch[0], steps_taken_batch[0]

    def evaluate(self, dataset_collection, config, max_tau=6, num_episodes=2000, sample=30, logger=None, **kwargs):
        """
        Batch evaluation of goal-reaching treatment planning on the val/test split.

        For every horizon tau in ``1..max_tau`` (or only ``self.tau`` when ``max_tau``
        is None), a plan is generated for each valid patient with
        ``generate_treatment_plan_batch``. The final simulated outcome is compared
        to the goal via an MSE computed after multiplying the difference by a
        dataset-specific scaling factor. The plans of the last evaluated horizon are
        pickled to ``<config.exp.result_dir>/treatment_<seed>.pkl``.

        Args:
            dataset_collection: provides the val/test data, the simulator and the
                scaling parameters.
            config: experiment config (dataset name for MIMIC scaling, result dir).
            max_tau: largest horizon to evaluate; None evaluates only ``self.tau``.
            num_episodes: unused; kept for interface compatibility.
            sample: candidates / gradient steps per action.
            logger: result logger; a module-level logger is used when None.
            **kwargs: forwarded to ``generate_treatment_plan_batch``.

        Returns:
            dict mapping each evaluated tau to its ``_calculate_metrics`` dict.

        Raises:
            ValueError: if the dataset name contains neither 'tumor' nor 'mimic'.
        """
        if logger is None:
            import logging
            logger = logging.getLogger(__name__)

        set_seed(self.config['exp']['seed'])
        self.to('cuda')
        orig_tau = self.tau
        # config['exp']['tau'] is overwritten per horizon below; remember it so it is
        # restored together with self.tau when evaluation ends.
        exp_cfg = self.config['exp']
        had_cfg_tau = 'tau' in exp_cfg
        orig_cfg_tau = exp_cfg['tau'] if had_cfg_tau else None
        all_metrics = {}
        treatments_list = []  # plans of the last evaluated horizon (pickled at the end)

        tau_values = range(1, max_tau + 1) if max_tau is not None else [self.tau]

        if 'tumor' in getattr(self.config, 'dataset', {}).get('name', ''):
            scaling_factor = dataset_collection.train_scaling_params[1]['cancer_volume']
        elif "mimic" in config['dataset']['name']:
            scaling_factor = dataset_collection.train_f.scaling_params['output_means']
        else:
            raise ValueError(f"Unsupported dataset for evaluation: {config['dataset']['name']!r}")
        logger.info(f"scaling_factor:{scaling_factor.shape}, {scaling_factor}")

        try:
            for current_tau in tau_values:
                self.config['exp']['tau'] = current_tau

                try:
                    data = dataset_collection.val_f.data_original if not self.config['exp']['test'] else dataset_collection.test_f.data_original
                except (AttributeError, KeyError):
                    # Fall back to `data` when the dataset has no `data_original`
                    data = dataset_collection.val_f.data if not self.config['exp']['test'] else dataset_collection.test_f.data

                # One batch covers the whole evaluated split
                dataset_name = self.config['dataset']['name']
                split = 'test' if self.config['exp']['test'] else 'val'
                if 'mimic' in dataset_name:
                    batch_size = int(self.config['dataset']['max_number'] * self.config['dataset']['split'][split])
                elif 'tumor' in dataset_name:
                    batch_size = self.config['dataset']['num_patients'][split]
                else:
                    raise ValueError(f"Unsupported dataset for evaluation: {dataset_name!r}")

                dataloader = get_dataloader(CIPDataset(data, self.config), batch_size=batch_size, shuffle=False)
                val_samples = convert_dataloader_to_samples(dataloader)

                # Keep only samples whose history has every required key
                valid_samples = [
                    (history_dict, future_dict, goal)
                    for history_dict, future_dict, goal in val_samples
                    if self._validate_sample_keys(history_dict)
                ]

                mse_values = []
                success_count = 0
                treatment_similarities = []
                steps_used_list = []
                treatments_list = []

                # Batched planning
                num_batches = (len(valid_samples) + batch_size - 1) // batch_size

                for batch_idx in range(num_batches):
                    start_idx = batch_idx * batch_size
                    end_idx = min(start_idx + batch_size, len(valid_samples))
                    batch_samples = valid_samples[start_idx:end_idx]

                    if not batch_samples:
                        continue

                    history_dict_batch = [sample_[0] for sample_ in batch_samples]
                    future_dict_batch = [sample_[1] for sample_ in batch_samples]
                    goal_batch = [sample_[2] for sample_ in batch_samples]

                    t1 = time.time()
                    treatments_batch, outputs_batch, steps_used_batch = self.generate_treatment_plan_batch(
                        history_dict_batch, goal_batch, dataset_collection, future_dict_batch,
                        future_length=current_tau, early_stop=True, sample=sample, **kwargs
                    )
                    print(f"batch_idx:{batch_idx}, used time: {time.time() - t1}")

                    for i, (treatments, outputs, steps_used) in enumerate(zip(treatments_batch, outputs_batch, steps_used_batch)):
                        # Only the final simulated outcome is scored
                        predictions = outputs[-1] if len(outputs) > 0 else None
                        if predictions is None:
                            continue

                        goal_np = goal_batch[i] if isinstance(goal_batch[i], np.ndarray) else goal_batch[i].cpu().numpy()
                        mse = (((predictions - goal_np) * scaling_factor) ** 2).mean()
                        mse_values.append(mse)

                        if mse < getattr(self, 'goal_threshold', 0.005) ** 2:
                            success_count += 1

                        treatments_list.append(treatments)
                        steps_used_list.append(steps_used)

                        if 'current_treatments' in future_dict_batch[i]:
                            self._calculate_treatment_similarity(future_dict_batch[i], treatments, treatment_similarities)

                self.train()

                metrics = self._calculate_metrics(mse_values, success_count, treatment_similarities, steps_used_list, current_tau, dataset_collection)
                self._print_evaluation_results(metrics, current_tau, logger)
                all_metrics[current_tau] = metrics

            path = os.path.join(config.exp.result_dir, f'treatment_{self.config.exp.seed}.pkl')
            with open(path, 'wb') as file:
                pickle.dump(treatments_list, file)
        finally:
            if had_cfg_tau:
                exp_cfg['tau'] = orig_cfg_tau
            self.tau = orig_tau
        return all_metrics

    def evaluate_and_log_case_studies(self, dataset_collection, config, logger, model_name=None, max_tau=6, sample=100, size=100, case_study_results=None, **kwargs):
        """
        Replay saved treatment plans for individual patients and record outcome trajectories.

        The treatments pickled at ``<config.exp.result_dir>/treatment_<seed>.pkl``
        (written by ``evaluate`` / ``optimize_interventions``) are replayed through
        ``generate_treatment_plan_batch`` with ``eval_simulation=True``. This is done
        for patients ``0..size-1`` of the validation (or test) split.

        Per patient the result dict holds:
            - the initial outcome;
            - the goal;
            - the outcome trajectory simulated from the ground-truth future
              treatments;
            - one entry per model under ``'models'``.

        Existing patient entries are reused, so the same dict can be passed in
        again to add further models.

        Args:
            dataset_collection: provides the val/test data and the simulator.
            config: experiment config (``config.exp.result_dir`` locates the pickle).
            logger: progress logger.
            model_name: key under ``'models'``; defaults to the class name. Names
                containing ``'vcip'`` store only the simulated trajectory. Other
                names store a dict that may also hold the model's own predictions.
            max_tau: number of replayed steps.
            sample: forwarded to ``generate_treatment_plan_batch``.
            size: number of patients (ids ``0..size-1``) to process.
            case_study_results: existing results dict to extend, or None.
            **kwargs: forwarded to ``generate_treatment_plan_batch``.

        Returns:
            ``(case_study_results, seconds_per_patient)``, where the time is the
            total wall time divided by ``size`` (0.0 when ``size`` is 0).
        """
        set_seed(self.config['exp']['seed'])

        device = 'cuda'
        start_time = time.time()
        if model_name is None:
            model_name = self.__class__.__name__
        if case_study_results is None:
            case_study_results = {}

        logger.info(f"--- Starting Case Study for model: {model_name} ---")

        path = os.path.join(config.exp.result_dir, f'treatment_{self.config.exp.seed}.pkl')
        # NOTE(review): pickle.load can execute arbitrary code; only load result files
        # produced by this pipeline.
        with open(path, 'rb') as file:
            treatments_load = pickle.load(file)

        # 1. Load the full split so patients can be indexed by id
        try:
            data = dataset_collection.val_f.data_original if not self.config['exp']['test'] else dataset_collection.test_f.data_original
        except (AttributeError, KeyError):
            # Fall back to `data` when the dataset has no `data_original`
            data = dataset_collection.val_f.data if not self.config['exp']['test'] else dataset_collection.test_f.data

        dataloader = get_dataloader(CIPDataset(data, self.config), batch_size=len(data['outputs']), shuffle=False)
        all_samples = convert_dataloader_to_samples(dataloader)

        is_vcip = 'vcip' in model_name

        for patient_id in range(size):
            if patient_id >= len(all_samples):
                logger.warning(f"Patient ID {patient_id} is out of bounds. Skipping.")
                continue

            logger.info(f"Processing patient_id: {patient_id}")
            history_dict, future_dict, goal = all_samples[patient_id]
            goal_np = goal if isinstance(goal, np.ndarray) else goal.cpu().numpy()

            # Patient-level data is computed once and shared by all models
            if patient_id not in case_study_results:
                initial_outcome = history_dict['outputs'][:, -1, :]
                case_study_results[patient_id] = {
                    'initial_outcome': initial_outcome.squeeze(),
                    'goal': goal_np.squeeze(),
                    'models': {}
                }

                true_treatments_tensor = torch.FloatTensor(future_dict['current_treatments']).to(device)
                true_outcome_trajectory = []
                for t_step in range(1, true_treatments_tensor.shape[1] + 1):
                    outcome_at_t = dataset_collection.val_f.simulate_output_after_actions(
                        history_dict, true_treatments_tensor[:, :t_step, :], dataset_collection.train_scaling_params
                    )
                    true_outcome_trajectory.append(outcome_at_t)

                case_study_results[patient_id]['ground_truth_outcomes'] = np.array(true_outcome_trajectory).squeeze()

            # Flatten whatever layout was saved to (1, steps, action_dim), because
            # generate_treatment_plan_batch indexes it as [:, t:t+1, :]. A plain
            # squeeze() would also drop size-1 step/action axes.
            saved_plan = torch.as_tensor(np.asarray(treatments_load[patient_id]))
            treatment = saved_plan.reshape(1, -1, saved_plan.shape[-1])

            plan_result = self.generate_treatment_plan_batch(
                [history_dict], [goal], dataset_collection, [future_dict],
                future_length=max_tau,
                early_stop=True,
                sample=sample,
                return_model_predictions=not is_vcip,
                eval_simulation=True, treatments_load=treatment,
                **kwargs
            )
            outcome_trajectory = plan_result[1][0]

            if patient_id == 0:
                print('-' * 100)
                print(f"evaluate log out: {outcome_trajectory.squeeze()}")
                # Read the stored ground truth: when this patient came from an earlier
                # call, the trajectory was not recomputed in this call.
                print(f"evaluate log turth: {case_study_results[patient_id].get('ground_truth_outcomes')}")

            if is_vcip:
                case_study_results[patient_id]['models'][model_name] = np.array(outcome_trajectory).squeeze()
            else:
                model_predicted_trajectory = plan_result[3][0]
                case_study_results[patient_id]['models'][model_name] = {
                    'simulated_outcome_trajectory': np.array(outcome_trajectory).squeeze()
                }
                if model_predicted_trajectory.size > 0:
                    case_study_results[patient_id]['models'][model_name]['predicted_outcome_trajectory'] = np.array(model_predicted_trajectory).squeeze()

        self.train()
        logger.info(f"--- Case Study for model: {model_name} Finished ---")
        used_time = time.time() - start_time
        return case_study_results, (used_time / size if size > 0 else 0.0)

    # Auxiliary helpers used during evaluation
    def _validate_sample_keys(self, history_dict):
        """
        Report whether a history dict carries every field the planner needs.

        ``current_treatments``, ``prev_treatments``, ``prev_outputs`` and
        ``active_entries`` must always be present. ``vitals`` is additionally
        required when ``self.has_vitals`` is truthy.

        Returns:
            True if all required keys are present, False otherwise.
        """
        needed = ('current_treatments', 'prev_treatments', 'prev_outputs', 'active_entries')
        if not all(name in history_dict for name in needed):
            return False
        needs_vitals = getattr(self, 'has_vitals', False)
        return not (needs_vitals and 'vitals' not in history_dict)

    def _calculate_treatment_similarity(self, future_dict, treatments, treatment_similarities):
        """
        Append the cosine similarity between the true and the planned treatments.

        Both inputs are first flattened to ``(steps, last_dim)``. The two layouts
        differ:
            - the ground truth is stored per patient with a leading batch axis,
              e.g. ``(1, T, D)``;
            - plans from ``generate_treatment_plan_batch`` are stacked per step as
              ``(tau, 1, 1, D)``.
        Slicing the raw first axes would therefore compare the whole ground-truth
        sequence with only the first planned step.

        The similarity is computed over the first ``min(T, tau)`` steps. 1-D inputs
        are treated as one value per step. Nothing is appended if either side is
        empty or all zeros.

        Args:
            future_dict: must contain ``'current_treatments'`` (array or tensor).
            treatments: planned treatments (array or tensor).
            treatment_similarities: list the similarity is appended to (mutated).
        """
        def _as_steps(values):
            # -> 2-D array (steps, features)
            if isinstance(values, torch.Tensor):
                values = values.detach().cpu().numpy()
            values = np.asarray(values, dtype=float)
            if values.ndim <= 1:
                return values.reshape(-1, 1)
            return values.reshape(-1, values.shape[-1])

        actual_steps = _as_steps(future_dict['current_treatments'])
        planned_steps = _as_steps(treatments)
        min_len = min(len(actual_steps), len(planned_steps))
        if min_len == 0:
            return

        actual_steps = actual_steps[:min_len]
        planned_steps = planned_steps[:min_len]
        actual_norm = np.linalg.norm(actual_steps)
        planned_norm = np.linalg.norm(planned_steps)
        if actual_norm > 0 and planned_norm > 0:
            similarity = np.sum(actual_steps * planned_steps) / (actual_norm * planned_norm)
            treatment_similarities.append(similarity)

    def _calculate_metrics(self, mse_values, success_count, treatment_similarities, steps_used_list, tau, dataset_collection):
        """
        Aggregate per-patient evaluation results for one horizon ``tau``.

        The returned dict contains:
            - the success rate;
            - the mean MSE (``inf`` when nothing was evaluated);
            - its square root as ``avg_rmse``;
            - the mean treatment similarity (0.0 when none);
            - the number of evaluated patients;
            - the mean number of steps used;
            - the early-stop rate ``1 - mean_steps / tau``.

        ``dataset_collection`` is accepted for interface compatibility but unused.
        """
        evaluated = len(mse_values)
        if mse_values:
            success_rate = success_count / evaluated
            avg_mse = np.mean(mse_values)
        else:
            success_rate = 0
            avg_mse = float('inf')

        if steps_used_list:
            avg_steps = np.mean(steps_used_list)
            early_stop_rate = 1.0 - (avg_steps / tau)
        else:
            avg_steps = 0
            early_stop_rate = 0

        return {
            'success_rate': success_rate,
            'avg_mse': avg_mse,
            'avg_rmse': np.sqrt(avg_mse),
            'avg_treatment_similarity': np.mean(treatment_similarities) if treatment_similarities else 0.0,
            'num_evaluated': evaluated,
            'avg_steps_used': avg_steps,
            'early_stop_rate': early_stop_rate,
        }

    def _print_evaluation_results(self, metrics, tau, logger):
        """
        Log the metrics produced by ``_calculate_metrics`` for horizon ``tau``.

        Emits one header line followed by one ``logger.info`` call per metric.
        Each value is formatted with a leading sign-space (e.g. ``' .2%'``).
        """
        logger.info(f"{type(self).__name__} assessment results (tau = {tau}):")
        # (prefix, metric key, format spec, suffix)
        rows = (
            ("Success rate: ", 'success_rate', ' .2%', ''),
            ("Average MSE: ", 'avg_mse', ' .6f', ''),
            ("Average RMSE: ", 'avg_rmse', ' .6f', ''),
            ("Average intervention similarity: ", 'avg_treatment_similarity', ' .4f', ''),
            ("Average steps used: ", 'avg_steps_used', ' .2f', f"/{tau}"),
            ("Early stop rate: ", 'early_stop_rate', ' .2%', ''),
        )
        for prefix, key, spec, suffix in rows:
            logger.info(prefix + format(metrics[key], spec) + suffix)

    def optimize_interventions(self, encoder=None, num_iterations=100, learning_rate=0.01, batch_size=64, logger=None):
        """
        Run ``optimize_interventions_onetime`` for every horizon ``1..self.tau``.

        For each horizon ``i``, both ``self.tau`` and ``config.exp.tau`` are set to
        ``i`` before optimising. The treatments of the final horizon are pickled to
        ``<config.exp.result_dir>/treatment_<seed>.pkl``.

        Args:
            encoder: forwarded to ``optimize_interventions_onetime``.
            num_iterations: forwarded to ``optimize_interventions_onetime``.
            learning_rate: forwarded to ``optimize_interventions_onetime``.
            batch_size: used only when the dataset name contains neither 'mimic'
                nor 'tumor'; otherwise the size of the evaluated split is used.
            logger: progress logger; a module-level logger is used when None.

        Returns:
            Multi-line summary string, one line per horizon.
        """
        if logger is None:
            import logging
            logger = logging.getLogger(__name__)

        results = ['\n' + '-' * 50]
        tau = self.tau
        if 'mimic' in self.config['dataset']['name']:
            if self.config['exp']['test']:
                batch_size = int(self.config['dataset']['max_number'] * self.config['dataset']['split']['test'])
            else:
                batch_size = int(self.config['dataset']['max_number'] * self.config['dataset']['split']['val'])
        elif 'tumor' in self.config['dataset']['name']:
            if self.config['exp']['test']:
                batch_size = self.config['dataset']['num_patients']['test']
            else:
                batch_size = self.config['dataset']['num_patients']['val']

        try:
            for i in range(1, tau + 1):
                logger.info(f'start predicting results for tau={i} ...')
                self.tau = i
                self.config.exp.tau = i
                print(self.tau, self.config.exp.tau)
                treatments, loss = self.optimize_interventions_onetime(encoder=encoder, num_iterations=num_iterations, learning_rate=learning_rate, batch_size=batch_size)

                # Only the plans of the longest horizon are persisted
                if i == tau:
                    path = os.path.join(self.config.exp.result_dir, f'treatment_{self.config.exp.seed}.pkl')
                    with open(path, 'wb') as file:
                        pickle.dump(treatments.cpu().numpy(), file)

                results.append(f"Optimized interventions for tau={i}, {num_iterations} iterations, lr={learning_rate}: {loss}")
                logger.info(f"Optimized interventions for tau={i}, RMSE: {loss}")
        finally:
            # A full run ends with both values at `tau`; restore that state if a
            # horizon raised midway.
            if self.tau != tau:
                self.tau = tau
                self.config.exp.tau = tau
        return '\n'.join(results)