import torch
from evaluate_sensors import get_val_performance_data, get_full_err_scores
import numpy as np
import tqdm
from scipy.stats import iqr
from torch_geometric.data import DataLoader as GraphDataLoader
def get_score(test_result, val_result, anomaly_scores = None):
    """Compute validation-thresholded detection metrics for the test split.

    Args:
        test_result: ``[predicted, ground_truth, labels]`` for the test split
            (the structure returned by ``test``). Labels are read from
            column 0 of the third entry, giving one label per row.
        val_result: the same structure for the validation split.
        anomaly_scores: optional dict with precomputed ``'test'`` and
            ``'normal'`` scores. When ``None``, both are computed with
            ``get_full_err_scores(test_result, val_result)``.

    Returns:
        dict mapping ``'F1'``, ``'precision'``, ``'recall'`` and ``'AUC'`` to
        the first four entries returned by ``get_val_performance_data``
        (called with ``topk=1``).
    """
    np_test_result = np.array(test_result)
    print(np_test_result.shape)
    # One label per row: column 0 of the label matrix.
    test_labels = np_test_result[2, :, 0].tolist()

    if anomaly_scores is None:
        test_scores, normal_scores = get_full_err_scores(test_result, val_result)
        print(test_scores.shape, normal_scores.shape)
    else:
        test_scores, normal_scores = anomaly_scores['test'], anomaly_scores['normal']
    print(len(test_scores), len(test_labels))

    # The threshold is derived from the normal (validation) scores.
    info = get_val_performance_data(test_scores, normal_scores, test_labels, topk=1)

    print('=========================** Result **============================\n')

    metrics = ['F1', "precision", 'recall', 'AUC']
    return {m: info[i] for i, m in enumerate(metrics)}

def score_model(model,test_dataloader,val_dataloader,args, labels):
    """Load saved weights into *model*, evaluate it, and print/return metrics.

    The checkpoint is loaded from
    ``os.path.join(args.model_save_dir, f"{args.model_type}_{args.n_masks}.pth")``.
    Both splits are run through ``test`` on the CPU. Only the test split
    receives *labels*; the validation split is treated as all-normal.

    Returns:
        The metrics dict from ``get_score``, extended with ``'val_loss'`` and
        ``'test_loss'`` (mean per-batch MSE for each split).
    """
    import os

    path = os.path.join(args.model_save_dir, args.model_type + f"_{args.n_masks}.pth")
    model.load_model(path)
    device = 'cpu'
    print(model)

    test_loss, test_result = test(model, test_dataloader, device, labels)
    val_loss, val_result = test(model, val_dataloader, device, labels=None)

    results_dict = get_score(test_result, val_result)
    results_dict['val_loss'] = val_loss
    results_dict['test_loss'] = test_loss

    for k, v in results_dict.items():
        print(k, ":", v)

    return results_dict

def concat_dict(x,keys):
    """Concatenate the tensors ``x[k]`` for each k in *keys* along dimension 1.

    The order of *keys* determines the column order of the result.
    """
    return torch.cat([x[key] for key in keys], dim=1)


def test(model, dataloader, device, labels):
    """Run *model* over *dataloader* and collect predictions and targets.

    Args:
        model: wrapper exposing ``model.model`` (put into eval mode here),
            ``predict(batch)`` returning a dict of tensors, and ``n_feats``.
        dataloader: sized iterable of batches with a ``last_state`` attribute,
            which is used as the regression target.
        device: device the concatenated predictions are moved to before the
            loss is computed.
        labels: optional per-row labels. Every row whose label equals 1 is
            set to 1 across all columns of the returned label matrix.
            ``None`` gives an all-zero label matrix.

    Returns:
        ``(avg_loss, [predicted, ground_truth, label_matrix])``. ``avg_loss``
        is the mean of the per-batch MSE losses. The three entries are NumPy
        arrays, each reshaped to ``(-1, model.n_feats)``.

    Raises:
        ValueError: if *dataloader* yields no batches.
    """
    loss_func = torch.nn.MSELoss(reduction='mean')

    batch_losses = []
    predicted_chunks = []
    target_chunks = []

    model.model.eval()

    for d in tqdm.tqdm(dataloader, total=len(dataloader)):
        with torch.no_grad():
            predicted = model.predict(d)
            predicted = concat_dict(predicted, predicted.keys()).to(device).reshape(-1, model.n_feats)
            last_state = d.last_state.reshape(-1, model.n_feats)
            loss = loss_func(predicted, last_state)
        predicted_chunks.append(predicted)
        target_chunks.append(last_state)
        batch_losses.append(loss.item())

    if not batch_losses:
        raise ValueError("dataloader yielded no batches")

    # Concatenate once at the end. Growing the tensors with torch.cat on every
    # batch re-copies all previous rows and is quadratic in total size.
    predicted_all = torch.cat(predicted_chunks, dim=0)
    ground_all = torch.cat(target_chunks, dim=0)
    # .cpu() so .numpy() also works when `device` is an accelerator.
    predicted_np = predicted_all.cpu().numpy()
    ground_np = ground_all.cpu().numpy()

    label_matrix = np.zeros_like(predicted_np)
    if labels is not None:
        # np.asarray makes list/tensor labels compare element-wise. A plain
        # Python list compared with `== 1` is just the scalar False.
        anomalous_rows = np.flatnonzero(np.asarray(labels).reshape(-1) == 1)
        label_matrix[anomalous_rows, :] = 1

    print(label_matrix.shape, ground_all.shape, predicted_all.shape)
    avg_loss = sum(batch_losses) / len(batch_losses)

    return avg_loss, [predicted_np, ground_np, label_matrix]




def test_with_normalized_loss(test_dataset, model, **kwargs):
    """Score *test_dataset* with *model*, print detection metrics, return F1.

    The dataset is batched (batch size 25, no shuffling). Each batch goes
    through ``model.compute_novelty_scores_old``, and the per-feature losses in
    ``percentiles['losses']`` for the keys ``'numerical'``, ``'binary'`` and
    ``'categorical'`` (whichever are present) are concatenated column-wise.
    Each row's anomaly score is the sum of its feature losses.

    If ``model.getgraphloss()`` returns a threshold, metrics are evaluated at
    that threshold. Otherwise the F1-maximizing threshold on the
    precision-recall curve is used.

    Args:
        test_dataset: dataset accepted by ``GraphDataLoader`` that exposes
            ``labels``. These presumably align one-to-one with the score rows;
            confirm against the dataset implementation.
        model: model wrapper exposing ``model``, ``device``,
            ``compute_novelty_scores_old`` and ``getgraphloss``.
        **kwargs: accepted for interface compatibility and unused.

    Returns:
        The F1 score at the chosen threshold.

    Raises:
        ValueError: if *test_dataset* yields no batches.
    """
    from sklearn.metrics import roc_auc_score
    from sklearn.metrics import precision_recall_curve
    from sklearn.metrics import precision_score, recall_score

    testloader = GraphDataLoader(test_dataset, batch_size=25, shuffle=False)
    model.model.eval()

    loss_batches = []
    for d in tqdm.tqdm(testloader, total=len(testloader)):
        d.to(model.device)
        percentiles, _, _ = model.compute_novelty_scores_old(d, return_prediction=True)
        loss_batches.append(percentiles['losses'])

    if not loss_batches:
        raise ValueError("test_dataset produced no batches")

    # Use a fixed order instead of a set so the column layout is the same
    # on every run (set order over strings varies with hash randomization).
    available = set(loss_batches[0].keys())
    keys = [k for k in ('numerical', 'binary', 'categorical') if k in available]
    feature_losses = torch.cat([concat_dict(batch, keys) for batch in loss_batches], dim=0)

    th = model.getgraphloss()

    print('normalized max sensor losses:')
    prds = feature_losses.cpu().numpy().sum(axis=1)
    labels = test_dataset.labels
    precision, recall, thresholds = precision_recall_curve(labels, prds)
    f1_scores = 2*recall*precision/(recall+precision+1e-5)

    if th is not None:
        # Evaluate directly at th. Looking th up by exact float equality in
        # `thresholds` fails unless th happens to equal one of the scores.
        # precision_recall_curve predicts positive when score >= threshold.
        thresholded = prds >= th
        precision_at_th = precision_score(labels, thresholded, zero_division=0)
        recall_at_th = recall_score(labels, thresholded, zero_division=0)
        f1score = 2*recall_at_th*precision_at_th/(recall_at_th+precision_at_th+1e-5)
        print('F1-Score: ', f1score)
        print('Precision:', precision_at_th)
        print('Recall::', recall_at_th)
    else:
        # The final curve point (precision 1, recall 0, F1 0) has no
        # threshold, so search only the points that do.
        best = int(np.argmax(f1_scores[:-1]))
        f1score = f1_scores[best]
        print('F1-Score: ', f1score)
        # Use >= to match precision_recall_curve, so the printed precision
        # and recall correspond to the reported F1.
        thresholded = prds >= thresholds[best]
        print('Precision:', precision_score(labels, thresholded))
        print('Recall:', recall_score(labels, thresholded))
        print('AUC: ', roc_auc_score(labels, prds))

    return f1score




def score_normalization(scores):
    """Robustly standardize *scores* along axis 0: ``(x - median) / IQR``.

    Where the interquartile range is 0, the divisor is 1 instead, so constant
    columns become all zeros rather than NaN/inf.

    Args:
        scores: array-like of scores, either 1-D or with samples along axis 0.

    Returns:
        ndarray with the same shape as *scores*.
    """
    spread = iqr(scores, axis=0)
    # np.where handles both array and scalar IQRs. For 1-D input, iqr returns
    # a NumPy scalar, which does not support item assignment.
    spread = np.where(spread == 0, 1, spread)
    return (scores - np.median(scores, axis=0)) / spread
