import os
import sys
import gc
from typing import Dict, Tuple, Optional
from argparse import Namespace
from collections import Counter

import torch
import numpy as np
from optuna.trial import Trial

sys.path.append("..")
from common import utils
from codes.trainer import ModelTrainer

# Process-wide switch, applied at import time: ask cuDNN to pick deterministic
# algorithms so runs seeded in run_train() are reproducible on GPU.
# NOTE(review): cudnn.benchmark is not touched here; PyTorch's reproducibility
# notes also recommend benchmark=False for deterministic results.
torch.backends.cudnn.deterministic = True

def get_class_weight(class_weight, train_labels):
    """
    Calculate or parse class weights for handling class imbalance.

    Args:
        class_weight: Class weight specification string. Accepted forms:
            - "auto": weights computed from ``train_labels`` by
              ``utils.calc_class_weight`` (described as inverse-frequency
              weights)
            - "manual-{value}": a single weight parsed with ``float``,
              e.g. "manual-2.5" -> np.array([2.5])
        train_labels: Training labels; only used when ``class_weight`` is
            "auto".

    Returns:
        np.ndarray: Array of class weights. The "manual-*" form always
        yields a one-element array.

    Raises:
        ValueError: If ``class_weight`` is not a string in one of the accepted
            forms, or if the manual value cannot be parsed as a float.
    """
    manual_prefix = "manual-"

    # Reject non-strings (e.g. None) up front so callers get the documented
    # ValueError instead of an AttributeError from ``.startswith``.
    if not isinstance(class_weight, str):
        raise ValueError(
            "class_weight must be a string ('auto' or 'manual-{value}'), "
            f"got {type(class_weight).__name__}"
        )

    if class_weight == "auto":
        return utils.calc_class_weight(train_labels)

    if class_weight.startswith(manual_prefix):
        raw_value = class_weight[len(manual_prefix):]
        try:
            return np.array([float(raw_value)])
        except ValueError as err:
            raise ValueError(
                f"Invalid manual class weight {raw_value!r} in "
                f"class_weight={class_weight!r}; expected 'manual-<float>'"
            ) from err

    raise ValueError(
        f"Unrecognized class_weight {class_weight!r}; "
        "expected 'auto' or 'manual-{value}'"
    )

def _release_memory() -> None:
    """
    Run a garbage-collection pass and, if CUDA has been used in this process,
    release cached GPU memory and reset the CUDA peak-memory counters.
    """
    gc.collect()
    # Guard the CUDA calls: torch.cuda.reset_peak_memory_stats() needs a usable
    # CUDA device and can raise on CPU-only builds/machines, which would crash
    # a device="cpu" run after training already finished. is_initialized() is
    # also False when CUDA exists but was never used, so a CPU run does not
    # spin up a CUDA context just to clean it up.
    if torch.cuda.is_initialized():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()


def run_train(
    params: Namespace,
    save_root: str,
    trial: Optional[Trial] = None,
) -> Tuple[Dict, str]:
    """
    Execute the complete training pipeline for PCG classification model.

    Steps:
    1. Seed torch and numpy RNGs with ``params.seed``
    2. Build the result directory name from a timestamp and the hostname
    3. Create the ModelTrainer, set the model, optionally load pretrained weights
    4. Prepare training ("train") and validation ("val") dataloaders
    5. Configure the loss (label counts for "balanced") and the optimizer
    6. Save params and run training with validation
    7. Free the trainer and release memory (GPU caches only if CUDA was used)

    Args:
        params: Namespace containing all training parameters including:
            - seed: Random seed for reproducibility
            - host: Hostname for tracking experiments
            - finetune_target: Optional directory containing a pretrained
              "net.pth"; None to train from scratch
            - freeze: Whether to freeze pretrained weights (defaults to False,
              and is written onto ``params``, when absent and fine-tuning)
            - class_weight: Class balancing strategy ("balanced", "auto", or "manual-{val}")
            - device: Device for training (e.g., "cuda:0", "cpu")
            - epochs: Number of training epochs
            - eval_every: Frequency of validation evaluation
            - patience: Early stopping patience
        save_root: Root directory where training results will be saved.
            Results are saved to: save_root/{timestamp}-{hostname}/
        trial: Optional Optuna trial object for hyperparameter optimization,
            forwarded to the trainer (used for pruning unpromising trials).

    Returns:
        tuple: (best_result, save_dir) where:
            - best_result: Dictionary of best validation metrics as returned
              by ``trainer.get_best_loss()``:
                * loss: Best validation loss
                * f1score: F1 score at best checkpoint
                * Recall, Precision: Classification metrics
                * AUROC, AUPRC: Area under ROC and precision-recall curves
                * confusion_matrix: Confusion matrix text representation
            - save_dir: Full path to the directory containing saved model
                and training logs
    """
    torch.manual_seed(params.seed)
    np.random.seed(params.seed)

    # Each run gets its own directory: save_root/{timestamp}-{hostname}
    save_setting = f"{utils.get_timestamp()}-{params.host}"
    save_dir = os.path.join(save_root, save_setting)

    # Trainer prep
    trainer = ModelTrainer(params, save_dir)
    trainer.set_trial(trial)
    trainer.set_model()
    if params.finetune_target is not None:
        weight_file = os.path.join(params.finetune_target, "net.pth")
        # Configs without `freeze` default to not freezing. The default is
        # written back onto params rather than just read — presumably so
        # save_params() below records it; confirm before changing.
        if not hasattr(params, "freeze"):
            params.freeze = False
        trainer.set_pretrained_model(weight_file, params.freeze)

    print("Preparing dataloader ...")
    train_loader = trainer.prepare_dataloader(data_split="train", is_train=True)
    valid_loader = trainer.prepare_dataloader(data_split="val", is_train=False)

    # Only the "balanced" strategy passes per-class label counts to the loss.
    # NOTE(review): "auto" / "manual-*" never reach set_lossfunc here; they are
    # presumably handled inside ModelTrainer — confirm.
    if params.class_weight == "balanced":
        trainer.set_lossfunc(Counter(train_loader.dataset.labels))

    trainer.set_optimizer()
    trainer.save_params()

    print("Starting training ...")
    trainer.run(train_loader, valid_loader)
    _, best_result = trainer.get_best_loss()

    # Drop our reference to the trainer (and whatever model/optimizer state it
    # holds) before collecting, so repeated calls — e.g. one per Optuna trial —
    # do not accumulate memory.
    del trainer
    _release_memory()

    # Best validation metrics (e.g. for a hyperparameter-search objective)
    # plus the directory holding the saved model and logs.
    return best_result, save_dir

if __name__ == "__main__":
    # Running this file directly does nothing; run_train() is the entry point
    # and is presumably invoked by a separate driver that imports this module.
    ...
