import numpy as np
import sklearn.covariance
import torch
import math
import time
import pandas as pd

from utils import compute_roc_auc_score, correlation_baseline_score, lasso_baseline_score, barik_honorio_score, linear_quadratic_optimization_score
from deep_graph import permute_node_ordering_and_compute_covariance_matrix

def get_correlation_roc_auc(dataset, anticorrelation=False):
    """Summarize the correlation-baseline ROC AUC over ``dataset``.

    Every element of ``dataset`` is indexed with ``"X"`` and ``"A"``; both
    values are handed to ``correlation_baseline_score`` together with the
    ``anticorrelation`` flag.

    Returns:
        A ``(mean, error)`` pair: the mean of the per-element scores, and
        their standard deviation divided by ``sqrt(len(dataset))``.
    """
    per_sample_scores = [
        correlation_baseline_score(sample["X"], sample["A"], anticorrelation=anticorrelation)
        for sample in dataset
    ]

    mean_score = np.mean(per_sample_scores)
    score_error = np.std(per_sample_scores) / math.sqrt(len(dataset))
    return mean_score, score_error

def get_barik_honorio_roc_auc(dataset):
    """Summarize the ``barik_honorio_score`` ROC AUC over ``dataset``.

    Every element of ``dataset`` is indexed with ``"X"`` and ``"A"`` and both
    values are passed to ``barik_honorio_score``.

    Returns:
        A ``(mean, error)`` pair: the mean of the per-element scores, and
        their standard deviation divided by ``sqrt(len(dataset))``.
    """
    per_sample_scores = [
        barik_honorio_score(sample["X"], sample["A"])
        for sample in dataset
    ]

    mean_score = np.mean(per_sample_scores)
    score_error = np.std(per_sample_scores) / math.sqrt(len(dataset))
    return mean_score, score_error

def get_linear_quadratic_optimization_roc_auc(dataset, beta, theta1, theta2, smooth):
    """Summarize the ``linear_quadratic_optimization_score`` ROC AUC over ``dataset``.

    Every element of ``dataset`` is indexed with ``"X"`` and ``"A"``; those
    values plus ``beta``, ``theta1``, ``theta2`` and ``smooth`` are forwarded
    positionally, in that order, to ``linear_quadratic_optimization_score``.

    Returns:
        A ``(mean, error)`` pair: the mean of the per-element scores, and
        their standard deviation divided by ``sqrt(len(dataset))``.
    """
    per_sample_scores = [
        linear_quadratic_optimization_score(sample["X"], sample["A"], beta, theta1, theta2, smooth)
        for sample in dataset
    ]

    mean_score = np.mean(per_sample_scores)
    score_error = np.std(per_sample_scores) / math.sqrt(len(dataset))
    return mean_score, score_error

def get_lasso_roc_auc(dataset, reg=0):
    """Summarize the ``lasso_baseline_score`` ROC AUC over ``dataset``.

    Every element of ``dataset`` is indexed with ``"X"`` and ``"A"``; both
    values are passed to ``lasso_baseline_score`` with ``reg`` as the
    regularization argument.

    Returns:
        A ``(mean, error)`` pair: the mean of the per-element scores, and
        their standard deviation divided by ``sqrt(len(dataset))``.
    """
    per_sample_scores = [
        lasso_baseline_score(sample["X"], sample["A"], reg=reg)
        for sample in dataset
    ]

    mean_score = np.mean(per_sample_scores)
    score_error = np.std(per_sample_scores) / math.sqrt(len(dataset))
    return mean_score, score_error

def eval_baseline(train_dataset, val_dataset, test_dataset, game_type):
    """Score the non-learned baselines and collect the results in a dict.

    The correlation and anticorrelation baselines are scored on all three
    splits. The lasso regularization value is picked on ``val_dataset`` from
    ``10**0 .. 10**-4`` and then scored on ``test_dataset``; no lasso score is
    computed on the train split, whose entries are fixed at ``0.``. One
    summary line per baseline is printed.

    Args:
        train_dataset, val_dataset, test_dataset: datasets forwarded to the
            ``get_*_roc_auc`` helpers.
        game_type: when equal to ``"barik_honorio"``, test-split entries for
            the Barik-Honorio baseline are added to the result.

    Returns:
        A dict keyed ``"<baseline>_<split>_roc_auc_mean"`` /
        ``"<baseline>_<split>_roc_auc_std"``. The ``*_std`` entries hold the
        second value returned by the ``get_*_roc_auc`` helpers.
    """
    # Correlation baseline on every split.
    corr_train_mean, corr_train_err = get_correlation_roc_auc(train_dataset, anticorrelation=False)
    corr_val_mean, corr_val_err = get_correlation_roc_auc(val_dataset, anticorrelation=False)
    corr_test_mean, corr_test_err = get_correlation_roc_auc(test_dataset, anticorrelation=False)

    # Anticorrelation baseline on every split.
    anti_train_mean, anti_train_err = get_correlation_roc_auc(train_dataset, anticorrelation=True)
    anti_val_mean, anti_val_err = get_correlation_roc_auc(val_dataset, anticorrelation=True)
    anti_test_mean, anti_test_err = get_correlation_roc_auc(test_dataset, anticorrelation=True)

    # Graphical Lasso baseline: keep the candidate with the highest validation
    # mean. If no candidate scores above 0, best_reg stays at its initial 0.
    print('Tuning Graphical Lasso regularization parameter...', end=' ', flush=True)

    best_reg = 0
    lasso_val_mean, lasso_val_err = 0, 0.
    for exponent in range(5):
        candidate_reg = pow(10, -exponent)
        candidate_mean, candidate_err = get_lasso_roc_auc(val_dataset, reg=candidate_reg)
        if candidate_mean > lasso_val_mean:
            best_reg = candidate_reg
            lasso_val_mean, lasso_val_err = candidate_mean, candidate_err
    print(f'Done! Best regularization parameter found: {best_reg}')

    # The train split is not scored for lasso; the validation numbers are the
    # ones already obtained for best_reg during tuning.
    lasso_train_mean, lasso_train_err = 0., 0.
    lasso_test_mean, lasso_test_err = get_lasso_roc_auc(test_dataset, reg=best_reg)

    print(f"Correlation baseline     --- train ROC_AUC:{corr_train_mean:.4f}+-{corr_train_err:.4f}, val ROC_AUC:{corr_val_mean:.4f}+-{corr_val_err:.4f}, test ROC_AUC:{corr_test_mean:.4f}+-{corr_test_err:.4f}")
    print(f"Anticorrelation baseline --- train ROC_AUC:{anti_train_mean:.4f}+-{anti_train_err:.4f}, val ROC_AUC:{anti_val_mean:.4f}+-{anti_val_err:.4f}, test ROC_AUC:{anti_test_mean:.4f}+-{anti_test_err:.4f}")
    print(f"Lasso baseline           --- train ROC_AUC:{lasso_train_mean:.4f}+-{lasso_train_err:.4f}, val ROC_AUC:{lasso_val_mean:.4f}+-{lasso_val_err:.4f}, test ROC_AUC:{lasso_test_mean:.4f}+-{lasso_test_err:.4f}")

    baseline_results = {
        # correlation
        "correlation_train_roc_auc_mean": corr_train_mean,
        "correlation_train_roc_auc_std": corr_train_err,
        "correlation_val_roc_auc_mean": corr_val_mean,
        "correlation_val_roc_auc_std": corr_val_err,
        "correlation_test_roc_auc_mean": corr_test_mean,
        "correlation_test_roc_auc_std": corr_test_err,
        # anticorrelation
        "anticorrelation_train_roc_auc_mean": anti_train_mean,
        "anticorrelation_train_roc_auc_std": anti_train_err,
        "anticorrelation_val_roc_auc_mean": anti_val_mean,
        "anticorrelation_val_roc_auc_std": anti_val_err,
        "anticorrelation_test_roc_auc_mean": anti_test_mean,
        "anticorrelation_test_roc_auc_std": anti_test_err,
        # lasso
        "lasso_train_roc_auc_mean": lasso_train_mean,
        "lasso_train_roc_auc_std": lasso_train_err,
        "lasso_val_roc_auc_mean": lasso_val_mean,
        "lasso_val_roc_auc_std": lasso_val_err,
        "lasso_test_roc_auc_mean": lasso_test_mean,
        "lasso_test_roc_auc_std": lasso_test_err,
    }

    if game_type != "barik_honorio":
        return baseline_results

    # Barik-Honorio baseline is only scored on the test split.
    bh_test_mean, bh_test_err = get_barik_honorio_roc_auc(test_dataset)
    baseline_results["barik_honorio_test_roc_auc_mean"] = bh_test_mean
    baseline_results["barik_honorio_test_roc_auc_std"] = bh_test_err

    return baseline_results

def eval_everything(model, train_eval_loader, val_loader, test_loader, device, args):
    """Run ``eval`` on the train, validation and test loaders.

    The loaders are evaluated in that order, each with the same ``model``,
    ``device`` and ``args``.

    Returns:
        A 6-tuple ``(train_mean, val_mean, test_mean, train_std, val_std,
        test_std)`` where each mean/std pair is the two values returned by
        ``eval`` for the corresponding loader.
    """
    means = []
    spreads = []
    for loader in (train_eval_loader, val_loader, test_loader):
        split_mean, split_spread = eval(model=model, data_loader=loader, device=device, args=args)
        means.append(split_mean)
        spreads.append(split_spread)

    return (*means, *spreads)

def eval(model, data_loader, device, args):
    """Evaluate ``model`` on ``data_loader`` and summarize the per-batch ROC AUC.

    The model is switched to evaluation mode and left in that mode. All
    forward passes run under ``torch.no_grad()`` so no autograd graph is
    built during evaluation.

    Args:
        model: callable module; receives the covariance matrix ``C`` when
            ``args.model_to_train == 'deep_graph'`` and ``X`` otherwise.
        data_loader: iterable of batches indexable with ``"X"`` and ``"A"``,
            whose values support ``.to(device)``.
        device: target device for the batch tensors.
        args: namespace providing ``model_to_train`` and, for the deep_graph
            model, ``num_deep_graph_eval_runs``.

    Returns:
        A ``(mean, error)`` pair: the mean of the per-batch ROC AUC values,
        and their standard deviation divided by the square root of the
        number of batches scored.
    """
    model.eval()
    roc_aucs = []

    # Evaluation never backpropagates, so skip autograd bookkeeping entirely.
    with torch.no_grad():
        for data in data_loader:
            X, A = data["X"].to(device), data["A"].to(device)

            if args.model_to_train == 'deep_graph':
                # Accumulator for predictions averaged over several random
                # node orderings; indexed as (X.shape[0], X.shape[1], X.shape[1]).
                A_pred = torch.zeros((X.shape[0], X.shape[1], X.shape[1]), device=device)

                for _ in range(args.num_deep_graph_eval_runs):
                    _, C, _, perm = permute_node_ordering_and_compute_covariance_matrix(X, A)
                    inv_perm = np.argsort(perm)
                    A_pred_ = model(C)

                    # Add the prediction to its transpose over the last two
                    # axes, then undo the node permutation on both axes.
                    A_pred_sym = A_pred_ + torch.transpose(A_pred_, 1, 2)
                    A_pred_sym = A_pred_sym[:, inv_perm, :][:, :, inv_perm]  # restores original order
                    A_pred += A_pred_sym.detach()

                A_pred = A_pred / args.num_deep_graph_eval_runs
            else:
                A_pred = model(X)

            roc_aucs.append(compute_roc_auc_score(A, A_pred))

    # One score is appended per batch, so this equals the number of batches
    # and also works for loaders that do not define __len__.
    return np.mean(roc_aucs), np.std(roc_aucs) / math.sqrt(len(roc_aucs))