from sklearn.metrics import roc_auc_score
import numpy as np
import torch

from encoder import MLPEncoder, TransformerEncoder, MLPOnNodesEncoder, PerGameTransformerEncoder
from decoder import DotProductDecoder, CosineSimilarityDecoder, CorrelationCoefficientDecoder, MLPDecoder
from barik_honorio_model import barik_honorio
from linear_quadratic_optimization import non_smoothing_model, smoothing_model

from inverse_covariance import QuicGraphicalLasso

def compute_roc_auc_score(A_true, A_pred):
    """Return the ROC-AUC of predicted edge scores against the true adjacency.

    NOTE: only off-diagonal entries are scored, since the diagonal is assumed
    to always contain zero.

    Args:
        A_true: batched ground-truth adjacency tensor
            (assumes shape (batch, n, n) -- TODO confirm against callers).
        A_pred: batched tensor of predicted edge scores, same layout as
            ``A_true``.

    Returns:
        The value returned by ``sklearn.metrics.roc_auc_score`` for the
        flattened off-diagonal entries.
    """
    def _off_diagonal_as_numpy(batch):
        # Drop the diagonal, then move to host memory and detach from the
        # autograd graph so the values can be handed to scikit-learn.
        return mask_diagonal(batch).flatten().cpu().detach().numpy()

    scores = _off_diagonal_as_numpy(A_pred)
    labels = _off_diagonal_as_numpy(A_true)

    return roc_auc_score(labels, scores)

def mask_diagonal(A):
    """Return the off-diagonal entries of every matrix in ``A`` as a 1-D tensor.

    Args:
        A: batched tensor; the mask is built from ``A.shape[0]`` (batch size)
            and ``A.shape[1]`` (matrix side), so ``A`` is expected to be
            (batch, n, n).

    Returns:
        A 1-D tensor holding the selected elements in row-major order.
    """
    batch_size, n_nodes = A.shape[0], A.shape[1]
    identity = torch.eye(n_nodes, dtype=torch.bool, device=A.device)
    # True everywhere except on the diagonal, replicated across the batch.
    off_diagonal = identity.logical_not().unsqueeze(0).expand(batch_size, -1, -1)

    return torch.masked_select(A, off_diagonal)

def mask_diagonal_and_sigmoid(A):
    """Apply a sigmoid to ``A`` and zero the diagonal of every matrix.

    Args:
        A: batched tensor; the mask is built from ``A.shape[0]`` and
            ``A.shape[1]``, so ``A`` is expected to be (batch, n, n).

    Returns:
        A tensor with the same shape as ``A`` holding ``sigmoid(A)`` off the
        diagonal and the product with zero on the diagonal.
    """
    batch_size, n_nodes = A.shape[0], A.shape[1]
    identity = torch.eye(n_nodes, dtype=torch.bool, device=A.device)
    keep = identity.logical_not().unsqueeze(0).expand(batch_size, -1, -1)
    probabilities = torch.sigmoid(A)
    # Multiplying by the boolean mask (instead of selecting) keeps the shape.
    return probabilities * keep

def zero_out_diagonal(A):
    """Set the diagonal of every matrix in the batch ``A`` to zero, in place.

    Args:
        A: tensor whose first dimension indexes the matrices to modify.

    Returns:
        None; ``A`` itself is mutated.
    """
    batch_size = A.shape[0]
    for idx in range(batch_size):
        # A[idx] is a view into A, so the in-place fill writes through.
        matrix_view = A[idx]
        matrix_view.fill_diagonal_(0)

def correlation_baseline_score(x, A_true, anticorrelation=False):
    """Score the (anti-)correlation baseline with ROC-AUC.

    The correlation matrix of the rows of ``x`` is used directly as the
    predicted edge scores.

    Args:
        x: 2-D tensor handed to ``np.corrcoef`` (which treats each row as a
            variable) -- presumably (n_nodes, n_observations); confirm.
        A_true: ground-truth adjacency matrix for a single graph; a batch
            dimension is added before scoring.
        anticorrelation: when True, the correlation matrix is negated before
            being scored.

    Returns:
        The ROC-AUC computed by ``compute_roc_auc_score``.
    """
    sign = -1 if anticorrelation else 1
    correlations = np.corrcoef(x.cpu().numpy())
    A_pred = torch.Tensor(sign * correlations)

    # compute_roc_auc_score works on batches, so add a leading batch axis.
    batched_true = A_true.unsqueeze(0)
    batched_pred = A_pred.unsqueeze(0)
    return compute_roc_auc_score(batched_true, batched_pred)

def lasso_baseline_score(X, A_true, reg):
    """Score a QUIC graphical-lasso baseline with ROC-AUC.

    Args:
        X: 2-D tensor; its transpose is passed to ``QuicGraphicalLasso.fit``.
        A_true: ground-truth adjacency matrix for a single graph; a batch
            dimension is added before scoring.
        reg: regularisation strength, forwarded as ``lam``.

    Returns:
        The ROC-AUC computed by ``compute_roc_auc_score``.
    """
    estimator = QuicGraphicalLasso(lam=reg, init_method='cov')
    observations = X.T.cpu().numpy()
    estimator.fit(observations)
    # NOTE(review): the fitted covariance (not the precision matrix) is used
    # as the edge score -- confirm this is intended.
    A_pred = torch.Tensor(estimator.covariance_)

    batched_true = A_true.unsqueeze(0)
    batched_pred = A_pred.unsqueeze(0)
    return compute_roc_auc_score(batched_true, batched_pred)

def barik_honorio_score(X, A_true):
    """Score the ``barik_honorio`` baseline with ROC-AUC.

    Args:
        X: tensor of observations; converted to a NumPy array on the CPU
            before being handed to ``barik_honorio``.
        A_true: ground-truth adjacency matrix for a single graph; a batch
            dimension is added before scoring.

    Returns:
        The ROC-AUC computed by ``compute_roc_auc_score``.
    """
    estimate = barik_honorio(X.cpu().numpy())
    A_pred = torch.Tensor(estimate)

    batched_true = A_true.unsqueeze(0)
    batched_pred = A_pred.unsqueeze(0)
    return compute_roc_auc_score(batched_true, batched_pred)

def linear_quadratic_optimization_score(X, A_true, beta, theta1, theta2, smooth):
    """Score the linear-quadratic optimisation baseline with ROC-AUC.

    Args:
        X: tensor of observations; converted to a NumPy array on the CPU
            before being handed to the solver.
        A_true: ground-truth adjacency matrix for a single graph; a batch
            dimension is added before scoring.
        beta, theta1, theta2: hyper-parameters forwarded unchanged to the
            solver.
        smooth: selects ``smoothing_model`` when truthy, otherwise
            ``non_smoothing_model``.

    Returns:
        The ROC-AUC computed by ``compute_roc_auc_score``.
    """
    solver = smoothing_model if smooth else non_smoothing_model
    estimate = solver(X.cpu().numpy(), beta, theta1, theta2)
    A_pred = torch.Tensor(estimate)

    batched_true = A_true.unsqueeze(0)
    batched_pred = A_pred.unsqueeze(0)
    return compute_roc_auc_score(batched_true, batched_pred)

def permute_features(X, B):
    """Apply one shared random permutation along the third axis of ``X`` and ``B``.

    Args:
        X: 3-D array/tensor; the permutation length is ``X.shape[-1]``.
        B: 3-D array/tensor indexed with the same permutation as ``X``.

    Returns:
        A tuple ``(X, B)`` of the permuted inputs.
    """
    n_features = X.shape[-1]
    order = np.random.permutation(n_features)
    # The same index array is used for both so they stay aligned.
    return X[:, :, order], B[:, :, order]

def get_encoder(encoder_type, n_nodes, n_games, hidden_dim, dropout, transformer_num_layers, transformer_feedforward_dim):
    """Build the encoder selected by ``encoder_type``.

    Args:
        encoder_type: one of ``"mlp_on_seq"``, ``"mlp_on_nodes"``,
            ``"transformer"``, ``"column_transformer"`` or
            ``"per_game_transformer"``.
        n_nodes, n_games, hidden_dim: forwarded positionally to every encoder.
        dropout: forwarded as ``dropout`` to every encoder.
        transformer_num_layers: forwarded as ``num_layers`` to the three
            transformer variants; ignored by the MLP encoders.
        transformer_feedforward_dim: forwarded to ``"transformer"`` and
            ``"column_transformer"`` only.

    Returns:
        The constructed encoder instance.

    Raises:
        NotImplementedError: if ``encoder_type`` is not one of the supported
            names (previously the exception class was returned, not raised).
    """
    if encoder_type == "mlp_on_seq":
        return MLPEncoder(n_nodes, n_games, hidden_dim, dropout=dropout)
    if encoder_type == "mlp_on_nodes":
        return MLPOnNodesEncoder(n_nodes, n_games, hidden_dim, dropout=dropout)
    if encoder_type == "transformer":
        return TransformerEncoder(n_nodes, n_games, hidden_dim, dropout=dropout, num_layers=transformer_num_layers, transformer_feedforward_dim=transformer_feedforward_dim)
    if encoder_type == "column_transformer":
        # ColumnTransformerEncoder is missing from this module's top-level
        # imports, so the branch used to fail with NameError. Import it here.
        # NOTE(review): assumed to live in `encoder` next to the other
        # encoders -- confirm.
        from encoder import ColumnTransformerEncoder
        return ColumnTransformerEncoder(n_nodes, n_games, hidden_dim, dropout=dropout, num_layers=transformer_num_layers, transformer_feedforward_dim=transformer_feedforward_dim)
    if encoder_type == "per_game_transformer":
        return PerGameTransformerEncoder(n_nodes, n_games, hidden_dim, dropout=dropout, num_layers=transformer_num_layers)

    # Fail loudly instead of handing the exception class back to the caller.
    raise NotImplementedError("Unknown encoder type: {!r}".format(encoder_type))

def get_decoder(decoder_type, in_channels, permutation_invariant):
    """Build the decoder selected by ``decoder_type``.

    Args:
        decoder_type: one of ``"dot_product"``, ``"cosine_similarity"``,
            ``"correlation_coefficient"`` or ``"mlp"``.
        in_channels: forwarded to ``MLPDecoder``; ignored by the others.
        permutation_invariant: forwarded to ``MLPDecoder``; ignored by the
            others.

    Returns:
        The constructed decoder instance.

    Raises:
        NotImplementedError: if ``decoder_type`` is not one of the supported
            names (previously the exception class was returned, not raised).
    """
    if decoder_type == "dot_product":
        return DotProductDecoder()
    if decoder_type == "cosine_similarity":
        return CosineSimilarityDecoder()
    if decoder_type == "correlation_coefficient":
        return CorrelationCoefficientDecoder()
    if decoder_type == "mlp":
        return MLPDecoder(in_channels, permutation_invariant=permutation_invariant)

    # Fail loudly instead of handing the exception class back to the caller.
    raise NotImplementedError("Unknown decoder type: {!r}".format(decoder_type))

def get_loss_weights(use_weighted_loss, train_loader):
    """Return the positive-class weight to use with a BCE loss.

    Args:
        use_weighted_loss: when falsy, a neutral weight of one is returned.
        train_loader: iterable of batches; each batch is indexed with
            ``"A"`` to obtain an adjacency tensor whose entries are summed
            to count edges.

    Returns:
        ``torch.ones((1,), dtype=float)`` when weighting is disabled,
        otherwise the ratio (no. negatives / no. positives) over all entries
        of every ``A`` in ``train_loader``.

    NOTE(review): with no positive entries (or an empty loader) the ratio
    divides by zero -- confirm callers never hit that case.
    """
    if not use_weighted_loss:
        return torch.ones((1,), dtype=float)

    print('Using weighted BCE.')

    positives = 0
    total = 0
    for batch in train_loader:
        adjacency = batch["A"].flatten()
        positives += adjacency.sum()
        total += adjacency.shape[0]

    # no. negatives / no. positives
    return (total - positives) / positives
