import torch
import torch.nn.functional as F
import numpy as np


def _default_2d_array(array):
    return array.reshape(-1, array.shape[-1])


def _default_2d_func(func):
    def wrapper(preds, targets):
        return func(_default_2d_array(preds), _default_2d_array(targets))
    return wrapper

def apply_metric(metric_fn, preds, targets):
    """
    Apply a metric function to the predictions and targets.

    Best-effort: any exception raised while computing the metric is reported
    with a printed warning, and NaN is returned instead of propagating.

    Args:
        metric_fn: The metric function to apply, called as ``metric_fn(preds, targets)``.
        preds: The predicted values.
        targets: The ground truth values.

    Returns:
        The computed metric value as a Python scalar (via ``.item()`` when the
        result provides it, otherwise ``float(result)``), or ``float('nan')``
        on failure.
    """
    try:
        result = metric_fn(preds, targets)
        # Tensors / numpy scalars expose .item(); plain Python numbers do not.
        return result.item() if hasattr(result, "item") else float(result)
    except Exception as e:
        # Callables such as functools.partial have no __name__; an
        # AttributeError raised here would escape the best-effort fallback.
        name = getattr(metric_fn, "__name__", repr(metric_fn))
        print(f"Warning: Failed applying metric {name}: {e}")
        return float('nan')


@_default_2d_func
def r2_score(preds, targets):
    """Median across columns of the per-column coefficient of determination.

    The decorator flattens both inputs to 2-D (leading axes merged, last axis
    kept). Each column ``i`` is then scored as
    ``R²_i = 1 - SS_res_i / (SS_tot_i + 1e-9)``.

    Args:
        preds: Predictions, same shape as ``targets``.
        targets: Ground-truth values (floating point).

    Returns:
        0-d tensor holding the median of the per-column R² values. Note that
        ``torch.median`` returns the lower of the two middle values when the
        column count is even.
    """
    # Column-wise reductions replace the Python loop over columns. The result
    # also keeps the inputs' dtype/device, whereas building it with
    # torch.tensor(list_of_tensors) produced a float32 CPU tensor.
    target_mean = targets.mean(dim=0)
    ss_tot = ((targets - target_mean) ** 2).sum(dim=0)
    ss_res = ((targets - preds) ** 2).sum(dim=0)
    per_column_r2 = 1 - ss_res / (ss_tot + 1e-9)
    return torch.median(per_column_r2)

def get_avg_data(data, i=0):
    """Return the per-timepoint mean of trial ``i`` as a 1-D float32 CPU tensor.

    Args:
        data: Indexable collection of trials. This is either a tensor whose
            first axis indexes trials, or a sequence of per-trial tensors
            (trials may differ in length). Axis 0 of ``data[i]`` is treated as
            the time axis. Presumably the remaining axes are neurons/features;
            confirm against callers.
        i: Index of the trial to reduce.

    Returns:
        Tensor of shape ``(data[i].shape[0],)``. Entry ``j`` is the mean over
        all values of ``data[i][j]``.
    """
    trial = torch.as_tensor(data[i]).detach()
    if not trial.is_floating_point():
        # torch.mean rejects integer/bool dtypes (e.g. raw spike counts).
        trial = trial.float()
    if trial.dim() > 1:
        # Average every non-time axis for all timepoints in one vectorized
        # call. The time length comes from this trial, not from data[0].
        trial = trial.flatten(start_dim=1).mean(dim=1)
    # copy=True so the result never aliases the caller's tensor.
    return trial.to(device="cpu", dtype=torch.float32, copy=True)

def average_r2_score(preds, targets):
    """R² between per-timepoint averaged traces, pooled over all trials.

    Each trial of ``preds`` and ``targets`` is reduced with ``get_avg_data``
    to one value per timepoint (the mean over the trial's remaining axes).
    The per-trial traces are concatenated, and a single R² is computed over
    the pooled series: ``1 - SS_res / (SS_tot + 1e-9)``.

    Args:
        preds: Predicted trials; a tensor whose first axis indexes trials, or
            a sequence of per-trial tensors.
        targets: Ground-truth trials, laid out like ``preds``.

    Returns:
        0-d tensor with the pooled R².
    """
    # len() works for both stacked tensors and ragged lists of trials.
    n_trials = len(preds)

    pred = torch.cat([get_avg_data(preds, i) for i in range(n_trials)])
    target = torch.cat([get_avg_data(targets, i) for i in range(n_trials)])

    target_mean = torch.mean(target, dim=0)
    ss_tot = torch.sum((target - target_mean) ** 2, dim=0)
    ss_res = torch.sum((target - pred) ** 2, dim=0)

    # pred/target are 1-D, so the statistics above are already scalars; the
    # final mean just keeps the 0-d tensor return type.
    return torch.mean(1 - ss_res / (ss_tot + 1e-9))

@_default_2d_func
def pseudo_r2_score(preds, targets):
    """Median across columns of the Poisson (deviance-based) pseudo-R².

    After the decorator flattens both inputs to 2-D, each column is scored as

        pseudo_R² = 1 - (LL_sat - LL_model) / (LL_sat - LL_null)

    The three log-likelihoods (constant ``log(t!)`` terms dropped) are:

    - ``LL_model = sum(t * log(p) - p)``: the predicted rates.
    - ``LL_sat = sum(t * log(t) - t)``: the saturated model, which predicts
      ``t`` itself.
    - ``LL_null = sum(t * log(mean(t)) - mean(t))``: predicts the column mean.

    This equals ``1 - D(t, p) / D(t, mean(t))`` with ``D`` the Poisson deviance.

    Args:
        preds: Predicted rates, same shape as ``targets``. They should be
            non-negative; ``log(p + eps)`` is NaN for ``p < -eps``.
        targets: Observed counts/values.

    Returns:
        0-d tensor: the median of the per-column scores. A column whose
        targets are constant makes the denominator zero (NaN/inf score).
    """
    eps = 1e-10

    ll_model = torch.sum(targets * torch.log(preds + eps) - preds, dim=0)
    # The saturated model predicts the targets themselves, so the rate term is
    # ``- targets``. Using ``- preds`` here added a spurious
    # sum(targets) - sum(preds) term to both differences below.
    ll_saturated = torch.sum(targets * torch.log(targets + eps) - targets, dim=0)

    target_mean = torch.mean(targets, dim=0)
    ll_null = torch.sum(targets * torch.log(target_mean + eps) - target_mean, dim=0)

    pseudo_r2 = 1 - (ll_model - ll_saturated) / (ll_null - ll_saturated)
    return torch.median(pseudo_r2)

@_default_2d_func
def linear_regression(preds, targets):
    """Fit ``targets`` from ``preds`` by affine least squares; return the fit.

    A constant column of ones is appended to the (2-D flattened) ``preds`` so
    the model has an intercept. The weights come from the Moore-Penrose
    pseudo-inverse, and the in-sample fitted values are returned.
    """
    intercept_column = preds.new_ones(preds.shape[0], 1)
    design = torch.cat((preds, intercept_column), dim=1)
    weights = torch.pinverse(design) @ targets
    return design @ weights


@_default_2d_func
def regression_r2_score(preds, targets):
    """R² of an affine least-squares mapping from ``preds`` onto ``targets``.

    ``preds`` are first linearly regressed onto ``targets`` (in-sample fit,
    see ``linear_regression``). The fitted values are then scored with
    ``r2_score``, and the result is clamped from below at -10.
    """
    fitted = linear_regression(preds, targets)
    score = r2_score(fitted, targets)
    return torch.clamp(score, min=-10)
