import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
import os
import logging
import time
from tqdm import tqdm
import argparse

from contrastive_mapping_model_final import EnhancedMLPWithAttention, SupervisedContrastiveLossWithKernel

# Make the directory containing this file the working directory, so relative
# paths used later (data file prefix, model save path) resolve against it.
# NOTE: this is a module-level statement, so it also runs when the module is
# imported, not only when it is executed as a script.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Root logging configuration: INFO level, "timestamp - level - message" lines.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Constants for filenames ---
# Base names of the NumPy arrays; the --data_prefix option is prepended to
# them to build the paths that are loaded.
X_TRAIN_FILENAME = "X_train.npy"
Y_TRAIN_FILENAME = "y_train.npy"
X_VAL_FILENAME = "X_val.npy"
Y_VAL_FILENAME = "y_val.npy"
# The test-split names are defined here but are not loaded by this training script.
X_TEST_FILENAME = "X_test.npy"
Y_TEST_FILENAME = "y_test.npy"


def parse_arguments(argv=None) -> argparse.Namespace:
    """Build the command-line interface and parse the training options.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behaviour; passing a list
            allows programmatic use and testing.

    Returns:
        The parsed options (data, model, loss, training and output groups).
    """
    parser = argparse.ArgumentParser(description="Contrastive Model Training Script")

    group_data = parser.add_argument_group('Data Parameters')
    group_data.add_argument('--data_prefix', type=str, default="", help="Prefix for data file paths.")
    group_data.add_argument('--num_classes', type=int, default=5, help="Number of classes for classification.")

    group_model = parser.add_argument_group('Model Parameters')
    group_model.add_argument('--input_dim', type=int, default=1024, help="Input embedding dimension.")
    group_model.add_argument('--hidden_dims', type=int, nargs='+', default=[512, 256], help="MLP hidden layer dimensions.")
    group_model.add_argument('--output_dim', type=int, default=64, help="Optimized embedding dimension (model output).")
    group_model.add_argument('--alpha_mix', type=float, default=0.3, help="Mixing parameter for MLP and attention features in EnhancedMLPWithAttention.")

    group_loss = parser.add_argument_group('Loss Function Parameters')
    group_loss.add_argument('--temperature', type=float, default=0.1, help="Temperature for contrastive loss.")
    group_loss.add_argument('--lambda_offset', type=float, default=0.1, help="Lambda offset parameter (currently not directly used in loss logic shown but kept for consistency).")
    group_loss.add_argument('--enable_magnitude_loss', action='store_true', help="Enable magnitude loss component.")
    group_loss.add_argument('--w_contrastive', type=float, default=1.0, help="Weight for contrastive loss.")
    group_loss.add_argument('--w_offset', type=float, default=1.0, help="Weight for offset loss.")
    group_loss.add_argument('--w_class', type=float, default=3.0, help="Weight for classification loss.")
    group_loss.add_argument('--w_ortho', type=float, default=5.0, help="Weight for orthogonality loss.")
    group_loss.add_argument('--w_magnitude_value', type=float, default=1.0, help="Weight value for magnitude loss if enabled.")

    group_train = parser.add_argument_group('Training Parameters')
    group_train.add_argument('--learning_rate', type=float, default=5e-5, help="Learning rate.")
    group_train.add_argument('--batch_size', type=int, default=256, help="Batch size.")
    group_train.add_argument('--max_epochs', type=int, default=100, help="Maximum number of training epochs.")
    group_train.add_argument('--patience', type=int, default=10, help="Patience for early stopping.")
    group_train.add_argument('--weight_decay', type=float, default=1e-5, help="Weight decay for AdamW optimizer.")

    group_output = parser.add_argument_group('Output Parameters')
    group_output.add_argument('--model_save_path', type=str, default="contrastive_model.pth", help="Path to save the trained model.")

    # argv=None keeps the default argparse behaviour (read sys.argv[1:]).
    return parser.parse_args(argv)

# Loss terms tracked in the per-epoch summaries. A term that the criterion does
# not include in its returned component mapping simply stays at 0.0.
_LOSS_COMPONENT_KEYS = ("contrastive", "offset", "classification", "orthogonality", "magnitude")


def _compute_class_weights(y_train, num_classes):
    """Return a float32 tensor with one classification-loss weight per class.

    Weights use scikit-learn's 'balanced' scheme over the labels present in
    ``y_train``. If the number of labels found differs from ``num_classes``,
    the computed weights are written at their label index and every other
    slot keeps weight 1.0. Any failure falls back to uniform weights so that
    training can still proceed.

    Args:
        y_train: Array of training labels.
        num_classes: Length of the returned weight tensor.

    Returns:
        A 1-D ``torch.float32`` tensor of length ``num_classes``.
    """
    uniform_weights = torch.ones(num_classes, dtype=torch.float32)
    try:
        classes = np.unique(y_train)
        if len(classes) == 0:
            logger.error("No classes found in y_train. Cannot compute class weights.")
            raise ValueError("Empty y_train")

        count_matches = len(classes) == num_classes or (num_classes == 1 and len(classes) <= 2)
        if not count_matches:
            logger.warning(f"Number of unique classes in y_train ({len(classes)}) does not match --num_classes ({num_classes}). Using detected classes for weight calculation if possible.")

        if num_classes == 1 and len(classes) > 1:
            # A single output unit cannot carry per-label weights.
            logger.info("num_classes is 1, but multiple values in y_train detected. Class weights will be uniform.")
            class_weights_tensor = uniform_weights
        else:
            weights = compute_class_weight(class_weight='balanced', classes=classes, y=y_train)
            if len(weights) == num_classes:
                class_weights_tensor = torch.tensor(weights, dtype=torch.float32)
            else:
                logger.warning(f"Computed weights length ({len(weights)}) for classes {classes} "
                               f"differs from --num_classes ({num_classes}). "
                               "This can happen if not all classes are present in y_train. Adjusting carefully.")
                class_weights_tensor = uniform_weights.clone()
                for weight, cls_idx in zip(weights, classes):
                    # Lower bound matters: a negative label would otherwise
                    # silently index from the end of the tensor.
                    if 0 <= cls_idx < num_classes:
                        class_weights_tensor[int(cls_idx)] = weight
                    else:
                        logger.warning(f"Label {cls_idx} is outside [0, {num_classes}); its weight is ignored.")
    except Exception as e:
        # Deliberate best-effort: class weighting is an optimisation, so any
        # failure here degrades to uniform weights instead of aborting the run.
        logger.error(f"Error calculating class weights: {e}", exc_info=True)
        logger.warning(f"Using default uniform weights for {num_classes} classes.")
        return uniform_weights

    logger.info(f"Calculated Class Weights (for {num_classes} classes): {class_weights_tensor.numpy()}")
    return class_weights_tensor


def _run_epoch(model, criterion, loader, device, desc, optimizer=None):
    """Run one pass over ``loader`` and return the mean loss and mean loss components.

    Args:
        model: Network applied to each batch of embeddings.
        criterion: Callable returning ``(loss, loss_components)`` for
            ``(model_output, labels)``.
        loader: Non-empty iterable of ``(embeddings, labels)`` batches that supports ``len()``.
        device: Device the batches are moved to.
        desc: Progress-bar label.
        optimizer: When given, the pass trains (backward + step); when
            ``None`` the pass is evaluation-only with gradients disabled.

    Returns:
        ``(average_loss, components)`` where ``components`` maps each name in
        ``_LOSS_COMPONENT_KEYS`` to its per-batch average.
    """
    is_training = optimizer is not None
    model.train(is_training)  # train(False) is equivalent to eval()

    running_loss = 0.0
    component_totals = dict.fromkeys(_LOSS_COMPONENT_KEYS, 0.0)
    progress_bar = tqdm(loader, desc=desc, leave=False)

    with torch.set_grad_enabled(is_training):
        for embeddings, labels in progress_bar:
            embeddings, labels = embeddings.to(device), labels.to(device)
            if is_training:
                optimizer.zero_grad()
            model_output = model(embeddings)
            loss, loss_components = criterion(model_output, labels)
            if is_training:
                loss.backward()
                optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            for key in component_totals:
                if key in loss_components:
                    # float() turns a scalar tensor into a plain number so the
                    # running totals never keep tensors (or their graphs) alive.
                    component_totals[key] += float(loss_components[key])

            progress_bar.set_postfix(loss=f"{batch_loss:.4f}")

    num_batches = len(loader)
    averaged_components = {key: total / num_batches for key, total in component_totals.items()}
    return running_loss / num_batches, averaged_components


def main():
    """Train the contrastive model with early stopping and save the best weights."""
    args = parse_arguments()

    # --- Path Setup from args ---
    x_train_path = f"{args.data_prefix}{X_TRAIN_FILENAME}"
    y_train_path = f"{args.data_prefix}{Y_TRAIN_FILENAME}"
    x_val_path = f"{args.data_prefix}{X_VAL_FILENAME}"
    y_val_path = f"{args.data_prefix}{Y_VAL_FILENAME}"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    logger.info(f"Running with arguments: {args}")

    logger.info("Loading pre-processed data for contrastive training...")
    try:
        X_train = np.load(x_train_path)
        y_train = np.load(y_train_path)
        X_val = np.load(x_val_path)
        y_val = np.load(y_val_path)
    except (OSError, ValueError) as e:
        logger.error(f"Error loading data files: {e}. Please ensure .npy files exist at specified paths (prefix: '{args.data_prefix}').")
        # Non-zero status so shells and schedulers see the failure.
        raise SystemExit(1) from e
    logger.info("Data loaded successfully.")
    logger.info(f"Train shapes: X={X_train.shape}, y={y_train.shape}")
    logger.info(f"Val shapes:   X={X_val.shape}, y={y_val.shape}")

    # An empty split would yield a zero-length loader and a division by zero
    # when averaging the epoch loss; fail early with a clear message instead.
    if len(X_train) == 0 or len(X_val) == 0:
        logger.error("Training and validation sets must both contain at least one sample.")
        raise SystemExit(1)

    logger.info("Calculating class weights...")
    class_weights_tensor = _compute_class_weights(y_train, args.num_classes)

    train_dataset = TensorDataset(torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.long))
    val_dataset = TensorDataset(torch.tensor(X_val, dtype=torch.float32), torch.tensor(y_val, dtype=torch.long))

    # Pinned host memory only helps host-to-GPU copies; enabling it without
    # CUDA just produces a warning.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=pin_memory)

    logger.info("Initializing model and loss function...")
    model = EnhancedMLPWithAttention(
        input_dim=args.input_dim,
        hidden_dims=args.hidden_dims,
        output_dim=args.output_dim,
        num_classes=args.num_classes,
        alpha_mix=args.alpha_mix
    ).to(device)

    criterion = SupervisedContrastiveLossWithKernel(
        num_classes=args.num_classes,
        embedding_dim=args.output_dim,
        temperature=args.temperature,
        lambda_offset=args.lambda_offset,
        class_weights=class_weights_tensor,
        w_contrastive=args.w_contrastive,
        w_offset=args.w_offset,
        w_class=args.w_class,
        w_ortho=args.w_ortho,
        enable_magnitude_loss=args.enable_magnitude_loss,
        w_magnitude_value=args.w_magnitude_value
    ).to(device)

    # NOTE(review): only model.parameters() are optimised; if the criterion
    # holds learnable parameters they are not updated — confirm this is intended.
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    # The scheduler reacts after half the early-stopping patience. `verbose`
    # is deprecated in PyTorch, so learning-rate changes are logged explicitly below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=args.patience // 2)

    # Create the output directory up front so a bad path fails before training.
    save_dir = os.path.dirname(args.model_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0

    logger.info("Starting contrastive model training...")
    training_start_time = time.time()

    for epoch in range(args.max_epochs):
        epoch_label = f"Epoch {epoch+1}/{args.max_epochs}"

        avg_train_loss, train_components = _run_epoch(
            model, criterion, train_loader, device, f"{epoch_label} [Train]", optimizer=optimizer
        )
        avg_val_loss, val_components = _run_epoch(
            model, criterion, val_loader, device, f"{epoch_label} [Val]"
        )

        logger.info(f"{epoch_label} Summary:")
        logger.info(f"  Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        train_comps_str = ", ".join(f"{k.capitalize()[:4]}={v:.3f}" for k, v in train_components.items())
        val_comps_str = ", ".join(f"{k.capitalize()[:4]}={v:.3f}" for k, v in val_components.items())
        logger.info(f"  Train Components: {train_comps_str}")
        logger.info(f"  Val Components:   {val_comps_str}")

        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(avg_val_loss)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after != lr_before:
            logger.info(f"Learning rate changed from {lr_before:.2e} to {lr_after:.2e}.")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # dict.copy() on a state_dict is shallow and would keep aliasing
            # the live parameters; clone each tensor for a real snapshot.
            best_model_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
            patience_counter = 0
            logger.info(f"Validation loss improved to {best_val_loss:.4f}. Saving model state to '{args.model_save_path}'...")
            torch.save(best_model_state, args.model_save_path)
        else:
            patience_counter += 1
            logger.info(f"No improvement in validation loss for {patience_counter} epochs (Best: {best_val_loss:.4f}).")
            if patience_counter >= args.patience:
                logger.info(f"Early stopping triggered at epoch {epoch+1}.")
                break

    training_end_time = time.time()
    logger.info(f"Contrastive training finished. Total time: {training_end_time - training_start_time:.2f} seconds.")

    if best_model_state is not None:
        logger.info(f"Final best model state saved to '{args.model_save_path}' (Val Loss: {best_val_loss:.4f})")
    else:
        logger.error("Training did not result in a best model state (validation loss never improved). No model saved.")

    logger.info("Script finished.")


if __name__ == "__main__":
    main()