"""
eval_exp2_ood_detectors.py

Compare OOD action rates across several detectors.

State-conditioned mode uses (s,a) pairs to align with theory:
  - 1-NN: nearest action among dataset actions from nearby states
  - LOF / Mahalanobis: operate on concatenated (s,a) features

Detectors:

1. 1-NN (existing baseline)

2. Gaussian KDE (implemented, but currently disabled in evaluate_single_run)

3. Local Outlier Factor (LOF)

4. Mahalanobis distance

Input:

- saved_best_models/**/rollouts_base_env.npz (contains actions from policy)

- D4RL offline dataset (contains training distribution actions)

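Output:

- figures/ood_detectors/exp2_ood_detectors_detailed_{state|action}.csv

- figures/ood_detectors/exp2_ood_detectors_summary_{state|action}.csv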

"""

import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.neighbors import LocalOutlierFactor, NearestNeighbors
from sklearn.neighbors import KernelDensity
import gym
import d4rl
import time

class OODDetector:
    """Common interface for out-of-distribution (OOD) detectors.

    A detector is built around a reference dataset, calibrated with
    ``fit()``, and then queried with ``compute_ood_rate()``. Both of those
    methods are placeholders here and must be overridden by subclasses.
    """

    def __init__(self, dataset_actions):
        """Remember the reference data the detector is calibrated on.

        Args:
            dataset_actions: (N, action_dim) numpy array of actions from
                the offline dataset.
        """
        self.dataset_actions = dataset_actions

    def fit(self):
        """Calibrate the detector on ``self.dataset_actions``.

        Raises:
            NotImplementedError: always; subclasses provide the logic.
        """
        raise NotImplementedError

    def compute_ood_rate(self, eval_actions):
        """Report how many of ``eval_actions`` are flagged as OOD.

        Args:
            eval_actions: (M, action_dim) numpy array of actions from
                policy rollouts.

        Returns:
            ood_rate: float describing the share of OOD actions.

        Raises:
            NotImplementedError: always; subclasses provide the logic.
        """
        raise NotImplementedError

class OneNNDetector(OODDetector):
    """1-NN detector from the paper (Section E.4).

    An evaluation action is flagged as OOD when its Euclidean distance to
    the nearest dataset action exceeds ``kappa`` times the median
    nearest-neighbour distance measured inside the dataset itself.
    """

    def __init__(self, dataset_actions, kappa=3.0):
        """
        Args:
            dataset_actions: (N, action_dim) numpy array of dataset actions
            kappa: multiplier applied to the median in-dataset
                nearest-neighbour distance to obtain the OOD threshold
        """
        super().__init__(dataset_actions)
        self.kappa = kappa
        self.threshold = None
        # Fitted neighbour index; built once in fit() and reused for every
        # compute_ood_rate() call instead of being rebuilt per query.
        self._nbrs = None

    def fit(self):
        """Build the neighbour index and derive the distance threshold."""
        start_time = time.time()
        print("[1-NN] Fitting 1-NN OOD detector (computing nearest neighbors)...")
        self._nbrs = NearestNeighbors(n_neighbors=2, metric='euclidean')
        self._nbrs.fit(self.dataset_actions)

        # Querying with the training points themselves: column 0 is the
        # zero-distance self match, column 1 the nearest *other* point.
        # (With exact duplicates column 1 is also 0, which is still the
        # correct nearest-neighbour distance.)
        distances, _ = self._nbrs.kneighbors(self.dataset_actions)
        nn_distances = distances[:, 1]

        med_nn = np.median(nn_distances)

        # NOTE(review): if more than half of the dataset actions have an
        # exact duplicate, med_nn (and thus the threshold) is 0 and every
        # non-duplicate eval action is flagged.
        self.threshold = self.kappa * med_nn

        elapsed = time.time() - start_time
        print(f"[1-NN] Done. median_nn={med_nn:.4f}, threshold={self.threshold:.4f}, time={elapsed:.2f}s")

    def compute_ood_rate(self, eval_actions):
        """Return the percentage of eval actions farther than the threshold.

        Args:
            eval_actions: (M, action_dim) numpy array of policy actions

        Returns:
            OOD rate as a percentage in [0, 100].
        """
        if self.threshold is None or self._nbrs is None:
            self.fit()

        # Reuse the index from fit(); only the single nearest dataset
        # action is needed for the threshold comparison.
        distances, _ = self._nbrs.kneighbors(eval_actions, n_neighbors=1)

        ood_mask = distances[:, 0] > self.threshold
        ood_rate = ood_mask.mean()

        return ood_rate * 100  # Return as percentage

class StateConditionedOneNNDetector:
    """1-NN detector with state-conditioned action neighborhood.

    For a query (state, action) pair, the ``k_state`` dataset states closest
    to the query state are looked up, and the query action is compared only
    against the dataset actions stored at those states. The pair is OOD
    when the smallest such action distance exceeds the fitted threshold.
    """

    def __init__(self, dataset_states, dataset_actions, kappa=3.0, k_state=10,
                 max_threshold_points=50000):
        """
        Args:
            dataset_states: (N, state_dim) numpy array of dataset states
            dataset_actions: (N, action_dim) numpy array, row-aligned with
                ``dataset_states``
            kappa: multiplier on the median in-dataset distance
            k_state: number of nearest dataset states defining the
                neighborhood (must be >= 1)
            max_threshold_points: maximum number of dataset points sampled
                to estimate the threshold

        Raises:
            ValueError: if ``k_state`` < 1 or the two arrays differ in length.
        """
        if k_state < 1:
            raise ValueError(f"k_state must be >= 1, got {k_state}")
        if len(dataset_states) != len(dataset_actions):
            raise ValueError(
                f"dataset_states and dataset_actions must have the same length, "
                f"got {len(dataset_states)} and {len(dataset_actions)}"
            )
        self.dataset_states = dataset_states
        self.dataset_actions = dataset_actions
        self.kappa = kappa
        self.k_state = k_state
        self.max_threshold_points = max_threshold_points
        self.state_nbrs = None
        self.threshold = None

    def fit(self):
        """Index the dataset states and estimate the action-distance threshold.

        The threshold is ``kappa`` times the median, over a random subsample
        of dataset points, of the distance from each point's action to the
        closest action among its state neighbours (the point itself
        excluded).
        """
        start_time = time.time()
        print("[1-NN|state] Fitting state-conditioned 1-NN detector...")
        self.state_nbrs = NearestNeighbors(n_neighbors=self.k_state + 1, metric='euclidean')
        self.state_nbrs.fit(self.dataset_states)

        n = len(self.dataset_states)
        n_thresh = min(self.max_threshold_points, n)
        idx = np.random.choice(n, size=n_thresh, replace=False)

        # One batched query for all sampled points instead of one query per
        # point. k_state + 1 neighbours are requested because the sampled
        # point itself is normally among its own neighbours.
        _, nn_idx = self.state_nbrs.kneighbors(
            self.dataset_states[idx],
            n_neighbors=self.k_state + 1
        )
        neighbor_actions = self.dataset_actions[nn_idx]      # (n_thresh, k_state + 1, action_dim)
        own_actions = self.dataset_actions[idx][:, None, :]  # (n_thresh, 1, action_dim)
        dists = np.linalg.norm(neighbor_actions - own_actions, axis=2)

        # Exclude each point's match with itself from the minimum. Rows
        # where the point is absent from its neighbour list (possible with
        # duplicated states) keep all k_state + 1 candidates.
        dists[nn_idx == idx[:, None]] = np.inf
        min_dists = dists.min(axis=1)

        med_nn = np.median(min_dists)
        self.threshold = self.kappa * med_nn
        elapsed = time.time() - start_time
        print(f"[1-NN|state] Done. median_nn={med_nn:.4f}, threshold={self.threshold:.4f}, time={elapsed:.2f}s")

    def compute_ood_rate(self, eval_states, eval_actions):
        """Return the percentage of (state, action) pairs flagged as OOD.

        Args:
            eval_states: (M, state_dim) numpy array of rollout states
            eval_actions: (M, action_dim) numpy array, row-aligned with
                ``eval_states``

        Returns:
            OOD rate as a percentage in [0, 100].
        """
        if self.threshold is None or self.state_nbrs is None:
            self.fit()

        _, nn_idx = self.state_nbrs.kneighbors(eval_states, n_neighbors=self.k_state)
        neighbor_actions = self.dataset_actions[nn_idx]
        # Broadcast each eval action against its k_state candidate actions.
        diffs = neighbor_actions - eval_actions[:, None, :]
        dists = np.linalg.norm(diffs, axis=2)
        min_dists = dists.min(axis=1)
        ood_mask = min_dists > self.threshold
        return ood_mask.mean() * 100

class KDEDetector(OODDetector):
    """Gaussian kernel density estimator; low log-density means OOD."""

    def __init__(self, dataset_actions, bandwidth='scott', quantile=0.01,
                 max_threshold_points=50000):
        """
        Args:
            dataset_actions: (N, action_dim) numpy array of dataset actions
            bandwidth: 'scott', 'silverman', or float
            quantile: threshold quantile (lower = stricter OOD detection)
            max_threshold_points: subsample size for threshold estimation
        """
        super().__init__(dataset_actions)
        self.bandwidth = bandwidth
        self.quantile = quantile
        self.max_threshold_points = max_threshold_points
        self.kde = None
        self.threshold = None

    def _resolve_bandwidth(self):
        """Turn ``self.bandwidth`` into the numeric kernel bandwidth."""
        rule = self.bandwidth
        if rule == 'scott':
            # Scott's rule: n^(-1/(d+4))
            n_samples, n_dims = self.dataset_actions.shape
            return n_samples ** (-1 / (n_dims + 4))
        if rule == 'silverman':
            # Silverman's rule: (n * (d + 2) / 4)^(-1/(d+4))
            n_samples, n_dims = self.dataset_actions.shape
            return (n_samples * (n_dims + 2) / 4) ** (-1 / (n_dims + 4))
        # Anything else is used as the bandwidth value itself.
        return rule

    def fit(self):
        """Fit the KDE and set the log-likelihood threshold.

        The threshold is the ``quantile`` of the log-likelihoods that the
        fitted KDE assigns to a random subsample of the dataset.
        """
        t0 = time.time()
        print("[KDE] Fitting Gaussian KDE and estimating log-likelihood threshold...")
        bw = self._resolve_bandwidth()

        self.kde = KernelDensity(kernel='gaussian', bandwidth=bw)
        self.kde.fit(self.dataset_actions)

        # Threshold estimation on a subsample (at most max_threshold_points
        # rows, drawn without replacement).
        total = len(self.dataset_actions)
        sample_size = min(self.max_threshold_points, total)
        chosen = np.random.choice(total, size=sample_size, replace=False)
        sample_scores = self.kde.score_samples(self.dataset_actions[chosen])
        self.threshold = np.quantile(sample_scores, self.quantile)

        elapsed = time.time() - t0
        print(f"[KDE] Done. bandwidth={bw:.4f}, threshold={self.threshold:.4f}, time={elapsed:.2f}s")

    def compute_ood_rate(self, eval_actions):
        """Return the percentage of eval actions scoring below the threshold."""
        if self.kde is None:
            self.fit()

        scores = self.kde.score_samples(eval_actions)
        below_threshold = scores < self.threshold
        return below_threshold.mean() * 100

class MahalanobisDetector(OODDetector):
    """Single Gaussian + Mahalanobis distance detector.

    A point is OOD when its squared Mahalanobis distance to the dataset
    mean exceeds an upper quantile of the same distance measured on a
    subsample of the dataset.
    """

    def __init__(self, dataset_actions, quantile=0.99,
                 max_threshold_points=50000, eps=1e-6):
        """
        Args:
            dataset_actions: (N, d) numpy array of reference features
            quantile: upper quantile for distance threshold (e.g. 0.99)
            max_threshold_points: subsample size for threshold estimation
            eps: diagonal jitter for covariance matrix
        """
        super().__init__(dataset_actions)
        self.quantile = quantile
        self.max_threshold_points = max_threshold_points
        self.eps = eps
        self.mu = None
        self.cov_inv = None
        self.threshold = None

    def fit(self):
        """Fit Gaussian and compute Mahalanobis distance threshold"""
        start_time = time.time()
        print("[Mahalanobis] Fitting single Gaussian and estimating Mahalanobis distance threshold...")
        X = np.asarray(self.dataset_actions)
        self.mu = X.mean(axis=0)
        # np.cov collapses to a 0-d array when there is a single feature
        # column; atleast_2d keeps the (d, d) shape so the regularisation
        # and inversion below also work for d == 1.
        cov = np.atleast_2d(np.cov(X, rowvar=False))
        # Small diagonal regularization so the matrix is invertible.
        cov = cov + self.eps * np.eye(cov.shape[0])
        self.cov_inv = np.linalg.inv(cov)

        # Approximate the distance distribution on a random subsample.
        n = len(X)
        n_thresh = min(self.max_threshold_points, n)
        idx = np.random.choice(n, size=n_thresh, replace=False)
        d2 = self._mahalanobis_squared(X[idx])
        # Larger distance means more OOD, so the threshold is an upper quantile.
        self.threshold = np.quantile(d2, self.quantile)

        elapsed = time.time() - start_time
        print(f"[Mahalanobis] Done. quantile={self.quantile}, threshold={self.threshold:.4f}, time={elapsed:.2f}s")

    def _mahalanobis_squared(self, X):
        """Compute squared Mahalanobis distance for each row in X"""
        X = np.asarray(X)
        diff = X - self.mu    # (N, d)
        # einsum: diff * Σ^{-1} * diff^T -> (N,)
        return np.einsum('ij,jk,ik->i', diff, self.cov_inv, diff)

    def compute_ood_rate(self, eval_actions):
        """Return the percentage of rows whose distance^2 exceeds the threshold.

        Args:
            eval_actions: (M, d) numpy array with the same feature layout
                as the data passed to the constructor

        Returns:
            OOD rate as a percentage in [0, 100].
        """
        if self.threshold is None:
            self.fit()

        d2_eval = self._mahalanobis_squared(eval_actions)
        ood_mask = d2_eval > self.threshold
        ood_rate = ood_mask.mean()

        return ood_rate * 100

class LOFDetector(OODDetector):
    """Local Outlier Factor detector (novelty mode)."""

    def __init__(self, dataset_actions, n_neighbors=20, contamination=0.01):
        """
        Args:
            dataset_actions: (N, d) numpy array of reference features
            n_neighbors: number of neighbors for LOF
            contamination: expected proportion of outliers (used for threshold)
        """
        super().__init__(dataset_actions)
        self.n_neighbors = n_neighbors
        self.contamination = contamination
        self.lof = None

    def fit(self):
        """Create the LOF model and fit it on the reference data."""
        t0 = time.time()
        print("[LOF] Fitting Local Outlier Factor model...")
        # novelty=True is what allows predict() on unseen data later on.
        model = LocalOutlierFactor(
            n_neighbors=self.n_neighbors,
            contamination=self.contamination,
            novelty=True
        )
        self.lof = model
        model.fit(self.dataset_actions)

        elapsed = time.time() - t0
        print(f"[LOF] Done. n_neighbors={self.n_neighbors}, contamination={self.contamination}, time={elapsed:.2f}s")

    def compute_ood_rate(self, eval_actions):
        """Return the percentage of rows that the LOF model labels as outliers."""
        if self.lof is None:
            self.fit()

        # predict() labels outliers with -1 and inliers with 1.
        labels = self.lof.predict(eval_actions)
        outlier_fraction = (labels == -1).mean()
        return outlier_fraction * 100

def load_d4rl_actions(env_name):
    """Load actions from D4RL dataset.

    Args:
        env_name: D4RL environment id, e.g. 'hopper-medium-expert-v2'

    Returns:
        The 'actions' entry of ``d4rl.qlearning_dataset``.
    """
    env = gym.make(env_name)
    try:
        dataset = d4rl.qlearning_dataset(env)
        return dataset['actions']
    finally:
        # Close the environment even if dataset loading fails.
        env.close()

def load_d4rl_state_action(env_name):
    """Load states and actions from D4RL dataset.

    Args:
        env_name: D4RL environment id, e.g. 'hopper-medium-expert-v2'

    Returns:
        (states, actions): the 'observations' and 'actions' entries of
        ``d4rl.qlearning_dataset``.
    """
    env = gym.make(env_name)
    try:
        dataset = d4rl.qlearning_dataset(env)
        return dataset['observations'], dataset['actions']
    finally:
        # Close the environment even if dataset loading fails.
        env.close()

def _flatten_episode_array(arr):
    """Merge per-episode data into a single array along the first axis.

    - object-dtype array: its elements are concatenated along axis 0
    - 3-D array: the two leading axes are merged, the last axis is kept
    - anything else: returned unchanged
    """
    holds_separate_episodes = arr.dtype == object
    if holds_separate_episodes:
        return np.concatenate(list(arr), axis=0)
    if arr.ndim != 3:
        return arr
    last_dim = arr.shape[-1]
    return arr.reshape(-1, last_dim)

def load_rollout_actions(rollout_path):
    """Load actions from saved rollouts.

    Args:
        rollout_path: path to an .npz archive containing an 'actions' entry

    Returns:
        The 'actions' entry after ``_flatten_episode_array``.

    Raises:
        ValueError: if the archive cannot be opened
        KeyError: if the archive has no 'actions' entry
    """
    # NOTE(review): allow_pickle=True unpickles object arrays, which can
    # execute arbitrary code; only load rollout files from trusted sources.
    try:
        data = np.load(rollout_path, allow_pickle=True)
    except Exception as e:
        raise ValueError(f"Failed to load {rollout_path}: {e}") from e

    # The context manager closes the archive on every exit path, including
    # the KeyError below and failures while flattening.
    with data:
        if 'actions' not in data:
            available_keys = list(data.keys())
            raise KeyError(f"'actions' key not found in {rollout_path}. Available keys: {available_keys}")

        return _flatten_episode_array(data['actions'])

def load_rollout_state_action(rollout_path):
    """Load states and actions from saved rollouts.

    Args:
        rollout_path: path to an .npz archive containing 'states' and
            'actions' entries

    Returns:
        (states, actions): both flattened with ``_flatten_episode_array``
        and truncated to the shorter of the two lengths.

    Raises:
        ValueError: if the archive cannot be opened
        KeyError: if 'states' or 'actions' is missing from the archive
    """
    # NOTE(review): allow_pickle=True unpickles object arrays, which can
    # execute arbitrary code; only load rollout files from trusted sources.
    try:
        data = np.load(rollout_path, allow_pickle=True)
    except Exception as e:
        raise ValueError(f"Failed to load {rollout_path}: {e}") from e

    # The context manager closes the archive on every exit path.
    with data:
        if 'states' not in data or 'actions' not in data:
            available_keys = list(data.keys())
            raise KeyError(f"'states'/'actions' keys not found in {rollout_path}. Available keys: {available_keys}")

        states = _flatten_episode_array(data['states'])
        actions = _flatten_episode_array(data['actions'])

    n = min(len(states), len(actions))
    if len(states) != len(actions):
        # Truncation keeps the arrays the same length but cannot guarantee
        # that row i of states still belongs to row i of actions (e.g. if
        # every episode stores one extra state), so make it visible.
        print(f"WARNING: {rollout_path} has {len(states)} states but {len(actions)} actions; "
              f"truncating both to {n} rows (verify state/action alignment)")
    return states[:n], actions[:n]

def evaluate_single_run(env_name, algo, seed, detectors_config, state_conditioned=False):
    """
    Evaluate OOD rates for a single (env, algo, seed) run

    Args:
        env_name: e.g., 'halfcheetah-medium-expert-v2'
        algo: algorithm name used in the saved-model directory name
        seed: int
        detectors_config: dict of detector hyperparameters
        state_conditioned: if True, detectors work on (state, action)
            pairs; otherwise on actions only

    Returns:
        dict mapping detector name to OOD rate (percent), or None when the
        rollout file for this run does not exist
    """
    print(f"\nEvaluating {algo} on {env_name}, seed={seed}")

    # Rollouts of the trained policy (policy distribution)
    rollout_path = Path(f"saved_best_models/{env_name}_exp2_{algo}_{algo}-1/seed{seed}/rollouts_base_env.npz")

    if not rollout_path.exists():
        print(f"WARNING: Rollout not found at {rollout_path}")
        print(f"  Expected path: {rollout_path.absolute()}")
        # Help the user spot a misnamed file in the same directory
        run_dir = rollout_path.parent
        npz_candidates = list(run_dir.glob("*.npz")) if run_dir.exists() else []
        if npz_candidates:
            print(f"  Found similar files in {run_dir}: {[f.name for f in npz_candidates]}")
        return None

    if state_conditioned:
        dataset_states, dataset_actions = load_d4rl_state_action(env_name)
        eval_states, eval_actions = load_rollout_state_action(rollout_path)
        print(f"Loaded {len(dataset_actions)} dataset (s,a) pairs")
        print(f"Loaded {len(eval_actions)} evaluation (s,a) pairs")

        # LOF / Mahalanobis see the concatenated (s,a) vector
        dataset_features = np.concatenate([dataset_states, dataset_actions], axis=1)
        eval_features = np.concatenate([eval_states, eval_actions], axis=1)
        one_nn = StateConditionedOneNNDetector(
            dataset_states,
            dataset_actions,
            kappa=detectors_config['1nn_kappa'],
            k_state=detectors_config['state_k'],
            max_threshold_points=detectors_config['state_max_threshold_points'],
        )
    else:
        dataset_actions = load_d4rl_actions(env_name)
        eval_actions = load_rollout_actions(rollout_path)
        dataset_states = None
        eval_states = None
        print(f"Loaded {len(dataset_actions)} dataset actions")
        print(f"Loaded {len(eval_actions)} evaluation actions")

        # Action-only mode: the features are the actions themselves
        dataset_features = dataset_actions
        eval_features = eval_actions
        one_nn = OneNNDetector(dataset_actions, kappa=detectors_config['1nn_kappa'])

    # KDEDetector is intentionally not part of this set (disabled).
    detectors = {
        '1-NN': one_nn,
        'LOF': LOFDetector(
            dataset_features,
            n_neighbors=detectors_config['lof_neighbors'],
            contamination=detectors_config['lof_contamination'],
        ),
        'Mahalanobis': MahalanobisDetector(
            dataset_features,
            quantile=detectors_config['mahalanobis_quantile'],
            max_threshold_points=detectors_config['mahalanobis_max_threshold_points'],
        ),
    }

    # Fit every detector before any of them is queried
    for detector in detectors.values():
        detector.fit()

    results = {}
    for name, detector in detectors.items():
        # Only the state-conditioned 1-NN takes states and actions
        # separately; all other detectors take one feature matrix.
        takes_state_and_action = state_conditioned and name == '1-NN'
        query = (eval_states, eval_actions) if takes_state_and_action else (eval_features,)
        ood_rate = detector.compute_ood_rate(*query)
        results[name] = ood_rate
        print(f"{name}: OOD rate = {ood_rate:.2f}%")

    return results

def main():
    """Run OOD detector comparison for Exp2.

    Evaluates every (env, algo, seed) combination configured below, writes
    a per-run CSV and a mean/std summary CSV to ``figures/ood_detectors``,
    and prints the summary table. Returns early, writing nothing, when no
    run produced results.
    """

    # Configuration
    envs = [
        # 'halfcheetah-medium-expert-v2',
        'hopper-medium-expert-v2',
        'walker2d-medium-expert-v2'
    ]

    algos = ['oraac_diffusion']
    seeds = [0, 1, 2]

    # Detector hyperparameters
    # These can be tuned based on your specific needs
    detectors_config = {
        '1nn_kappa': 3.0,          # From paper
        # 'kde_bandwidth': 'scott',  # or 'silverman' or float
        # 'kde_quantile': 0.01,      # 1% threshold (lower tail)
        # 'kde_max_threshold_points': 50000,
        'lof_neighbors': 20,
        'lof_contamination': 0.01,  # Expect 1% outliers
        'mahalanobis_quantile': 0.95,          # upper tail
        'mahalanobis_max_threshold_points': 50000,
        'state_k': 10,
        'state_max_threshold_points': 50000,
    }

    state_conditioned = True
    # Suffix of the output file names, so both modes can coexist on disk
    output_tag = 'state' if state_conditioned else 'action'

    # Store all results (one dict per successfully evaluated run)
    all_results = []

    # Evaluate each combination
    for env_name in envs:

        for algo in algos:
            for seed in seeds:
                print(f"\n{'*'*60}")
                print(f"Evaluating {algo} on {env_name}, seed={seed}")
                results = evaluate_single_run(
                    env_name,
                    algo,
                    seed,
                    detectors_config,
                    state_conditioned=state_conditioned,
                )

                # evaluate_single_run returns None when the rollout file is missing
                if results is not None:
                    all_results.append({
                        'env': env_name,
                        'algo': algo,
                        'seed': seed,
                        **results
                    })
                print(f"{'*'*60}\n")

    if not all_results:
        print("\nWARNING: No results collected. Please check:")
        print("  1. d4rl is installed")
        print("  2. Rollout files exist at expected paths")
        print("  3. Environment names and algorithm names are correct")
        return

    # Convert to DataFrame
    df = pd.DataFrame(all_results)

    # Compute mean and std across seeds
    summary = df.groupby(['env', 'algo']).agg({
        '1-NN': ['mean', 'std'],
        # 'KDE': ['mean', 'std'],
        'LOF': ['mean', 'std'],
        'Mahalanobis': ['mean', 'std'],
    }).reset_index()

    # Flatten the (detector, statistic) column MultiIndex into names such
    # as '1-NN_mean'; the strip removes the trailing '_' of 'env' / 'algo'.
    summary.columns = ['_'.join(col).strip('_') for col in summary.columns.values]

    # Save results. parents=True is required because the directory is
    # nested: without it mkdir fails when 'figures/' does not exist yet,
    # discarding all results computed above.
    output_dir = Path('figures/ood_detectors')
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save detailed results
    df.to_csv(output_dir / f'exp2_ood_detectors_detailed_{output_tag}.csv', index=False)

    # Save summary
    summary.to_csv(output_dir / f'exp2_ood_detectors_summary_{output_tag}.csv', index=False)

    # Print summary table
    print("\n" + "="*80)
    print(f"SUMMARY: OOD Action Rates (%) - Mean ± Std across {len(seeds)} seeds")
    print("="*80)

    for env in envs:
        print(f"\n{env}:")
        env_data = summary[summary['env'] == env]

        if len(env_data) == 0:
            print("  No data available")
            continue

        for _, row in env_data.iterrows():
            algo = row['algo']
            print(f"  {algo.upper()}:")
            print(f"    1-NN:       {row['1-NN_mean']:.2f} ± {row['1-NN_std']:.2f}")
            # print(f"    KDE:        {row['KDE_mean']:.2f} ± {row['KDE_std']:.2f}")
            print(f"    LOF:        {row['LOF_mean']:.2f} ± {row['LOF_std']:.2f}")
            print(f"    Mahalanobis:{row['Mahalanobis_mean']:.2f} ± {row['Mahalanobis_std']:.2f}")

    print("\n" + "="*80)
    print(f"Results saved to {output_dir}")

# Run the detector comparison only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
