#!/usr/bin/env python3

import sys
from pathlib import Path
project_root = str(Path(__file__).resolve().parent.parent.parent)
if project_root not in sys.path:
    sys.path.append(project_root)
import project_config

import os
import sys

# Set threading limits BEFORE importing numpy/scipy/sklearn to avoid OpenBLAS errors
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'

import argparse
import csv
import numpy as np
import torch
import pandas as pd
from torch.utils.data import DataLoader
from torchgeo.datasets import So2Sat

# Scikit-learn imports
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import cross_validate, StratifiedKFold
from sklearn.metrics import make_scorer, accuracy_score, f1_score, balanced_accuracy_score

# Add src to path
from pathlib import Path
project_root = Path(__file__).parent.parent.parent.absolute()
sys.path.append(str(project_root))

# Import baselines
from src.models.mm_baselines import PPD, JIVE, AJIVE, SLIDE, ShIndICA

def compute_metrics(rep, labels, cv_folds=5, seed=42):
    """
    Compute Accuracy, Balanced Accuracy, and Macro F1 with stratified K-fold CV.

    Three probes are scored on ``rep``: a class-balanced logistic regression
    (linear probe), a 5-nearest-neighbour classifier, and Gaussian naive Bayes.
    Each uses ``StratifiedKFold(n_splits=cv_folds, shuffle=True,
    random_state=seed)`` and the per-fold scores are averaged.

    Args:
        rep: Features, array-like with one row per sample. Inputs with more
            than two dimensions are flattened per sample; 0-/1-D input is
            treated as a single feature column.
        labels: Class labels, one per sample (flattened before use).
        cv_folds: Number of CV folds (default 5).
        seed: Random state used to shuffle samples into folds.

    Returns:
        dict mapping ``'<probe>_acc'``, ``'<probe>_bal_acc'`` and
        ``'<probe>_f1'`` (macro F1) to mean CV scores, for ``<probe>`` in
        ``logistic``, ``knn``, ``gnb``. A probe whose cross-validation raises
        is reported on stdout and left out. An empty dict is returned when
        ``rep`` or ``labels`` is None or their sample counts differ.
    """
    # Local import: only needed to keep the logistic probe one-vs-rest (see below).
    from sklearn.multiclass import OneVsRestClassifier

    results = {}

    if rep is None or labels is None:
        return results

    X = np.asarray(rep)
    y = np.asarray(labels).ravel()

    # sklearn needs a 2-D (n_samples, n_features) matrix.
    if X.ndim < 2:
        X = X.reshape(-1, 1)
    elif X.ndim > 2:
        X = X.reshape(X.shape[0], -1)

    if X.shape[0] != y.shape[0]:
        print(f"compute_metrics: {X.shape[0]} samples but {y.shape[0]} labels; skipping.")
        return results

    # Scorer names -> sklearn scoring strings; cross_validate reports them as 'test_<name>'.
    scoring = {
        'acc': 'accuracy',
        'bal_acc': 'balanced_accuracy',
        'macro_f1': 'f1_macro'
    }

    cv = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=seed)

    # (result-key prefix, name used in failure messages, estimator).
    probes = [
        # Linear probe. class_weight='balanced' is CRITICAL for So2Sat's class imbalance.
        # liblinear is one-vs-rest; newer scikit-learn deprecates (and later rejects)
        # multiclass liblinear inside LogisticRegression, so the OvR scheme is made
        # explicit. Note: class weights are then balanced per one-vs-rest sub-problem.
        ('logistic', 'Logistic',
         OneVsRestClassifier(LogisticRegression(max_iter=1000, solver='liblinear',
                                                class_weight='balanced'))),
        ('knn', 'KNN', KNeighborsClassifier(n_neighbors=5)),  # k=5 is more robust than k=1
        ('gnb', 'GNB', GaussianNB()),
    ]

    for prefix, display_name, estimator in probes:
        try:
            scores = cross_validate(estimator, X, y, cv=cv, scoring=scoring)
        except Exception as e:
            # Best-effort: one failing probe must not prevent the others from running.
            print(f"{display_name} failed: {e}")
            continue
        results[f'{prefix}_acc'] = np.mean(scores['test_acc'])
        results[f'{prefix}_bal_acc'] = np.mean(scores['test_bal_acc'])
        results[f'{prefix}_f1'] = np.mean(scores['test_macro_f1'])

    return results


def print_metrics_summary(metrics, representation_name):
    """Print a one-line summary of the probe scores found in ``metrics``.

    Accuracies are listed first, then macro F1 scores, each formatted to four
    decimals; keys absent from ``metrics`` are skipped. An empty/None
    ``metrics`` prints a "No metrics" line instead.
    """
    if not metrics:
        print(f"No metrics for {representation_name}")
        return

    # (metrics key, display label) in print order: accuracies, then macro F1.
    display_fields = (
        ('logistic_acc', 'LogReg Acc'),
        ('knn_acc', 'KNN Acc'),
        ('gnb_acc', 'GNB Acc'),
        ('logistic_f1', 'LogReg F1'),
        ('knn_f1', 'KNN F1'),
        ('gnb_f1', 'GNB F1'),
    )
    summary = [
        f"{label}={metrics[key]:.4f}"
        for key, label in display_fields
        if key in metrics
    ]

    if not summary:
        print(f"Results for {representation_name}: (no accuracy scores available)")
        return
    print(f"Results for {representation_name}: {', '.join(summary)}")

def _zscore_per_channel(x, modality):
    """Standardize each channel of an (N, C, H, W) tensor using stats pooled over dims 0, 2, 3.

    Channels with zero std get std 1.0 (i.e. they are only mean-centred) and a
    warning naming ``modality`` and the affected channel indices is printed.
    """
    mean = x.mean(dim=(0, 2, 3), keepdim=True)
    std = x.std(dim=(0, 2, 3), keepdim=True)
    zero_std_mask = (std == 0)
    if zero_std_mask.any():
        # view(-1) (not squeeze) so a single bad channel still prints as a list.
        bad_channels = zero_std_mask.view(-1).nonzero(as_tuple=False).view(-1).tolist()
        print(f"Warning: zero std encountered in {modality} channels at indices {bad_channels}, setting to 1.0 to avoid division by zero.")
        std[zero_std_mask] = 1.0
    return (x - mean) / (std + 1e-8)


def load_so2sat_raw(root_dir, n_samples=None, seed=42):
    """Load the first samples of the So2Sat v2 train split and normalize them.

    Intended to match the training script's preprocessing (per the original
    notes; the training code is not visible here). Statistics are computed on
    the loaded subset itself:

    * radar (image channels 0-7): if any channel has a negative minimum, every
      channel is shifted by its own minimum; then ``log1p``, clipping at an
      estimated 99th percentile, and a per-channel z-score;
    * optical (image channels 8-10): per-channel z-score.

    Args:
        root_dir: Root directory passed to torchgeo's ``So2Sat`` dataset.
        n_samples: Maximum number of samples to load, taken in dataset order
            (no shuffling). ``None`` or a non-positive value caps at 20000.
        seed: Seeds the random subsample used to estimate the radar 99th
            percentile, making the normalization reproducible.

    Returns:
        ``(radar, optical, labels)`` NumPy arrays with ``N = min(cap, dataset
        size)`` rows: float32 radar ``(N, 8, H, W)``, float32 optical
        ``(N, 3, H, W)``, and the dataset's per-sample labels concatenated.

    Raises:
        RuntimeError: if no samples could be loaded.
    """
    print("Loading Raw So2Sat Data (Train split only for evaluation)...")

    dataset = So2Sat(root=root_dir, version="2", split="train", transforms=None, checksum=False)

    # The full split is large; cap the evaluation size for speed.
    limit = n_samples if n_samples and n_samples > 0 else 20000

    loader = DataLoader(dataset, batch_size=256, shuffle=False)

    radar_chunks, optical_chunks, label_chunks = [], [], []
    samples_loaded = 0
    for batch in loader:
        image = batch["image"]
        radar_chunks.append(image[:, 0:8, :, :].float())
        optical_chunks.append(image[:, 8:11, :, :].float())
        label_chunks.append(batch["label"])

        samples_loaded += image.shape[0]
        if samples_loaded >= limit:
            break

    if not radar_chunks:
        raise RuntimeError(f"No samples could be loaded from the So2Sat train split at {root_dir!r}.")

    # The final batch can overshoot the cap; trim so at most `limit` samples are returned.
    radar = torch.cat(radar_chunks, dim=0)[:limit]
    optical = torch.cat(optical_chunks, dim=0)[:limit]
    labels = torch.cat(label_chunks, dim=0)[:limit]

    # --- Normalization (Match training script behavior) ---
    print("Normalizing Raw Data...")

    # Radar: log1p needs non-negative input (training does this shift across train+val).
    min_r = radar.amin(dim=(0, 2, 3), keepdim=True)
    if (min_r < 0).any():
        print("Detected negative values in radar bands; shifting to make non-negative before log1p.")
        radar = radar - min_r
    radar = torch.log1p(radar)

    # Estimate the 99th percentile from at most 1e6 values to avoid memory blowup.
    # randint (with replacement) replaces randperm, which would allocate an int64
    # permutation of *every* radar value just to keep 1e6 of them.
    max_quantile_samples = 1_000_000
    flat = radar.reshape(-1)
    if flat.numel() > max_quantile_samples:
        generator = torch.Generator().manual_seed(seed)
        idx = torch.randint(flat.numel(), (max_quantile_samples,), generator=generator)
        flat = flat[idx]
    q99 = torch.quantile(flat, 0.99)
    radar = torch.clamp(radar, max=q99)

    radar = _zscore_per_channel(radar, "radar")
    optical = _zscore_per_channel(optical, "optical")

    return radar.numpy(), optical.numpy(), labels.numpy()

def _evaluate_and_record(rep, labels, name, rep_type, all_results):
    """Score `rep` with compute_metrics, print a summary, and append a tagged row to `all_results`."""
    metrics = compute_metrics(rep, labels)
    print_metrics_summary(metrics, name)
    metrics.update({'representation': name, 'dim': rep.shape[1], 'type': rep_type})
    all_results.append(metrics)
    return metrics


def _load_learned_representations(results_dir, model_prefix, seed):
    """Load representations and labels saved by the training run.

    Returns ``((rep0, rep1, rep2), train_labels, unimodal_radar_rep,
    unimodal_optical_rep)``, or ``None`` (after printing the error) if any
    multimodal file is missing. The unimodal files are optional and come back
    as ``None`` when absent.
    """
    try:
        # Training script saves [Shared, Radar_Spec, Optical_Spec].
        rep0 = np.load(f"{results_dir}/{model_prefix}_rep0.npy")  # Shared
        rep1 = np.load(f"{results_dir}/{model_prefix}_rep1.npy")  # Radar Spec
        rep2 = np.load(f"{results_dir}/{model_prefix}_rep2.npy")  # Optical Spec

        # Labels saved during training keep the rows aligned with the representations.
        train_labels = np.load(f"{results_dir}/{model_prefix}_train_labels.npy")

        print("✓ Loaded saved representations and labels from training run.")
    except FileNotFoundError as e:
        print(f"Error loading saved representations: {e}")
        return None

    unimodal_radar_rep = None
    unimodal_optical_rep = None

    try:
        unimodal_radar_rep = np.load(f"{results_dir}/radar_reps_seed{seed}.npy")
        print(f"✓ Loaded unimodal radar representations (Dim: {unimodal_radar_rep.shape[1]})")
    except FileNotFoundError:
        print("  No unimodal radar representations found.")

    try:
        unimodal_optical_rep = np.load(f"{results_dir}/optical_reps_seed{seed}.npy")
        print(f"✓ Loaded unimodal optical representations (Dim: {unimodal_optical_rep.shape[1]})")
    except FileNotFoundError:
        print("  No unimodal optical representations found.")

    return (rep0, rep1, rep2), train_labels, unimodal_radar_rep, unimodal_optical_rep


def _evaluate_learned(reps, train_labels, unimodal_radar_rep, unimodal_optical_rep, all_results):
    """Evaluate the model's shared/specific representations and any unimodal AE representations."""
    rep0, rep1, rep2 = reps

    print("\n--- Evaluating Learned Representations ---")
    configs = [
        ('Shared', rep0),
        ('Radar_Specific', rep1),
        ('Optical_Specific', rep2),
        ('All Concat', np.concatenate([rep0, rep1, rep2], axis=1)),
    ]
    for name, rep in configs:
        print(f"Evaluating {name} (Dim: {rep.shape[1]})...")
        _evaluate_and_record(rep, train_labels, name, 'Model', all_results)

    # Unimodal reps may cover fewer samples than the training labels; truncate labels to match.
    if unimodal_radar_rep is not None:
        print("\n--- Evaluating Unimodal Radar AE ---")
        print(f"Evaluating Unimodal Radar (Dim: {unimodal_radar_rep.shape[1]})...")
        unimodal_labels = train_labels[:unimodal_radar_rep.shape[0]]
        _evaluate_and_record(unimodal_radar_rep, unimodal_labels, 'Unimodal_Radar_AE', 'Unimodal', all_results)

    if unimodal_optical_rep is not None:
        print("\n--- Evaluating Unimodal Optical AE ---")
        print(f"Evaluating Unimodal Optical (Dim: {unimodal_optical_rep.shape[1]})...")
        unimodal_labels = train_labels[:unimodal_optical_rep.shape[0]]
        _evaluate_and_record(unimodal_optical_rep, unimodal_labels, 'Unimodal_Optical_AE', 'Unimodal', all_results)

    if unimodal_radar_rep is not None and unimodal_optical_rep is not None:
        print(f"Evaluating Unimodal Concatenated (Dim: {unimodal_radar_rep.shape[1] + unimodal_optical_rep.shape[1]})...")
        # Use the minimum number of samples across both modalities.
        n_unimodal = min(unimodal_radar_rep.shape[0], unimodal_optical_rep.shape[0])
        unimodal_concat = np.concatenate(
            [unimodal_radar_rep[:n_unimodal], unimodal_optical_rep[:n_unimodal]], axis=1
        )
        _evaluate_and_record(unimodal_concat, train_labels[:n_unimodal], 'Unimodal_Concat', 'Unimodal', all_results)


def _evaluate_baselines(data_dir, n_samples, seed, eval_raw, all_results):
    """Load raw So2Sat data, optionally probe raw pixels, and evaluate decomposition baselines."""
    print("\n--- Evaluating Baselines on Raw Data ---")

    raw_radar, raw_optical, raw_labels = load_so2sat_raw(data_dir, n_samples=n_samples, seed=seed)

    # Flatten each sample to one feature vector for the probes and decomposition methods.
    raw_radar_flat = raw_radar.reshape(raw_radar.shape[0], -1)
    raw_optical_flat = raw_optical.reshape(raw_optical.shape[0], -1)

    # 3a. Raw-pixel baselines (slow on high-dimensional input, hence opt-in).
    if eval_raw:
        raw_concat = np.concatenate([raw_radar_flat, raw_optical_flat], axis=1)
        raw_configs = [
            ('Raw Radar', 'Raw_Radar', raw_radar_flat),
            ('Raw Optical', 'Raw_Optical', raw_optical_flat),
            ('Raw Concatenated', 'Raw_Concat', raw_concat),
        ]
        for display_name, name, rep in raw_configs:
            print(f"Evaluating {display_name} (Dim: {rep.shape[1]})...")
            _evaluate_and_record(rep, raw_labels, name, 'Baseline', all_results)

    # 3b. Decomposition baselines, fed the flattened raw modalities.
    baseline_methods = [
        ('JIVE', lambda: JIVE()),
        ('AJIVE', lambda: AJIVE()),
        ('SLIDE', lambda: SLIDE()),  # SLIDE can be slow on high dims
        ('ShIndICA', lambda: ShIndICA(joint_rank_options=[5, 10, 20, 50])),
    ]

    data_for_baselines = [
        torch.FloatTensor(raw_radar_flat),
        torch.FloatTensor(raw_optical_flat)
    ]

    for method_name, constructor in baseline_methods:
        print(f"\nRunning {method_name}...")
        try:
            method = constructor()
            # These methods expect a list of tensors, one per modality.
            decomposed, rank_info = method.decompose(data_for_baselines)

            print(f"{method_name} subspaces: {list(decomposed.keys())}")
            print(f"{method_name} rank info: {rank_info}")

            # Evaluate each subspace on its own ...
            subspaces = []
            for key, val in decomposed.items():
                sub_rep = val.cpu().numpy() if torch.is_tensor(val) else val
                subspaces.append(sub_rep)
                print(f"  Evaluating {method_name} {key} (Dim: {sub_rep.shape[1]})...")
                _evaluate_and_record(sub_rep, raw_labels, f"{method_name}_{key}", 'Baseline', all_results)

            # ... and all subspaces concatenated.
            concat_subs = np.concatenate(subspaces, axis=1)
            print(f"  Evaluating {method_name} Concat (Dim: {concat_subs.shape[1]})...")
            _evaluate_and_record(concat_subs, raw_labels, f"{method_name}_Concat", 'Baseline', all_results)

        except Exception as e:
            # Best-effort: a failing baseline must not abort the remaining ones.
            print(f"  {method_name} failed: {e}")


def _print_results_table(all_results):
    """Print a fixed-width summary table; metrics a row lacks (e.g. a failed probe) show as N/A."""
    def fmt(res, key):
        value = res.get(key)
        return "N/A" if value is None else f"{value:.4f}"

    print("\n" + "="*100)
    print(f"{'Representation':<30} {'Dim':<6} {'LogReg F1':<10} {'LogReg Bal':<10} {'KNN F1':<10}")
    print("-" * 100)

    for res in all_results:
        name = res.get('representation', 'N/A')
        dim = res.get('dim', 0)
        lr_f1 = fmt(res, 'logistic_f1')
        lr_bal = fmt(res, 'logistic_bal_acc')
        knn_f1 = fmt(res, 'knn_f1')
        print(f"{name:<30} {dim:<6} {lr_f1:<10} {lr_bal:<10} {knn_f1:<10}")


def main():
    """CLI entry point: evaluate So2Sat representations and save a CSV of probe scores.

    By default only the raw-data decomposition baselines are evaluated.
    ``--eval_learned`` additionally evaluates the representations saved by the
    training run (and aborts if they are missing). ``--eval_raw`` adds raw-pixel
    probes, and ``--skip_baselines`` disables the raw-data section entirely.
    Results are written to ``<results_dir>/<prefix>_rseed-<seed>_full_evaluation.csv``.
    """
    parser = argparse.ArgumentParser(description="Evaluate So2Sat Representations")
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--prefix', type=str, default='so2sat')
    parser.add_argument('--results_dir', type=str, default=project_config.SO2SAT_RESULTS_DIR)
    parser.add_argument('--data_dir', type=str, default=project_config.SO2SAT_DATA_ROOT)
    parser.add_argument('--n_samples_eval', type=int, default=10000, help="Number of samples for baseline calculation (slow)")
    parser.add_argument('--eval_learned', action='store_true',
                        help="Also evaluate the learned/unimodal representations saved by the training run")
    parser.add_argument('--eval_raw', action='store_true',
                        help="Also probe the flattened raw radar/optical pixels (slow)")
    parser.add_argument('--skip_baselines', action='store_true',
                        help="Skip loading raw data and the decomposition baselines")
    args = parser.parse_args()

    print("="*80)
    print(f"Evaluating So2Sat: {args.prefix}_rseed-{args.seed}")
    print("="*80)

    model_prefix = f"{args.prefix}_rseed-{args.seed}"
    all_results = []

    # --- 1/2. Learned representations (only loaded when requested) ---
    if args.eval_learned:
        loaded = _load_learned_representations(args.results_dir, model_prefix, args.seed)
        if loaded is None:
            return
        _evaluate_learned(*loaded, all_results)

    # --- 3. Raw data & baselines ---
    if not args.skip_baselines:
        _evaluate_baselines(args.data_dir, args.n_samples_eval, args.seed, args.eval_raw, all_results)

    # --- 4. Save & Print Results ---
    _print_results_table(all_results)

    os.makedirs(args.results_dir, exist_ok=True)
    df = pd.DataFrame(all_results)
    save_path = f"{args.results_dir}/{model_prefix}_full_evaluation.csv"
    df.to_csv(save_path, index=False)
    print(f"\nResults saved to {save_path}")

if __name__ == '__main__':
    main()