import os
import torch
import numpy as np
import pickle
from load_data import validation_collator, action_recognition_kwargs
from options import get_args, prettyprint_args
from inference import evaluate, transform_evaluate, attach_hooks, report_hook_results

import datetime
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve, accuracy_score, precision_recall_fscore_support, confusion_matrix
import Resnets3D.models.resnet as Resnets3D
from torch.utils.data import DataLoader
from datasets.raw_video_dataset import HMDB51Dataset
from perturbation_experiments import get_file_corruption_dataloaders, get_network_corruption_dataloaders

# Checkpoint paths used by load_model, one per supported dataset.
hmdb51_path = "models/hmdb51/2020_05_03_21_27_17.pth"
ucf101_path = "models/ucf101/2020_05_10_23_20_09.pth"
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def get_train_eval_dataloader(args):
    """Build an unshuffled, batch-size-1 DataLoader over HMDB51 split 1 ('test').

    The dataset root is ``<args.base_path>/<args.img_folder>``; the videos are
    read from its ``hmdb51`` sub-folder and the split files from
    ``train_test_splits``. Any ``transform`` produced by
    ``action_recognition_kwargs`` is overridden with ``None``.

    NOTE(review): HMDB51Dataset is used unconditionally, even when
    ``args.dataset == 'ucf101'`` -- confirm this is intended.

    Args:
        args: parsed options; reads ``base_path``, ``img_folder``,
            ``load_width``, ``load_height`` and ``num_workers`` (plus whatever
            ``action_recognition_kwargs`` reads).

    Returns:
        A ``DataLoader`` using ``validation_collator``.
    """
    root = os.path.join(args.base_path, args.img_folder)
    kwargs = action_recognition_kwargs(args)
    kwargs['transform'] = None
    video_dir = os.path.join(root, "hmdb51")
    split_dir = os.path.join(root, "train_test_splits")
    dataset = HMDB51Dataset(
        video_dir,
        split_dir,
        'test',
        args.load_width,
        args.load_height,
        split_id=1,
        **kwargs,
    )
    return DataLoader(
        dataset,
        batch_size=1,
        num_workers=args.num_workers,
        shuffle=False,
        collate_fn=validation_collator,
    )

def get_percentile(args, dl, model):
    """Return the OOD score threshold and evaluation scores, using a pickle cache.

    Entries in ``ood_lookup.pkl`` (current working directory) are keyed on
    ``(dataset, percentile, temperature, mean_subtract)``. A cached entry is
    reused unless ``args.force_recalculate_threshold`` is set. Otherwise
    ``train_eval`` is run and its result is written back to the cache.

    Args:
        args: parsed options; reads ``dataset``, ``percentile``,
            ``temperature``, ``mean_subtract`` and
            ``force_recalculate_threshold``.
        dl: dataloader handed to ``train_eval`` on a cache miss.
        model: model handed to ``train_eval`` on a cache miss.

    Returns:
        Tuple ``(threshold, scores)``.
    """
    cache_path = "ood_lookup.pkl"
    key = (args.dataset, args.percentile, args.temperature, args.mean_subtract)
    lookup_dict = {}
    if os.path.isfile(cache_path):
        # NOTE: pickle can execute arbitrary code on load; this must stay a
        # trusted, locally generated cache file.
        with open(cache_path, "rb") as f:
            lookup_dict = pickle.load(f)
        if key in lookup_dict and not args.force_recalculate_threshold:
            print("Key found!")
            threshold, scores = lookup_dict[key]
            # Cache hit: nothing changed, so the file is not rewritten.
            return threshold, scores
        print("Key not found...")

    threshold, scores = train_eval(args, dl, model)
    lookup_dict[key] = (threshold, scores)
    # Dump to a temporary file and swap it in, so an interrupted write cannot
    # truncate the existing cache.
    tmp_path = cache_path + ".tmp"
    with open(tmp_path, "wb") as g:
        pickle.dump(lookup_dict, g)
    os.replace(tmp_path, cache_path)
    return threshold, scores


def get_results(args, test_dataloader, perturbed_dataloader, model, hooks):
    """Return predictions and scores on the perturbed data, using a pickle cache.

    Entries in ``score_lookup.pkl`` (current working directory) are keyed on
    the experiment configuration. A cached entry is reused unless
    ``args.force_recalculate_scores`` is set. Otherwise ``transform_evaluate``
    is run and its result is written back to the cache.

    Args:
        args: parsed options; reads ``dataset``, ``experiment_mode``,
            ``temperature``, ``force_recalculate_scores`` and the
            mode-specific corruption settings used in the key.
        test_dataloader: clean data passed to ``transform_evaluate``.
        perturbed_dataloader: perturbed data passed to ``transform_evaluate``.
        model: model under evaluation.
        hooks: passed to ``transform_evaluate`` as ``data_hooks``.

    Returns:
        Tuple ``(aug_preds, aug_y, aug_scores, hook_results)``.
    """
    cache_path = "score_lookup.pkl"
    if args.experiment_mode == 'network':
        key = (args.dataset, args.experiment_mode, args.temperature, args.packet_loss_rate, args.mean_subtract, args.corruption_version)
    else:
        # NOTE(review): unlike the network key, this key omits mean_subtract;
        # kept as-is so existing cache entries stay valid.
        key = (args.dataset, args.experiment_mode, args.temperature, args.corrupt_mode, args.corrupt_prob)

    lookup_dict = {}
    if os.path.isfile(cache_path):
        # NOTE: pickle can execute arbitrary code on load; this must stay a
        # trusted, locally generated cache file.
        with open(cache_path, "rb") as f:
            lookup_dict = pickle.load(f)
        if key in lookup_dict and not args.force_recalculate_scores:
            print("Found score cache")
            aug_preds, aug_y, aug_scores, hook_results = lookup_dict[key]
            # Cache hit: nothing changed, so the file is not rewritten.
            return aug_preds, aug_y, aug_scores, hook_results
        print("Score cache not found, regenerating...")

    # NOTE(review): mean_sub is hard-coded True and ignores args.mean_subtract
    # (which is part of the network cache key) -- confirm this is intended.
    _, aug_preds, aug_y, aug_scores, hook_results = transform_evaluate(
        test_dataloader, perturbed_dataloader, model, transform=None,
        restore_training=False, leave=True, temperature=args.temperature,
        mean_sub=True, data_hooks=hooks)
    lookup_dict[key] = (aug_preds, aug_y, aug_scores, hook_results)
    # Dump to a temporary file and swap it in, so an interrupted write cannot
    # truncate the existing cache.
    tmp_path = cache_path + ".tmp"
    with open(tmp_path, "wb") as g:
        pickle.dump(lookup_dict, g)
    os.replace(tmp_path, cache_path)
    return aug_preds, aug_y, aug_scores, hook_results


def train_eval(args, dl, model):
    """Score ``dl`` with ``model`` and derive the OOD score threshold.

    The threshold is the ``(100 - args.percentile)``-th percentile of the
    scores, so roughly ``args.percentile`` percent of them lie at or above it.

    Args:
        args: parsed options; reads ``temperature`` and ``percentile``.
        dl: dataloader to evaluate.
        model: model to evaluate.

    Returns:
        Tuple ``(threshold, scores)``; ``scores`` is the fourth value
        returned by ``evaluate``.
    """
    print("Evaluating on test split 1:")
    _, _, _, scores = evaluate(dl, model, restore_training=False, leave=True,
                               temperature=args.temperature)
    lower_tail = 100 - args.percentile
    return np.percentile(scores, lower_tail), scores


def load_model(args):
    """Build a 3D ResNet-18 for ``args.dataset`` and load its checkpoint.

    Args:
        args: parsed options; reads ``dataset`` ('hmdb51' or 'ucf101'),
            ``n_classes``, ``load_width`` and ``max_frames``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        NotImplementedError: if ``args.dataset`` has no known checkpoint.
    """
    if args.dataset == 'hmdb51':
        checkpoint_path = hmdb51_path
    elif args.dataset == 'ucf101':
        checkpoint_path = ucf101_path
    else:
        raise NotImplementedError("No checkpoint for dataset {!r}".format(args.dataset))
    model = Resnets3D.resnet18(num_classes=args.n_classes, sample_size=args.load_width, sample_duration=args.max_frames)
    # map_location remaps stored tensors onto the selected device. Without it,
    # a checkpoint saved with CUDA tensors fails to load on the CPU fallback.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model


def report_ood_detection(aug_preds, aug_y, aug_scores, hook_results, results_dict, threshold=None):
    """Print OOD detection counts on perturbed inputs and record the raw outputs.

    Only inputs with a non-negative score are counted ("readable"; negative
    scores presumably mark inputs that could not be scored -- confirm). Among
    readable inputs, a score at or below ``threshold`` is a detection and a
    score above it is a miss.

    Args:
        aug_preds: predictions on the perturbed inputs; stored unchanged.
        aug_y: labels for the perturbed inputs; stored unchanged.
        aug_scores: scores on the perturbed inputs (array-like of numbers).
        hook_results: unused; kept for call compatibility.
        results_dict: dict that receives the predictions, labels and scores.
        threshold: detection threshold. Defaults to
            ``results_dict['threshold']``.

    Returns:
        ``results_dict``, updated in place with ``aug_preds``, ``aug_y`` and
        ``aug_scores``.
    """
    if threshold is None:
        threshold = results_dict['threshold']
    scores = np.asarray(aug_scores)
    readable = (scores >= 0)
    total = np.count_nonzero(readable)
    detects = np.count_nonzero((scores <= threshold) & readable)
    misses = np.count_nonzero((scores > threshold) & readable)
    if total:
        print("OOD Detections: {} ({:.2%})".format(detects, detects / total))
        print("OOD Misses: {} ({:.2%})".format(misses, misses / total))
    else:
        # No readable inputs: avoid dividing by zero.
        print("OOD Detections: {} (n/a)".format(detects))
        print("OOD Misses: {} (n/a)".format(misses))
    print("Total:", total)
    results_dict['aug_preds'] = aug_preds
    results_dict['aug_y'] = aug_y
    results_dict['aug_scores'] = aug_scores
    return results_dict

def report_auc_metrics(scores, aug_scores, result_dict):
    """Print AUROC/AUPR for separating unperturbed from perturbed inputs.

    Unperturbed scores are the positive class (label 1) and perturbed scores
    the negative class (label 0); the raw scores serve as the ranking.
    AUPR is the trapezoidal area under the precision-recall curve.

    Args:
        scores: scores on the unperturbed evaluation data.
        aug_scores: scores on the perturbed data.
        result_dict: dict that receives the labels, scores and PR curve.

    Returns:
        ``result_dict``, updated in place with ``ood_labels``, ``ood_probs``,
        ``ood_precision`` and ``ood_recall``.
    """
    banner = "=" * 80
    print()
    print(banner)
    print("PART 2: AUROC/AUPR on IN/OUT binary classification task")
    print()
    print("Perturbed = OOD (0), Unperturbed = in-distribution (1)")
    print(banner)
    labels = np.concatenate([np.ones_like(scores), np.zeros_like(aug_scores)])
    ranking = np.concatenate([scores, aug_scores])
    pr_precision, pr_recall, _ = precision_recall_curve(labels, ranking)
    print("AUROC:", roc_auc_score(labels, ranking))
    print("AUPR:", auc(pr_recall, pr_precision))
    result_dict.update(
        ood_labels=labels,
        ood_probs=ranking,
        ood_precision=pr_precision,
        ood_recall=pr_recall,
    )
    return result_dict

def report_ood_filtered_metrics(aug_preds, aug_y, aug_scores, threshold, result_dict):
    """Print classification metrics on the inputs scored as in-distribution.

    An input is kept when its score is strictly above ``threshold``. Accuracy
    and macro precision/recall/F1 are computed on the kept subset only; the
    unfiltered accuracy is printed for comparison. When
    ``result_dict['hook_results']`` has a ``'TruncatedL2'`` entry, mean values
    over several subsets are also printed, ignoring entries below zero.

    Args:
        aug_preds: numpy array of predictions on the perturbed inputs.
        aug_y: numpy array of labels for the perturbed inputs.
        aug_scores: numpy array of scores on the perturbed inputs.
        threshold: in-distribution score threshold.
        result_dict: results dict; must contain a ``'hook_results'`` dict.

    Returns:
        ``result_dict``, updated in place with the ``filtered_*`` entries.
    """
    separator = "=" * 80
    print()
    print(separator)
    print("PART 3: Acc/P/R/F metrics on OOD-filtered input")
    print(separator)

    keep = (aug_scores > threshold)
    kept_preds = aug_preds[keep]
    kept_y = aug_y[keep]
    print("Support:", len(kept_preds))
    kept_acc = accuracy_score(kept_y, kept_preds)
    print("Unfiltered accuracy:", accuracy_score(aug_y, aug_preds))
    print("Accuracy:", kept_acc)
    macro_p, macro_r, macro_f1, _ = precision_recall_fscore_support(kept_y, kept_preds, average='macro')
    for label, value in (("Precision:", macro_p), ("Recall:", macro_r), ("F1:", macro_f1)):
        print(label, value)
    result_dict.update({
        'filtered_preds': kept_preds,
        'filtered_y': kept_y,
        'filtered_acc': kept_acc,
        'filtered_precision': macro_p,
        'filtered_recall': macro_r,
        'filtered_f1': macro_f1,
    })

    hook_outputs = result_dict['hook_results']
    if 'TruncatedL2' not in hook_outputs:
        return result_dict
    l2 = np.array(hook_outputs['TruncatedL2'])
    valid = (l2 >= 0)
    correct = (np.array(aug_y) == np.array(aug_preds))
    print("L2:", l2[valid].mean())
    for label, subset in (
        ("L2 (correct):", keep & correct),
        ("L2 (incorrect):", keep & ~correct),
        ("L2 (ID):", keep),
        ("L2 (OOD):", ~keep),
    ):
        print(label, l2[valid & subset].mean())
    return result_dict

def save_results(args, result_dict):
    """Pickle ``result_dict`` to a timestamped file under ``<base_path>/ood_experiments``.

    The file name is the current local time formatted as
    ``YYYY_MM_DD_HH_MM_SS.pkl``. The output directory is created if missing.

    Args:
        args: parsed options; reads ``base_path``.
        result_dict: object to pickle (highest pickle protocol).

    Returns:
        Path of the written file.
    """
    time_str = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
    save_dir = os.path.join(args.base_path, 'ood_experiments/')
    # Create the directory on first use, rather than failing at the very end
    # of a long experiment.
    os.makedirs(save_dir, exist_ok=True)
    fname = os.path.join(save_dir, time_str + ".pkl")
    with open(fname, 'wb') as f:
        pickle.dump(result_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
    print("Results log saved to ", fname)
    return fname

if __name__ == '__main__':
    args = get_args()
    print(prettyprint_args(args))

    # Metric 1: threshold value + OOD detections.
    print("=" * 80)
    print("PART 1: Threshold value determination and OOD detection")
    print("=" * 80)
    print()
    print("Loading model...")
    model = load_model(args)
    train_dataloader = get_train_eval_dataloader(args)
    result_dict = {'args': args, 'aug_info': None, 'hook_results': {}}
    threshold, scores = get_percentile(args, train_dataloader, model)
    result_dict['threshold'] = threshold
    result_dict['scores'] = scores
    print("{:.2%} TPR threshold: {}".format(args.percentile / 100, threshold))
    # NOTE(review): mutates args after the threshold was computed and cached
    # with the original mean_subtract value; later steps see False.
    args.mean_subtract = False
    if args.experiment_mode == 'file':
        test_dataloader, perturbed_dataloader = get_file_corruption_dataloaders(args)
    elif args.experiment_mode == 'network':
        test_dataloader, perturbed_dataloader = get_network_corruption_dataloaders(args)
    else:
        raise NotImplementedError()
    hooks = attach_hooks(args)
    aug_preds, aug_y, aug_scores, hook_results = get_results(args, test_dataloader, perturbed_dataloader, model, hooks)
    cm = confusion_matrix(aug_y, aug_preds)
    result_dict['cm'] = cm
    result_dict['y'] = aug_y
    result_dict['preds'] = aug_preds
    # NOTE(review): these are read from the live hooks rather than from the
    # cached hook_results; on a score-cache hit they are presumably empty.
    for hook_name, hook in hooks.items():
        result_dict['hook_results'][hook_name] = hook.results
    aug_preds, aug_y, aug_scores = np.array(aug_preds), np.array(aug_y), np.array(aug_scores)
    result_dict = report_ood_detection(aug_preds, aug_y, aug_scores, hook_results, result_dict)

    # Metric 2: AUROC/AUPR on the IN/OUT binary classification task.
    result_dict = report_auc_metrics(scores, aug_scores, result_dict)

    # Metric 3: Acc/P/R/F on OOD-filtered input.
    result_dict = report_ood_filtered_metrics(aug_preds, aug_y, aug_scores, threshold, result_dict)

    # Metric 4: precision on the proxy task (does the OOD flag predict a
    # prediction change relative to a fixed baseline run?).
    print()
    print("Proxy task: prediction change detection")
    baseline_path = 'ood_experiments/2020_09_13_22_40_04.pkl'
    if os.path.isfile(baseline_path):
        readable = (aug_scores >= 0)
        with open(baseline_path, 'rb') as f:
            baseline = pickle.load(f)
        # Compare as arrays: `!=` on two Python lists yields one bool, not a mask.
        gt = (np.asarray(baseline['preds']) != np.asarray(result_dict['preds'])) & readable
        # score <= threshold counts as OOD, the same rule as the Part 1
        # detection count (the Part 3 filter keeps score > threshold).
        ood = (aug_scores <= result_dict['threshold']) & readable
        print("Support:", np.count_nonzero(gt))
        print("Proxy acc:", accuracy_score(gt, ood))
        print("P/R/F/S", precision_recall_fscore_support(gt, ood, average='binary'))
    else:
        # Do not lose the results of a full run just because the baseline is absent.
        print("Baseline results not found at", baseline_path, "- skipping proxy task")
    if not args.lightweight:
        save_results(args, result_dict)
