import numpy as np
import torch
from scipy.stats import spearmanr
from torch.nn import functional as F
from sklearn.ensemble import RandomForestRegressor
import joblib


import numpy as np
import torch
import torch.nn.functional as F
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from src.iees_utils import generate_pfams, calculate_consistency_index, compute_iees_score


def extract_proxy_features(model, dataloader, device, weights):
    """
    Extract low-cost features and corresponding IEES scores for each exit.

    Every sample of every batch is pushed through every exit on its own
    (batch of size 1), so each (sample, exit) pair contributes exactly one
    feature row and one target.

    Parameters:
    - model: early-exit DNN exposing `num_exits`, `forward_to_exit(x, exit_idx, flag)`
      and the `.activations` / `.gradients` attributes that are read after each call
    - dataloader: torch DataLoader yielding (inputs, labels) batches; labels are unused
    - device: 'cuda' or 'cpu'
    - weights: list [λ1, λ2, λ3] forwarded to compute_iees_score

    Returns:
    - features: np.array of shape (N_samples * N_exits, 5) with columns
      [confidence, top-2 margin, entropy, activation mean, activation max]
    - targets: np.array of corresponding IEES scores
    """
    features = []
    targets = []
    for inputs, _labels in dataloader:
        inputs = inputs.to(device)
        for i in range(inputs.size(0)):
            # Slice (not index) so the sample keeps its leading batch axis.
            input_image = inputs[i:i + 1]
            # PFAM state accumulates across the exits of a single sample only.
            pfams_list = []
            cumulative_maps = []
            for exit_idx in range(model.num_exits):
                # NOTE(review): the third positional flag is passed as True;
                # presumably it enables activation/gradient capture -- confirm.
                logits, _class_idx, _target_score = model.forward_to_exit(
                    input_image, exit_idx, True)
                output = F.softmax(logits, dim=1)

                pfams_list, cumulative_map = generate_pfams(
                    model.activations, model.gradients, input_image, pfams_list)
                cumulative_maps.append(cumulative_map)
                consistency_index = calculate_consistency_index(cumulative_maps)

                # Only the first returned value (the IEES score) is the target.
                iees_score, *_ = compute_iees_score(
                    model.activations, model.gradients, output,
                    consistency_index, weights)
                targets.append(iees_score)

                top2 = torch.topk(output, 2, dim=1).values
                entropy = (-output * torch.log(output + 1e-8)).sum(dim=1)
                activations = model.activations

                # Each statistic is a 1-element tensor (batch of 1); .item()
                # stores it as a plain scalar so the stacked result is (N, 5)
                # rather than (N, 5, 1).
                features.append([
                    output.max(dim=1).values.item(),
                    (top2[:, 0] - top2[:, 1]).item(),
                    entropy.item(),
                    activations.mean(dim=(1, 2, 3)).item(),
                    activations.amax(dim=(1, 2, 3)).item(),
                ])

    # reshape keeps the documented 2-D layout even when the loader is empty.
    feature_matrix = np.asarray(features, dtype=np.float64).reshape(-1, 5)
    return feature_matrix, np.array(targets)


def train_proxy_rf(features, targets, save_path=None):
    """
    Fit a random-forest regressor that predicts IEES scores from proxy features.

    Parameters:
    - features: array-like of shape (N, F); an (N, F, 1) array is also accepted
      and has its trailing singleton axis removed before fitting
    - targets: array-like of N regression targets
    - save_path: optional path; when truthy the fitted model is written there
      with joblib.dump

    Returns:
    - the fitted RandomForestRegressor
    """
    # Feature extractors that stack per-sample 1-element arrays yield
    # (N, F, 1); the estimator only accepts 2-D input, so drop that axis
    # (the same normalization train_proxy applies).
    feature_array = np.asarray(features)
    if feature_array.ndim == 3 and feature_array.shape[2] == 1:
        features = feature_array.squeeze(-1)

    rf_model = RandomForestRegressor(n_estimators=100, max_depth=5, random_state=42)
    rf_model.fit(features, targets)
    if save_path:
        joblib.dump(rf_model, save_path)
    return rf_model

def _drop_singleton_feature_axis(features):
    """Return `features` as (N, F) when it arrives as (N, F, 1); otherwise unchanged."""
    if features.ndim == 3 and features.shape[2] == 1:
        return features.squeeze(-1)
    return features


def train_proxy(model, val_loader, test_loader, weights, device=None):
    """
    Train a random-forest IEES proxy on validation data and report test metrics.

    Parameters:
    - model: early-exit DNN passed through to extract_proxy_features
    - val_loader: DataLoader whose extracted features/targets fit the regressor
    - test_loader: DataLoader whose extracted features/targets evaluate it
    - weights: IEES weights passed through to extract_proxy_features
    - device: device used for feature extraction; when None, 'cuda' is used
      if available and 'cpu' otherwise

    Returns:
    - the fitted RandomForestRegressor

    Prints R², MAE, RMSE and Spearman ρ of the test-set predictions.
    """
    if device is None:
        # Previously hard-coded to 'cuda', which fails on CPU-only hosts.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    X_val, y_val = extract_proxy_features(model, val_loader, device=device, weights=weights)
    X_val = _drop_singleton_feature_axis(X_val)

    regressor = RandomForestRegressor(n_estimators=100, random_state=42)
    print("Training RF on validation features...")
    regressor.fit(X_val, y_val)

    # Evaluate the fitted proxy on held-out test features.
    print("Evaluating RF on test features...")
    X_test, y_test = extract_proxy_features(model, test_loader, device=device, weights=weights)
    X_test = _drop_singleton_feature_axis(X_test)
    y_pred = regressor.predict(X_test)

    # Regression error metrics plus rank correlation of predicted vs. true scores.
    print("R² Score: ", r2_score(y_test, y_pred))
    print("MAE: ", mean_absolute_error(y_test, y_pred))
    print("RMSE: ", np.sqrt(mean_squared_error(y_test, y_pred)))
    rho, _ = spearmanr(y_test, y_pred)
    print("Spearman ρ:", rho)
    return regressor

def inference_with_proxy(model, rf_proxies, thresholds, x, device):
    """
    Run early-exit inference for a single sample, gated by proxy-predicted IEES.

    Exits are visited in order. At each exit five cheap features
    [confidence, top-2 margin, entropy, activation mean, activation max] are
    computed and fed to that exit's proxy regressor; the first exit whose
    predicted score reaches its threshold returns its logits.

    Parameters:
    - model: early-exit DNN exposing `exits`, `forward_until_exit(x, i)`
      (returning a 3-tuple whose 2nd and 3rd items are logits and activations)
      and `forward_full(x)`
    - rf_proxies: sequence of regressors, one per exit, each with `.predict`
    - thresholds: sequence of per-exit score thresholds
    - x: input tensor holding ONE sample (the scalar features below are
      extracted with `.item()`, so logits must be of shape (1, num_classes))
    - device: device the input is moved to

    Returns:
    - logits of the first exit that passes its threshold, otherwise the
      result of `model.forward_full(x)`
    """
    x = x.to(device)
    with torch.no_grad():
        for i, _exit_head in enumerate(model.exits):
            _, logits, activations = model.forward_until_exit(x, i)
            probs = F.softmax(logits, dim=1)

            conf = probs.max().item()
            # Index the class axis, not the batch axis: topk values are
            # (batch, 2), so [0] / [1] on the first axis would be wrong.
            top2 = torch.topk(probs, 2, dim=1).values
            margin = (top2[:, 0] - top2[:, 1]).item()
            entropy = (-probs * torch.log(probs + 1e-8)).sum().item()
            act_mean = activations.mean().item()
            act_max = activations.amax().item()

            feat_vec = np.array([[conf, margin, entropy, act_mean, act_max]])
            pred_iees = rf_proxies[i].predict(feat_vec)[0]

            if pred_iees >= thresholds[i]:
                return logits  # Early exit decision

    return model.forward_full(x)  # Final exit fallback


# Placeholder: Define your actual IEES computation logic using PFAM
def compute_iees_for_sample(x, exit_index):
    """
    Stub for the per-sample IEES computation.

    Both arguments are currently ignored and a constant score of 0.0 is
    returned; the real PFAM-based calculation has not been wired in here.
    """
    dummy_score = 0.0  # constant stand-in until the real logic is supplied
    return dummy_score
