#!/usr/bin/env python3

import sys
from pathlib import Path
project_root = str(Path(__file__).resolve().parent.parent.parent)
if project_root not in sys.path:
    sys.path.append(project_root)
import project_config

import os
import sys

# Set threading limits BEFORE importing numpy/scipy/sklearn to avoid OpenBLAS errors
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'

import argparse
import csv
import numpy as np
import torch
import pandas as pd
import h5py
from torch.utils.data import DataLoader

# Scikit-learn imports
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import cross_validate, StratifiedKFold
from sklearn.metrics import make_scorer, accuracy_score, f1_score, balanced_accuracy_score

# Add the project root to sys.path so the `src.*` package imports below resolve
from pathlib import Path
project_root = Path(__file__).parent.parent.parent.absolute()
sys.path.append(str(project_root))

# Import baselines
from src.models.mm_baselines import PPD, JIVE, AJIVE, SLIDE, ShIndICA

def compute_metrics(rep, labels, cv_folds=5, seed=42):
    """
    Compute Accuracy, Balanced Accuracy, and Macro F1 with stratified K-fold CV.

    Three probes are cross-validated on the same shuffled folds: a class-balanced
    Logistic Regression, a 5-nearest-neighbour classifier and Gaussian Naive
    Bayes. Each probe is evaluated independently; if one raises, a message is
    printed and that probe's keys are simply absent from the result.

    Args:
        rep: Feature array with one row per sample. Inputs with more than two
            dimensions are flattened per sample; a 1-D input is treated as a
            single feature column.
        labels: One class label per sample.
        cv_folds: Number of StratifiedKFold splits.
        seed: Random state used to shuffle the folds.

    Returns:
        dict mapping '<probe>_acc', '<probe>_bal_acc' and '<probe>_f1' (probe in
        'logistic', 'knn', 'gnb') to the mean test score across folds. Empty if
        ``rep`` or ``labels`` is None.
    """
    results = {}

    if rep is None or labels is None:
        return results

    X = np.asarray(rep)
    y = np.asarray(labels)

    if X.ndim == 1:
        # scikit-learn estimators require a 2-D design matrix.
        X = X.reshape(-1, 1)
    elif X.ndim > 2:
        X = X.reshape(X.shape[0], -1)

    scoring = {
        'acc': 'accuracy',
        'bal_acc': 'balanced_accuracy',
        'macro_f1': 'f1_macro',
    }
    cv = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=seed)

    # (result-key prefix, name used in failure messages, estimator factory)
    probes = [
        # Linear probe; class_weight='balanced' compensates for skewed class frequencies.
        ('logistic', 'Logistic',
         lambda: LogisticRegression(max_iter=1000, solver='liblinear', class_weight='balanced')),
        # k=5 is less sensitive to single noisy neighbours than k=1.
        ('knn', 'KNN', lambda: KNeighborsClassifier(n_neighbors=5)),
        ('gnb', 'GNB', GaussianNB),
    ]

    for prefix, display_name, make_clf in probes:
        try:
            scores = cross_validate(make_clf(), X, y, cv=cv, scoring=scoring)
        except Exception as e:
            # Best effort: one failing probe must not prevent the others.
            print(f"{display_name} failed: {e}")
            continue
        results[f'{prefix}_acc'] = np.mean(scores['test_acc'])
        results[f'{prefix}_bal_acc'] = np.mean(scores['test_bal_acc'])
        results[f'{prefix}_f1'] = np.mean(scores['test_macro_f1'])

    return results


def print_metrics_summary(metrics, representation_name):
    """Print a one-line summary of the classifier scores found in ``metrics``.

    Accuracy entries are listed first (LogReg, KNN, GNB), then macro-F1 entries
    in the same order; any key missing from ``metrics`` is skipped.

    Args:
        metrics: dict of scores, e.g. as returned by ``compute_metrics``.
        representation_name: label shown in the printed line.
    """
    if not metrics:
        print(f"No metrics for {representation_name}")
        return

    # (metrics key, display label) in reporting order.
    display_order = (
        ('logistic_acc', 'LogReg Acc'),
        ('knn_acc', 'KNN Acc'),
        ('gnb_acc', 'GNB Acc'),
        ('logistic_f1', 'LogReg F1'),
        ('knn_f1', 'KNN F1'),
        ('gnb_f1', 'GNB F1'),
    )
    fragments = [
        f"{label}={metrics[key]:.4f}"
        for key, label in display_order
        if key in metrics
    ]

    if not fragments:
        print(f"Results for {representation_name}: (no accuracy scores available)")
    else:
        print(f"Results for {representation_name}: {', '.join(fragments)}")

def extract_nyu_scene_labels(mat_file_path):
    """
    Extract scene type labels from the NYU Depth V2 labeled .mat (HDF5) file.

    Each entry of the ``sceneTypes`` dataset is an HDF5 object reference to an
    array of character codes spelling one image's scene name.

    Args:
        mat_file_path: Path to ``nyu_depth_v2_labeled.mat``.

    Returns:
        Tuple ``(raw_scenes, scene_labels_int, label_encoder)``: the scene names
        as a numpy array of strings, their integer codes, and the fitted
        ``LabelEncoder`` used to produce those codes.
    """
    print("Extracting NYU Depth V2 scene labels...")

    with h5py.File(mat_file_path, 'r') as f:
        # Taking only row 0 (``f['sceneTypes'][0]``) yields every reference only
        # when the dataset is laid out as (1, N); an (N, 1) layout would yield a
        # single label. Flattening the whole dataset handles both layouts.
        scene_refs = np.ravel(f['sceneTypes'][()])
        # Dereference each entry and join its character codes into a string.
        raw_scenes = [
            ''.join(chr(c) for c in f[ref][:].flatten())
            for ref in scene_refs
        ]

    print(f"Extracted {len(raw_scenes)} scene labels")
    print(f"Unique scenes: {sorted(set(raw_scenes))}")

    # Encode string labels to integers (LabelEncoder is imported at module level).
    label_encoder = LabelEncoder()
    scene_labels_int = label_encoder.fit_transform(raw_scenes)

    return np.array(raw_scenes), scene_labels_int, label_encoder


def load_nyu_raw_images(data_root, indices=None):
    """
    Load raw NYU Depth V2 RGB images and depth maps for baseline evaluation.

    Files are read as ``<data_root>/images/NNNN.png`` and
    ``<data_root>/depths/NNNN_depth.png`` (index zero-padded to 4 digits).

    Args:
        data_root: Path to the NYU dataset root.
        indices: Optional iterable of integer indices to load (for alignment
            with representations). If None, every ``*.png`` in ``images/`` is
            loaded in sorted filename order.

    Returns:
        (rgb_images, depth_maps): float32 arrays of shape (N, 3, H, W) scaled to
        [0, 1], and (N, 1, H, W) divided by 1000 (millimetres -> metres).

    Raises:
        ValueError: if there are no samples to load.
    """
    print("Loading raw NYU Depth V2 images and depth maps...")

    img_dir = os.path.join(data_root, "images")
    depth_dir = os.path.join(data_root, "depths")

    if indices is None:
        # Load all available images; zero-padded names make lexical order numeric.
        img_files = sorted(f for f in os.listdir(img_dir) if f.endswith('.png'))
        indices = [int(f.split('.')[0]) for f in img_files]
    else:
        # Accept any iterable (list, range, numpy array) and allow a truthiness check.
        indices = list(indices)

    if not indices:
        raise ValueError(f"No NYU samples to load from {img_dir!r}")

    # Imported lazily: Pillow is only needed once there is something to read.
    from PIL import Image

    rgb_list = []
    depth_list = []

    for idx in indices:
        with Image.open(os.path.join(img_dir, f"{idx:04d}.png")) as img:
            # (H, W, 3) -> (3, H, W) in [0, 1]. float32 halves memory compared with
            # the float64 produced by dividing a uint8 array by 255.0; the
            # baselines in this script consume float32 tensors anyway.
            rgb = np.asarray(img.convert('RGB'), dtype=np.float32).transpose(2, 0, 1) / 255.0
        rgb_list.append(rgb)

        with Image.open(os.path.join(depth_dir, f"{idx:04d}_depth.png")) as img:
            depth = np.asarray(img, dtype=np.float32) / 1000.0  # Convert mm to meters
        depth_list.append(depth[np.newaxis, :, :])  # Add channel dimension

    rgb_array = np.stack(rgb_list, axis=0)
    depth_array = np.stack(depth_list, axis=0)

    print(f"Loaded {len(indices)} samples: RGB {rgb_array.shape}, Depth {depth_array.shape}")

    return rgb_array, depth_array

def main():
    """
    Evaluate decomposition baselines on NYU Depth V2 scene classification.

    Steps: read scene labels from ``nyu_depth_v2_labeled.mat``; load the raw
    RGB images and depth maps; run every configured baseline's ``decompose`` on
    the two flattened modalities; score each returned subspace and the
    concatenation of all subspaces with ``compute_metrics``; print a summary
    table and write every score row to a CSV in ``--results_dir``.
    """
    import traceback  # only used to report baseline failures

    parser = argparse.ArgumentParser(description="Evaluate NYU Depth V2 Baselines for Scene Classification")
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--results_dir', type=str, default=project_config.NYU_RESULTS_DIR)
    parser.add_argument('--data_dir', type=str, default=project_config.MM_BENCHMARKS_DATA_ROOT)
    parser.add_argument('--n_samples', type=int, default=None, help="Number of samples to use (None = all)")
    args = parser.parse_args()

    # --seed names the output file, so make the run actually depend on it:
    # seed the global RNGs (for randomised baselines) and the CV splits below.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Create the output directory up front so an unusable location fails fast
    # instead of after the (expensive) decompositions have finished.
    os.makedirs(args.results_dir, exist_ok=True)

    print("="*80)
    print(f"Evaluating NYU Depth V2 Baseline Methods for Scene Classification (seed={args.seed})")
    print("="*80)

    # --- 1. Extract Scene Labels ---
    mat_file = os.path.join(args.data_dir, "nyu_depth_v2_labeled.mat")
    scene_labels_str, scene_labels_int, label_encoder = extract_nyu_scene_labels(mat_file)
    print(f"Number of unique scenes: {len(label_encoder.classes_)}")
    print(f"Scene classes: {list(label_encoder.classes_)}")

    # --- 2. Load Raw Data for Baselines ---
    print("\n--- Loading Raw Data for Baseline Evaluation ---")

    # None -> every image found on disk; otherwise the first n_samples indices.
    indices = None if args.n_samples is None else list(range(args.n_samples))

    raw_rgb, raw_depth = load_nyu_raw_images(args.data_dir, indices=indices)

    # Get corresponding labels
    raw_labels = scene_labels_int if indices is None else scene_labels_int[indices]

    # Labels are matched to images purely by position; a count mismatch means the
    # alignment is broken and every score below would be meaningless.
    if len(raw_labels) != raw_rgb.shape[0]:
        raise ValueError(
            f"Loaded {raw_rgb.shape[0]} images but have {len(raw_labels)} scene labels; "
            "images and labels are not aligned."
        )

    print(f"Using {len(raw_labels)} samples for evaluation")
    print(f"Labels shape: {raw_labels.shape} ({len(np.unique(raw_labels))} unique classes)")

    all_results = []

    # --- 3. Evaluate Baselines on Raw Data ---
    #print("\n--- Evaluating Baselines on Raw Data ---")

    # Flatten each sample to a vector for the baselines
    raw_rgb_flat = raw_rgb.reshape(raw_rgb.shape[0], -1)
    raw_depth_flat = raw_depth.reshape(raw_depth.shape[0], -1)

    # 3a. Raw Data Baselines (currently disabled)
    #print(f"Evaluating Raw RGB (Dim: {raw_rgb_flat.shape[1]})...")
    #m_rgb = compute_metrics(raw_rgb_flat, raw_labels)
    #print_metrics_summary(m_rgb, 'Raw_RGB')
    #m_rgb.update({'representation': 'Raw_RGB', 'dim': raw_rgb_flat.shape[1], 'type': 'Baseline'})
    #all_results.append(m_rgb)

    #print(f"Evaluating Raw Depth (Dim: {raw_depth_flat.shape[1]})...")
    #m_depth = compute_metrics(raw_depth_flat, raw_labels)
    #print_metrics_summary(m_depth, 'Raw_Depth')
    #m_depth.update({'representation': 'Raw_Depth', 'dim': raw_depth_flat.shape[1], 'type': 'Baseline'})
    #all_results.append(m_depth)

    #print(f"Evaluating Raw Concatenated (Dim: {raw_rgb_flat.shape[1] + raw_depth_flat.shape[1]})...")
    #raw_concat = np.concatenate([raw_rgb_flat, raw_depth_flat], axis=1)
    #m_cat = compute_metrics(raw_concat, raw_labels)
    #print_metrics_summary(m_cat, 'Raw_Concat')
    #m_cat.update({'representation': 'Raw_Concat', 'dim': raw_concat.shape[1], 'type': 'Baseline'})
    #all_results.append(m_cat)

    # 3b. Decomposition Baselines
    baseline_methods = [
        #('JIVE', lambda: JIVE()),
        #('AJIVE', lambda: AJIVE()),
        #('SLIDE', lambda: SLIDE()),
        ('ShIndICA', lambda: ShIndICA(joint_rank_options=[5, 10, 20, 50])),
    ]

    # The decomposition methods take a list of float32 tensors, one per modality.
    # as_tensor avoids a second full copy of the (large) raw arrays when they are
    # already float32.
    data_for_baselines = [
        torch.as_tensor(raw_rgb_flat, dtype=torch.float32),
        torch.as_tensor(raw_depth_flat, dtype=torch.float32),
    ]

    for method_name, constructor in baseline_methods:
        print(f"\nRunning {method_name}...")
        try:
            method = constructor()
            decomposed, rank_info = method.decompose(data_for_baselines)

            print(f"{method_name} subspaces: {list(decomposed.keys())}")
            print(f"{method_name} rank info: {rank_info}")

            # Convert each subspace to numpy once; reused for per-subspace and concat scoring.
            subspaces = {
                key: val.cpu().numpy() if torch.is_tensor(val) else val
                for key, val in decomposed.items()
            }

            # Evaluate each subspace
            for key, sub_rep in subspaces.items():
                print(f"  Evaluating {method_name} {key} (Dim: {sub_rep.shape[1]})...")
                m = compute_metrics(sub_rep, raw_labels, seed=args.seed)
                print_metrics_summary(m, f"{method_name}_{key}")
                m.update({'representation': f"{method_name}_{key}", 'dim': sub_rep.shape[1], 'type': 'Baseline'})
                all_results.append(m)

            # Concatenate all subspaces
            concat_subs = np.concatenate(list(subspaces.values()), axis=1)
            print(f"  Evaluating {method_name} Concat (Dim: {concat_subs.shape[1]})...")
            m = compute_metrics(concat_subs, raw_labels, seed=args.seed)
            print_metrics_summary(m, f"{method_name}_Concat")
            m.update({'representation': f"{method_name}_Concat", 'dim': concat_subs.shape[1], 'type': 'Baseline'})
            all_results.append(m)

        except Exception as e:
            # Keep going with the remaining methods, but keep the traceback for debugging.
            print(f"  {method_name} failed: {e}")
            traceback.print_exc()

    # --- 4. Save & Print Results ---
    print("\n" + "="*100)
    print(f"{'Representation':<30} {'Dim':<6} {'LogReg F1':<10} {'LogReg Bal':<10} {'KNN F1':<10}")
    print("-" * 100)

    for res in all_results:
        name = res.get('representation', 'N/A')
        dim = res.get('dim', 0)

        # A probe that failed has no key; show N/A rather than a misleading 0.0000.
        lr_f1, lr_bal, knn_f1 = (
            f"{res[key]:.4f}" if key in res else "N/A"
            for key in ('logistic_f1', 'logistic_bal_acc', 'knn_f1')
        )

        print(f"{name:<30} {dim:<6} {lr_f1:<10} {lr_bal:<10} {knn_f1:<10}")

    # Save CSV
    df = pd.DataFrame(all_results)
    save_path = os.path.join(
        args.results_dir, f"nyu_depth_baselines_scene_classification_seed{args.seed}.csv"
    )
    df.to_csv(save_path, index=False)
    print(f"\nResults saved to {save_path}")
    print("\nBaseline Evaluation Complete!")
    print(f"  - Evaluated {len(all_results)} baseline configurations")
    print(f"  - {len(np.unique(raw_labels))} scene classes")
    print(f"  - {len(raw_labels)} samples")

if __name__ == '__main__':
    main()