import numpy as np
import torch as th
from sklearn.metrics import accuracy_score, mean_absolute_error, mean_squared_error
import scipy

def target_alignment(kernel, X, Y):
    """Kernel-target alignment between ``kernel(X)`` and the labels ``Y``.

    Returns ``<K, T>_F / (||K||_F * ||T||_F)`` where ``K = kernel(X)`` and
    ``T`` is the ideal target kernel built from the labels:

    * at most two distinct labels: ``T[i, j] = +1`` for equal labels and
      ``-1`` otherwise. This equals ``outer(Y, Y)`` for +/-1 labels and is
      also correct for other encodings such as {0, 1}.
    * more than two distinct labels: ``T[i, j] = 1`` for equal labels and
      ``0`` otherwise.

    Args:
        kernel: Callable returning the Gram matrix ``K(X)`` as a tensor.
        X: Input data accepted by ``kernel``.
        Y: Labels for the rows of ``X``. A tensor, or anything
            ``torch.as_tensor`` accepts (lists, numpy arrays).

    Returns:
        Scalar tensor. It is differentiable w.r.t. any parameters ``kernel``
        uses, since ``T`` carries no gradient.
    """
    K = kernel(X)
    Y = th.as_tensor(Y, device=K.device)
    y = Y.reshape(-1, 1)
    same_label = (y == y.T).float()
    if len(th.unique(Y)) > 2:
        T = same_label
    else:
        # Signed target: +1 for same label, -1 for different label.
        T = 2 * same_label - 1
    inner_product = th.sum(K * T)
    norm = th.sqrt(th.sum(K * K) * th.sum(T * T))
    return inner_product / norm

class TrainableSVC(th.nn.Module):
    """Linear classifier on precomputed kernel features, trained with softmax cross-entropy.

    Despite the name, no hinge loss or margin is involved. The logits are
    ``K @ W + b`` and ``fit`` minimises ``torch.nn.functional.cross_entropy``.

    Args:
        n_samples: Number of columns of the kernel matrices passed to
            ``forward`` (``K @ W`` requires ``K.shape[-1] == n_samples``).
        n_classes: Number of output logits / classes.
    """

    def __init__(self, n_samples, n_classes):
        super().__init__()
        # Small random init keeps the initial logits near zero (near-uniform softmax).
        self.W = th.nn.Parameter(th.randn(n_samples, n_classes) * 0.01)
        self.b = th.nn.Parameter(th.zeros(n_classes))

    def forward(self, K):
        """Return logits ``K @ W + b``; for a 2-D ``K`` the shape is ``(K.shape[0], n_classes)``."""
        return K @ self.W + self.b

    def fit(self, K, Y, optimizer=None, epochs=100):
        """Train the classifier with full-batch gradient steps on a fixed kernel matrix.

        Args:
            K: Kernel matrix used as features. The caller in this module
                passes a detached matrix so the kernel itself is not trained.
            Y: Integer class labels for the rows of ``K``.
            optimizer: Optimizer over this module's parameters. Defaults to
                ``Adam(lr=0.01)``.
            epochs: Number of optimizer steps. With ``epochs <= 0`` nothing is
                trained and the loss of the current weights is returned.

        Returns:
            float: the cross-entropy computed in the last epoch, i.e. before
            that epoch's optimizer step.
        """
        if optimizer is None:
            optimizer = th.optim.Adam(self.parameters(), lr=0.01)
        loss = None
        for _ in range(epochs):
            # Clear gradients *before* backward so gradients accumulated
            # earlier never leak into an update.
            optimizer.zero_grad()
            loss = th.nn.functional.cross_entropy(self.forward(K), Y)
            loss.backward()
            optimizer.step()
        # Leave the parameters without pending gradients after training.
        optimizer.zero_grad()
        if loss is None:
            # No training epochs were run: just report the current loss.
            with th.no_grad():
                loss = th.nn.functional.cross_entropy(self.forward(K), Y)
        return loss.item()

def cross_entropy(kernel, X, Y):
    """Cross-entropy of a linear classifier fitted on top of ``kernel(X)``.

    A fresh ``TrainableSVC`` is first fitted on a *detached* copy of the
    kernel matrix, so its weights are learned without back-propagating into
    the kernel. The returned loss is then recomputed on the attached kernel.
    Calling ``.backward()`` on it therefore reaches the kernel's own
    parameters, with the classifier weights held at their fitted values.

    Args:
        kernel: Differentiable callable returning ``K(X)`` of shape ``(n, n)``.
        X: Input data accepted by ``kernel``.
        Y: Class labels, LongTensor of shape ``(n,)`` with values in
            ``[0, n_classes)``.

    Returns:
        Scalar loss tensor.
    """
    n_classes = Y.max().item() + 1
    K = kernel(X)
    Y = Y.to(K.device)
    # matmul neither promotes dtypes nor mixes devices, so the classifier must
    # live where the kernel does (e.g. float64 or CUDA kernels).
    model = TrainableSVC(K.size(0), n_classes).to(K.device)
    if K.is_floating_point():
        model = model.to(K.dtype)
    model.fit(K.detach(), Y)
    return th.nn.functional.cross_entropy(model(K), Y)


def evaluate(model, X, y):
    """Score a fitted estimator on a held-out split.

    Args:
        model: Object exposing ``predict(X)`` and a ``kernel`` attribute.
            Presumably an sklearn-style estimator such as ``SVC``; confirm
            against callers.
        X: Test inputs.
        y: True test labels.

    Returns:
        Dict with keys ``'Test/Accuracy'``, ``'Test/L1'`` (mean absolute
        error) and ``'Test/L2'`` (mean squared error). When ``model.kernel``
        is callable it also contains ``'Test/KTA'``, the kernel-target
        alignment of ``model.kernel`` on ``(X, y)``.
    """
    # Predict once; every metric compares the same predictions.
    y_pred = model.predict(X)
    # sklearn metrics take (y_true, y_pred).
    results = {
        'Test/Accuracy': accuracy_score(y, y_pred),
        'Test/L1': mean_absolute_error(y, y_pred),
        'Test/L2': mean_squared_error(y, y_pred),
    }
    if callable(model.kernel):
        results['Test/KTA'] = target_alignment(model.kernel, X, y)
    return results


def CI(data, confidence=0.95, bounds=(0, 1)):
    """Per-column Student-t confidence interval for the mean.

    Args:
        data: Array of shape ``(n_runs, n_points)``. Each column is treated as
            an independent sample of ``n_runs`` observations.
        confidence: Two-sided confidence level, e.g. ``0.95``.
        bounds: ``(low, high)`` range the interval is clipped to. The default
            ``(0, 1)`` suits bounded metrics such as accuracy. Pass ``None``
            to disable clipping.

    Returns:
        Tuple ``(range(n_points), lower, upper)``, where ``lower`` and
        ``upper`` are arrays of shape ``(n_points,)``.
    """
    # `import scipy` alone does not load the stats submodule on older SciPy.
    import scipy.stats

    data = np.asarray(data)
    n = data.shape[0]
    mean = np.mean(data, axis=0)
    sem = scipy.stats.sem(data, axis=0)
    # A t-interval for a sample mean has n - 1 degrees of freedom. Computing
    # the half-width directly (instead of t.interval) yields a degenerate
    # interval rather than NaN when a column has zero variance.
    half_width = scipy.stats.t.ppf((1 + confidence) / 2, df=n - 1) * sem
    lower, upper = mean - half_width, mean + half_width
    if bounds is not None:
        lower = np.clip(lower, *bounds)
        upper = np.clip(upper, *bounds)
    return range(data.shape[1]), lower, upper
