

import sys
from pathlib import Path
project_root = str(Path(__file__).resolve().parent.parent)
if project_root not in sys.path:
    sys.path.append(project_root)
import project_config

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchgeo.datasets import So2Sat
import numpy as np
import os
import argparse
from tqdm import tqdm

# ==========================================
# 1. SIMPLE CNN AUTOENCODER
# ==========================================
class UnimodalAE(nn.Module):
    """Small convolutional autoencoder for a single input modality.

    The encoder applies two stride-2 convolutions followed by a linear
    projection to ``latent_dim``. The decoder mirrors it: a linear layer,
    then two stride-2 transposed convolutions. The linear layers are sized
    for an 8x8 feature map after the two downsampling steps, which
    corresponds to 32x32 inputs.

    Args:
        input_channels: Channel count of the input and of the reconstruction.
        latent_dim: Size of the bottleneck vector.
        width: Multiplier applied to a 32-channel base width; the resulting
            channel count is never allowed to drop below 16.
    """

    def __init__(self, input_channels, latent_dim=200, width=0.5):
        super().__init__()

        # Channel counts of the two convolutional stages.
        ch_stage1 = max(16, int(32 * width))
        ch_stage2 = ch_stage1 * 2
        # Flattened size of the (ch_stage2, 8, 8) feature map.
        flat_features = ch_stage2 * 8 * 8

        # --- Encoder: 32x32 -> 16x16 -> 8x8 -> latent vector ---
        encoder_layers = [
            nn.Conv2d(input_channels, ch_stage1, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(ch_stage1),
            nn.Conv2d(ch_stage1, ch_stage2, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(ch_stage2),
            nn.Flatten(),
            nn.Linear(flat_features, latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # --- Decoder: latent vector -> 8x8 -> 16x16 -> 32x32 ---
        self.decoder_fc = nn.Linear(latent_dim, flat_features)

        decoder_layers = [
            nn.Unflatten(1, (ch_stage2, 8, 8)),
            nn.ConvTranspose2d(
                ch_stage2, ch_stage1, kernel_size=3, stride=2, padding=1, output_padding=1
            ),
            nn.ReLU(),
            nn.BatchNorm2d(ch_stage1),
            # Final layer stays linear: the targets are z-normalised, so the
            # reconstruction must be free to take any real value.
            nn.ConvTranspose2d(
                ch_stage1, input_channels, kernel_size=3, stride=2, padding=1, output_padding=1
            ),
        ]
        self.decoder_conv = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and reconstruct it.

        Returns:
            A ``(x_hat, z)`` tuple: the reconstruction and the latent code.
        """
        latent = self.encoder(x)
        expanded = self.decoder_fc(latent)
        reconstruction = self.decoder_conv(expanded)
        return reconstruction, latent

# ==========================================
# 2. DATA PREPARATION (Matches Multimodal)
# ==========================================
def load_and_process_so2sat(root_dir, modality, n_samples=None, seed=42):
    """Load one modality of So2Sat into memory and normalise it.

    Radar data is shifted to be non-negative per channel, log1p-transformed
    and clipped at the 99th percentile of the training data. Both modalities
    are then z-normalised per channel using training-set statistics.

    Args:
        root_dir: Root directory handed to ``torchgeo.datasets.So2Sat``.
        modality: ``'radar'`` selects channels 0-7; any other value selects
            channels 8-10.
        n_samples: Exact cap on the number of training samples. ``None`` or
            ``0`` loads the full split. The validation cap is 20% of this
            value, but at least one sample.
        seed: Currently unused; kept for interface compatibility. The
            quantile subsample draws from torch's global RNG, so call
            ``torch.manual_seed`` beforehand for reproducible clipping.

    Returns:
        ``(train_x, train_y, val_x, val_y)`` where the ``*_x`` tensors are
        float tensors laid out as (N, C, H, W) and the ``*_y`` tensors hold
        the labels as produced by the dataset.

    Raises:
        ValueError: If ``n_samples`` is negative.
    """
    if n_samples is not None and n_samples < 0:
        raise ValueError(f"n_samples must be non-negative or None, got {n_samples}")

    print(f"\n--- Loading So2Sat ({modality.upper()}) ---")

    # 1. Load Raw Data
    train_ds = So2Sat(root=root_dir, version="2", split="train", checksum=False)
    val_ds = So2Sat(root=root_dir, version="2", split="validation", checksum=False)

    def extract_tensors(dataset, limit=None):
        """Collect up to ``limit`` samples (all when ``None``) of the modality."""
        data_list = []
        labels_list = []
        count = 0
        loader = DataLoader(dataset, batch_size=512, shuffle=False, num_workers=4)

        print("Extracting tensors from TorchGeo dataset...")
        for batch in tqdm(loader):
            imgs = batch["image"]
            lbls = batch["label"]

            if modality == 'radar':
                # Channels 0-7: Sentinel-1 (assumes TorchGeo's default band
                # order, since no explicit `bands` are requested above).
                x = imgs[:, 0:8, :, :]
            else:
                # Channels 8-10: the first three Sentinel-2 channels
                # (same default-band-order assumption).
                x = imgs[:, 8:11, :, :]

            if limit is not None:
                # Trim the final batch so the cap is exact instead of
                # overshooting by up to one batch.
                remaining = limit - count
                x = x[:remaining]
                lbls = lbls[:remaining]

            data_list.append(x)
            labels_list.append(lbls)

            count += x.shape[0]
            if limit is not None and count >= limit:
                break

        return torch.cat(data_list, dim=0), torch.cat(labels_list, dim=0)

    # Extract. A falsy n_samples (None or 0) means "no cap". The validation
    # cap is floored at 1 so tiny n_samples values cannot round down to 0,
    # which would otherwise disable the cap and load the whole split.
    train_limit = n_samples if n_samples else None
    val_limit = max(1, int(n_samples * 0.2)) if n_samples else None
    train_x, train_y = extract_tensors(train_ds, limit=train_limit)
    val_x, val_y = extract_tensors(val_ds, limit=val_limit)

    # 2. Preprocessing (CRITICAL: Must match your multimodal script)
    print("\nPreprocessing...")

    # Diagnostics to pin down where NaN values are introduced.
    print(f"NaNs in train_x before log1p: {torch.isnan(train_x).sum().item()}")
    print(f"NaNs in val_x before log1p: {torch.isnan(val_x).sum().item()}")
    # Ranges before processing
    print(f"Train x range before processing: min {train_x.min().item():.4f}, max {train_x.max().item():.4f}")
    print(f"Val x range before processing: min {val_x.min().item():.4f}, max {val_x.max().item():.4f}")
    if modality == 'radar':
        # Shift every channel by its minimum over train+val so all values are
        # >= 0 and log1p is well defined. Combining the per-split minima gives
        # the same result as concatenating the splits, without copying them.
        min_vals = torch.minimum(
            train_x.amin(dim=(0, 2, 3), keepdim=True),
            val_x.amin(dim=(0, 2, 3), keepdim=True),
        )
        train_x = train_x - min_vals
        val_x = val_x - min_vals
        print(f"After min subtraction, Train x range: min {train_x.min().item():.4f}, max {train_x.max().item():.4f}")
        print(f"After min subtraction, Val x range: min {val_x.min().item():.4f}, max {val_x.max().item():.4f}")
        # Log Transform + Clip
        train_x = torch.log1p(train_x)
        val_x = torch.log1p(val_x)
        print(f"NaNs in train_x after log1p: {torch.isnan(train_x).sum().item()}")
        print(f"NaNs in val_x after log1p: {torch.isnan(val_x).sum().item()}")

        # Estimate the 99th percentile on at most 1M randomly chosen values;
        # the full tensor can be too large to pass to torch.quantile directly.
        flat_data = train_x.flatten()
        if flat_data.numel() > 1000000:
            indices = torch.randperm(flat_data.numel())[:1000000]
            q99 = torch.quantile(flat_data[indices], 0.99)
        else:
            q99 = torch.quantile(flat_data, 0.99)
        train_x = torch.clamp(train_x, max=q99)
        val_x = torch.clamp(val_x, max=q99)
        print(f"Radar: Log1p + Clipped at {q99:.4f}")
        print(f"NaNs in train_x after clipping: {torch.isnan(train_x).sum().item()}")
        print(f"NaNs in val_x after clipping: {torch.isnan(val_x).sum().item()}")
        assert not torch.isnan(train_x).any(), "NaNs found in train_x after radar preprocessing!"
        assert not torch.isnan(val_x).any(), "NaNs found in val_x after radar preprocessing!"

    # Z-Normalize (per-channel stats from the training set only, so no
    # validation information leaks into the normalisation)
    mean = train_x.mean(dim=(0, 2, 3), keepdim=True)
    std = train_x.std(dim=(0, 2, 3), keepdim=True)

    train_x = (train_x - mean) / (std + 1e-8)
    val_x = (val_x - mean) / (std + 1e-8)

    print(f"Data Normalized. Train shape: {train_x.shape}")

    return train_x.float(), train_y, val_x.float(), val_y

# ==========================================
# 3. TRAINING LOOP
# ==========================================
def train(model, train_loader, val_loader, device, epochs=50, lr=1e-4):
    """Train an autoencoder with MSE reconstruction loss and keep the best weights.

    Args:
        model: Module whose forward pass returns ``(x_hat, z)``.
        train_loader: Iterable of ``(x, _)`` batches used for optimisation.
        val_loader: Iterable of ``(x, _)`` batches used for model selection.
        device: Device the input batches are moved to; the model is expected
            to already live there.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        A state dict holding an independent copy of the weights from the
        epoch with the lowest mean validation loss. If no epoch ever improved
        on infinity (``epochs=0`` or every validation loss was NaN), a copy
        of the model's current weights is returned instead.
    """
    # Local import: only this function needs it.
    import copy

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    best_val_loss = float('inf')
    best_state = None

    for epoch in range(epochs):
        # Train
        model.train()
        train_loss = 0.0
        for x, _ in train_loader:
            x = x.to(device)
            optimizer.zero_grad()
            x_hat, z = model(x)
            loss = criterion(x_hat, x)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        avg_train = train_loss / len(train_loader)

        # Val
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x, _ in val_loader:
                x = x.to(device)
                x_hat, _ = model(x)
                loss = criterion(x_hat, x)
                val_loss += loss.item()

        avg_val = val_loss / len(val_loader)

        print(f"Epoch {epoch+1:03d} | Train Loss: {avg_train:.5f} | Val Loss: {avg_val:.5f}")

        if avg_val < best_val_loss:
            best_val_loss = avg_val
            # state_dict() hands back references to the live parameter
            # tensors, which later optimizer steps overwrite in place.
            # Deep-copy so the snapshot really is this epoch's weights.
            best_state = copy.deepcopy(model.state_dict())

    if best_state is None:
        # Nothing was ever selected; return usable weights rather than None.
        best_state = copy.deepcopy(model.state_dict())

    return best_state

# ==========================================
# 4. MAIN
# ==========================================
if __name__ == "__main__":
    # ---- Command-line interface ----
    cli = argparse.ArgumentParser()
    cli.add_argument('--modality', type=str, required=True, choices=['radar', 'optical'])
    cli.add_argument('--data_dir', type=str, default=project_config.SO2SAT_DATA_ROOT)
    cli.add_argument('--save_dir', type=str, default=project_config.SO2SAT_RESULTS_DIR)
    cli.add_argument('--gpu', type=int, default=0)
    cli.add_argument('--seed', type=int, default=42)
    cli.add_argument('--n_samples', type=int, default=None)
    cli.add_argument('--epochs', type=int, default=1000)
    # Should match the latent size of the multimodal model being compared against.
    cli.add_argument('--latent_dim', type=int, default=200)
    args = cli.parse_args()

    # ---- Environment: device, output directory, RNG seeds ----
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{args.gpu}")
    else:
        device = torch.device("cpu")
    os.makedirs(args.save_dir, exist_ok=True)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    def output_path(filename):
        """Return the location of ``filename`` inside the results directory."""
        return os.path.join(args.save_dir, filename)

    # ---- Data ----
    X_train, y_train, X_val, y_val = load_and_process_so2sat(
        args.data_dir, args.modality, args.n_samples, args.seed
    )

    train_ds = TensorDataset(X_train, y_train)
    val_ds = TensorDataset(X_val, y_val)
    train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=256, shuffle=False)

    # ---- Model ----
    # Channel counts mirror the slices taken by the data loader.
    channels_by_modality = {'radar': 8, 'optical': 3}
    in_channels = channels_by_modality[args.modality]
    model = UnimodalAE(in_channels, args.latent_dim).to(device)

    print(f"\nStarting training for {args.modality}...")
    best_weights = train(model, train_loader, val_loader, device, args.epochs)

    # ---- Persist the selected weights ----
    model.load_state_dict(best_weights)
    model_path = output_path(f"{args.modality}_ae_seed{args.seed}.pth")
    torch.save(best_weights, model_path)
    print(f"Model saved to {model_path}")

    # ---- Latent representations of the training set ----
    print("Extracting representations...")
    model.eval()

    # Iterate without shuffling so row i of the output matches label i.
    extract_loader = DataLoader(train_ds, batch_size=256, shuffle=False)

    with torch.no_grad():
        latent_chunks = [
            model(batch_x.to(device))[1].cpu().numpy()
            for batch_x, _ in extract_loader
        ]
    reps = np.concatenate(latent_chunks, axis=0)

    rep_path = output_path(f"{args.modality}_reps_seed{args.seed}.npy")
    np.save(rep_path, reps)

    # Labels are stored alongside so the representations can be paired with them.
    label_path = output_path(f"train_labels_seed{args.seed}.npy")
    np.save(label_path, y_train.numpy())

    print(f"Representations saved to {rep_path}")
    print(f"Shape: {reps.shape}")