

import sys
from pathlib import Path
project_root = str(Path(__file__).resolve().parent.parent)
if project_root not in sys.path:
    sys.path.append(project_root)
import project_config

import torch
import torch.nn as nn
import sys
import os
import argparse
from torch.utils.data import TensorDataset, DataLoader, random_split, Dataset
import numpy as np
import random
import matplotlib.pyplot as plt
import pandas as pd
from torchgeo.datasets import So2Sat
from tqdm import tqdm

# add src to path
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(str(project_root))

from src.models.larrp_unimodal import AdaptiveRankReducedLinear
from src.functions.train_larrp_multimodal_geo import train_overcomplete_ae_with_pretrained
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train Model on So2Sat dataset")
    parser.add_argument('--seed', type=int, default=0, help='Random seed for reproducibility')
    parser.add_argument('--gpu', type=int, default=0, help='GPU id to use.')
    parser.add_argument('--n_samples', type=int, default=None, help='Number of samples to use from the dataset.')
    parser.add_argument('--normalize', action='store_true', help='Whether to normalize the data.')
    parser.add_argument('--prefix', type=str, default='so2sat', help='Prefix for saving results.')
    args = parser.parse_args()
    if args.n_samples is not None and args.n_samples <= 0:
        parser.error("--n_samples must be a positive integer")

    # --- 1. Load Data from So2Sat Dataset ---
    print("Loading data from So2Sat dataset...")

    # Load So2Sat datasets
    train_dataset = So2Sat(
        root=project_config.SO2SAT_DATA_ROOT,
        version="2",
        split="train",
        transforms=None,
        checksum=False
    )

    val_dataset = So2Sat(
        root=project_config.SO2SAT_DATA_ROOT,
        version="2",
        split="validation",
        transforms=None,
        checksum=False
    )

    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Validation dataset size: {len(val_dataset)}")

    # --- 2. Extract modalities from So2Sat data ---
    print("\nExtracting radar and optical modalities...")

    def _extract_modalities(dataset, max_samples, desc):
        """Split So2Sat samples into radar, optical and label tensors.

        Iterates ``dataset`` in order (no shuffling) in batches of 256. Image
        channels 0-7 are kept as the radar modality (first 8 bands are S1) and
        channels 8-10 as the optical modality (the RGB subset of the S2 bands
        used by this script).

        Args:
            dataset: So2Sat dataset yielding dicts with "image" and "label".
            max_samples: stop once this many samples are collected and
                truncate to exactly this count; None reads the whole split.
            desc: progress-bar description.

        Returns:
            (radar, optical, labels) tensors concatenated along dim 0.
        """
        radar_batches, optical_batches, label_batches = [], [], []
        n_seen = 0
        for batch in tqdm(DataLoader(dataset, batch_size=256, shuffle=False), desc=desc):
            image = batch["image"]  # Shape: (B, 18, 32, 32)
            radar_batches.append(image[:, 0:8, :, :])    # (B, 8, 32, 32)
            optical_batches.append(image[:, 8:11, :, :])  # RGB only
            label_batches.append(batch["label"])
            n_seen += image.shape[0]
            if max_samples is not None and n_seen >= max_samples:
                break
        radar = torch.cat(radar_batches, dim=0)
        optical = torch.cat(optical_batches, dim=0)
        labels = torch.cat(label_batches, dim=0)
        if max_samples is not None:
            # The last batch can overshoot the limit by up to 255 samples.
            radar = radar[:max_samples]
            optical = optical[:max_samples]
            labels = labels[:max_samples]
        return radar, optical, labels

    # Training split: at most --n_samples samples (whole split if not given).
    train_radar, train_optical, train_labels = _extract_modalities(
        train_dataset, args.n_samples, "Processing training data")

    # how many samples are there per class (in percent)?
    unique, counts = torch.unique(train_labels, return_counts=True)
    class_distribution = {int(u): float(c) / len(train_labels) * 100 for u, c in zip(unique, counts)}
    print("Training class distribution (percent):")
    for cls, pct in class_distribution.items():
        print(f"  Class {cls}: {pct:.2f}%")

    # Validation split: 10% of --n_samples (at least one sample) when a limit is given.
    val_limit = None if args.n_samples is None else max(1, int(0.1 * args.n_samples))
    val_radar, val_optical, val_labels = _extract_modalities(
        val_dataset, val_limit, "Processing validation data")

    print(f"\nFinal dataset shapes:")
    print(f"Train Radar: {train_radar.shape}")
    print(f"Train Optical: {train_optical.shape}")
    print(f"Train Labels: {train_labels.shape}")
    print(f"Val Radar: {val_radar.shape}")
    print(f"Val Optical: {val_optical.shape}")
    print(f"Val Labels: {val_labels.shape}")

    # value ranges
    print(f"\nValue ranges:")
    print(f"Train Radar: min={train_radar.min().item()}, max={train_radar.max().item()}")
    print(f"Train Optical: min={train_optical.min().item()}, max={train_optical.max().item()}")
    print(f"Val Radar: min={val_radar.min().item()}, max={val_radar.max().item()}")
    print(f"Val Optical: min={val_optical.min().item()}, max={val_optical.max().item()}")

    # --- 3. Prepare Data for the Model ---
    # Every normalization statistic (shift, clip threshold, mean/std) is
    # computed from the training split only and applied unchanged to the
    # validation split, so no validation information leaks into preprocessing.
    train_data = [train_radar.float(), train_optical.float()]
    val_data = [val_radar.float(), val_optical.float()]

    def _zscore_per_channel(train_t, val_t, modality):
        """Z-score ``train_t`` and ``val_t`` per channel using ``train_t``'s statistics.

        Mean/std are reduced over dims (0, 2, 3) with keepdim, giving
        (1, C, 1, 1) tensors, i.e. one statistic per channel. 1e-8 is added to
        std to avoid division by zero. Prints the per-channel statistics and
        returns the normalized (train, val) pair.
        """
        mean = train_t.mean(dim=(0, 2, 3), keepdim=True)
        std = train_t.std(dim=(0, 2, 3), keepdim=True)
        train_t = (train_t - mean) / (std + 1e-8)
        val_t = (val_t - mean) / (std + 1e-8)
        print(f"Modality {modality} normalized per channel:")
        for ch in range(mean.shape[1]):
            print(f"  Channel {ch}: mean={mean[0, ch, 0, 0].item():.4f}, std={std[0, ch, 0, 0].item():.4f}")
        return train_t, val_t

    if args.normalize:
        print("\nNormalizing data (z-score per channel using training statistics)...")

        # Radar (modality 0): shift to non-negative, log1p, clip at the 99th
        # percentile, then z-score.
        radar_train, radar_val = train_data[0], val_data[0]
        min_vals = radar_train.amin(dim=(0, 2, 3), keepdim=True)
        if torch.any(min_vals < 0):
            radar_train = radar_train - min_vals
            radar_val = radar_val - min_vals
            print("Modality 0 min subtraction done.")
        # log1p requires x > -1. Validation values below the training minimum
        # would otherwise turn into NaN/-inf, so floor both splits at 0. This is
        # a no-op for the training split, which is non-negative at this point.
        radar_train = torch.log1p(radar_train.clamp_min(0))
        radar_val = torch.log1p(radar_val.clamp_min(0))
        print("Modality 0 log-transformed.")

        # torch.quantile has an input-size limit, so estimate the 99th
        # percentile from at most 1M sampled values. Indices are drawn with
        # randint (with replacement) rather than randperm(numel), which would
        # allocate an int64 index for every element of the training tensor.
        # A dedicated seeded generator makes the estimate reproducible without
        # consuming the global RNG.
        flat = radar_train.flatten()
        max_quantile_samples = 1_000_000
        if flat.numel() > max_quantile_samples:
            gen = torch.Generator().manual_seed(args.seed)
            idx = torch.randint(flat.numel(), (max_quantile_samples,), generator=gen)
            q99 = torch.quantile(flat[idx], 0.99)
        else:
            q99 = torch.quantile(flat, 0.99)
        radar_train = torch.clamp(radar_train, max=q99)
        radar_val = torch.clamp(radar_val, max=q99)
        print(f"Modality 0 clipped at 99th percentile: {q99.item():.4f}")
        train_data[0], val_data[0] = _zscore_per_channel(radar_train, radar_val, 0)

        # Optical (modality 1): z-score only.
        train_data[1], val_data[1] = _zscore_per_channel(train_data[1], val_data[1], 1)

    # Fail loudly if preprocessing produced non-finite values. An explicit
    # raise, unlike assert, still runs under `python -O`.
    for i, (train_t, val_t) in enumerate(zip(train_data, val_data)):
        if not torch.isfinite(train_t).all():
            raise ValueError(f"NaN/Inf found in train_data modality {i} after preprocessing!")
        if not torch.isfinite(val_t).all():
            raise ValueError(f"NaN/Inf found in val_data modality {i} after preprocessing!")

    train_labels_data = train_labels
    val_labels_data = val_labels
    
    # Define input shapes for CNN: (channels, height, width), one per modality
    # in train_data order (radar, optical). They are derived from the prepared
    # tensors so they cannot drift from the channel slicing used at extraction.
    input_shapes = [tuple(t.shape[1:]) for t in train_data]

    DEVICE = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"\nUsing device: {DEVICE}")

    # Seed every RNG used from here on. NOTE: this runs after data loading and
    # preprocessing, so it only governs training and the steps that follow.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)

    # --- Your Model Hyperparameters and Training Call ---
    rank_reduction_frequency = 10
    rank_reduction_threshold = 0.01
    early_stopping = 50
    patience = 10
    r_square_threshold = 0.05

    class Args:
        """Training configuration passed to train_overcomplete_ae_with_pretrained.

        Every setting is set as an instance attribute in __init__, in the order
        listed. ``gpu`` is copied from the parsed command-line ``args``.
        """

        def __init__(self):
            settings = {
                'latent_dim': 200,
                'batch_size': 512,
                'lr': 1e-4,
                'weight_decay': 2e-5,
                'dropout': 0.0,
                'ae_depth': 2,
                'ae_width': 0.5,
                'epochs': 2000,
                'num_workers': 8,
                'multi_gpu': False,
                'gpu_ids': '',
                'gpu': args.gpu,
            }
            for attr_name, attr_value in settings.items():
                setattr(self, attr_name, attr_value)

    train_args = Args()

    print("\nStarting model training...")
    # Keyword arguments for train_overcomplete_ae_with_pretrained, grouped
    # loosely by parameter name. The positional arguments (train/val data,
    # latent size, device, config object) are passed directly in the call.
    optim_kwargs = {
        'epochs': train_args.epochs,
        'lr': train_args.lr,
        'batch_size': train_args.batch_size,
        'wd': train_args.weight_decay,
        'lr_schedule': 'cosine',
        'end_lr': 1e-7,
    }
    model_kwargs = {
        'ae_depth': train_args.ae_depth,
        'ae_width': train_args.ae_width,
        'dropout': train_args.dropout,
        'input_shapes': input_shapes,
    }
    rank_search_kwargs = {
        'early_stopping': early_stopping,
        'initial_rank_ratio': 1.0,
        'rank_reduction_frequency': rank_reduction_frequency,
        'rank_reduction_threshold': rank_reduction_threshold,
        # warmup_epochs reuses the early_stopping value
        'warmup_epochs': early_stopping,
        'patience': patience,
        'min_rank': 1,
        'r_square_threshold': r_square_threshold,
        'threshold_type': 'absolute',
        'compressibility_type': 'direct',
        'decision_metric': 'ExVarScore',
    }
    run_kwargs = {
        'compute_jacobian': False,
        'sharedwhenall': False,
        'pretrained_name': f"_{args.prefix}_rseed-{args.seed}",
    }
    (model, reps, train_loss, r_squares,
     rank_history, loss_curves) = train_overcomplete_ae_with_pretrained(
        train_data,
        val_data,
        train_args.latent_dim,
        DEVICE,
        train_args,
        **optim_kwargs,
        **model_kwargs,
        **rank_search_kwargs,
        **run_kwargs,
    )
    
    # Save results
    model_prefix = f"{args.prefix}_rseed-{args.seed}"
    save_dir = project_config.SO2SAT_RESULTS_DIR
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), f"{save_dir}/{model_prefix}_final_model.pth")
    for i, rep in enumerate(reps):
        # Representations may be None, torch tensors, or array-likes; the
        # evaluation code handles all three, so saving must too. Skip None
        # instead of crashing, and detach tensors so .numpy() also works when
        # they still require grad.
        if rep is None:
            print(f"Representation {i} is None; skipping save.")
            continue
        rep_np = rep.detach().cpu().numpy() if torch.is_tensor(rep) else np.asarray(rep)
        np.save(f"{save_dir}/{model_prefix}_rep{i}.npy", rep_np)
    pd.DataFrame(rank_history).to_csv(f"{save_dir}/{model_prefix}_rank_history.csv", index=False)

    # Save labels for later analysis
    np.save(f"{save_dir}/{model_prefix}_train_labels.npy", train_labels_data.numpy())
    np.save(f"{save_dir}/{model_prefix}_val_labels.npy", val_labels_data.numpy())

    # --- Classification Accuracy Evaluation ---
    print("\n" + "="*80)
    print("Computing classification accuracies...")
    print("="*80)
    
    def compute_multi_classifier_accs(rep, labels):
        """Estimate classifier accuracies on a representation with 5-fold CV.

        Each classifier is scored with ``cross_val_score(..., cv=5,
        scoring='accuracy')`` and the mean fold accuracy is reported:

        - ``knn``: 1-nearest-neighbour on the raw features;
        - ``logistic``: StandardScaler + LogisticRegression (lbfgs, max_iter=2000);
        - ``gnb``: StandardScaler + GaussianNB.

        Args:
            rep: array-like with one row per sample, or None.
            labels: array-like with one entry per sample, or None.
                String/object labels are integer-encoded with LabelEncoder.

        Returns:
            dict with keys 'knn', 'logistic', 'gnb'. A value is None when the
            inputs are missing, there are fewer than 10 samples, the numbers
            of rows and labels differ, or that classifier's CV raised. Failures
            are printed, not raised, so one bad representation cannot abort
            the run.
        """
        out = {'knn': None, 'logistic': None, 'gnb': None}
        if rep is None or labels is None:
            return out
        try:
            X = np.asarray(rep)
            y = np.asarray(labels)
            # encode non-numeric labels
            if y.dtype.kind in {'U', 'S', 'O'}:
                y = LabelEncoder().fit_transform(y.astype(str))
            n = X.shape[0]
            n_labels = len(y)
        except Exception as e:
            print(f"Multi-classifier accuracy failed: {e}")
            return out
        if n < 10:
            return out
        if n_labels != n:
            print(f"Multi-classifier accuracy failed: {n} representation rows but {n_labels} labels")
            return out

        cv_folds = 5
        # (result key, name used in failure messages, estimator). Scaling lives
        # inside a Pipeline so it is re-fit on each training fold only.
        classifiers = [
            ('knn', 'KNN', KNeighborsClassifier(n_neighbors=1)),
            ('logistic', 'Logistic', Pipeline([
                ('scaler', StandardScaler()),
                # multi_class is left at its default ('auto'); passing it
                # explicitly is deprecated since scikit-learn 1.5.
                ('classifier', LogisticRegression(max_iter=2000, solver='lbfgs')),
            ])),
            ('gnb', 'GaussianNB', Pipeline([
                ('scaler', StandardScaler()),
                ('classifier', GaussianNB()),
            ])),
        ]
        for key, name, estimator in classifiers:
            try:
                scores = cross_val_score(estimator, X, y, cv=cv_folds, scoring='accuracy')
                out[key] = float(np.mean(scores))
            except Exception as e:
                # Best-effort per classifier: a failure leaves only this entry None.
                print(f"{name} CV failed: {e}")
        return out
    
    # Prepare results list
    results = []

    # reps is expected to be [shared, modality_0_specific, modality_1_specific];
    # every representation is scored against the training labels.
    labels = train_labels_data.numpy()

    print(f"\nNumber of unique classes: {len(np.unique(labels))}")
    print(f"Number of samples: {len(labels)}")

    def _rep_to_numpy(rep):
        """Return ``rep`` as a NumPy array; torch tensors are detached and moved to CPU."""
        if torch.is_tensor(rep):
            return rep.detach().cpu().numpy()
        return np.asarray(rep)

    def _fmt_acc(value):
        """Format an accuracy to 4 decimals, or 'N/A' when the classifier failed (None)."""
        return "N/A" if value is None else f"{value:.4f}"

    # (results key, printed label, indices into reps to concatenate), in output order.
    evaluations = [
        ('shared', 'Shared representation', (0,)),
        ('radar_specific', 'Radar-specific representation', (1,)),
        ('optical_specific', 'Optical-specific representation', (2,)),
        ('shared_radar_concat', 'Shared+Radar concat', (0, 1)),
        ('shared_optical_concat', 'Shared+Optical concat', (0, 2)),
        ('all_concat', 'All concat', (0, 1, 2)),
    ]
    # Convert each available representation once, not once per combination.
    rep_arrays = {i: _rep_to_numpy(rep) for i, rep in enumerate(reps) if rep is not None}

    for rep_name, display_name, indices in evaluations:
        # Skip combinations whose parts are missing (reps too short or entry None).
        if not all(i in rep_arrays for i in indices):
            continue
        if len(indices) == 1:
            features = rep_arrays[indices[0]]
        else:
            features = np.concatenate([rep_arrays[i] for i in indices], axis=1)
        accs = compute_multi_classifier_accs(features, labels)
        if all(v is None for v in accs.values()):
            print(f"{display_name}: Failed")
        else:
            # Format each score separately. Previously only GNB was checked for
            # None, so a failed KNN/logistic run crashed ':.4f' with TypeError.
            print(f"{display_name} - GNB: {_fmt_acc(accs['gnb'])}, KNN: {_fmt_acc(accs['knn'])}, Logistic: {_fmt_acc(accs['logistic'])}")
        results.append({
            'representation': rep_name,
            'accuracy': accs['gnb'],
            'accuracy_knn': accs['knn'],
            'accuracy_logistic': accs['logistic'],
            'dim': features.shape[1],
        })
    
    # Save classification results ('accuracy' holds the GaussianNB score)
    import csv
    csv_file = f"{save_dir}/{model_prefix}_classification_accuracies.csv"
    fieldnames = ['representation', 'accuracy', 'accuracy_logistic', 'accuracy_knn', 'dim']
    with open(csv_file, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)

    print(f"\n✓ Saved classification results to: {csv_file}")

    def _fmt_summary_acc(value):
        """Format an accuracy for the summary table; None (failed) becomes 'N/A'."""
        return "N/A" if value is None else f"{value:.4f}"

    # Print summary table
    print("\n" + "="*80)
    print("Classification Accuracy Summary:")
    print("="*80)
    print(f"{'Representation':<30} {'Acc (GNB)':<12} {'Acc (Logistic)':<15} {'Acc (KNN)':<12} {'Dim':<10}")
    print("-" * 79)
    for row in results:
        print(
            f"{row['representation']:<30} "
            f"{_fmt_summary_acc(row.get('accuracy')):<12} "
            f"{_fmt_summary_acc(row.get('accuracy_logistic')):<15} "
            f"{_fmt_summary_acc(row.get('accuracy_knn')):<12} "
            f"{row['dim']:<10}"
        )

    print("\n" + "="*80)
    # f-prefix is required here: without it the literal text '{model_prefix}' was printed.
    print(f"Experiment completed! Model and results saved with prefix '{model_prefix}'")
    print("="*80)