import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
import numpy as np
import pickle
import pandas as pd
import matplotlib.pyplot as plt
from pprint import pprint
from datasets.data import load_dataset
from losses import get_loss_helper, CrossEntropyHelper
from metrics.osr import evaluation
from utils.utils import actualize_centers


from config_flags import load_flags
from absl import flags

# Root directory that is scanned for results; each of its entries is treated
# as one model.
RES_DIR = 'results/benchmark_osr/'

################################################################################
# The options below require the saved models to be re-run on some data.

# If True, runs whose loss is 'cac' are re-evaluated with class centers
# recomputed from the model outputs instead of the original anchors.
CALCULATE_CAC_CENTERS = True

# If True, runs whose loss is 'crossentropy' are re-evaluated with centers
# computed from the representations rather than from the logits output.
CALCULATE_CE_CENTERS = False

# If True, runs whose loss is 'crossentropy' are re-evaluated using the
# softmax score to compute the AUROC.
USE_SOFTMAX_SCORE = False

################################################################################

# Entry names directly under RES_DIR, and the matching joined paths.
models = os.listdir(RES_DIR)
path_to_models = [os.path.join(RES_DIR, entry) for entry in models]


def actualize_cac_results(save_path):
    """Re-evaluate a saved run with class centers recomputed from its outputs.

    The run's flags are loaded from ``flags.txt`` inside ``save_path``.  New
    centers are the per-class mean of the model outputs over the
    ``ds_train_known`` samples that the original loss helper classifies
    correctly.  The test sets are then evaluated twice: once with the original
    anchors (result not returned) and once with the recomputed centers.

    NOTE(review): the model file is located through ``args.save_path`` and
    ``args.prefix`` from the loaded flags, not through the ``save_path``
    argument -- confirm the two always point at the same run.

    Args:
        save_path: directory containing the run's ``flags.txt``.

    Returns:
        dict with keys ``accuracy`` (``results["acc"]`` multiplied by 100),
        ``real_auroc``, ``max_val_auroc`` and ``oscr``, taken from the
        evaluation that uses the recomputed centers.
    """
    load_flags(os.path.join(save_path, 'flags.txt'))
    args = flags.FLAGS

    tf.keras.utils.set_random_seed(args.seed)

    # Only the datasets and the number of classes are used below.
    datasets, nb_classes, _, _, _, _ = load_dataset(
        args,
        parallel_strategy=None,
        data_augmentation=False
    )

    # Original anchors: identity rows repeated nb_features times along axis 1,
    # scaled by the anchor multiplier.
    class_anchors = tf.repeat(tf.eye(nb_classes), args.nb_features, axis=1)
    class_anchors = class_anchors * args.anchor_multiplier

    loss_helper = get_loss_helper(args, class_anchors, nb_classes)

    model = tf.keras.models.load_model(
        os.path.join(args.save_path, args.prefix) + "model.save"
    )

    # need to unbatch everything cause otherwise dataset isn't going to be
    # seen deterministically
    train_images, train_labels = zip(*datasets["ds_train_known"].unbatch())
    train_images = np.array(train_images)
    train_labels = np.array(train_labels)

    train_outputs = model.predict(train_images, batch_size=args.batch_size)
    is_correct = loss_helper.predicted_class(train_outputs) == train_labels
    kept_outputs = train_outputs[is_correct]
    kept_labels = train_labels[is_correct]

    # One center per class: mean output of its correctly classified samples.
    class_centers = [
        np.mean(kept_outputs[kept_labels == class_index], axis=0)
        for class_index in range(nb_classes)
    ]
    print("CAC centers computed")
    print(class_centers)

    ##### TEST on new centers
    new_loss_helper = get_loss_helper(args, class_centers, nb_classes)

    test_images, test_labels = zip(*datasets["ds_test_known"].unbatch())
    test_images = np.array(test_images)
    test_labels = np.array(test_labels)

    pred_known = model.predict(test_images, batch_size=args.batch_size)
    pred_unknown = model.predict(datasets["ds_test_unknown"])

    print("Test results :")
    print("\tUsing original anchors:")
    # Called for comparison only; its return value is not used.
    evaluation(pred_known, pred_unknown, test_labels, loss_helper)
    print("\n\tUsing updated anchors:")
    results = evaluation(pred_known, pred_unknown, test_labels, new_loss_helper)

    summary = {"accuracy": results["acc"]*100.}
    for metric in ("real_auroc", "max_val_auroc", "oscr"):
        summary[metric] = results[metric]
    return summary
    
def crossentropy_representations(save_path):
    """Evaluate a cross-entropy run using centers computed on its features.

    The run's flags are loaded from ``flags.txt`` inside ``save_path``.  A
    feature extractor is built from the saved model up to the layer named
    ``features_layer``; centers are obtained from ``actualize_centers`` on the
    ``ds_train_known`` samples.  The test sets are then passed through the
    feature extractor and evaluated with a ``CrossEntropyHelper`` that uses
    those centers as class anchors.

    NOTE(review): the model file is located through ``args.save_path`` and
    ``args.prefix`` from the loaded flags, not through the ``save_path``
    argument -- confirm the two always point at the same run.

    Args:
        save_path: directory containing the run's ``flags.txt``.

    Returns:
        dict with keys ``accuracy`` (``results["acc"]`` multiplied by 100),
        ``real_auroc``, ``max_val_auroc`` and ``oscr``.
    """
    load_flags(os.path.join(save_path, 'flags.txt'))
    args = flags.FLAGS

    tf.keras.utils.set_random_seed(args.seed)

    # Load dataset (only the datasets and the number of classes are used)
    datasets, nb_classes, _, _, _, _ = load_dataset(
        args,
        parallel_strategy=None,
        data_augmentation=False
    )

    class_anchors = tf.repeat(tf.eye(nb_classes), args.nb_features, axis=1)
    class_anchors *= args.anchor_multiplier

    loss_helper = get_loss_helper(args, class_anchors, nb_classes)

    model_path = os.path.join(args.save_path, args.prefix) + "model.save"

    model = tf.keras.models.load_model(
        model_path
    )

    # Sub-model that stops at the "features_layer" output.
    features_extractor = tf.keras.models.Model(
            inputs=model.inputs,
            outputs=model.get_layer(name="features_layer").output,
        )

    # need to unbatch everything cause otherwise dataset isn't going to be
    # seen deterministically
    images, labels = zip(*datasets["ds_train_known"].unbatch())
    images = np.array(images)
    labels = np.array(labels)

    ce_centers = actualize_centers(model, features_extractor, images, labels,
                                   loss_helper, nb_classes)
    print("Cross entropy centers computed")
    print(ce_centers)

    ##### TEST on new centers
    new_loss_helper = CrossEntropyHelper(osr_score="min", use_softmax=True)
    new_loss_helper.use_class_anchors(ce_centers)

    images_test_k, labels_test_k = zip(*datasets["ds_test_known"].unbatch())
    images_test_k = np.array(images_test_k)
    labels_test_k = np.array(labels_test_k)

    pred_known = features_extractor.predict(images_test_k, batch_size=args.batch_size)
    pred_unknown = features_extractor.predict(datasets["ds_test_unknown"])

    print("Test results :")
    print("\n\tUsing distance to cross entropy centers:")
    results = evaluation(pred_known, pred_unknown, labels_test_k, new_loss_helper)

    # Accuracy is scaled to percent so it is on the same scale as the value
    # returned by actualize_cac_results and the history-based defaults, which
    # all end up in the same accuracy table.
    return {"accuracy": results["acc"] * 100.,
            "real_auroc": results["real_auroc"],
            "max_val_auroc": results["max_val_auroc"],
            "oscr": results["oscr"]
        }
    
def softmax_auroc(save_path):
    """Evaluate a cross-entropy run using the softmax score.

    The run's flags are loaded from ``flags.txt`` inside ``save_path``.  The
    saved model predicts on the known and unknown test sets and the
    predictions are evaluated with
    ``CrossEntropyHelper(osr_score="max", use_softmax=True)``.

    NOTE(review): the model file is located through ``args.save_path`` and
    ``args.prefix`` from the loaded flags, not through the ``save_path``
    argument -- confirm the two always point at the same run.

    Args:
        save_path: directory containing the run's ``flags.txt``.

    Returns:
        dict with keys ``accuracy`` (``results["acc"]`` multiplied by 100),
        ``real_auroc``, ``max_val_auroc`` and ``oscr``.
    """
    load_flags(os.path.join(save_path, 'flags.txt'))
    args = flags.FLAGS

    tf.keras.utils.set_random_seed(args.seed)

    # Load dataset (only the datasets themselves are used here)
    datasets, _, _, _, _, _ = load_dataset(
        args,
        parallel_strategy=None,
        data_augmentation=False
    )

    model_path = os.path.join(args.save_path, args.prefix) + "model.save"

    model = tf.keras.models.load_model(
        model_path
    )

    new_loss_helper = CrossEntropyHelper(osr_score="max", use_softmax=True)

    # unbatched so the known test set is read as one array of samples
    images_test_k, labels_test_k = zip(*datasets["ds_test_known"].unbatch())
    images_test_k = np.array(images_test_k)
    labels_test_k = np.array(labels_test_k)

    pred_known = model.predict(images_test_k, batch_size=args.batch_size)
    pred_unknown = model.predict(datasets["ds_test_unknown"])

    print("Test results :")
    print("\n\tUsing softmax score:")
    results = evaluation(pred_known, pred_unknown, labels_test_k, new_loss_helper)

    # Accuracy is scaled to percent so it is on the same scale as the value
    # returned by actualize_cac_results and the history-based defaults, which
    # all end up in the same accuracy table.
    return {"accuracy": results["acc"] * 100.,
            "real_auroc": results["real_auroc"],
            "max_val_auroc": results["max_val_auroc"],
            "oscr": results["oscr"]
        }

# Index every results directory by (model, loss, dataset): the three directory
# levels below RES_DIR are read in that order, and each entry stores the path
# of its dataset directory under "path".
general_dict = {}
for model, path_to_model in zip(models, path_to_models):
    for loss in os.listdir(path_to_model):
        path_to_loss = os.path.join(path_to_model, loss)
        for dataset in os.listdir(path_to_loss):
            general_dict[(model, loss, dataset)] = {
                "path": os.path.join(path_to_loss, dataset),
            }
        
# pprint(general_dict)
    

# Collect the metrics of every run stored below each (model, loss, dataset)
# directory.  The levels walked below "path" are: split, then three nested
# parameter directories (named nb_f / am / md here).  Result layout:
#   value["splits"][split_index][(nb_f, am, md)] -> list of metric dicts
for (_, loss, dataset_name), value in general_dict.items():
    # Entries whose name contains '.sh' are skipped; the rest are treated as
    # split directories.
    splits = [split for split in os.listdir(value["path"]) if '.sh' not in split]

    value["splits"] = {}
    for split in splits:
        path_to_split = os.path.join(value["path"], split)
        # The split index is the integer after the last '-' of the name.
        split_index = int(split.split('-')[-1])
        value["splits"][split_index] = {}

        for nb_f in os.listdir(path_to_split):
            path_nb_f = os.path.join(path_to_split, nb_f)

            for am in os.listdir(path_nb_f):
                path_am = os.path.join(path_nb_f, am)

                for md in os.listdir(path_am):
                    path_md = os.path.join(path_am, md)

                    # NOTE: unpickling can execute arbitrary code; only scan
                    # result directories produced by trusted runs.
                    # The context manager closes the file as soon as the
                    # history is loaded instead of leaking the handle.
                    with open(os.path.join(path_md, 'history.pkl'), 'rb') as history_file:
                        history = pickle.load(history_file)
                    #TODO: loop in case there were multiple runs for those parameters

                    # default values: last recorded entry of each metric
                    final_acc = history['test_accuracy'][-1] * 100.
                    final_auroc = history['real_auroc'][-1]
                    final_max_val_auroc = history['max_val_auroc'][-1]
                    final_oscr = history['oscr'][-1]

                    run_values = {
                        "accuracy": final_acc.numpy(),
                        "real_auroc": final_auroc,
                        "max_val_auroc": final_max_val_auroc,
                        "oscr": final_oscr,
                    }

                    # replace default values if needed; when several options
                    # apply to the same loss, the last one wins
                    if CALCULATE_CAC_CENTERS and loss == 'cac':
                        print(path_md)
                        run_values = actualize_cac_results(path_md)
                    if CALCULATE_CE_CENTERS and loss == 'crossentropy':
                        print(path_md)
                        run_values = crossentropy_representations(path_md)
                    if USE_SOFTMAX_SCORE and loss == 'crossentropy':
                        print(path_md)
                        run_values = softmax_auroc(path_md)

                    if loss == 'cac':
                        print(run_values)
                        print()

                    # Append to the list for these run params, creating it on
                    # first use.
                    value["splits"][split_index].setdefault((nb_f, am, md), []).append(run_values)
                        

# Aggregate every run configuration across splits.  For each entry of
# general_dict, "splits_mean" and "splits_std" map run params to a dict holding
# the mean / standard deviation of each metric over the splits.
metric_names = ("accuracy", "real_auroc", "max_val_auroc", "oscr")

for key, value in general_dict.items():
    value["splits_mean"] = {}
    value["splits_std"] = {}

    # First pass: gather, per run params, the per-split values of each metric.
    for split_values in value["splits"].values():
        for run_params, run_values in split_values.items():
            gathered = value["splits_mean"].setdefault(
                run_params, {name: [] for name in metric_names}
            )
            # TODO: loop on run values in case there were multiple runs for those parameters
            first_run = run_values[0]
            for name in metric_names:
                gathered[name].append(first_run[name])

    # Second pass: reduce the gathered lists.  The std must be taken before
    # the lists are overwritten in place by their mean.
    for run_params, gathered in value["splits_mean"].items():
        value["splits_std"][run_params] = {
            name: np.std(gathered[name]) for name in metric_names
        }
        for name in metric_names:
            gathered[name] = np.mean(gathered[name])
        

pprint(general_dict)

# Pivot the per-split means into data[metric][dataset] -> list of values,
# appended in general_dict order (one value per entry and run params).
rows_labels = []
data = {
    metric: {}
    for metric in ("accuracy", "real_auroc", "max_val_auroc", "oscr")
}

for (model, loss, dataset), value in general_dict.items():
    # Row labels are the distinct losses, in first-seen order.
    if loss not in rows_labels:
        rows_labels.append(loss)

    # Make sure every metric has a list for this dataset.
    for per_dataset in data.values():
        per_dataset.setdefault(dataset, [])

    for run_values in value["splits_mean"].values():
        for metric, per_dataset in data.items():
            per_dataset[dataset].append(run_values[metric])

# Same pivot as for the means, applied to the standard deviations:
# stds[metric][dataset] -> list of values in general_dict order.
stds = {
    metric: {}
    for metric in ("accuracy", "real_auroc", "max_val_auroc", "oscr")
}

for (model, loss, dataset), value in general_dict.items():
    # Make sure every metric has a list for this dataset.
    for per_dataset in stds.values():
        per_dataset.setdefault(dataset, [])

    for run_values in value["splits_std"].values():
        for metric, per_dataset in stds.items():
            per_dataset[dataset].append(run_values[metric])

print(stds)
          
# Print one table per metric (rows: rows_labels, columns: datasets in the fixed
# order below), first for the means and then for the standard deviations.
# Each table is printed as-is and then as LaTeX rounded to one decimal.
dataset_columns = ["mnist", "svhn", "cifar10", "cifar+10", "cifar+50", "tiny_imagenet"]

for heading, tables, blank_line_after in (
    ("Mean results:", data, True),
    ("Std results:", stds, False),
):
    print(heading)
    for key, value in tables.items():
        print("Metric:", key)
        df = pd.DataFrame.from_dict(value)
        df.index = rows_labels
        df = df[dataset_columns]
        df.columns = [column.upper() for column in df.columns]
        print(df)
        print("Latex:")
        df = df.round(1)
        print(df.to_latex(float_format="%.1f"))
        print()
    # Only the mean section is followed by an extra blank line.
    if blank_line_after:
        print()
    
    
#                   MNIST       SVHN    CIFAR10   CIFAR+10   CIFAR+50  TINY_IMAGENET
# dist          99.837807  97.987930  95.784996  96.159996  96.224998      66.840004
# crossentropy  99.788864  97.749283  96.089996  96.120003  96.134995      61.359997
# cac           99.813332  98.181381  96.074997  96.180008  96.300003      64.519997
    
    
    
# cac & 99.8 & 98.2 & 96.1 & 96.2 & 96.3 & 64.5 \\ acc
# cac & 99.8 & 98.2 & 96.1 & 96.2 & 96.3 & 66.1
# cac & 98.8 & 96.5 & 85.0 & 85.6 & 85.9 & 63.7 \\ auroc
# cac & 98.9 & 96.6 & 86.1 & 86.5 & 87.0 & 58.2 \\
