"""Data manager - manages all data files during KOH process"""
import logging
from pathlib import Path
from typing import List, Dict, Optional

import numpy as np
import pandas as pd


class DataManager:
    """Manages data files during KOH process.

    File structure (each name gets a ``<file_prefix>_`` prefix when ``file_prefix`` is set):
        - seed_points.csv: Warmup initial points
        - history.csv: All HF experiment data
        - hf_predictions.csv: HF points during optimization (with iteration and running best)
        - lf_predictions.csv: LF points during optimization (with iteration)

    The target column of the seed/history/HF files holds *stored* values,
    ``stored_y = objective_transform * true_y``; ``get_history_data`` and
    ``get_history_exclude_point`` convert back to the original task space.
    """

    # Logger named after this module; predictor failures are reported here instead of being dropped.
    _logger = logging.getLogger(__name__)

    # Non-feature columns that are coerced to a numeric dtype when a CSV is loaded.
    _NUMERIC_EXTRA_COLUMNS = ("mu_LF", "sigma2_LF", "iteration", "best_objective")

    def __init__(
        self,
        task_data_dir: str,
        feature_names: List[str],
        target_name: str = "f",
        file_prefix: Optional[str] = None,
        objective_transform: float = 1.0
    ):
        """Initialize the data manager and load (or create empty) every data file.

        Args:
            task_data_dir: Task data directory (e.g., data/FeCr); created if missing.
            feature_names: List of feature column names.
            target_name: Target column name.
            file_prefix: File name prefix (e.g., FeCr_20250101), used to distinguish different runs.
            objective_transform: Non-zero factor with ``stored_y = objective_transform * true_y``.
                Use 1.0 for a maximization task, and -1.0 for a minimization task run
                inside a maximization framework (stored_y = -true_y).

        Raises:
            ValueError: If ``objective_transform`` is 0.0 (the transform must be invertible).
        """
        self.task_data_dir = Path(task_data_dir)
        self.feature_names = feature_names
        self.target_name = target_name
        self.file_prefix = file_prefix
        self.objective_transform = float(objective_transform)
        if self.objective_transform == 0.0:
            raise ValueError("objective_transform cannot be 0.0")

        # File paths
        self.seed_points_path = self._build_path("seed_points.csv")
        self.history_path = self._build_path("history.csv")
        self.hf_predictions_path = self._build_path("hf_predictions.csv")
        self.lf_predictions_path = self._build_path("lf_predictions.csv")

        # Expected column layouts of each file
        xy_columns = self.feature_names + [self.target_name]
        hf_columns = xy_columns + ["iteration", "best_objective"]
        lf_columns = self.feature_names + ["mu_LF", "sigma2_LF", "iteration"]

        # Initialize data (missing expected columns are back-filled with NaN on load)
        self.seed_points_df = self._load_or_create_csv(self.seed_points_path, xy_columns, "initial points")
        self.history_df = self._load_or_create_csv(self.history_path, xy_columns, "history data")
        self.hf_predictions_df = self._load_or_create_csv(self.hf_predictions_path, hf_columns, "HF predictions")
        self.lf_predictions_df = self._load_or_create_csv(self.lf_predictions_path, lf_columns, "LF predictions")

    def _build_path(self, filename: str) -> Path:
        """Return the path of ``filename`` in the task directory, prefixed by ``file_prefix`` if set.

        Side effect: creates ``task_data_dir`` (including parents) if it does not exist yet.
        """
        self.task_data_dir.mkdir(parents=True, exist_ok=True)
        name = f"{self.file_prefix}_{filename}" if self.file_prefix else filename
        return self.task_data_dir / name

    def _load_or_create_csv(self, path: Path, columns: List[str], name: str) -> pd.DataFrame:
        """Load ``path`` as a DataFrame, or return an empty DataFrame with ``columns``.

        Known numeric columns (features, target, LF statistics, iteration, best objective) are
        coerced with ``pd.to_numeric(errors="coerce")``, so unparseable cells become NaN instead
        of leaving an object-dtype column. Expected columns missing from an existing file are
        added and filled with NaN. A file that exists but is completely empty, for example one
        left by an interrupted write, is treated as if it did not exist.

        Args:
            path: CSV file location.
            columns: Columns the DataFrame is expected to have.
            name: Human-readable description of the file, used in log messages.
        """
        if path.exists():
            try:
                df = pd.read_csv(path)
            except pd.errors.EmptyDataError:
                self._logger.warning("%s file %s is empty; starting with no rows", name, path)
            else:
                numeric_columns = set(self.feature_names) | {self.target_name, *self._NUMERIC_EXTRA_COLUMNS}
                for col in df.columns:
                    if col in numeric_columns:
                        df[col] = pd.to_numeric(df[col], errors="coerce")
                self._ensure_columns(df, columns)
                return df
        return pd.DataFrame(columns=columns)

    def _ensure_columns(self, df: pd.DataFrame, columns: List[str]):
        """Add each column of ``columns`` that ``df`` lacks, in place, filled with NaN."""
        for col in columns:
            if col not in df.columns:
                df[col] = np.nan

    def save_all(self):
        """Write all four DataFrames to their CSV files (index not written)."""
        for df, path in (
            (self.seed_points_df, self.seed_points_path),
            (self.history_df, self.history_path),
            (self.hf_predictions_df, self.hf_predictions_path),
            (self.lf_predictions_df, self.lf_predictions_path),
        ):
            df.to_csv(path, index=False)

    @staticmethod
    def _append_row(df: pd.DataFrame, row: dict) -> pd.DataFrame:
        """Return a new DataFrame equal to ``df`` with ``row`` appended (fresh RangeIndex).

        Concatenating onto a zero-row frame is deprecated in pandas (its empty object-dtype
        columns will start influencing the result dtype). So for an empty ``df`` the new row is
        used directly, reindexed to ``df``'s column order plus any extra keys of ``row``.
        """
        new_row = pd.DataFrame([row])
        if len(df) == 0:
            extra = [c for c in new_row.columns if c not in df.columns]
            return new_row.reindex(columns=list(df.columns) + extra)
        return pd.concat([df, new_row], ignore_index=True)

    # ==================== Add Data ====================

    def add_seed_point(self, x: dict, y_H: float):
        """Add a warmup initial point to the seed points and, if not already there, to history.

        Args:
            x: Feature values keyed by feature name.
            y_H: Objective value in stored space (it is written to history as-is).

        Does nothing if ``x`` is already among the seed points.
        """
        if self._point_exists(self.seed_points_df, x):
            return

        row = {**x, self.target_name: y_H}
        self.seed_points_df = self._append_row(self.seed_points_df, row)

        # Also add to history
        if not self._point_exists(self.history_df, x):
            self.history_df = self._append_row(self.history_df, row)

    def add_hf_experiment(self, x: dict, y_H: float, iteration: int):
        """Add HF experiment (optimization phase).

        The point is appended to history if it is not there yet. It is appended to
        hf_predictions even when it already exists in history, e.g. initial points from the
        warmup phase. Nothing happens if it is already in hf_predictions.

        Args:
            x: Feature values keyed by feature name.
            y_H: Observed objective in stored space.
            iteration: Optimization iteration at which the experiment was run.
        """
        # Avoid duplicate addition to hf_predictions
        if self._point_exists(self.hf_predictions_df, x):
            return

        if not self._point_exists(self.history_df, x):
            self.history_df = self._append_row(self.history_df, {**x, self.target_name: y_H})

        hf_row = {
            **x,
            self.target_name: y_H,
            "iteration": iteration,
            "best_objective": self._running_best(y_H),
        }
        self.hf_predictions_df = self._append_row(self.hf_predictions_df, hf_row)

    def _running_best(self, y_H: float) -> float:
        """Return the maximum stored objective over hf_predictions plus ``y_H``, ignoring NaN.

        Both the recorded ``best_objective`` column and the raw target column are considered,
        so rows from older files that lack ``best_objective`` still count. Returns NaN only if
        no non-NaN value exists at all.
        """
        candidates = [float(y_H)]
        df = self.hf_predictions_df
        if len(df) > 0:
            for col in ("best_objective", self.target_name):
                if col in df.columns:
                    candidates.append(float(pd.to_numeric(df[col], errors="coerce").max(skipna=True)))
        # Python's max() is order-dependent with NaN (max(nan, 2.0) is nan), so drop NaNs first.
        valid = [v for v in candidates if not np.isnan(v)]
        return max(valid) if valid else float("nan")

    def add_lf_prediction(self, x: dict, mu_LF: float, sigma2_LF: float, iteration: int):
        """Append one LF prediction record (optimization phase).

        Strategy: always add a new record, even for a point already present, so the GP can
        learn from every prediction made.
        """
        row = {**x, "mu_LF": mu_LF, "sigma2_LF": sigma2_LF, "iteration": iteration}
        self.lf_predictions_df = self._append_row(self.lf_predictions_df, row)

    def add_lf_predictions_batch(self, X_batch: List[dict], mu_LF_batch: np.ndarray,
                                 sigma2_LF_batch: np.ndarray, iteration: int):
        """Append LF predictions for a batch of points, skipping those whose mean is NaN.

        Strategy: always add new records, so the GP can learn from every prediction made.
        """
        for x, mu, sigma2 in zip(X_batch, mu_LF_batch, sigma2_LF_batch):
            if not np.isnan(mu):
                self.add_lf_prediction(x, mu, sigma2, iteration)

    # ==================== Query Data ====================

    def _row_features(self, row: pd.Series) -> dict:
        """Extract ``{feature_name: value}`` from one DataFrame row."""
        return {feat: row[feat] for feat in self.feature_names}

    def _rows_to_history(self, df: pd.DataFrame) -> List[Dict]:
        """Convert rows of ``df`` into ``[{"x": {...}, "y": true_y}, ...]``.

        The LLM uses y in the "original task space" (true_y = stored_y / objective_transform),
        so it is not fed negated values under a min->max transformation.
        """
        history = []
        for _, row in df.iterrows():
            y_true = float(row[self.target_name]) / self.objective_transform
            history.append({"x": self._row_features(row), "y": y_true})
        return history

    def get_history_data(self) -> List[Dict]:
        """Get history data (for ICL).

        Returns:
            [{"x": {...}, "y": value}, ...] with y in the original task space.
        """
        if len(self.history_df) == 0:
            return []
        return self._rows_to_history(self.history_df)

    def get_history_exclude_point(self, x: dict) -> List[Dict]:
        """Get history without the rows whose features all equal ``x`` (for leave-one-out ICL).

        Returns the same format as ``get_history_data``. Matching uses exact ``==`` comparison
        on every feature.
        """
        if len(self.history_df) == 0:
            return []
        keep = ~self._match_mask(self.history_df, x)
        return self._rows_to_history(self.history_df[keep])

    def get_history_points(self) -> tuple:
        """Get history X and y (for training the residual GP).

        Returns:
            (X, y) where X is (n, d) and y is (n,), both float64; y is in stored space.
        """
        if len(self.history_df) == 0:
            return np.array([]).reshape(0, len(self.feature_names)), np.array([])

        # Ensure float64 type (avoid object type)
        X = self.history_df[self.feature_names].values.astype(np.float64)
        y = self.history_df[self.target_name].values.astype(np.float64)
        return X, y

    def get_lf_training_data(self) -> tuple:
        """Get all LF prediction records (for training the low-fidelity GP).

        Returns:
            (X, mu_LF, sigma2_LF) as float64 numpy arrays of shapes (m, d), (m,), (m,).
        """
        # NOTE(review): LF predictions are produced by the predictor from history in the original
        # task space (get_history_data), while get_history_points returns stored-space targets.
        # Confirm the consumers reconcile the two when objective_transform != 1.
        if len(self.lf_predictions_df) == 0:
            n_feat = len(self.feature_names)
            return np.array([]).reshape(0, n_feat), np.array([]), np.array([])

        # Ensure float64 type (avoid object type)
        X = self.lf_predictions_df[self.feature_names].values.astype(np.float64)
        mu_LF = self.lf_predictions_df["mu_LF"].values.astype(np.float64)
        sigma2_LF = self.lf_predictions_df["sigma2_LF"].values.astype(np.float64)
        return X, mu_LF, sigma2_LF

    def _resolve_iteration(self, x_dict: dict, iteration: Optional[int]):
        """Return ``iteration`` if given, else the iteration of the point's first LF record.

        Falls back to 0 when the point has no LF record yet or its recorded iteration is NaN.
        """
        if iteration is not None:
            return iteration
        existing = self._get_existing_lf_prediction(x_dict)
        if existing is None or pd.isna(existing["iteration"]):
            return 0
        return existing["iteration"]

    def recompute_all_lf_predictions(self, predictor, iteration: Optional[int] = None, only_new: bool = True):
        """Compute leave-one-out (LOO) LF predictions for history points and append them.

        For each selected history point, ``predictor.predict(x, history_without_x)`` is called.
        The returned ``(mu_LF, sigma2_LF)`` is appended as a new LF record. Predictor failures
        and NaN means are logged and skipped (best-effort), so those points are retried on a
        later call instead of being stored as NaN.

        Args:
            predictor: Object exposing ``predict(x_dict, history) -> (mu_LF, sigma2_LF, _)``.
            iteration: Iteration number to record. If None, reuse the iteration of the point's
                first existing LF record (0 if there is none).
            only_new: If True, only compute points that have no LF record yet (performance
                optimization). If False, recompute every history point, appending new records.
        """
        if len(self.history_df) == 0:
            return

        points_to_compute = [self._row_features(row) for _, row in self.history_df.iterrows()]
        if only_new:
            # NOTE(review): any LF record counts as "already computed", including records that were
            # added while the point was still a non-history candidate (e.g. via
            # add_lf_predictions_batch). No column distinguishes LOO records; confirm this is intended.
            points_to_compute = [x for x in points_to_compute if self._get_existing_lf_prediction(x) is None]
            if not points_to_compute:
                return

        success_count = 0
        for x_dict in points_to_compute:
            history_exclude = self.get_history_exclude_point(x_dict)
            try:
                mu_LF, sigma2_LF, _ = predictor.predict(x_dict, history_exclude)
                if pd.isna(mu_LF):
                    # Storing NaN would poison LF GP training and block any retry under only_new.
                    self._logger.warning("LOO LF prediction returned NaN for point %s; skipped", x_dict)
                    continue
                self.add_lf_prediction(x_dict, mu_LF, sigma2_LF, self._resolve_iteration(x_dict, iteration))
                success_count += 1
            except Exception:
                self._logger.exception("LOO LF prediction failed for point %s; skipped", x_dict)

        self._logger.debug(
            "Recomputed LOO LF predictions for %d/%d history points", success_count, len(points_to_compute)
        )

    @staticmethod
    def _default_batch_size(input_dim: int) -> int:
        """Pick the LF batch size: high-dimensional tasks need smaller batches to avoid token limits."""
        if input_dim >= 10:
            return 10  # High-dimensional tasks (e.g., COF)
        if input_dim >= 5:
            return 15  # Medium-dimensional tasks
        return 20  # Low-dimensional tasks (e.g., FeCr, CB)

    def recompute_non_history_lf_predictions(self, predictor, iteration: Optional[int] = None,
                                             only_if_history_changed: bool = True):
        """Recompute LF predictions for all non-history points (batch prediction with full history).

        Strategy:
        1. Find all distinct points in lf_predictions_df that are not in history_df.
        2. Batch predict these points with ``predictor.predict_batch`` using the full history.
        3. Append a new LF record for every point whose predicted mean is not NaN.

        The whole step is best-effort: if the batch prediction fails, the error is logged and no
        (further) records are added.

        Args:
            predictor: Object exposing
                ``predict_batch(points, history, batch_size=...) -> (mu_batch, sigma2_batch)``.
            iteration: Iteration number to record. If None, reuse the iteration of the point's
                first existing LF record (0 if there is none).
            only_if_history_changed: If True, skip when every history point already has an LF
                record (taken to mean history has no new points; performance optimization).
        """
        if len(self.lf_predictions_df) == 0 or len(self.history_df) == 0:
            return

        if only_if_history_changed:
            # NOTE(review): if recompute_all_lf_predictions(only_new=True) has already run for the
            # current history, every history point has a record and this always returns early.
            # Verify the intended call order.
            history_points = (self._row_features(row) for _, row in self.history_df.iterrows())
            if all(self._get_existing_lf_prediction(x) is not None for x in history_points):
                return

        # Distinct non-history LF points, in first-occurrence order
        unique_points = {}
        for _, row in self.lf_predictions_df.iterrows():
            x_dict = self._row_features(row)
            if not self._point_exists(self.history_df, x_dict):
                unique_points.setdefault(tuple(x_dict[feat] for feat in self.feature_names), x_dict)
        non_history_points = list(unique_points.values())
        if not non_history_points:
            return

        history = self.get_history_data()
        batch_size = self._default_batch_size(len(self.feature_names))

        try:
            mu_LF_batch, sigma2_LF_batch = predictor.predict_batch(
                non_history_points, history, batch_size=batch_size
            )
            success_count = 0
            for x_dict, mu, sigma2 in zip(non_history_points, mu_LF_batch, sigma2_LF_batch):
                if np.isnan(mu):
                    continue
                # Add new prediction (preserve historical records)
                self.add_lf_prediction(x_dict, mu, sigma2, self._resolve_iteration(x_dict, iteration))
                success_count += 1
        except Exception:
            self._logger.exception(
                "Batch LF prediction failed for %d non-history points", len(non_history_points)
            )
            return

        self._logger.debug(
            "Recomputed LF predictions for %d/%d non-history points", success_count, len(non_history_points)
        )

    def _get_existing_lf_prediction(self, x_dict: dict) -> Optional[dict]:
        """Return the first (oldest) LF record matching ``x_dict``, or None if there is none.

        Features missing from ``x_dict`` are not used for matching.

        Returns:
            {"mu_LF": ..., "sigma2_LF": ..., "iteration": ...} or None.
        """
        if len(self.lf_predictions_df) == 0:
            return None

        mask = self._match_mask(self.lf_predictions_df, x_dict, skip_missing=True)
        if not mask.any():
            return None
        row = self.lf_predictions_df.loc[mask].iloc[0]
        return {
            "mu_LF": row["mu_LF"],
            "sigma2_LF": row["sigma2_LF"],
            "iteration": row["iteration"],
        }

    # ==================== Helper Methods ====================

    def _match_mask(self, df: pd.DataFrame, x: dict, skip_missing: bool = False) -> pd.Series:
        """Return a boolean Series, aligned to ``df.index``, that is True where every feature equals ``x``.

        Args:
            df: DataFrame containing the feature columns.
            x: Feature values keyed by feature name.
            skip_missing: If True, ignore features absent from ``x`` instead of raising KeyError.
        """
        # Built on df.index (not a fresh RangeIndex) so the & below never misaligns.
        mask = pd.Series(True, index=df.index)
        for feat in self.feature_names:
            if skip_missing and feat not in x:
                continue
            mask &= df[feat] == x[feat]
        return mask

    def _point_exists(self, df: pd.DataFrame, x: dict) -> bool:
        """Return True if some row of ``df`` matches ``x`` on every feature (exact equality)."""
        if len(df) == 0:
            return False
        return bool(self._match_mask(df, x).any())

    def _update_lf_prediction(self, x: dict, mu_LF: float, sigma2_LF: float):
        """Overwrite mu_LF / sigma2_LF in place on every LF record matching ``x``."""
        mask = self._match_mask(self.lf_predictions_df, x)
        if mask.any():
            self.lf_predictions_df.loc[mask, "mu_LF"] = mu_LF
            self.lf_predictions_df.loc[mask, "sigma2_LF"] = sigma2_LF

