#!/usr/bin/env python3
"""
Sensor-based POD/SVD baseline: fits spatial modes on snapshots from training Z-slices (interpolated onto each held-out test mesh) and reconstructs the held-out fields from sparse sensor observations.
"""

import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import time
import argparse
from pathlib import Path
from typing import Dict, Optional, Tuple
import pyvista as pv
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend
from scipy.linalg import svd
import itertools
from scipy.spatial import Delaunay
from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator

# Import dataset and evaluation functions from training pipeline
import sys
sys.path.append('..')
from training.graph_dataset import make_dataset
from training.train import create_test_plot, _mean_angular_similarity, _graph_ssim_speed


def load_coordinates_from_slice(slice_file: Path) -> Optional[np.ndarray]:
    """Read a VTU slice with pyvista and return its in-plane point coordinates.

    Only the first two coordinate columns (x, y) are kept; the result is used for
    plotting. Any failure while reading the file or its points prints a warning
    and yields ``None`` instead of raising.
    """
    try:
        xy = pv.read(str(slice_file)).points[:, :2]
    except Exception as err:
        print(f"Warning: Could not load coordinates from {slice_file}: {err}")
        return None
    return xy


def vectorize_field(field_2c: np.ndarray) -> np.ndarray:
    """Flatten a per-node field into one 1-D vector.

    For an ``(N_points, 2)`` input the result has length ``N_points*2`` in
    row-major order, so the two channels of each node are adjacent:
    ``[u0, v0, u1, v1, ...]``.
    """
    flat = field_2c.reshape((-1,))
    return flat


def devectorize_field(vec: np.ndarray) -> np.ndarray:
    """Inverse of ``vectorize_field``: ``(N_points*2,) -> (N_points, 2)``.

    Args:
        vec: flattened field with the two channels of each node adjacent.

    Returns:
        The same values reshaped to ``(N_points, 2)``.

    Raises:
        ValueError: if the length of ``vec`` is odd.
    """
    n_total = vec.shape[0]
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if n_total % 2 != 0:
        raise ValueError(f"Vector length must be divisible by 2 (got {n_total}).")
    return vec.reshape(n_total // 2, 2)


class SensorPOD:
    """Low-cost SVD (POD) baseline: fit spatial modes, then reconstruct fields from sensors.

    Fields are flattened row-major with the two channels of each node adjacent
    (``[u0, v0, u1, v1, ...]``, the same layout as ``vectorize_field``), so the
    rows of the mode matrix ``U`` follow that layout. Snapshots are not
    mean-centred before the SVD.
    """

    def __init__(self, n_modes: int = 15):
        self.n_modes = n_modes
        self.U: Optional[np.ndarray] = None    # spatial modes (N*2, k)
        self.S: Optional[np.ndarray] = None    # leading singular values (k,)
        self.N_points: Optional[int] = None    # node count of the fitted snapshots
        self._full_s: Optional[np.ndarray] = None  # full singular value spectrum

    def fit_modes(self, case_fields: Dict[int, np.ndarray]) -> None:
        """Build the snapshot matrix from fields on one mesh and compute the POD basis.

        Each field is flattened into one column of ``X`` (N*2, M), with columns in
        sorted case-key order, and ``X`` is decomposed with a thin SVD. The leading
        ``min(n_modes, len(s))`` left singular vectors / values are kept in
        ``U`` / ``S``; the full spectrum is kept in ``_full_s``.

        Args:
            case_fields: dict {case_number: field (N_points, 2)}; all fields must
                have the same shape.

        Raises:
            ValueError: if ``case_fields`` is empty or the fields differ in shape.
        """
        if len(case_fields) == 0:
            raise ValueError("No fields provided for SVD fit.")

        cases_sorted = sorted(case_fields.keys())
        fields = [np.asarray(case_fields[c]) for c in cases_sorted]
        ref_shape = fields[0].shape
        for case, field in zip(cases_sorted, fields):
            # np.stack would silently accept e.g. (N, 2) next to (2N, 1) after flattening.
            if field.shape != ref_shape:
                raise ValueError(
                    f"Field for case {case} has shape {field.shape}, expected {ref_shape}.")

        X = np.stack([f.reshape(-1) for f in fields], axis=1)  # (N*2, M)
        U, s, _ = svd(X, full_matrices=False)

        self.N_points = int(ref_shape[0])
        self._full_s = s.copy()
        # Only len(s) = min(N*2, M) modes exist, so clamp the request to that.
        k = max(0, min(self.n_modes, s.shape[0]))
        self.U = U[:, :k]
        self.S = s[:k]

    def reconstruct_from_sensors(self, obs_idx: np.ndarray, y_obs: np.ndarray) -> np.ndarray:
        """Reconstruct a full field from values observed at a subset of entries.

        Solves ``min_c ||U[obs_idx] c - y_obs||_2`` with ``np.linalg.lstsq`` (the
        minimum-norm solution when under-determined) and returns ``U @ c``.

        Args:
            obs_idx: indices into the flattened field (length m), with the 2
                channels of each node interleaved.
            y_obs: observed values (m,).

        Returns:
            Reconstructed field of shape (N_points, 2).

        Raises:
            RuntimeError: if called before ``fit_modes``.
            ValueError: if ``obs_idx`` and ``y_obs`` have different lengths.
        """
        if self.U is None or self.S is None or self.N_points is None:
            raise RuntimeError("SensorPOD.reconstruct_from_sensors called before fit_modes().")

        obs_idx = np.asarray(obs_idx)
        y_obs = np.asarray(y_obs)
        if obs_idx.shape[0] != y_obs.shape[0]:
            raise ValueError(
                f"Got {obs_idx.shape[0]} observation indices but {y_obs.shape[0]} observed values.")

        Phi = self.U                 # (N*2, k)
        A = Phi[obs_idx, :]          # (m, k)
        c, *_ = np.linalg.lstsq(A, y_obs, rcond=None)
        recon_vec = Phi @ c          # (N*2,)
        return recon_vec.reshape(self.N_points, 2)



def collect_case_map(test_ds,
                     needed_cases: Optional[set] = None,
                     max_iterations: int = 10000,
                     verbose: bool = False,
                     max_cases: int = 360) -> Tuple[Dict[int, np.ndarray], Optional[np.ndarray], Optional[Tuple[np.ndarray, np.ndarray]]]:
    """Iterate over a dataset and collect a map: case_number -> target field (N_points, 2).

    Batches are read pass after pass over ``test_ds`` until one of:
      * every case in ``needed_cases`` has been seen (or, when ``needed_cases`` is
        None, at least ``max_cases`` distinct cases have been seen);
      * ``max_iterations`` batches have been pulled in total;
      * a whole pass yields no batch or adds no new case (a deterministic or
        exhausted one-shot dataset cannot produce anything new).
    Batches are not cached between passes (unlike ``itertools.cycle``), so
    memory is bounded by the collected fields rather than by every batch seen.

    Args:
        test_ds: iterable of batch mappings with ``'target_inputs'`` (batch, N_points, 2)
            and a case id array under ``'angle_deg'`` or ``'case_number'``.
        needed_cases: case ids to wait for; None means "stop at ``max_cases`` cases".
        max_iterations: upper bound on the number of batches pulled.
        verbose: print a summary line when done.
        max_cases: distinct-case target used only when ``needed_cases`` is None.

    Returns:
        (case_to_field, coords, edges). The first field seen for each case is kept.
        ``coords`` (N_points, 2) and ``edges`` (senders, receivers) come from the
        first batch whose ``graph_structures`` provide ``original_coordinates``;
        either is None when not found.

    Raises:
        KeyError: if a batch has neither ``'angle_deg'`` nor ``'case_number'``.
    """
    case_to_field: Dict[int, np.ndarray] = {}
    coords_found: Optional[np.ndarray] = None
    edges_found: Optional[Tuple[np.ndarray, np.ndarray]] = None
    pulls = 0

    def _have_enough() -> bool:
        # Stop criterion evaluated after each batch.
        if needed_cases is not None:
            return needed_cases.issubset(case_to_field.keys())
        return len(case_to_field) >= max_cases

    done = False
    while not done and pulls < max_iterations:
        n_known_before_pass = len(case_to_field)
        batches_in_pass = 0

        for batch in test_ds:
            if pulls >= max_iterations:
                break
            pulls += 1
            batches_in_pass += 1

            targets = np.asarray(batch['target_inputs'])  # (batch_size, N_points, 2)
            if 'angle_deg' in batch:
                case_nums = batch['angle_deg']
            elif 'case_number' in batch:
                case_nums = batch['case_number']
            else:
                raise KeyError("Batch has neither 'angle_deg' nor 'case_number'; cannot identify cases.")
            case_nums = np.asarray(case_nums)

            # Extract coordinates and edges until a batch provides coordinates.
            if coords_found is None and 'graph_structures' in batch:
                gs = batch['graph_structures']
                if isinstance(gs, dict):
                    if 'original_coordinates' in gs:
                        c = gs['original_coordinates']
                        # Drop a leading batch axis if present: (B, N, 2) -> (N, 2).
                        coords_found = np.array(c[0] if c.ndim == 3 else c)
                    if 'o2o_senders' in gs and 'o2o_receivers' in gs:
                        senders = gs['o2o_senders']
                        receivers = gs['o2o_receivers']
                        # Same for edge lists: (B, E) -> (E,).
                        if senders.ndim == 2:
                            senders = senders[0]
                        if receivers.ndim == 2:
                            receivers = receivers[0]
                        edges_found = (np.array(senders), np.array(receivers))

            # Keep the first field seen for each case id.
            for b in range(targets.shape[0]):
                cnum = int(case_nums[b])
                if cnum not in case_to_field:
                    case_to_field[cnum] = targets[b]

            if _have_enough():
                done = True
                break

        # An empty or unproductive pass means re-reading cannot help; stop instead
        # of spinning until max_iterations.
        if batches_in_pass == 0 or len(case_to_field) == n_known_before_pass:
            break

    if verbose:
        print(f"      Collected {len(case_to_field)} unique cases after {pulls} batches.")
    return case_to_field, coords_found, edges_found


def plot_and_save_spectra(output_dir: Path, singular_values: np.ndarray) -> None:
    """Persist a singular-value spectrum and plot it.

    Writes into ``output_dir``:
      * ``spectra.npz`` with ``singular_values``, ``energy`` (= s**2),
        ``energy_fraction`` and ``cumulative_energy``;
      * ``singular_values.png`` and ``cumulative_energy.png`` line plots.
    When the total energy is not positive, the energy fraction is ``energy * 0``.
    """
    energy = singular_values ** 2
    total_energy = np.sum(energy)
    if total_energy > 0:
        energy_frac = energy / total_energy
    else:
        energy_frac = energy * 0
    cum_energy = np.cumsum(energy_frac)

    np.savez(output_dir / "spectra.npz",
             singular_values=singular_values,
             energy=energy,
             energy_fraction=energy_frac,
             cumulative_energy=cum_energy)

    def _save_curve(values, ylabel, title, filename, ylim=None):
        # One marker-line plot of `values` against the 1-based mode index.
        fig, ax = plt.subplots()
        ax.plot(np.arange(1, len(values) + 1), values, marker='o')
        ax.set_xlabel("Mode index")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if ylim is not None:
            ax.set_ylim(*ylim)
        fig.tight_layout()
        fig.savefig(output_dir / filename, dpi=150, bbox_inches='tight')
        plt.close(fig)

    _save_curve(singular_values, "Singular value",
                "Singular Values (Train subset)", "singular_values.png")
    _save_curve(cum_energy, "Cumulative energy fraction",
                "Cumulative Energy (Train subset)", "cumulative_energy.png",
                ylim=(0.0, 1.01))


def _barycentric_weights(tri: Delaunay, points: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Locate ``points`` in triangulation ``tri`` and return linear-interpolation weights.

    Returns:
        inside: indices of the points that fall inside some simplex.
        vertices: (len(inside), ndim+1) source-node indices of each containing simplex.
        weights: (len(inside), ndim+1) barycentric coordinates matching ``vertices``.
    """
    points = np.asarray(points, dtype=np.float64)
    ndim = points.shape[1]
    simplex = tri.find_simplex(points)
    inside = np.flatnonzero(simplex >= 0)
    containing = simplex[inside]
    # Delaunay.transform[s] stacks T (ndim x ndim) and offset r: c = T @ (x - r)
    # gives the first ndim barycentric coordinates; the last one is 1 - sum(c).
    transform = tri.transform[containing]
    offsets = points[inside] - transform[:, ndim, :]
    bary = np.einsum('ijk,ik->ij', transform[:, :ndim, :], offsets)
    weights = np.column_stack([bary, 1.0 - bary.sum(axis=1)])
    return inside, tri.simplices[containing], weights


def interpolate_to_reference(src_coords: np.ndarray,
                             tgt_coords: np.ndarray,
                             case_fields: Dict[int, np.ndarray]) -> Dict[int, np.ndarray]:
    """Interpolate fields from a source mesh onto target nodes (linear, nearest fallback).

    Uses piecewise-linear interpolation on a Delaunay triangulation of
    ``src_coords``. Target nodes outside the triangulation, or whose interpolated
    value contains NaN, take the value of the nearest source node. If the
    triangulation cannot be built (e.g. degenerate/collinear input), every target
    node uses nearest-neighbour values.

    Point location and barycentric weights depend only on the two meshes, so they
    are computed once and reused for every field rather than rebuilt by a new
    interpolator per field.

    Args:
        src_coords: (N_src, 2) source node coordinates.
        tgt_coords: (N_tgt, 2) target node coordinates.
        case_fields: {case: (N_src, C) field on the source nodes}.

    Returns:
        {case: (N_tgt, C) float32 field on the target nodes}, same keys.
    """
    print(f"      🔄 Interpolating {len(case_fields)} fields from {src_coords.shape[0]} nodes -> {tgt_coords.shape[0]} nodes...")

    # 1. Triangulate the source mesh once.
    try:
        src_tri = Delaunay(src_coords)
    except Exception as e:
        print(f"      ⚠️ Delaunay failed: {e}. Falling back to Nearest.")
        src_tri = None

    # 2. Nearest source node for every target node (fallback values).
    from scipy.spatial import cKDTree
    tree = cKDTree(src_coords)
    _, nn_indices = tree.query(tgt_coords)

    # 3. Shared interpolation weights for all fields.
    if src_tri is not None:
        inside, vertices, weights = _barycentric_weights(src_tri, tgt_coords)

    new_fields = {}
    for c, field in case_fields.items():
        field = np.asarray(field)
        nearest_val = field[nn_indices]
        if src_tri is None:
            new_val = nearest_val
        else:
            # Work in float64 (as LinearNDInterpolator does); cast to float32 at the end.
            new_val = np.array(nearest_val, dtype=np.float64)
            values = field.astype(np.float64)[vertices]  # (n_inside, ndim+1, ...)
            new_val[inside] = np.einsum('ij,ij...->i...', weights, values)
            # NaNs in the source data propagate through interpolation; treat those
            # rows like out-of-hull rows and use the nearest-node values.
            nan_rows = np.isnan(new_val)
            if nan_rows.ndim > 1:
                nan_rows = nan_rows.any(axis=tuple(range(1, nan_rows.ndim)))
            new_val[nan_rows] = nearest_val[nan_rows]

        new_fields[c] = new_val.astype(np.float32)

    return new_fields


def _parse_z_list(spec) -> list:
    """Parse a comma-separated Z specification (e.g. ``"15,25"``) into floats; blanks are skipped."""
    return [float(z.strip()) for z in str(spec).split(',') if z.strip()]


def _resolve_obs_targets(obs_counts_str: str, obs_count: int, obs_frac: float) -> list:
    """Return the observation targets to sweep.

    Priority: explicit comma-separated counts > single ``obs_count`` (if > 0) >
    ``obs_frac``. Ints are sensor-node counts; a float is a fraction of the mesh
    nodes. Duplicates are dropped (first occurrence kept) because targets key the
    aggregation dicts, and a repeated target would be evaluated and counted twice.
    """
    if obs_counts_str.strip():
        targets = [int(x) for x in obs_counts_str.split(',') if x.strip()]
    elif obs_count > 0:
        targets = [obs_count]
    else:
        targets = [obs_frac]  # Use float for fraction
    return list(dict.fromkeys(targets))


def _num_sensor_nodes(target, n_nodes: int) -> int:
    """Sensor-node count for one case, clamped to ``[1, n_nodes]``.

    A float target is a fraction (``round(target * n_nodes)``); an int is an
    absolute count. Clamping keeps ``rng.choice(..., replace=False)`` valid.
    """
    k = int(round(target * n_nodes)) if isinstance(target, float) else int(target)
    return max(1, min(n_nodes, k))


def _interleaved_obs_indices(node_idx: np.ndarray) -> np.ndarray:
    """Flattened-field indices of both channels of each node: ``[2*i0, 2*i0+1, 2*i1, ...]``."""
    node_idx = np.asarray(node_idx, dtype=np.int64)
    return np.stack([2 * node_idx, 2 * node_idx + 1], axis=1).reshape(-1)


def _case_metrics(pred: np.ndarray, targ: np.ndarray,
                  ref_edges: Optional[Tuple[np.ndarray, np.ndarray]]):
    """Compute ``(rrmse, graph_ssim, angular_similarity, ssim_error)`` for one case.

    RRMSE is RMSE(pred - targ) / RMS(targ), or 0.0 when RMS(targ) <= 1e-9.
    Graph SSIM is computed on the speed (per-node vector norm) over the reference
    edges. It is NaN when no edges are available or when ``_graph_ssim_speed``
    raises; in the latter case the exception is returned as ``ssim_error``.
    """
    diff = pred - targ
    rmse = np.sqrt(np.mean(diff ** 2))
    denom = np.sqrt(np.mean(targ ** 2))
    rrmse = float(rmse / denom) if denom > 1e-9 else 0.0

    gssim = float('nan')
    ssim_error = None
    if ref_edges is not None:
        send, recv = ref_edges
        speed_true = np.linalg.norm(targ, axis=-1).astype(np.float64)
        speed_pred = np.linalg.norm(pred, axis=-1).astype(np.float64)
        try:
            gssim = float(_graph_ssim_speed(speed_true, speed_pred, send, recv, include_self=True))
        except Exception as e:  # best-effort metric (e.g. JAX/shape issues); reported by caller
            ssim_error = e

    cosine = float(_mean_angular_similarity(targ, pred))
    return rrmse, gssim, cosine, ssim_error


def _nan_mean_std(values) -> Tuple[float, float]:
    """Mean and std over the non-NaN entries of ``values``; ``(nan, nan)`` if none remain."""
    arr = np.asarray(values, dtype=np.float64)
    arr = arr[~np.isnan(arr)]
    if arr.size == 0:
        return float('nan'), float('nan')
    return float(arr.mean()), float(arr.std())


def _collect_training_snapshots(train_zs, test_z_val: float, ref_coords: np.ndarray,
                                needed_cases: set, *, slice_root: str, norm_stats: str,
                                batch_size: int, seed: int, max_collect_iterations: int) -> list:
    """Collect snapshots from every training Z except ``test_z_val``, on the reference mesh.

    Each training slice is read with ``collect_case_map``. If its node coordinates
    differ from ``ref_coords``, the fields are mapped with ``interpolate_to_reference``.
    Snapshots are appended slice by slice, in sorted case order.
    """
    snapshots = []
    for z_val in train_zs:
        # Skip if train Z is same as current test Z (strict holdout)
        if abs(z_val - test_z_val) < 1e-4:
            continue

        print(f"   -> Scanning Train Z={z_val} ...")
        ds = make_dataset(
            slice_root=slice_root,
            norm_stats_nc=norm_stats,
            batch_size=batch_size,
            shuffle=False,
            seed=seed,
            is_training=False,
            fixed_z=int(z_val),
            drop_remainder=False,
            angle_stride=1,
        )
        case_map, src_coords, _ = collect_case_map(ds, needed_cases=needed_cases,
                                                   max_iterations=max_collect_iterations, verbose=True)
        if not case_map or src_coords is None:
            print(f"⚠️  Warning: No data found for Z={z_val}, skipping.")
            continue

        # Interpolate only when the training mesh differs from the reference mesh.
        same_mesh = (src_coords.shape == ref_coords.shape
                     and np.allclose(src_coords, ref_coords, atol=1e-5))
        interp_map = case_map if same_mesh else interpolate_to_reference(src_coords, ref_coords, case_map)
        snapshots.extend(interp_map[c] for c in sorted(interp_map))
    return snapshots


def evaluate_sensor_pod(slice_root: str = "data_sliced_cropped_300k",
                        norm_stats: str = "normalization_cropped_300k_test/normalization_stats_train.nc",
                        test_z: str = "35",
                        train_z_values: str = "15,25,45,55",
                        n_modes: int = 15,
                        batch_size: int = 8,
                        output_dir: str = "baselines/results_sensor_pod",
                        seed: int = 42,
                        n_plot_samples: int = 10,
                        max_collect_iterations: int = 5000,
                        obs_frac: float = 0.05,
                        obs_count: int = -1,
                        obs_counts_str: str = ""):
    """Evaluate Sensor-POD by fitting modes on training Z-slices and testing on held-out Z-slices.

    For each test Z (mesh mismatch is handled by interpolation):
    1. Collect the reference mesh (coordinates + edges) and cases from the test Z.
    2. Collect training snapshots from the train Zs, excluding the test Z.
    3. Interpolate the training snapshots onto the test-Z mesh where needed.
    4. Fit POD on the combined, interpolated data.
    5. Reconstruct every test case (native mesh) from randomly placed sensors,
       once per observation target (node count or node fraction).

    Outputs under ``output_dir``:
      * ``test_z<Z>/`` - spectra of the basis fitted for test slice Z (one per Z,
        since each Z has its own training set and reference mesh);
      * ``frac<f>/`` or ``obs<k>/`` - example reconstruction plots;
      * ``summary_metrics.csv`` - metrics pooled over all test Zs and cases.

    Graph SSIM is NaN for cases where it could not be computed (no edge list, or
    the SSIM helper raised). Such cases are excluded from SSIM averages rather
    than counted as 0.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    train_zs = _parse_z_list(train_z_values)
    test_zs = _parse_z_list(test_z)

    print("🚀 Starting Sensor-POD Baseline (Generalized + Interpolated)")
    print(f"   Train Z-Slices: {train_zs}")
    print(f"   Test Z-Slices:  {test_zs}")
    print(f"   Modes:          {n_modes}")

    obs_targets = _resolve_obs_targets(obs_counts_str, obs_count, obs_frac)

    # Per-target metric lists pooled over all test Zs and cases.
    aggregated_rrmse = {t: [] for t in obs_targets}
    aggregated_ssim = {t: [] for t in obs_targets}
    aggregated_cosine = {t: [] for t in obs_targets}

    # Case ids requested from every slice: 1..360 (presumably one per angle; TODO confirm).
    needed_cases = set(range(1, 361))

    for test_z_val in test_zs:
        print("\n========================================")
        print(f"🔄 Processing TEST Z = {test_z_val}")
        print("========================================")

        # --- 1. Establish Reference Mesh (Test Z) ---
        print(f"\n📦 Loading Reference Mesh from Test Z={test_z_val}...")
        test_ds = make_dataset(
            slice_root=slice_root,
            norm_stats_nc=norm_stats,
            batch_size=batch_size,
            shuffle=False,
            seed=seed + 999,
            is_training=False,
            fixed_z=int(test_z_val),
            drop_remainder=False,
            angle_stride=1,
        )
        test_case_map, ref_coords, ref_edges = collect_case_map(
            test_ds, needed_cases=needed_cases,
            max_iterations=max_collect_iterations, verbose=True)

        if not test_case_map or ref_coords is None:
            print(f"Could not load test cases or coordinates for Z={test_z_val}; skipping.")
            continue

        print(f"   Reference Mesh (Test Z): {ref_coords.shape[0]} nodes.")
        if ref_edges is None:
            print("   ⚠️  No edge list for the reference mesh; graph SSIM will be reported as nan.")

        # --- 2. Collect & Interpolate Training Snapshots ---
        print("\n📦 Collecting & Interpolating TRAINING snapshots...")
        all_train_snapshots = _collect_training_snapshots(
            train_zs, test_z_val, ref_coords, needed_cases,
            slice_root=slice_root, norm_stats=norm_stats, batch_size=batch_size,
            seed=seed, max_collect_iterations=max_collect_iterations,
        )
        if not all_train_snapshots:
            print("No training snapshots collected! Aborting this test Z.")
            continue
        print(f"   Total training snapshots (interpolated): {len(all_train_snapshots)}")

        # --- 3. Fit Modes ---
        model = SensorPOD(n_modes=n_modes)
        t0 = time.time()
        model.fit_modes(dict(enumerate(all_train_snapshots)))
        fit_time = time.time() - t0
        print(f"🧠 Fitted SVD in {fit_time:.3f}s. Computed {model.U.shape[1]} modes.")

        if model._full_s is not None:
            # Each test Z has its own training set and mesh: keep one spectrum per Z
            # instead of overwriting a shared file.
            spectra_dir = out_dir / f"test_z{int(test_z_val)}"
            spectra_dir.mkdir(parents=True, exist_ok=True)
            plot_and_save_spectra(spectra_dir, model._full_s)

        # --- 4. Loop over Observation Targets ---
        sorted_test_ids = sorted(test_case_map.keys())
        for target in obs_targets:
            label = f"frac{target}" if isinstance(target, float) else f"obs{target}"
            sub_dir = out_dir / label
            sub_dir.mkdir(parents=True, exist_ok=True)

            predictions = {}
            sensor_masks = {}
            for c in sorted_test_ids:
                field = test_case_map[c]
                n_nodes = field.shape[0]
                k = _num_sensor_nodes(target, n_nodes)
                # Sensor placement is seeded per case, so runs are reproducible.
                rng = np.random.RandomState(seed + c)
                node_idx = rng.choice(n_nodes, size=k, replace=False)
                obs_idx = _interleaved_obs_indices(node_idx)
                y_obs = vectorize_field(field)[obs_idx]
                predictions[c] = model.reconstruct_from_sensors(obs_idx, y_obs)

                mask = np.zeros(n_nodes, dtype=bool)
                mask[node_idx] = True
                sensor_masks[c] = mask

            # --- Metrics ---
            per_case_rrmse = {}
            per_case_ssim = {}
            per_case_cosine = {}
            ssim_errors = []
            for c, pred in predictions.items():
                rrmse, gssim, cosine, ssim_error = _case_metrics(pred, test_case_map[c], ref_edges)
                per_case_rrmse[c] = rrmse
                per_case_ssim[c] = gssim
                per_case_cosine[c] = cosine
                if ssim_error is not None:
                    ssim_errors.append(ssim_error)
            if ssim_errors:
                print(f"   ⚠️  Graph SSIM failed for {len(ssim_errors)}/{len(predictions)} cases "
                      f"(first error: {ssim_errors[0]!r}); excluded from SSIM averages.")

            rrmses = np.array(list(per_case_rrmse.values()))
            cosines = np.array(list(per_case_cosine.values()))
            ssim_mean, _ = _nan_mean_std(list(per_case_ssim.values()))
            print(f"   [{label}] Z={test_z_val}: RRMSE={rrmses.mean():.4f}, SSIM={ssim_mean:.4f}, Cos={cosines.mean():.4f}")
            aggregated_rrmse[target].extend(per_case_rrmse.values())
            aggregated_ssim[target].extend(per_case_ssim.values())
            aggregated_cosine[target].extend(per_case_cosine.values())

            # --- Plotting (subset) ---
            for c in sorted_test_ids[:max(0, n_plot_samples)]:
                fig = create_test_plot(
                    original_data=test_case_map[c],
                    predicted_data=predictions[c],
                    coords=ref_coords,
                    case_number=c,
                    mae_error=per_case_rrmse[c],  # NOTE: the RRMSE is passed in the plot's MAE slot
                    slice_z=test_z_val,
                    obs_mask=sensor_masks[c],
                    observed_values=None,
                )
                p = sub_dir / f"pod_test_z{int(test_z_val)}_case{c}.png"
                fig.savefig(p, dpi=150, bbox_inches='tight')
                plt.close(fig)

    # --- Final Aggregation ---
    results_list = []
    print("\n📊 Final Aggregated Results (Average over all Test Zs & Cases):")
    for target in obs_targets:
        vals_r = np.array(aggregated_rrmse[target])
        if vals_r.size == 0:
            continue
        vals_c = np.array(aggregated_cosine[target])

        mean_r, std_r = vals_r.mean(), vals_r.std()
        mean_s, std_s = _nan_mean_std(aggregated_ssim[target])
        mean_c, std_c = vals_c.mean(), vals_c.std()
        count = vals_r.size

        print(f"   Target {target}: RRMSE={mean_r:.4f}±{std_r:.4f}, SSIM={mean_s:.4f}±{std_s:.4f}, Cos={mean_c:.4f}±{std_c:.4f} (N={count})")
        results_list.append({
            "target": target,
            "mean_rrmse": mean_r,
            "std_rrmse": std_r,
            "mean_ssim": mean_s,
            "std_ssim": std_s,
            "mean_cosine": mean_c,
            "std_cosine": std_c,
            "num_samples": count
        })

    # Save Aggregate CSV
    import csv
    with open(out_dir / "summary_metrics.csv", "w", newline='') as f:
        fieldnames = ["target", "mean_rrmse", "std_rrmse", "mean_ssim", "std_ssim", "mean_cosine", "std_cosine", "num_samples"]
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(results_list)
    print(f"\n✅ Done. Summary saved to {out_dir}/summary_metrics.csv")



def main():
    """Command-line entry point: parse options and run ``evaluate_sensor_pod``."""
    parser = argparse.ArgumentParser(description="Sensor-based POD Baseline (Generalization across Z-slices)")

    # (flag, type, default, help) in the order they are listed by --help.
    option_specs = (
        ('--slice_root', str, 'data_sliced_cropped_300k', None),
        ('--norm_stats', str, 'normalization_cropped_300k_test/normalization_stats_train.nc', None),
        # Split config
        ('--test_z', str, '35,40', 'Z-slice(s) to evaluate on (comma-separated, unseen during POD fit)'),
        ('--train_z_values', str, '15,20,28,45', 'Comma-separated Z-slices for training the basis'),
        ('--n_modes', int, 15, None),
        ('--batch_size', int, 8, None),
        ('--output_dir', str, 'baselines/results_sensor_pod', None),
        ('--seed', int, 42, None),
        ('--n_plot_samples', int, 1, None),
        ('--max_collect_iterations', int, 5000, None),
        ('--obs_frac', float, 0.01, None),
        ('--obs_count', int, -1, 'Single obs count (legacy)'),
        ('--obs_counts', str, '300,3000,8000', 'Comma-separated list of observation counts to sweep e.g. "300,1000,5000"'),
    )
    for flag, value_type, default, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)

    options = vars(parser.parse_args())
    # Every CLI option maps to the evaluator parameter of the same name, except
    # --obs_counts, which the evaluator receives as obs_counts_str.
    options['obs_counts_str'] = options.pop('obs_counts')
    evaluate_sensor_pod(**options)


if __name__ == '__main__':
    main()
