# train_utils.py
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import optuna
import torch.nn.functional as F
import copy
from src.models.CASMIR_V1 import get_density_range_index
from src.training.losses import WeightedL1Loss, WeightedMSELoss
from src.models.basic_models import MLP, MLPadv
from src.training.losses import RankSimLoss
from src.training.losses import ConRLoss


def train_pytorch_model(model, X_train, y_train, X_val, y_val, criterion, optimizer, config, device, trial=None):
    """Train a PyTorch model with early stopping on validation MAE (FDS-aware).

    The model is always called as ``model(x, targets=..., epoch=...)``. When the
    criterion is a ``RankSimLoss`` or ``ConRLoss`` it is additionally called
    with ``return_features=True`` and must return ``(outputs, features)``.
    ``MLP``/``MLPadv`` models with a truthy ``use_fds`` attribute must also
    return ``(outputs, features)`` while training; their ``fds_layer``
    statistics are refreshed at the end of every epoch from
    ``fds_layer.start_update_epoch`` onwards.

    Args:
        model: ``torch.nn.Module`` to train; moved to ``device`` in place.
        X_train, y_train: Training features/targets convertible by
            ``torch.tensor``. 1-D targets are reshaped to a column ``(n, 1)``.
        X_val, y_val: Validation features/targets, converted the same way.
        criterion: Loss callable. ``ConRLoss``, ``RankSimLoss``,
            ``WeightedL1Loss`` and ``WeightedMSELoss`` receive extra keyword
            arguments (features / batch indices); anything else is called as
            ``criterion(outputs, targets)``.
        optimizer: Optimizer already bound to ``model.parameters()``.
        config (dict): Reads ``epochs`` (100), ``patience`` (10),
            ``batch_size`` (64), ``num_workers`` (0), ``random_state`` (42).
        device: Device the model and batches are moved to.
        trial (optuna.trial.Trial, optional): When given, the validation MAE
            is reported every epoch and the trial may be pruned.

    Returns:
        float: Best validation MAE seen. The model is left with the weights of
        that epoch loaded (or its last weights if no epoch improved on ``inf``).

    Raises:
        optuna.exceptions.TrialPruned: If ``trial`` requests pruning, or if
            reporting to the trial fails unexpectedly.
    """
    model.to(device)

    epochs = config.get('epochs', 100)
    patience = config.get('patience', 10)
    batch_size = config.get('batch_size', 64)
    num_workers = config.get('num_workers', 0)

    def to_target_tensor(values):
        # 1-D targets are turned into a column so every target tensor is 2-D.
        tensor = torch.tensor(values, dtype=torch.float32)
        return tensor.unsqueeze(1) if tensor.ndim == 1 else tensor

    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    y_train_tensor = to_target_tensor(y_train)
    # Dataset positions travel with each batch; the weighted losses receive
    # them through their `batch_indices` argument.
    train_indices = torch.arange(len(X_train), dtype=torch.long)

    X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
    y_val_tensor = to_target_tensor(y_val)

    # Dedicated generator so the shuffling order is governed by `random_state`.
    generator = torch.Generator()
    generator.manual_seed(config.get("random_state", 42))

    def worker_init_fn(worker_id):
        # Seed Python's and NumPy's RNGs in each worker from torch's seed.
        import random
        worker_seed = torch.initial_seed() % 2**32
        random.seed(worker_seed)
        np.random.seed(worker_seed)

    train_dataset = TensorDataset(X_train_tensor, y_train_tensor, train_indices)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        generator=generator,
        worker_init_fn=worker_init_fn,
    )

    val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        worker_init_fn=worker_init_fn,
    )

    best_val_mae = float('inf')
    epochs_no_improve = 0
    best_model_state = None

    print(f"Starting PyTorch ({type(model).__name__}) training for {epochs} epochs...")
    for epoch in range(epochs):
        # NOTE(review): only MLP is checked here while the FDS checks below use
        # (MLP, MLPadv) -- confirm whether MLPadv also needs update_fds_epoch().
        if MLP is not None and isinstance(model, MLP) and getattr(model, 'use_fds', False):
            model.update_fds_epoch()

        model.train()
        train_loss_total = 0.0
        num_train_samples_processed = 0

        for batch_X, batch_y, batch_idx in train_loader:
            batch_X, batch_y, batch_idx = batch_X.to(device), batch_y.to(device), batch_idx.to(device)
            optimizer.zero_grad()

            features = None
            if isinstance(criterion, (RankSimLoss, ConRLoss)):
                outputs, features = model(batch_X, targets=batch_y, epoch=epoch, return_features=True)
            elif isinstance(model, (MLP, MLPadv)) and getattr(model, 'use_fds', False):
                outputs, features = model(batch_X, targets=batch_y, epoch=epoch)
            else:
                outputs = model(batch_X, targets=batch_y, epoch=epoch)

            try:
                if isinstance(criterion, ConRLoss):
                    loss = criterion(outputs, batch_y, features=features, weights=None)
                elif isinstance(criterion, RankSimLoss):
                    loss = criterion(outputs, batch_y, features=features)
                elif isinstance(criterion, (WeightedL1Loss, WeightedMSELoss)):
                    loss = criterion(outputs, batch_y, batch_indices=batch_idx)
                else:
                    loss = criterion(outputs, batch_y)

                # A NaN/inf loss would poison the weights; drop the batch instead.
                if not torch.isfinite(loss):
                    print(f"Warning: Epoch {epoch+1}, invalid loss value ({loss.item()}). Skipping batch.")
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                train_loss_total += loss.item() * batch_X.size(0)
                num_train_samples_processed += batch_X.size(0)
            except Exception as batch_err:
                # Best-effort training: report the failing batch and move on.
                print(f"Epoch {epoch+1} training batch error: {batch_err}")
                continue

        avg_train_loss = train_loss_total / num_train_samples_processed if num_train_samples_processed > 0 else 0

        # FDS statistics update
        if isinstance(model, (MLP, MLPadv)) and getattr(model, 'use_fds', False) and epoch >= model.fds_layer.start_update_epoch:
            print(f"Create Epoch [{epoch}] features of all training data...")
            encodings, labels = [], []

            with torch.no_grad():
                for batch_X, batch_y, _ in train_loader:
                    batch_X, batch_y = batch_X.to(device), batch_y.to(device)
                    _, feature = model(batch_X, targets=batch_y, epoch=epoch)
                    # Keep the batch axis explicitly: a bare squeeze() drops it
                    # for a final batch of size 1 and breaks the stacking.
                    encodings.append(feature.detach().reshape(feature.size(0), -1))
                    labels.append(batch_y.detach().reshape(-1))

            if encodings:
                encodings = torch.cat(encodings, dim=0)
                labels = torch.cat(labels, dim=0)

                model.fds_layer.update_last_epoch_stats(epoch)
                model.fds_layer.update_running_stats(encodings, labels, epoch)

        model.eval()
        val_abs_err_total = 0.0
        num_val_processed = 0
        with torch.no_grad():
            for batch_X_val, batch_y_val in val_loader:
                batch_X_val, batch_y_val = batch_X_val.to(device), batch_y_val.to(device)
                try:
                    outputs_val = model(batch_X_val, targets=None, epoch=epoch)
                    mae_batch = F.l1_loss(outputs_val, batch_y_val, reduction='sum')
                    val_abs_err_total += mae_batch.item()
                    num_val_processed += batch_X_val.size(0)
                except Exception as val_batch_err:
                    print(f"Epoch {epoch+1} validation batch error: {val_batch_err}")
                    continue

        # Average over the samples actually scored, so skipped batches cannot
        # deflate the MAE (or turn it into a spurious 0.0).
        avg_val_mae = val_abs_err_total / num_val_processed if num_val_processed > 0 else float('inf')

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val MAE: {avg_val_mae:.4f}")

        if avg_val_mae < best_val_mae:
            best_val_mae = avg_val_mae
            epochs_no_improve = 0
            # Deep copy: state_dict() tensors alias the live parameters.
            best_model_state = copy.deepcopy(model.state_dict())
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping after epoch {epoch+1} with no improvement.")
            break

        if trial is not None:
            try:
                trial.report(avg_val_mae, epoch)
                if trial.should_prune():
                    print(f"Trial pruned at epoch {epoch+1} due to poor performance (MAE: {avg_val_mae:.4f})")
                    raise optuna.exceptions.TrialPruned()
            except optuna.exceptions.TrialPruned:
                raise
            except Exception as e:
                # Any reporting failure is surfaced to Optuna as a pruned trial.
                print(f"Unexpected error during Optuna pruning: {e}")
                raise optuna.exceptions.TrialPruned()

    print(f"{type(model).__name__} training complete. Best Val MAE: {best_val_mae:.4f}")
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    else:
        print(f"Warning: Best {type(model).__name__} model state not found.")

    return best_val_mae


def train_CASMIR_V1(model, X_train, y_train, train_densities_np, X_val, y_val, optimizer, config, device, trial=None):
    """
    Train CASMIR model with composite loss.

    The per-batch objective is
    ``L1(y_pred, y) + lambda_aux * CE(gate_logits, density_range) +
    lambda_load * var(sum of gate weights per expert)``. Early stopping and
    Optuna pruning are driven by the validation MAE.

    Args:
        model (CASMIR_V1): CASMIR model instance to train. In training it is
            called with ``y``, ``density`` and ``apply_smoothing=True`` and
            must return ``(y_pred, gate_weights, gate_logits)``; in validation
            it is called with ``apply_smoothing=False`` and returns predictions.
        X_train, y_train (np.ndarray): Training data (y_train shape: (n, 1)).
        train_densities_np (np.ndarray): Precomputed training data densities (NumPy array, n_samples,).
        X_val, y_val (np.ndarray): Validation data (y_val shape: (n, 1)).
        optimizer: PyTorch optimizer.
        config (dict): Training configuration (epochs, patience, batch_size, lambda_aux, lambda_load,
            density_boundaries with keys 'low' and 'high', random_state).
        device: Device to use ('cuda' or 'cpu').
        trial (optuna.trial.Trial, optional): Optuna trial object for pruning during HPO.

    Returns:
        tuple: ``(best_val_mae, val_preds)``. ``best_val_mae`` is the best
        validation MAE (``inf`` if none was obtained). ``val_preds`` holds the
        validation predictions of the best epoch, matching the weights restored
        into ``model``; it falls back to the most recent predictions, or
        ``None`` when validation never produced any (including the early exit
        taken when ``density_boundaries`` is missing).

    Raises:
        ValueError: If ``train_densities_np`` is None.
        optuna.exceptions.TrialPruned: If ``trial`` requests pruning, or if
            reporting to the trial fails unexpectedly.
    """
    model.to(device)

    epochs = config.get('epochs', 100)
    patience = config.get('patience', 10)
    batch_size = config.get('batch_size', 64)
    lambda_aux = config.get('lambda_aux', 0.5)
    lambda_load = config.get('lambda_load', 0.05)
    density_boundaries = config.get('density_boundaries', None)

    if density_boundaries is None:
        print("Error: density_boundaries is None at CASMIR training start.")
        # Same 2-tuple shape as the normal return so callers can always unpack.
        return float('inf'), None
    density_low_threshold = density_boundaries['low']
    density_high_threshold = density_boundaries['high']

    try:
        if train_densities_np is None:
            raise ValueError("train_densities_np is None.")
        train_densities_tensor = torch.tensor(train_densities_np, dtype=torch.float32).unsqueeze(1).to(device)

        # The full datasets are placed on `device` up front, so batches need
        # no per-step transfer.
        train_dataset = TensorDataset(
            torch.tensor(X_train, dtype=torch.float32).to(device),
            torch.tensor(y_train, dtype=torch.float32).to(device),
            train_densities_tensor,
        )
        val_dataset = TensorDataset(
            torch.tensor(X_val, dtype=torch.float32).to(device),
            torch.tensor(y_val, dtype=torch.float32).to(device),
        )
    except TypeError as e:
        print(f"CASMIR dataset creation error: {e}")
        print(f"X_train type: {type(X_train)}, y_train type: {type(y_train)}, densities type: {type(train_densities_np)}")
        raise
    except Exception as e:
        print(f"Exception during CASMIR DataLoader creation: {e}")
        raise

    # Dedicated generator so the shuffling order is governed by `random_state`.
    generator = torch.Generator()
    generator.manual_seed(config.get("random_state", 42))

    def worker_init_fn(worker_id):
        # Seed Python's and NumPy's RNGs in each worker from torch's seed.
        import random
        worker_seed = torch.initial_seed() % 2**32
        random.seed(worker_seed)
        np.random.seed(worker_seed)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        generator=generator,
        worker_init_fn=worker_init_fn,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        worker_init_fn=worker_init_fn,
    )

    regression_loss_fn = nn.L1Loss()
    auxiliary_loss_fn = nn.CrossEntropyLoss()

    best_val_mae = float('inf')
    epochs_no_improve = 0
    best_model_state = None
    best_val_preds = None   # predictions of the epoch whose weights are restored
    last_val_preds = None   # most recent predictions, used only as a fallback

    print(f"Starting CASMIR training for {epochs} epochs...")

    for epoch in range(epochs):
        model.train()
        epoch_train_loss = 0.0
        epoch_reg_loss = 0.0
        epoch_aux_loss = 0.0
        epoch_load_loss = 0.0
        num_train_processed = 0

        for x_batch, y_batch, density_batch in train_loader:
            optimizer.zero_grad()

            try:
                # reshape(-1) keeps a 1-D vector even for a batch of one,
                # where squeeze() would yield a 0-d tensor.
                density_flat = density_batch.reshape(-1)
                y_pred, gate_weights, gate_logits = model(x_batch, y=y_batch, density=density_flat, apply_smoothing=True)

                loss_reg = regression_loss_fn(y_pred, y_batch)

                # Gate supervision: each sample's density is mapped to a range
                # index that serves as the class target for the gate logits.
                density_batch_np = density_flat.cpu().numpy()
                target_range_indices = [
                    get_density_range_index(d_val, density_low_threshold, density_high_threshold)
                    for d_val in density_batch_np
                ]
                target_range_indices = torch.LongTensor(target_range_indices).to(device)
                loss_aux = auxiliary_loss_fn(gate_logits, target_range_indices)

                # Load term: variance of the total gate weight each expert
                # received over the batch.
                summed_weights_per_expert = gate_weights.sum(dim=0)
                loss_load = torch.var(summed_weights_per_expert)

                total_loss = loss_reg + lambda_aux * loss_aux + lambda_load * loss_load

                # A NaN/inf loss would poison the weights; drop the batch instead.
                if not torch.isfinite(total_loss):
                    print(f"Warning: Epoch {epoch+1}, invalid loss value ({total_loss.item()}). Skipping batch.")
                    continue

                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                batch_size_actual = x_batch.size(0)
                epoch_train_loss += total_loss.item() * batch_size_actual
                epoch_reg_loss += loss_reg.item() * batch_size_actual
                epoch_aux_loss += loss_aux.item() * batch_size_actual
                epoch_load_loss += loss_load.item() * batch_size_actual
                num_train_processed += batch_size_actual

            except Exception as batch_err:
                # Best-effort training: report the failing batch and move on.
                print(f"Epoch {epoch+1} training batch error: {batch_err}")
                continue

        if num_train_processed == 0:
            # With drop_last=True a dataset smaller than batch_size yields no batches.
            print(f"Warning: No training data processed. Batch size ({batch_size}) may be larger than dataset size ({len(train_dataset)}).")
            continue

        avg_train_loss = epoch_train_loss / num_train_processed
        avg_reg_loss = epoch_reg_loss / num_train_processed
        avg_aux_loss = epoch_aux_loss / num_train_processed
        avg_load_loss = epoch_load_loss / num_train_processed

        model.eval()
        val_pred_list = []
        val_mae_total = 0.0
        num_val_processed = 0
        if len(val_loader.dataset) == 0:
            print("Warning: Validation dataset is empty.")
        with torch.no_grad():
            for x_val_batch, y_val_batch in val_loader:
                try:
                    batch_preds = model(x_val_batch, apply_smoothing=False)
                    val_mae_total += regression_loss_fn(batch_preds, y_val_batch).item() * x_val_batch.size(0)
                    val_pred_list.append(batch_preds)
                    num_val_processed += x_val_batch.size(0)
                except Exception as val_batch_err:
                    print(f"Epoch {epoch+1} validation batch error: {val_batch_err}")
                    continue

        # torch.cat rejects an empty list, so only concatenate real results.
        if val_pred_list:
            last_val_preds = torch.cat(val_pred_list, dim=0)
        # Average over the samples actually scored, so skipped batches cannot
        # deflate the MAE.
        avg_val_mae = val_mae_total / num_val_processed if num_val_processed > 0 else float('inf')

        print(f"Epoch [{epoch+1}/{epochs}] | Train Loss: {avg_train_loss:.4f} "
              f"(Reg: {avg_reg_loss:.4f}, Aux: {avg_aux_loss:.4f}, Load: {avg_load_loss:.4f}) | "
              f"Val MAE: {avg_val_mae:.4f}")

        if avg_val_mae < best_val_mae:
            best_val_mae = avg_val_mae
            epochs_no_improve = 0
            # Deep copy: state_dict() tensors alias the live parameters.
            best_model_state = copy.deepcopy(model.state_dict())
            best_val_preds = last_val_preds
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping after epoch {epoch+1} with no improvement.")
            break

        if trial is not None:
            try:
                trial.report(avg_val_mae, epoch)
                if trial.should_prune():
                    print(f"Trial pruned at epoch {epoch+1} due to poor performance (MAE: {avg_val_mae:.4f})")
                    raise optuna.exceptions.TrialPruned()
            except optuna.exceptions.TrialPruned:
                raise
            except Exception as e:
                # Any reporting failure is surfaced to Optuna as a pruned trial.
                print(f"Unexpected error during Optuna pruning: {e}")
                raise optuna.exceptions.TrialPruned()

    print(f"CASMIR training complete. Best Val MAE: {best_val_mae:.4f}")

    if best_model_state is not None:
        try:
            model.load_state_dict(best_model_state)
            print("Best model state loaded.")
        except Exception as load_err:
            print(f"Error: Failed to load best model state - {load_err}. Using last state.")
    else:
        print("Warning: Best model state not found. Using last model state.")

    if best_val_mae == float('inf'):
        print("Warning: No valid validation MAE obtained.")

    val_preds = best_val_preds if best_val_preds is not None else last_val_preds
    return best_val_mae, val_preds
