import optuna
from sklearn import metrics
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
import pandas as pd

from resreg.resampler import pdf_relevance, sigmoid_relevance, smoter
from src.data.datasets import calculate_shot_wise_mae
from src.models.basic_models import MLP, MLPadv, SimpleThreeMLPEnsemble, XGBoostWrapper, LightGBMWrapper, CatBoostWrapper
from src.models.CASMIR_V1 import CASMIR_V1
from src.training.losses import GAILossMD, RankSimLoss, WeightedL1Loss, WeightedMSELoss
from src.training.losses import BMCLossMD
from src.training.losses import ConRLoss
from src.training.train_utils import train_pytorch_model, train_CASMIR_V1
from src.models.samplers import apply_smoter, apply_gaussian_noise
from src.models.fds_layer import FDSLayer


# Prefer CUDA whenever PyTorch reports it as available; otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _get_pytorch_optimizer(trial, model, lr, optimizer_name, weight_decay=None):
    """Create an Adam or AdamW optimizer over ``model``'s parameters.

    Args:
        trial: Optuna trial. Not used here; kept so existing call sites keep working.
        model: Module whose ``parameters()`` are optimized.
        lr: Learning rate.
        optimizer_name: ``"Adam"`` selects ``optim.Adam``; any other value selects
            ``optim.AdamW``.
        weight_decay: Optional weight decay. ``None`` (the default) does not pass the
            argument at all, so each optimizer keeps its own PyTorch default
            (0 for Adam, 1e-2 for AdamW), matching the previous two-argument call.

    Returns:
        The constructed ``torch.optim`` optimizer.
    """
    optimizer_cls = optim.Adam if optimizer_name == "Adam" else optim.AdamW
    # Only forward weight_decay when explicitly given so AdamW's non-zero default
    # is not silently replaced for callers that never asked for a value.
    extra_kwargs = {} if weight_decay is None else {"weight_decay": weight_decay}
    return optimizer_cls(model.parameters(), lr=lr, **extra_kwargs)


def _import_optional_loss(loss_name, algo_name):
    """Return loss class ``loss_name`` from src.training.losses or raise a clear error.

    BalancedMSELoss / LDSLoss are not imported at module level, so they are
    resolved lazily here instead of failing with a bare NameError.
    NOTE(review): src.training.losses is assumed because every other loss in this
    file is imported from it — confirm these classes actually live there.

    Raises:
        NotImplementedError: if the module or the class cannot be found.
    """
    import importlib

    try:
        return getattr(importlib.import_module("src.training.losses"), loss_name)
    except (ImportError, AttributeError) as err:
        raise NotImplementedError(
            f"Algorithm '{algo_name}' needs {loss_name}, which could not be "
            f"imported from src.training.losses."
        ) from err


def _get_pytorch_criterion(trial, config, y_train_npy):
    """Create the loss function for a PyTorch MLP-family algorithm.

    Dispatches on ``config['algorithm_name']`` and, for most branches, samples
    loss-specific hyperparameters from ``trial``.

    Args:
        trial: Optuna trial used to suggest loss hyperparameters.
        config: Dict. Reads ``algorithm_name`` (required) and optionally
            ``calculated_weights_tensor``, ``gmm_params``, ``apply_weights_in_loss``,
            ``base_loss_for_reweighting`` plus the range hints
            ``lds_ks_default``, ``lds_sigma_default``,
            ``gai_init_noise_sigma_default`` and ``bmse_noise_sigma``.
        y_train_npy: Training targets; only forwarded to BalancedMSELoss / LDSLoss.

    Returns:
        A loss module. Unknown algorithm names fall back to ``nn.L1Loss`` with a
        warning.

    Raises:
        NotImplementedError: for the BalancedMSE / LDS-original branches when the
            required loss class cannot be imported.
    """
    algo_name = config['algorithm_name']
    weights_tensor = config.get('calculated_weights_tensor', None)
    gmm_params = config.get('gmm_params', None)
    apply_weights = config.get('apply_weights_in_loss', False)
    base_loss_name = config.get('base_loss_for_reweighting', 'l1')

    if algo_name in ['MLP', 'MLP_FDS']:
        loss_type = trial.suggest_categorical("base_loss", ["l1", "mse"])
        if loss_type == "l1":
            print(f"Algorithm '{algo_name}': Using L1Loss")
            return nn.L1Loss()
        print(f"Algorithm '{algo_name}': Using MSELoss")
        return nn.MSELoss()

    # Re-weighted variants driven by a precomputed per-sample weight tensor.
    # MLP_LDS_FDS_Notebook is included because objective_MLP also expects a
    # calculated_weights_tensor for it; without this it would train unweighted.
    if algo_name in ("MLP_SQRT_INV", "MLP_LDS_Notebook", "MLP_LDS_FDS_Notebook"):
        if apply_weights and weights_tensor is not None:
            print(f"Algorithm '{algo_name}': Using Weighted{base_loss_name.upper()}Loss with precomputed weights")
            if algo_name == "MLP_LDS_Notebook":
                # NOTE(review): these values are only printed — the weights were
                # computed before HPO, so lds_ks / lds_sigma do not change the loss.
                # Kept so existing studies keep the same parameter keys.
                lds_ks = trial.suggest_int("lds_ks", 3, config.get('lds_ks_default', 5) * 2 + 1, step=2)
                lds_sigma = trial.suggest_float("lds_sigma", 0.5, config.get('lds_sigma_default', 2) * 2, log=True)
                print(f"  (LDS HPO Params: ks={lds_ks}, sigma={lds_sigma:.2f})")

            if base_loss_name == "l1":
                return WeightedL1Loss(weights_tensor.cpu())
            return WeightedMSELoss(weights_tensor.cpu())

        print(f"Warning: No weights for {algo_name}, using standard {base_loss_name.upper()}Loss.")
        return nn.L1Loss() if base_loss_name == "l1" else nn.MSELoss()

    if algo_name == "MLP_GAI_BMSE":
        if gmm_params is None:
            print(f"Error: No GMM params for {algo_name}, using L1Loss.")
            return nn.L1Loss()

        # Search a 9x-wide log range centred (geometrically) on the configured default.
        init_noise_sigma = trial.suggest_float(
            "gai_init_noise_sigma",
            config.get("gai_init_noise_sigma_default", 6.0) / 3,
            config.get("gai_init_noise_sigma_default", 6.0) * 3,
            log=True
        )
        print(f"Algorithm '{algo_name}': Using GAILossMD (init_sigma={init_noise_sigma:.4f})")
        return GAILossMD(gmm_dict=gmm_params, init_noise_sigma=init_noise_sigma)

    if algo_name == "MLP_BMC_BMSE":
        init_noise_sigma = trial.suggest_float(
            "bmc_init_noise_sigma",
            config.get("bmse_noise_sigma", 6.0) / 3,
            config.get("bmse_noise_sigma", 6.0) * 3,
            log=True
        )
        print(f"Algorithm '{algo_name}': Using BMCLossMD (init_sigma={init_noise_sigma:.4f})")
        return BMCLossMD(init_noise_sigma=init_noise_sigma)

    if algo_name == 'MLP_BalancedMSE_DEPRECATED':
        # Resolve the class first so a missing dependency fails before sampling.
        balanced_mse_cls = _import_optional_loss('BalancedMSELoss', algo_name)
        gamma = trial.suggest_float("loss_gamma", 0.1, 2.0)
        loss_bins = trial.suggest_int("loss_bins_balanced", 10, 50)
        print(f"Algorithm '{algo_name}': Using BalancedMSELoss (gamma={gamma:.2f}, bins={loss_bins})")
        return balanced_mse_cls(y_train_npy, n_bins=loss_bins, gamma=gamma)

    if algo_name in ['MLP_LDS_Original', 'MLP_LDS_FDS_Original']:
        lds_loss_cls = _import_optional_loss('LDSLoss', algo_name)
        loss_bins = trial.suggest_int("loss_bins_lds", 30, 100)
        ks = trial.suggest_int("loss_ks", 3, 9, step=2)
        sigma = trial.suggest_float("loss_sigma", 0.5, 5.0, log=True)
        print(f"Algorithm '{algo_name}': Using LDSLoss (ks={ks}, sigma={sigma:.2f}, bins={loss_bins})")
        return lds_loss_cls(y_train_npy, n_bins=loss_bins, ks=ks, sigma=sigma)

    if algo_name == 'MLP_RankSim':
        lambda_val = trial.suggest_float("ranksim_lambda_val", 0.1, 10.0, log=True)
        alpha = trial.suggest_float("ranksim_alpha", 0.1, 10.0, log=True)
        print(f"Algorithm '{algo_name}': Using RankSimLoss (lambda={lambda_val:.3f}, alpha={alpha:.3f})")
        return RankSimLoss(lambda_val=lambda_val, alpha=alpha)

    if algo_name == 'MLP_ConR':
        w = trial.suggest_float("conr_distance_threshold", 0.1, 5.0)
        t = trial.suggest_float("conr_temperature", 0.01, 1.0, log=True)
        e = trial.suggest_float("conr_pushing_power", 0.001, 0.1, log=True)
        alpha = trial.suggest_float("conr_alpha", 0.1, 10.0, log=True)
        mse_weight = trial.suggest_float("conr_mse_weight", 0.1, 10.0, log=True)
        print(f"Algorithm '{algo_name}': Using ConRLoss (w={w:.3f}, t={t:.4f}, e={e:.4f})")
        return ConRLoss(w=w, t=t, e=e, alpha=alpha, mse_weight=mse_weight)

    print(f"Warning: Unknown algorithm '{algo_name}'. Using L1Loss.")
    return nn.L1Loss()


def objective_MLP(trial, config, X_train, y_train, X_val, y_val, shot_mapping=None):
    """Optuna objective function for PyTorch MLP-based models (including FDS, LDS).

    Samples architecture, batch size, optimizer settings and (for FDS variants)
    FDS smoothing parameters, builds the model, and trains it with
    ``train_pytorch_model``.

    Args:
        trial: Optuna trial.
        config: Dict; reads ``algorithm_name`` and ``input_dim`` plus optional
            ``fds_discretizer``, ``calculated_weights_tensor``, ``gmm_params``,
            ``hpo_epochs``, ``hpo_patience`` and ``num_workers``.
        X_train, y_train, X_val, y_val: Train/validation data; ``X_train`` must
            expose ``shape[0]``.
        shot_mapping: Unused; accepted for signature parity with the other objectives.

    Returns:
        The validation MAE reported by ``train_pytorch_model``, or ``float('inf')``
        if model construction or training raises. Pruning is re-raised.
    """
    algo_name = config['algorithm_name']

    hidden_dim1 = trial.suggest_categorical("hidden_dim1", [128, 256, 512])
    hidden_dim2 = trial.suggest_categorical("hidden_dim2", [64, 128, 256])
    hidden_dims = [hidden_dim1, hidden_dim2]

    dropout_rate = trial.suggest_float("dropout", 0.1, 0.4)

    train_size = X_train.shape[0]
    available_batch_sizes = [bs for bs in [32, 64, 128, 256] if bs <= train_size]
    if not available_batch_sizes:
        # Very small training sets: offer a single batch size no larger than the data.
        available_batch_sizes = [min(32, train_size)]
    batch_size = trial.suggest_categorical("batch_size", available_batch_sizes)

    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "AdamW"])
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)

    use_fds = False
    fds_config = None
    fds_discretizer = config.get('fds_discretizer', None)

    if algo_name in ["MLP_FDS", "MLP_LDS_FDS", "MLP_FDS_Notebook", "MLP_LDS_FDS_Notebook"]:
        if FDSLayer is None or fds_discretizer is None:
            print("Warning: FDS unavailable. Proceeding without FDS.")
        else:
            use_fds = True
            fds_momentum = trial.suggest_float("fds_momentum", 0.8, 0.99)
            fds_start_epoch = 0

            fds_config = {
                'num_target_bins': fds_discretizer.n_bins_,
                'fds_momentum': fds_momentum,
                'start_update_epoch': fds_start_epoch,
                'kernel': trial.suggest_categorical("fds_kernel", ['gaussian', 'triang', 'laplace']),
                'kernel_size': trial.suggest_int("fds_kernel_size", 3, 9, step=2),
                'kernel_sigma': trial.suggest_float("fds_kernel_sigma", 1.0, 4.0),
                'start_smooth_epoch': trial.suggest_int("fds_start_smooth_epoch", 1, 3)
            }
            print(f"FDS enabled: config={fds_config}")

    # Configuration sanity messages only; they do not alter the run.
    if algo_name in ["MLP_SQRT_INV", "MLP_LDS_Notebook", "MLP_LDS_FDS_Notebook"] and config.get('calculated_weights_tensor') is None:
        print(f"Warning: {algo_name} running but calculated_weights_tensor not in config.")

    if algo_name == "MLP_FDS_Notebook":
        print(f"Info: {algo_name} is FDS-only, no weight tensor needed.")

    if algo_name == "MLP_GAI_BMSE" and config.get('gmm_params') is None:
        print(f"Warning: {algo_name} running but gmm_params not in config.")

    is_ensemble = algo_name == "Simple_Ensemble"
    model_cls = SimpleThreeMLPEnsemble if is_ensemble else MLP
    try:
        model = model_cls(
            input_dim=config['input_dim'],
            hidden_dims=hidden_dims, output_dim=1, dropout_rate=dropout_rate,
            use_fds=use_fds, fds_config=fds_config, fds_discretizer=fds_discretizer,
            use_residual=True
        ).to(DEVICE)
    except Exception as model_init_err:
        model_type = "SimpleEnsemble" if is_ensemble else "MLP"
        print(f"{model_type} model initialization error: {model_init_err}")
        return float('inf')

    if is_ensemble:
        print(f"Using Simple 3-MLPs Ensemble (algorithm: {algo_name})")
    else:
        # The model built here is MLP, so the log names MLP.
        print(f"Using MLP model (algorithm: {algo_name})")

    # Built inline (like objective_CASMIR_V1) so the sampled weight_decay actually
    # reaches the optimizer instead of being a dead search dimension.
    optimizer_cls = optim.Adam if optimizer_name == "Adam" else optim.AdamW
    optimizer = optimizer_cls(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = _get_pytorch_criterion(trial, config, y_train)

    train_config = {
        'epochs': config.get('hpo_epochs', 50),
        'patience': config.get('hpo_patience', 5),
        'batch_size': batch_size,
        'num_workers': config.get('num_workers', 0),
    }

    print(f"{algo_name} Trial {trial.number}: Params {trial.params}")

    try:
        validation_mae = train_pytorch_model(
            model, X_train, y_train, X_val, y_val, criterion, optimizer, train_config, DEVICE, trial
        )
    except optuna.exceptions.TrialPruned:
        raise
    except Exception as e:
        print(f"{algo_name} Trial {trial.number} training error: {e}")
        import traceback
        traceback.print_exc()
        return float('inf')

    return validation_mae


def _suggest_tree_sampler_params(trial, y_train):
    """Suggest the relevance/resampling hyperparameters used by objective_tree.

    Sigmoid thresholds are drawn from percentile ranges of ``y_train``:
    ``sigmoid_cl`` from [p10, p45] and ``sigmoid_ch`` from [p55, p90].
    """
    params = {
        'relevance_method': trial.suggest_categorical('relevance_method', ['sigmoid', 'pdf']),
        'relevance_percentile': trial.suggest_float('relevance_percentile', 60, 90),
        'resampling_strategy': trial.suggest_categorical('resampling_strategy', ['balance', 'average', 'extreme']),
    }
    if params['relevance_method'] == 'sigmoid':
        params['sigmoid_cl'] = trial.suggest_float(
            'sigmoid_cl', np.percentile(y_train, 10), np.percentile(y_train, 45))
        params['sigmoid_ch'] = trial.suggest_float(
            'sigmoid_ch', np.percentile(y_train, 55), np.percentile(y_train, 90))
    else:
        params['pdf_bandwidth'] = trial.suggest_float('pdf_bandwidth', 0.1, 2.0, log=True)
    return params


def _build_tree_model(trial, algorithm_name):
    """Suggest booster hyperparameters and build the wrapper named in ``algorithm_name``.

    Raises:
        ValueError: if the name contains none of XGBoost / LightGBM / CatBoost.
    """
    if "XGBoost" in algorithm_name:
        return XGBoostWrapper({
            'max_depth': trial.suggest_int('max_depth', 3, 10),
            'learning_rate': trial.suggest_float('learning_rate', 1e-4, 1e-1, log=True),
            'n_estimators': trial.suggest_int('n_estimators', 50, 300),
            'min_child_weight': trial.suggest_int('min_child_weight', 1, 7),
            'subsample': trial.suggest_float('subsample', 0.6, 1.0),
            'colsample_bytree': trial.suggest_float('colsample_bytree', 0.6, 1.0),
            'gamma': trial.suggest_float('gamma', 1e-8, 1.0, log=True),
            'early_stopping_rounds': trial.suggest_int('early_stopping_rounds', 5, 30)
        })
    if "LightGBM" in algorithm_name:
        return LightGBMWrapper({
            'num_leaves': trial.suggest_int('num_leaves', 20, 100),
            'learning_rate': trial.suggest_float('learning_rate', 1e-4, 1e-1, log=True),
            'n_estimators': trial.suggest_int('n_estimators', 50, 300),
            'min_child_samples': trial.suggest_int('min_child_samples', 5, 100),
            'subsample': trial.suggest_float('subsample', 0.6, 1.0),
            'colsample_bytree': trial.suggest_float('colsample_bytree', 0.6, 1.0),
            'early_stopping_rounds': trial.suggest_int('early_stopping_rounds', 5, 30)
        })
    if "CatBoost" in algorithm_name:
        return CatBoostWrapper({
            'iterations': trial.suggest_int('iterations', 50, 300),
            'learning_rate': trial.suggest_float('learning_rate', 1e-4, 1e-1, log=True),
            'depth': trial.suggest_int('depth', 4, 10),
            'l2_leaf_reg': trial.suggest_float('l2_leaf_reg', 1.0, 10.0, log=True),
            'border_count': trial.suggest_int('border_count', 32, 255),
            'subsample': trial.suggest_float('subsample', 0.6, 1.0),
            'early_stopping_rounds': trial.suggest_int('early_stopping_rounds', 5, 30)
        })
    raise ValueError(f"Unsupported tree model: {algorithm_name}")


def _tree_relevance(y_train, sampler_params, log):
    """Compute per-sample relevance and its percentile threshold; record settings in ``log``."""
    if sampler_params['relevance_method'] == 'sigmoid':
        relevance = sigmoid_relevance(y_train, cl=sampler_params['sigmoid_cl'], ch=sampler_params['sigmoid_ch'])
        log.update({
            'relevance_method': 'sigmoid',
            'sigmoid_cl': sampler_params['sigmoid_cl'],
            'sigmoid_ch': sampler_params['sigmoid_ch']
        })
    else:
        bandwidth = sampler_params['pdf_bandwidth']
        relevance = pdf_relevance(y_train, bandwidth=bandwidth)
        log.update({
            'relevance_method': 'pdf',
            'pdf_bandwidth': bandwidth
        })
    relevance_threshold = np.percentile(relevance, sampler_params['relevance_percentile'])
    return relevance, relevance_threshold


def _resample_with_smoter(trial, X_train, y_train, sampler_params, log):
    """Resample with SMOTER, retrying once with 'balance' before using the raw data.

    Returns the (possibly unchanged) training data; outcome flags go into ``log``.
    """
    relevance, relevance_threshold = _tree_relevance(y_train, sampler_params, log)
    k_neighbors = trial.suggest_int('sampler_k_neighbors', 3, 10)
    strategy = sampler_params['resampling_strategy']

    def run_smoter(over):
        return smoter(
            X_train, y_train,
            relevance=relevance,
            relevance_threshold=relevance_threshold,
            k=k_neighbors,
            over=over,
            random_state=trial.number
        )

    try:
        X_res, y_res = run_smoter(strategy)
    except Exception as smoter_err:
        if strategy == 'balance':
            # The fallback would repeat this exact call (same arguments and seed),
            # so skip the pointless retry.
            print(f"SMOTER failed: {str(smoter_err)}. Using original data.")
            log['used_original_data'] = True
            return X_train, y_train
        print(f"SMOTER failed: {str(smoter_err)}. Trying fallback strategy...")
        try:
            X_res, y_res = run_smoter('balance')
        except Exception as fallback_err:
            print(f"Fallback strategy failed: {str(fallback_err)}. Using original data.")
            log['used_original_data'] = True
            return X_train, y_train
        log['fallback_to_balance'] = True
        return X_res, y_res

    log.update({
        'relevance_threshold': relevance_threshold,
        'k_neighbors': k_neighbors,
        'resampling_strategy': strategy,
        'success': True
    })
    return X_res, y_res


def _resample_with_gaussian_noise(trial, X_train, y_train, sampler_params, log):
    """Resample with ``apply_gaussian_noise`` using the shared relevance settings."""
    delta = trial.suggest_float('sampler_delta', 0.01, 0.2, log=True)
    relevance, relevance_threshold = _tree_relevance(y_train, sampler_params, log)
    strategy = sampler_params['resampling_strategy']
    X_res, y_res = apply_gaussian_noise(
        X_train, y_train,
        relevance=relevance,
        relevance_threshold=relevance_threshold,
        delta=delta,
        over=strategy
    )
    log.update({
        'relevance_threshold': relevance_threshold,
        'delta': delta,
        'resampling_strategy': strategy,
        'success': True
    })
    return X_res, y_res


def objective_tree(trial, hpo_config, X_train, y_train, X_val, y_val, shot_mapping=None):
    """Optuna objective function for tree-based models (XGBoost, LightGBM, CatBoost).

    Samples resampling and booster hyperparameters, optionally resamples the
    training set (``hpo_config['sampler']`` == "SMOTER" or "GaussianNoise"),
    fits the model and returns the validation MAE.

    Args:
        trial: Optuna trial.
        hpo_config: Dict; reads ``algorithm_name`` and ``sampler``.
        X_train, y_train, X_val, y_val: Train/validation data.
        shot_mapping: Unused; accepted for signature parity with the other objectives.

    Returns:
        Validation MAE (``sklearn.metrics.mean_absolute_error``).

    Raises:
        ValueError: for an unsupported ``algorithm_name``.
        optuna.exceptions.TrialPruned: if fitting or evaluation fails.
    """
    # Sampler params are suggested first to keep the original suggestion order.
    smoter_params = _suggest_tree_sampler_params(trial, y_train)
    model = _build_tree_model(trial, hpo_config["algorithm_name"])

    X_train_resampled, y_train_resampled = X_train, y_train
    sampler_params_log = {}

    try:
        sampler = hpo_config["sampler"]
        if sampler == "SMOTER":
            X_train_resampled, y_train_resampled = _resample_with_smoter(
                trial, X_train, y_train, smoter_params, sampler_params_log)
        elif sampler == "GaussianNoise":
            X_train_resampled, y_train_resampled = _resample_with_gaussian_noise(
                trial, X_train, y_train, smoter_params, sampler_params_log)
    except Exception as e:
        # Best effort: any sampler failure (including a missing 'sampler' key)
        # falls back to training on the original data.
        print(f"Sampler setup error: {str(e)}. Using original data.")
        X_train_resampled, y_train_resampled = X_train, y_train
        sampler_params_log['error'] = str(e)

    try:
        model.fit(X_train_resampled, y_train_resampled, X_val, y_val)
        y_pred = model.predict(X_val)
        val_score = metrics.mean_absolute_error(y_val, y_pred)
        trial.set_user_attr('sampler_params', sampler_params_log)
        return val_score
    except Exception as e:
        print(f"Model training/evaluation error: {str(e)}")
        raise optuna.exceptions.TrialPruned() from e

def objective_CASMIR_V1(trial, config, X_train, y_train, X_val, y_val, shot_mapping=None):
    """Optuna objective function for CASMIR model.

    Samples expert/gate widths, batch size, CAS neighbourhood parameters, optimizer
    settings and auxiliary-loss weights; trains with ``train_CASMIR_V1`` and scores
    the validation predictions with ``calculate_shot_wise_mae``.

    Args:
        trial: Optuna trial.
        config: Dict; reads ``input_dim``, ``density_boundaries`` and
            ``precomputed_train_densities_np`` plus optional ``casmir_num_experts``,
            ``ada_expert_type``, ``hpo_epochs`` and ``hpo_patience``.
        X_train, y_train, X_val, y_val: Data; the targets must support ``reshape``.
        shot_mapping: Forwarded to ``calculate_shot_wise_mae``.

    Returns:
        The ``'overall'`` entry of the shot-wise MAE result, or ``float('inf')`` when
        the precomputed densities are missing or training raises. Pruning is re-raised.
    """
    expert_hidden = [
        trial.suggest_categorical("expert_dim1", [128, 256, 512]),
        trial.suggest_categorical("expert_dim2", [64, 128, 256]),
    ]
    gate_hidden = [trial.suggest_categorical("gate_dim1", [64, 128, 256])]
    num_experts = config.get('casmir_num_experts', 3)

    n_train = X_train.shape[0]
    batch_choices = [size for size in [32, 64, 128, 256] if size <= n_train] or [min(32, n_train)]
    hpo_batch_size = trial.suggest_categorical("batch_size", batch_choices)

    # Upper bound for k: batch_size - 1, clamped into [5, 15].
    neighbour_cap = max(5, min(15, hpo_batch_size - 1))

    # Dict literals evaluate left to right, so the suggestion order is unchanged.
    cas_params_trial = {
        'k': trial.suggest_int('k_neighbors', 5, neighbour_cap),
        'feature_bw': trial.suggest_float('feature_bw', 0.5, 5.0, log=True),
        'label_bw': trial.suggest_float('label_bw', 1.0, 50.0, log=True),
        'density_factor': trial.suggest_float('density_factor', 0.0, 0.5),
        'strength_base': trial.suggest_float('strength_base', 0.1, 0.9),
        'density_c': trial.suggest_float('density_c', 1.0, 50.0, log=True),
        'epsilon': 1e-6,
    }

    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "AdamW"])
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    expert_dropout = trial.suggest_float("expert_dropout", 0.1, 0.4)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
    lambda_aux = trial.suggest_float("lambda_aux", 0.1, 1.0)
    lambda_load = trial.suggest_float("lambda_load", 0.01, 0.2)

    model = CASMIR_V1(
        input_dim=config['input_dim'],
        num_experts=num_experts,
        expert_hidden_dims=expert_hidden,
        gate_hidden_dims=gate_hidden,
        cas_params=cas_params_trial,
        expert_dropout=expert_dropout,
        expert_type=config.get('ada_expert_type', 'tabular'),
    )

    optimizer_cls = optim.Adam if optimizer_name == "Adam" else optim.AdamW
    optimizer = optimizer_cls(model.parameters(), lr=lr, weight_decay=weight_decay)

    train_config_hpo = dict(
        epochs=config.get('hpo_epochs', 300),
        patience=config.get('hpo_patience', 30),
        batch_size=hpo_batch_size,
        lambda_aux=lambda_aux,
        lambda_load=lambda_load,
        density_boundaries=config['density_boundaries'],
    )

    train_densities_np = config.get('precomputed_train_densities_np', None)
    if train_densities_np is None:
        print("Error: Precomputed NumPy densities not passed to CASMIR HPO objective.")
        return float('inf')

    print(f"CASMIR Trial {trial.number}: Params {trial.params}")

    try:
        y_train_col = y_train.reshape(-1, 1)
        y_val_col = y_val.reshape(-1, 1)
        # The MAE returned by the trainer is ignored; the objective is re-scored
        # below from the validation predictions.
        _, y_pred_val = train_CASMIR_V1(
            model=model,
            X_train=X_train,
            y_train=y_train_col,
            train_densities_np=train_densities_np,
            X_val=X_val,
            y_val=y_val_col,
            optimizer=optimizer,
            config=train_config_hpo,
            device=DEVICE,
            trial=trial,
        )
    except optuna.exceptions.TrialPruned:
        raise
    except Exception as e:
        print(f"CASMIR Trial {trial.number} training error: {e}")
        import traceback
        traceback.print_exc()
        return float('inf')

    shot_results = calculate_shot_wise_mae(
        y_true=y_val_col.flatten(),
        y_pred=y_pred_val.flatten(),
        shot_mapping=shot_mapping,
    )
    return shot_results['overall']


def evaluate_shot_wise_performance(y_pred, y_true, config, data_split='validation'):
    """
    Calculate overall and shot-wise MAE for the validation or test dataset.

    Args:
        y_pred: Model predictions, either a torch tensor (any device; may require
            grad) or an array-like. Flattened to 1-D before comparison.
        y_true: Ground-truth values (array-like). Flattened to 1-D.
        config: Configuration dictionary; reads
            ``config['bal_shot_info'][data_split]['indices']``, a mapping from shot
            type to positional indices (list or NumPy array). The keys 'few',
            'medium' and 'many' must be present.
        data_split: 'validation' or 'test'.

    Returns:
        Dictionary with ``total_mae``, ``shot_wise_mae`` (shot type -> MAE, or None
        for an empty group) and ``sample_counts`` (shot type -> number of indices,
        for every key in the indices mapping).
    """
    if isinstance(y_pred, torch.Tensor):
        # detach() so tensors that still track gradients can be converted.
        y_pred_np = y_pred.detach().cpu().numpy()
    else:
        y_pred_np = np.asarray(y_pred)
    # Flatten both sides: an (n, 1) prediction minus an (n,) target would broadcast
    # to an (n, n) matrix and silently produce a meaningless MAE.
    y_pred_np = y_pred_np.reshape(-1)
    y_true_np = np.asarray(y_true).reshape(-1)

    total_mae = np.mean(np.abs(y_true_np - y_pred_np))

    shot_info = config['bal_shot_info'][data_split]['indices']

    shot_wise_mae = {}
    for shot_type in ['few', 'medium', 'many']:
        indices = shot_info[shot_type]
        # len() instead of truthiness: `if indices:` raises for multi-element arrays.
        if indices is not None and len(indices) > 0:
            idx = np.asarray(indices)
            shot_wise_mae[shot_type] = np.mean(np.abs(y_true_np[idx] - y_pred_np[idx]))
        else:
            shot_wise_mae[shot_type] = None

    sample_counts = {
        shot_type: 0 if indices is None else len(indices)
        for shot_type, indices in shot_info.items()
    }

    return {
        'total_mae': total_mae,
        'shot_wise_mae': shot_wise_mae,
        'sample_counts': sample_counts
    }
