import typing as ty

import numpy as np
import scipy.special
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, mean_squared_error, r2_score
import torch
import torchmetrics
from . import util

def _roc_auc_or_zero(y_true, y_score, **kwargs):
    """Return ``roc_auc_score(y_true, y_score, **kwargs)``, or 0.0 if it raises."""
    try:
        return roc_auc_score(y_true, y_score, **kwargs)
    except Exception:
        # Best effort: AUC is reported as 0.0 when it cannot be computed.
        # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt and
        # SystemExit propagate.
        return 0.0


def evaluate_model(model, X_val, y_val, X_test, y_test, is_classification=True, is_binary=True):
    """Compute validation and test metrics for an estimator exposing ``predict``.

    For each split the estimator's ``predict`` (and, for classification,
    ``predict_proba``) is called and the output is scored with the
    scikit-learn metric functions imported by this module.

    Args:
        model: Fitted estimator. Must provide ``predict``; classification
            additionally requires ``predict_proba``.
        X_val, y_val: Validation inputs and targets.
        X_test, y_test: Test inputs and targets.
        is_classification: Score as classification when true, otherwise as
            regression.
        is_binary: Only consulted for classification; selects the binary or
            the multiclass metric set.

    Returns:
        A dict with the keys ``'val'`` and ``'test'``. Each value is a dict
        of metrics that always contains ``'score'``:

        * binary: ``score`` (accuracy), ``accuracy``, ``auc``, ``f1`` (macro)
        * multiclass: ``score`` (accuracy), ``accuracy``, ``f1_weighted``,
          ``f1_macro``, ``f1_micro``, ``auc`` (one-vs-rest, macro)
        * regression: ``score`` (negative MSE), ``mse``, ``r2``

        If anything raises while scoring a split, the error and traceback
        are printed and that split is reported as
        ``{'score': -999999999.0}``.
    """
    metrics = {}

    for part_name, X, y in (('val', X_val, y_val), ('test', X_test, y_test)):
        try:
            y_pred = model.predict(X)

            if not is_classification:
                mse = mean_squared_error(y, y_pred)
                metrics[part_name] = {
                    'score': -mse,
                    'mse': mse,
                    'r2': r2_score(y, y_pred),
                }
                continue

            y_pred_proba = model.predict_proba(X)
            accuracy = accuracy_score(y, y_pred)

            if is_binary:
                # Use the second column as the positive-class score when the
                # output is two-dimensional with several columns; otherwise
                # (single column or already 1-D) use the flattened values.
                proba = np.asarray(y_pred_proba)
                if proba.ndim == 2 and proba.shape[1] > 1:
                    y_pred_proba_pos = proba[:, 1]
                else:
                    y_pred_proba_pos = proba.ravel()

                metrics[part_name] = {
                    'score': accuracy,
                    'accuracy': accuracy,
                    'auc': _roc_auc_or_zero(y, y_pred_proba_pos),
                    'f1': f1_score(y, y_pred, average='macro'),
                }
            else:
                metrics[part_name] = {
                    # 'score' mirrors the binary branch so every result dict
                    # (including the failure sentinel) exposes the same key.
                    'score': accuracy,
                    'accuracy': accuracy,
                    'f1_weighted': f1_score(y, y_pred, average='weighted'),
                    'f1_macro': f1_score(y, y_pred, average='macro'),
                    'f1_micro': f1_score(y, y_pred, average='micro'),
                    'auc': _roc_auc_or_zero(
                        y, y_pred_proba, multi_class='ovr', average='macro'
                    ),
                }

        except Exception as e:
            print(f"Error calculating metrics for {part_name}: {e}")
            import traceback
            traceback.print_exc()
            # Sentinel score so callers can still read metrics[part]['score'].
            metrics[part_name] = {'score': -999999999.0}

    return metrics


def _flatten_model_output(out):
    """Return *out* with a batch axis followed by one flattened feature axis."""
    if out.ndim == 1:
        return out.unsqueeze(-1)
    if out.ndim >= 2:
        return out.flatten(start_dim=1)
    return out


def evaluate_model_torch(model, X_val, y_val, X_test, y_test, is_classification, is_binary, device, fast_threshold=4096):
    """Compute validation and test metrics for a PyTorch model with torchmetrics.

    The model is switched to eval mode and run under ``torch.inference_mode``.
    A split with at most ``fast_threshold`` rows is passed through the model
    in one call; larger splits are processed in chunks of 1024 rows. In both
    cases the output is brought to shape ``(rows, -1)`` before scoring, so
    the result does not depend on which path was taken.

    Args:
        model: ``torch.nn.Module`` to evaluate. It is left in eval mode.
        X_val, y_val: Validation inputs and targets.
        X_test, y_test: Test inputs and targets.
            NOTE(review): inputs are passed to the model as given (no device
            transfer here) -- presumably they already live on the model's
            device; confirm against callers.
        is_classification: Score as classification when true, else regression.
        is_binary: Only consulted for classification. Binary predictions are
            treated as one logit per row (sigmoid, threshold 0.5); multiclass
            predictions as one logit per class (softmax over axis 1).
        device: Device the metric objects and the targets are moved to.
        fast_threshold: Largest row count evaluated in a single forward pass.

    Returns:
        A dict with the keys ``'val'`` and ``'test'``. Metric values are
        whatever the torchmetrics objects return (they are not converted).

        * binary: ``score`` (accuracy), ``accuracy``, ``auc``, ``f1``,
          ``param_count``
        * multiclass: ``score`` (accuracy), ``accuracy``, ``f1_micro``,
          ``f1_macro``, ``f1_weighted``, ``auc``, ``param_count``
        * regression: ``score`` (negative MSE), ``mse``, ``r2``

        If computing the metrics of a split raises, the error is printed and
        that split is reported as ``{'score': -999999999.0}``. Errors raised
        by the forward pass itself are not caught.
    """
    model.eval()
    metrics = {}
    eval_bs = 1024
    param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)

    for part_name, X, y in (('val', X_val, y_val), ('test', X_test, y_test)):
        with torch.inference_mode():
            n_rows = X.shape[0]
            if n_rows <= fast_threshold:
                predictions = _flatten_model_output(model(X))
            else:
                predictions = torch.cat(
                    [
                        _flatten_model_output(model(X[start:start + eval_bs]))
                        for start in range(0, n_rows, eval_bs)
                    ],
                    dim=0,
                )

        try:
            if is_classification:
                if is_binary:
                    # reshape(-1) rather than squeeze(-1): squeeze would turn
                    # a single-row split into a 0-d tensor.
                    y_pred_proba = torch.sigmoid(predictions).reshape(-1)
                    y_true = torch.as_tensor(y).to(device).int().reshape(-1)
                    y_pred = (y_pred_proba > 0.5).int()
                    if y_true.shape != y_pred.shape:
                        # Explicit raise: an ``assert`` disappears under -O.
                        raise ValueError(
                            f"target shape {tuple(y_true.shape)} does not match "
                            f"prediction shape {tuple(y_pred.shape)}"
                        )

                    accuracy = torchmetrics.classification.BinaryAccuracy().to(device)(y_pred, y_true)
                    auc = torchmetrics.classification.BinaryAUROC().to(device)(y_pred_proba, y_true)
                    f1 = torchmetrics.classification.MulticlassF1Score(num_classes=2, average="macro").to(device)(y_pred, y_true)
                    metrics[part_name] = {
                        'score': accuracy,
                        'accuracy': accuracy,
                        'auc': auc,
                        'f1': f1,
                        'param_count': param_count,
                    }
                else:
                    y_pred_proba = torch.softmax(predictions, dim=1)
                    # as_tensor accepts arrays and returns existing tensors
                    # unchanged (no copy, no warning).
                    y_true = torch.as_tensor(y).to(device).long().reshape(-1)
                    y_pred = torch.argmax(y_pred_proba, dim=1)
                    # The softmax runs over axis 1, so its width is the
                    # number of classes being scored.
                    num_classes = y_pred_proba.shape[1]

                    accuracy = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)(y_pred, y_true)
                    auc = torchmetrics.classification.MulticlassAUROC(num_classes=num_classes).to(device)(y_pred_proba, y_true)
                    f1_macro = torchmetrics.classification.MulticlassF1Score(num_classes=num_classes, average="macro").to(device)(y_pred, y_true)
                    f1_micro = torchmetrics.classification.MulticlassF1Score(num_classes=num_classes, average="micro").to(device)(y_pred, y_true)
                    f1_weighted = torchmetrics.classification.MulticlassF1Score(num_classes=num_classes, average="weighted").to(device)(y_pred, y_true)
                    metrics[part_name] = {
                        'score': accuracy,
                        'accuracy': accuracy,
                        'f1_micro': f1_micro,
                        'f1_macro': f1_macro,
                        'f1_weighted': f1_weighted,
                        'auc': auc,
                        'param_count': param_count,
                    }
            else:
                target = torch.as_tensor(y).to(device)
                # Same element count but different layout (e.g. (N, 1) vs
                # (N,)): score in the target's layout instead of failing.
                if predictions.shape != target.shape and predictions.numel() == target.numel():
                    predictions = predictions.reshape(target.shape)
                mse = torchmetrics.MeanSquaredError().to(device)(predictions, target)
                r2 = torchmetrics.R2Score().to(device)(predictions, target)
                metrics[part_name] = {
                    'score': -mse,
                    'mse': mse,
                    'r2': r2
                }
        except Exception as e:
            print(f"Error calculating metrics for {part_name}: {e}")
            # Keep the split present in the result, with a sentinel score,
            # instead of silently omitting it.
            metrics[part_name] = {'score': -999999999.0}
    return metrics


def compute_val_loss(
    model, X_val, y_val, loss_fn, is_binary, is_classification, num_classes, device, eval_bs=1024
):
    """Return the mean per-batch loss of ``model`` on a validation set.

    The model is switched to eval mode and evaluated under ``torch.no_grad``
    in chunks of ``eval_bs`` rows; each chunk is moved to ``device`` first.
    The result is the unweighted mean of the per-batch ``loss_fn`` values,
    so a smaller final batch counts as much as a full one.

    Args:
        model: ``torch.nn.Module`` to evaluate. It is left in eval mode.
        X_val: Validation inputs, sliced along the first axis.
        y_val: Validation targets, sliced along the first axis.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a tensor that
            supports ``.item()``.
        is_binary: When true, targets get a trailing singleton axis squeezed
            and are cast to float. If the resulting shape still differs from
            the prediction's but holds the same number of elements, the
            target is reshaped to the prediction's shape.
        is_classification: With ``num_classes > 2`` (and ``is_binary``
            false), targets are flattened and cast to long.
        num_classes: Number of classes; only read when ``is_binary`` is
            false and ``is_classification`` is true.
        device: Device each batch is moved to.
        eval_bs: Rows per batch.

    Returns:
        The mean batch loss as a Python float, or ``nan`` when the
        validation set has no rows.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    n_rows = X_val.shape[0]

    with torch.no_grad():
        for start in range(0, n_rows, eval_bs):
            stop = start + eval_bs
            batch_X = X_val[start:stop].to(device)
            batch_y = y_val[start:stop].to(device)
            if is_binary:
                batch_y = batch_y.squeeze(-1).float()
            elif is_classification and num_classes > 2:
                batch_y = batch_y.view(-1).long()

            pred = model(batch_X)

            # squeeze(-1) collapses a 1-D single-element batch to 0-d (this
            # happens for the last batch when N % eval_bs == 1); realign the
            # target with the prediction whenever only the layout differs.
            if is_binary and batch_y.shape != pred.shape and batch_y.numel() == pred.numel():
                batch_y = batch_y.reshape(pred.shape)

            total_loss += loss_fn(pred, batch_y).item()
            n_batches += 1

    if n_batches == 0:
        # No rows: there is no mean to report (previously ZeroDivisionError).
        return float("nan")
    return total_loss / n_batches
