import torch
import json
import time
import os
import argparse
from torchvision.models import resnet18

from SLDA_Model import StreamingLDA
import utils
import retrieve_any_layer
import numpy as np

from sklearn.metrics import roc_auc_score



# Computes AUROC to gauge how well the OOD score separates ID from OOD samples (diagnostic only, not part of the experiment).
def compute_auc(in_scores, out_scores):
    """Return the AUROC for separating ID scores from OOD scores.

    ``in_scores`` are treated as the positive class (label 1) and
    ``out_scores`` as the negative class (label 0), so higher scores should
    indicate "more in-distribution". Lists are converted to NumPy arrays.

    Returns:
        The value from ``roc_auc_score``, or the sentinel ``-0.99`` when it
        raises ``ValueError`` (a message is printed in that case).
    """
    in_arr = np.array(in_scores) if isinstance(in_scores, list) else in_scores
    out_arr = np.array(out_scores) if isinstance(out_scores, list) else out_scores

    targets = np.concatenate([np.ones_like(in_arr), np.zeros_like(out_arr)])
    try:
        merged = np.concatenate((in_arr, out_arr))
        return roc_auc_score(targets, merged)
    except ValueError:
        print("Input contains NaN, infinity or a value too large for dtype('float64').")
        return -0.99

def get_auroc(model, val_data, ood_method, num_classes=1000):
    """Score every sample of ``val_data`` with an OOD detector and print the AUROC.

    Diagnostic only, not part of the experiment. Samples whose label is below
    ``num_classes * 5 // 10`` are treated as in-distribution (ID), all others
    as OOD. A sample's confidence is its highest score among the ID classes.

    Args:
        model: fitted SLDA model; ``Sigma``, ``Sigma_all``, ``muK`` and
            ``muK_all`` are read.
        val_data: DataLoader yielding ``(X, y)`` batches.
        ood_method: one of ``"rmd"``, ``"md"``, ``"mean"``.
        num_classes: total number of classes; the first half count as ID.

    Returns:
        ``(scores, labels)``: the per-class OOD scores of every sample stacked
        row-wise (float32 tensor) and the matching labels (long tensor). When
        ``val_data`` yields nothing, empty tensors are returned.

    Raises:
        NotImplementedError: if ``ood_method`` is not recognised.

    Note: reads the module-level ``feature_extraction_wrapper`` set in
    ``__main__``.
    """
    num_id_classes = num_classes * 5 // 10
    num_samples = len(val_data.dataset)

    # Blend each covariance 80/20 with the identity before inverting
    # (presumably to keep the inverse well conditioned).
    cov = model.Sigma.cpu()
    cov_inv = np.linalg.inv(0.8 * cov + 0.2 * np.eye(len(cov)))
    mean_all = model.muK_all.cpu()
    cov_all = model.Sigma_all.cpu()
    cov_all_inv = np.linalg.inv(0.8 * cov_all + 0.2 * np.eye(len(cov_all)))

    id_scores = []
    ood_scores = []
    score_batches = []
    label_batches = []
    seen = 0

    with torch.no_grad():
        for X, y in val_data:
            print('\rGetting OOD Scores %d/%d.' % (seen, num_samples), end='')

            if feature_extraction_wrapper is not None:
                feat = pool_feat(feature_extraction_wrapper(X.cuda()))
            else:
                feat = X.cuda()
            seen += feat.shape[0]

            if ood_method == "rmd":
                scores = compute_rmd(model, feat, cov_inv, mean_all, cov_all_inv)
            elif ood_method == "md":
                scores = compute_md(model, feat, cov_inv)
            elif ood_method == "mean":
                scores = compute_mean_distance(model, feat)
            else:
                raise NotImplementedError('This OOD method is not implemented.')

            scores = torch.as_tensor(np.asarray(scores), dtype=torch.float32)
            batch_labels = y.view(-1).long()

            # Per-sample confidence: best score over the ID classes only.
            confidences = scores[:, :num_id_classes].max(dim=1).values
            for confidence, label in zip(confidences.tolist(), batch_labels.tolist()):
                if label < num_id_classes:
                    id_scores.append(confidence)
                else:
                    ood_scores.append(confidence)

            score_batches.append(scores)
            label_batches.append(batch_labels)

    auc = compute_auc(np.array(id_scores), np.array(ood_scores))
    print("AUC: {:.4f}".format(auc))

    if not score_batches:
        return torch.empty((0, num_classes)), torch.empty(0).long()
    return torch.cat(score_batches), torch.cat(label_batches)


def get_feature_extraction_model(ckpt_file, imagenet_pretrained=False):
    """Build a torchvision ResNet-18, optionally restoring weights from a checkpoint.

    Args:
        ckpt_file: path to a checkpoint file, or ``None`` to skip loading.
            The weights are read from the ``'state_dict'`` entry when present,
            otherwise from ``'model_state'``.
        imagenet_pretrained: forwarded to ``resnet18(pretrained=...)``.

    Returns:
        The ResNet-18 model.
    """
    backbone = resnet18(pretrained=imagenet_pretrained)
    if ckpt_file is None:
        return backbone

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckpt_file)
    key = 'state_dict' if 'state_dict' in checkpoint else 'model_state'
    print("Resuming from {}".format(ckpt_file))
    utils.safe_load_dict(backbone, checkpoint[key])
    return backbone


def get_feature_extraction_model_deit(pretrained_model_path):
    """Build a DeiT-Small model and copy backbone weights from a local checkpoint.

    The architecture comes from ``torch.hub`` (``deit_small_patch16_224``,
    ``pretrained=False``), so hub network access or a warm hub cache is needed.
    Every parameter whose name is present in ``checkpoint['model']`` and does
    not contain ``'head'`` is copied over. The names that were not copied are
    printed.

    Args:
        pretrained_model_path: path to a checkpoint whose ``'model'`` entry is
            a state dict.

    Returns:
        The DeiT model with the transferred weights loaded.

    Raises:
        NotImplementedError: if ``pretrained_model_path`` is not a file.
    """
    # The log text refers to a 611-class pre-training; not verified here.
    print("Loading Deit model pre-trained with 611 classes")
    if not os.path.isfile(pretrained_model_path):
        raise NotImplementedError("Cannot find pre-trained model")
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(pretrained_model_path, map_location='cpu')

    target_model = torch.hub.load('facebookresearch/deit:main', 'deit_small_patch16_224', pretrained=False)

    target_state = target_model.state_dict()
    source_state = checkpoint['model']

    def _transferable(name):
        # Classifier head weights are skipped so the head keeps its fresh init.
        return name in source_state and 'head' not in name

    missing = [name for name in target_state if not _transferable(name)]
    target_state.update({name: source_state[name] for name in target_state if _transferable(name)})
    target_model.load_state_dict(target_state)
    print("Parameters not updated: ", missing)
    return target_model

def pool_feat(features):
    """Average ``features`` over dimension 1.

    For the DeiT block output this is presumably the token axis of a
    ``(batch, tokens, dim)`` tensor — TODO confirm.
    """
    return torch.mean(features, dim=1)


def predict(model, val_data, num_classes=1000):
    """Collect ``model.predict`` class probabilities and true labels over a loader.

    When the module-level ``feature_extraction_wrapper`` (set in ``__main__``)
    is not ``None``, features are extracted with it and mean-pooled. Otherwise
    the loader inputs are used directly as features.

    Returns:
        ``(probabilities, labels)`` with shapes ``(num_samples, num_classes)``
        and ``(num_samples,)``, where ``num_samples = len(val_data.dataset)``.
    """
    total = len(val_data.dataset)
    probabilities = torch.empty((total, num_classes))
    labels = torch.empty(total).long()
    offset = 0
    with torch.no_grad():
        for inputs, targets in val_data:
            if feature_extraction_wrapper is None:
                feat = inputs.cuda()
            else:
                feat = pool_feat(feature_extraction_wrapper(inputs.cuda()))
            stop = offset + feat.shape[0]
            probabilities[offset:stop] = model.predict(feat.cuda(), return_probas=True)
            labels[offset:stop] = targets.squeeze()
            offset = stop
    return probabilities, labels


def md(data, mean, mat, inverse=False):
    """Return the Mahalanobis distance of each row of ``data`` from ``mean``.

    Args:
        data: a single vector ``(d,)`` or a batch ``(n, d)``; torch tensors
            are converted to NumPy.
        mean: the centre, broadcastable against ``data``; torch tensors are
            converted to NumPy.
        mat: the covariance matrix, or its inverse when ``inverse`` is True.
        inverse: pass True if ``mat`` is already the inverse covariance, to
            skip ``np.linalg.inv``.

    Returns:
        NumPy array of shape ``(n, 1)``; a single vector yields ``(1, 1)``.
    """
    if isinstance(data, torch.Tensor):
        data = data.data.cpu().numpy()
    if isinstance(mean, torch.Tensor):
        mean = mean.data.cpu().numpy()
    if isinstance(mat, torch.Tensor):
        mat = mat.data.cpu().numpy()
    data = np.asarray(data)
    if data.ndim == 1:
        # Treat a single vector as a batch of one (reshape is not in-place).
        data = data.reshape(1, -1)
    delta = data - mean

    if not inverse:
        mat = np.linalg.inv(mat)

    # Row-wise quadratic form delta_i^T M delta_i, without building the
    # (n, n) matrix that the diagonal-of-product formulation needs.
    sq_dist = np.sum((delta @ mat) * delta, axis=1)
    # Clamp tiny negatives from floating-point error so sqrt cannot yield NaN.
    return np.sqrt(np.maximum(sq_dist, 0.0)).reshape(-1, 1)

    
def compute_md(slda_model, features, cov_inv, mean_all=None, cov_all_inv=None):
    """Score ``features`` against every class mean by inverse Mahalanobis distance.

    Args:
        slda_model: model exposing ``num_classes`` and per-class means ``muK``.
        features: batch of feature vectors (torch tensor or NumPy array).
        cov_inv: precomputed inverse of the shared covariance.
        mean_all, cov_all_inv: ignored. They are accepted so ``compute_md``
            can be called with the same arguments as ``compute_rmd``.

    Returns:
        NumPy array of shape ``(n_samples, num_classes)`` whose entry
        ``[i, k]`` is ``1 / MD(features[i], muK[k])``. Larger means closer,
        and a zero distance yields ``inf``.
    """
    if isinstance(features, torch.Tensor):
        features = features.data.cpu().numpy()

    columns = [
        1 / md(features, slda_model.muK[k].cpu(), cov_inv, inverse=True)
        for k in range(slda_model.num_classes)
    ]
    return np.concatenate(columns, axis=1)


def compute_rmd(slda_model, features, cov_inv, mean_all, cov_all_inv):
    """Score ``features`` by relative Mahalanobis distance to every class.

    For class ``k`` the score is
    ``-(MD(x, muK[k]; cov_inv) - MD(x, mean_all; cov_all_inv))``. The class
    distance is offset by the distance to a second Gaussian
    (``mean_all``/``cov_all_inv``, presumably fitted to all classes jointly).
    Larger scores mean "closer to class k".

    Args:
        slda_model: model exposing ``num_classes`` and per-class means ``muK``.
        features: batch of feature vectors (torch tensor or NumPy array).
        cov_inv: precomputed inverse of the shared class covariance.
        mean_all: mean of the reference Gaussian.
        cov_all_inv: precomputed inverse covariance of the reference Gaussian.

    Returns:
        NumPy array of shape ``(n_samples, num_classes)``.
    """
    if isinstance(features, torch.Tensor):
        features = features.data.cpu().numpy()

    # The reference distance does not depend on the class: compute it once.
    dist_reference = md(features, mean_all, cov_all_inv, inverse=True)

    columns = [
        -(md(features, slda_model.muK[k].cpu(), cov_inv, inverse=True) - dist_reference)
        for k in range(slda_model.num_classes)
    ]
    return np.concatenate(columns, axis=1)


def compute_mean_distance(slda_model, features):
    """Score each sample by inverse Euclidean distance to every class mean.

    Args:
        slda_model: model exposing ``num_classes`` and per-class means ``muK``.
        features: a feature vector ``(d,)`` or batch ``(n, d)`` (torch tensor
            or NumPy array).

    Returns:
        NumPy array of shape ``(n_samples, num_classes)`` (the same layout as
        ``compute_md``/``compute_rmd``). Entry ``[i, k]`` is
        ``1 / ||features[i] - muK[k]||``, and a zero distance yields ``inf``.
    """
    if isinstance(features, torch.Tensor):
        features = features.data.cpu().numpy()
    features = np.atleast_2d(features)

    columns = []
    for k in range(slda_model.num_classes):
        mean = slda_model.muK[k].data.cpu().numpy()
        # Norm per sample (axis=1), not over the whole batch.
        dist = np.linalg.norm(features - mean, axis=1)
        columns.append(1 / dist)
    return np.stack(columns, axis=1)


def _safe_percentage(numerator, denominator):
    """Return ``numerator / denominator * 100``, or 0.0 when ``denominator`` is 0."""
    if denominator == 0:
        return 0.0
    return numerator / denominator * 100


def _evaluate_accuracy(model, loader, num_classes, id_classes):
    """Run ``predict`` on ``loader``; return ``(top1_id, top1_ood, top1_all)`` from ``utils.id_ood_accuracy``."""
    probas, y_test = predict(model, loader, num_classes)
    return utils.id_ood_accuracy(probas, y_test, id_classes)


def open_world_training(model, loaders, test_loaders_open_wold, test_loader, id_classes, num_classes, ood_method, ood_threshold, cl_threshold, is_cil):
    """Deploy ``model`` on an open-world stream, learning new classes from the samples it asks about.

    For each streamed sample, per-class OOD scores are computed. The
    best-scoring *known* class is the prediction, where known classes are ID
    classes, converged novel classes and emerging novel classes (at least one
    update, not yet converged). The learner asks for the label when that best
    score is below ``ood_threshold`` or the prediction is an emerging class:

    * asking about an ID or converged class counts as a waste ask;
    * otherwise the sample is fed to ``model.fit_open_world``, and if the class
      mean then moves by less than ``cl_threshold`` the class is marked
      converged.

    Novel-class samples that are not asked about count as missed OOD. Every
    100 asks, accuracy on ``test_loader`` is recorded, and also on the current
    task's test loader when ``is_cil``. Summary statistics are printed at the
    end.

    Args:
        model: SLDA model. ``Sigma``, ``Sigma_all``, ``muK`` and ``muK_all``
            are read, and ``fit_open_world`` is called.
        loaders: iterable of stream loaders, one per task.
        test_loaders_open_wold: per-task test loaders (indexed only when ``is_cil``).
        test_loader: test loader covering all classes.
        id_classes: labels of the in-distribution classes.
        num_classes: total number of classes.
        ood_method: ``"rmd"``, ``"md"`` or ``"mean"``.
        ood_threshold: best-score threshold below which a sample is asked about.
        cl_threshold: class-mean update norm below which a novel class is converged.
        is_cil: class-incremental setting; enables per-task statistics.

    Raises:
        NotImplementedError: if ``ood_method`` is not recognised.

    Note: reads the module-level ``feature_extraction_wrapper`` set in ``__main__``.
    """
    asking_accuracies = []
    asking_accuracies_id = []
    asking_accuracies_ood = []
    waste_asks = []
    missed_oods = []
    image_count = 0
    ood_detection_count = 0
    waste_ask_count = 0
    missed_ood_count = 0
    class_mean_convergence = [False] * num_classes
    # -1 marks ID classes (never updated during the stream); novel classes start at 0.
    class_mean_update_count = [-1 if i in id_classes else 0 for i in range(num_classes)]
    id_class_set = set(id_classes)

    # Blend each covariance 80/20 with the identity before inverting. The
    # inverses are computed once and kept fixed for the whole stream.
    cov = model.Sigma.cpu()
    cov_inv = np.linalg.inv(0.8 * cov + 0.2 * np.eye(len(cov)))
    mean_all = model.muK_all.cpu()
    cov_all = model.Sigma_all.cpu()
    cov_all_inv = np.linalg.inv(0.8 * cov_all + 0.2 * np.eye(len(cov_all)))

    top1_id, top1_ood, top1_all = _evaluate_accuracy(model, test_loader, num_classes, id_classes)
    print("Accuracy before deployment:", top1_all)
    asking_accuracies.append(top1_all)
    asking_accuracies_id.append(top1_id)
    asking_accuracies_ood.append(top1_ood)
    waste_asks.append(0)
    missed_oods.append(0)
    task_stats = []

    for curr_task, loader in enumerate(loaders):  # one loader per task
        asking_accuracies_ct = []
        asking_accuracies_ct_id = []
        asking_accuracies_ct_ood = []
        waste_asks_ct = []
        missed_oods_ct = []
        ood_detection_count_ct = 0
        waste_ask_count_ct = 0
        missed_ood_count_ct = 0

        if is_cil:
            top1_id, top1_ood, top1_all = _evaluate_accuracy(model, test_loaders_open_wold[curr_task], num_classes, id_classes)
            asking_accuracies_ct.append(top1_all)
            asking_accuracies_ct_id.append(top1_id)
            asking_accuracies_ct_ood.append(top1_ood)
            waste_asks_ct.append(0)
            missed_oods_ct.append(0)

        with torch.no_grad():
            for X, y in loader:
                converged_classes = {i for i in range(num_classes) if class_mean_convergence[i]}
                emerging_classes = {i for i in range(num_classes)
                                    if not class_mean_convergence[i] and class_mean_update_count[i] >= 1}
                id_converged_classes = id_class_set | converged_classes
                # Candidate predictions in ascending class order, so ties go to the lowest index.
                known_classes = [c for c in range(num_classes)
                                 if c in id_converged_classes or c in emerging_classes]
                image_count += 1

                if feature_extraction_wrapper is not None:
                    feat = pool_feat(feature_extraction_wrapper(X.cuda()))
                else:
                    feat = X.cuda()

                if ood_method == "rmd":
                    scores = compute_rmd(model, feat, cov_inv, mean_all, cov_all_inv)
                elif ood_method == "md":
                    scores = compute_md(model, feat, cov_inv)
                elif ood_method == "mean":
                    scores = compute_mean_distance(model, feat)
                else:
                    raise NotImplementedError('This OOD method is not implemented.')
                # One row of per-class scores per sample in the batch.
                scores = np.asarray(scores)

                for x_single, y_single, sample_scores in zip(feat, y, scores):
                    # Only classes the learner already knows about may be predicted.
                    predicted_label = max(known_classes, key=sample_scores.__getitem__, default=None)
                    max_prob = sample_scores[predicted_label] if predicted_label is not None else -np.inf
                    true_label = int(y_single)

                    if max_prob < ood_threshold or predicted_label in emerging_classes:
                        if true_label in id_converged_classes:
                            waste_ask_count += 1
                            waste_ask_count_ct += 1
                        else:
                            # clone() so old_mean is a snapshot even if fit_open_world
                            # updates muK in place (int indexing returns a view).
                            old_mean = model.muK[true_label].clone()
                            model.fit_open_world(x_single.cpu(), y_single.view(1, ))

                            if not class_mean_convergence[true_label]:
                                diff_norm = torch.norm(model.muK[true_label] - old_mean).item()
                                if diff_norm < cl_threshold:
                                    class_mean_convergence[true_label] = True

                            class_mean_update_count[true_label] += 1

                        ood_detection_count += 1
                        ood_detection_count_ct += 1

                        # store intermediate results
                        if ood_detection_count % 100 == 0:
                            top1_id, top1_ood, top1_all = _evaluate_accuracy(model, test_loader, num_classes, id_classes)
                            print("Predicted Image: ", image_count, ", OOD Detection:", ood_detection_count, ", Waste Ask:", waste_ask_count, ", Missed OOD:", missed_ood_count, ", Accuracy:", top1_all)
                            asking_accuracies.append(top1_all)
                            asking_accuracies_id.append(top1_id)
                            asking_accuracies_ood.append(top1_ood)
                            waste_asks.append(waste_ask_count)
                            missed_oods.append(missed_ood_count)

                        # store intermediate results for current task
                        if ood_detection_count_ct % 100 == 0 and is_cil:
                            top1_id, top1_ood, top1_all = _evaluate_accuracy(model, test_loaders_open_wold[curr_task], num_classes, id_classes)
                            asking_accuracies_ct.append(top1_all)
                            asking_accuracies_ct_id.append(top1_id)
                            asking_accuracies_ct_ood.append(top1_ood)
                            waste_asks_ct.append(waste_ask_count_ct)
                            missed_oods_ct.append(missed_ood_count_ct)
                    elif true_label not in id_converged_classes:
                        missed_ood_count += 1
                        missed_ood_count_ct += 1

        if is_cil:
            top1_id, top1_ood, top1_all = _evaluate_accuracy(model, test_loaders_open_wold[curr_task], num_classes, id_classes)
            print("Task", curr_task, "Predicted Image: ", image_count, "OOD Detection:", ood_detection_count, "Waste Ask:", waste_ask_count, "Missed OOD:", missed_ood_count, "Converged Class Update:", ", Accuracy:", top1_all)

            true_ood_count_ct = ood_detection_count_ct - waste_ask_count_ct
            task_stats.append(
                {
                    "accuracies": asking_accuracies_ct,
                    "accuracies_id": asking_accuracies_ct_id,
                    "accuracies_ood": asking_accuracies_ct_ood,
                    "waste_asks": waste_asks_ct,
                    "missed_oods": missed_oods_ct,
                    "final_accuracy": top1_all,
                    "final_accuracy_id": top1_id,
                    "final_accuracy_ood": top1_ood,
                    "final_waste_asks": waste_ask_count_ct,
                    "final_missed_ood": missed_ood_count_ct,
                    "final_ood_detection": ood_detection_count_ct,
                    "precision": _safe_percentage(true_ood_count_ct, true_ood_count_ct + waste_ask_count_ct),
                    "recall": _safe_percentage(true_ood_count_ct, true_ood_count_ct + missed_ood_count_ct),
                }
            )

    true_ood_count = ood_detection_count - waste_ask_count
    # Guarded so a run with no detections reports 0 instead of raising ZeroDivisionError.
    precision = _safe_percentage(true_ood_count, true_ood_count + waste_ask_count)
    recall = _safe_percentage(true_ood_count, true_ood_count + missed_ood_count)
    f1 = 2 * ((precision * recall) / (precision + recall)) if precision + recall else 0.0

    print("Total OOD detection:", ood_detection_count)
    print("Waste Ask Count:", waste_ask_count)
    print("Missed OOD Count:", missed_ood_count)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1 Score:", f1)

    # intermediate results for plots
    if is_cil:
        print("\ntask_stats =", task_stats)
    else:
        print("\nwaste_asks =", waste_asks)
        print("missed_oods =", missed_oods)
        print("accuracies =", asking_accuracies)
        print("accuracies_id =", asking_accuracies_id)
        print("accuracies_ood =", asking_accuracies_ood)



def get_data_loader(images_dir, training, min_class, max_class, batch_size=128, shuffle=False, dataset='cifar10', open_world=False, is_cil=False):
    """Return the ``utils`` data loader for ``dataset`` over classes ``[min_class, max_class)``.

    Supported names are ``'cifar10'``, ``'cifar100'`` and ``'tinyimagenet'``.
    A banner naming the dataset is printed before the loader is built. All
    remaining arguments are forwarded unchanged to the matching
    ``utils.get_*_loader`` function.

    Raises:
        NotImplementedError: for any other dataset name.
    """
    # (dataset name, banner, name of the loader factory in utils)
    supported = (
        ('cifar10', "Using CIFAR-10 dataset...", 'get_cifar10_loader'),
        ('cifar100', "Using CIFAR-100 dataset...", 'get_cifar100_loader'),
        ('tinyimagenet', "Using TinyImageNet dataset...", 'get_tinyimagenet_loader'),
    )
    for name, banner, factory_name in supported:
        if dataset == name:
            print(banner)
            factory = getattr(utils, factory_name)
            return factory(images_dir, min_class, max_class, training, batch_size=batch_size,
                           shuffle=shuffle, open_world=open_world, is_cil=is_cil)

    #### IMPLEMENT ANOTHER DATASET HERE ####
    raise NotImplementedError('Please implement another dataset.')

def run_experiment(dataset, images_dir, classifier, feature_extraction_wrapper, feature_size, batch_size,
                   shuffle, is_cil, num_classes, ood_method, ood_threshold, cl_threshold):
    """Base-initialise ``classifier`` on the ID classes, then run open-world training.

    The first ``num_classes * 5 // 10`` classes are in-distribution (ID). Their
    training features are extracted, buffered and passed to
    ``classifier.fit_base`` in a single call. The stream for
    ``open_world_training`` covers all classes. Final accuracies and the
    wall-clock time are printed.

    Args:
        dataset: dataset name understood by ``get_data_loader``.
        images_dir: dataset root passed to the loaders.
        classifier: SLDA model; ``fit_base`` is called here.
        feature_extraction_wrapper: backbone used for base-init features, or
            ``None`` when the loader already yields features. NOTE:
            ``predict`` and ``open_world_training`` read the module-level name
            of the same name, not this argument.
        feature_size: dimensionality of the (pooled) features.
        batch_size: batch size for feature extraction and testing.
        shuffle: whether loaders shuffle.
        is_cil: class-incremental setting; per-task test loaders are built.
        num_classes: total number of classes.
        ood_method, ood_threshold, cl_threshold: forwarded to ``open_world_training``.
    """
    start_time = time.time()

    num_id_classes = num_classes * 5 // 10

    print("\nTraining classes from {} to {}".format(0, num_id_classes))

    # get training loader for the ID classes
    train_loader = get_data_loader(images_dir, True, 0, num_id_classes, batch_size=batch_size,
                                   shuffle=shuffle, dataset=dataset, is_cil=is_cil)

    print('\nGetting data for base initialization...')
    num_train = len(train_loader.dataset)
    # SLDA's base initialization needs every feature at once, so buffer them all.
    base_init_data = torch.empty((num_train, feature_size))
    base_init_labels = torch.empty(num_train).long()

    id_classes = []  # ID labels in first-seen order
    seen_labels = set()  # O(1) membership for the check below
    start = 0
    with torch.no_grad():
        for batch_x, batch_y in train_loader:
            # view(-1), not squeeze(): a final batch of one sample would squeeze
            # to a 0-d tensor, which cannot be iterated.
            batch_labels = batch_y.view(-1)
            for label in batch_labels.tolist():
                if label not in seen_labels:
                    seen_labels.add(label)
                    id_classes.append(label)
            print('\rLoading features %d/%d.' % (start, num_train), end='')
            if feature_extraction_wrapper is not None:
                # extract feature from pre-trained model and mean pool
                batch_x_feat = pool_feat(feature_extraction_wrapper(batch_x.cuda()))
            else:
                batch_x_feat = batch_x.cuda()
            end = start + batch_x_feat.shape[0]
            base_init_data[start:end] = batch_x_feat
            base_init_labels[start:end] = batch_labels
            start = end

    print('\nFirst time...doing base initialization...')
    classifier.fit_base(base_init_data, base_init_labels)

    test_loader = get_data_loader(images_dir, False, 0, num_classes, batch_size=batch_size, shuffle=shuffle,
                                  dataset=dataset, is_cil=is_cil)

    # open world training stage
    print('\n\nOpen world training stage...')
    loaders_open_world = get_data_loader(images_dir, True, 0, num_classes, batch_size=1,
                                         shuffle=shuffle, dataset=dataset, open_world=True, is_cil=is_cil)

    # Per-task test loaders are only read by open_world_training when is_cil,
    # so they are not built otherwise. Task t covers classes
    # [0, num_id_classes + (t + 1) * num_cls_per_task).
    test_loaders_open_wold = []
    if is_cil:
        num_task = 5
        num_cls_per_task = (num_classes - num_id_classes) // num_task
        for curr_task in range(num_task):
            test_loaders_open_wold.append(
                get_data_loader(images_dir, False, 0, num_id_classes + ((curr_task + 1) * num_cls_per_task),
                                batch_size=batch_size, shuffle=shuffle, dataset=dataset, is_cil=is_cil))

    print("id classes:", id_classes)
    open_world_training(classifier, loaders_open_world, test_loaders_open_wold, test_loader, id_classes, num_classes, ood_method, ood_threshold, cl_threshold, is_cil)

    # print final accuracies and time
    probas, y_test = predict(classifier, test_loader, num_classes)
    top1_id, top1_ood, top1_all = utils.id_ood_accuracy(probas, y_test, id_classes)
    end_time = time.time()
    print('\nFinal: Total accuracy=%0.2f%% -- ID classes accuracy=%0.2f%% -- OOD classes accuracy=%0.2f%%' % (top1_all, top1_id, top1_ood))
    print('\nTotal Time (seconds): %0.2f' % (end_time - start_time))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # experiment parameters
    parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100', 'tinyimagenet'])
    parser.add_argument('--images_dir', type=str)  # dataset root directory passed to the utils data loaders

    parser.add_argument('--num_classes', type=int, default=10)  # total number of classes; the first half are in-distribution
    parser.add_argument('--batch_size', type=int, default=256)  # batch size for getting features & testing
    parser.add_argument('--input_feature_size', type=int, default=384)  # embedding size of the DeiT-Small (deit_small_patch16_224) backbone
    parser.add_argument('--shuffle_data', action='store_true')  # true to shuffle data (usually don't want this)

    # SLDA parameters
    parser.add_argument('--streaming_update_sigma', action='store_true')  # true to update covariance online
    parser.add_argument('--shrinkage', type=float, default=1e-4)  # shrinkage for SLDA

    # OpenLD parameters
    parser.add_argument('--pretrained_model_path', type=str, default='./best_checkpoint.pth')
    parser.add_argument('--experimental_setting', type=str, default='random', choices=['random', 'cil'])
    parser.add_argument('--ood_method', type=str, default='rmd', choices=['rmd', 'md', 'mean'])
    parser.add_argument('--ood_threshold', type=float, default=0.0)  # best known-class score below this triggers an ask
    parser.add_argument('--cl_threshold', type=float, default=0.1)  # class-mean update norm below this marks a novel class converged

    args = parser.parse_args()
    print("Arguments {}".format(json.dumps(vars(args), indent=4, sort_keys=True)))

    # setup SLDA model
    classifier = StreamingLDA(args.input_feature_size, args.num_classes, test_batch_size=args.batch_size,
                              shrinkage_param=args.shrinkage, streaming_update_sigma=args.streaming_update_sigma)

    # Features are taken from the output of block 'blocks.11' of the DeiT backbone.
    # NOTE: predict(), get_auroc() and open_world_training() read this module-level
    # feature_extraction_wrapper directly, so it must stay a global.
    feature_extraction_model = get_feature_extraction_model_deit(args.pretrained_model_path)
    feature_extraction_wrapper = retrieve_any_layer.ModelWrapper(feature_extraction_model.cuda(), ['blocks.11'],
                                                                 return_single=True).eval()

    # argparse `choices` restricts the setting to 'random' or 'cil'.
    is_cil = args.experimental_setting == "cil"

    # run the streaming experiment
    run_experiment(args.dataset, args.images_dir, classifier, feature_extraction_wrapper,
                   args.input_feature_size, args.batch_size, args.shuffle_data, is_cil, args.num_classes, args.ood_method, args.ood_threshold, args.cl_threshold)
