from benchmarking_utils.ncc_classifiers import NCCClassifierHandler
from models.ssl_models import VICRegModel, classes
import torch
import pytorch_lightning as pl
import os
from glob import glob
from tqdm import tqdm
import json
import random
import yaml

# Report the installed torch version as soon as this module is loaded.
print(f"torch version: {torch.__version__}")


def run_ncc_preds(feature_bank, targets_bank, from_layer, end_layer,
                  feature_bank_test, targets_bank_test):
    """Build an NCC handler from the train banks and predict on train and test.

    Args:
        feature_bank: Train features, passed to ``NCCClassifierHandler`` and
            then to ``predict_ncc``.
        targets_bank: Train targets matching ``feature_bank``.
        from_layer: Forwarded to the handler as ``from_layer``.
        end_layer: Forwarded to the handler as ``end_layer``.
        feature_bank_test: Test features passed to ``predict_ncc``.
        targets_bank_test: Test targets matching ``feature_bank_test``.

    Returns:
        A ``(train_predictions, test_predictions)`` tuple holding whatever
        ``predict_ncc`` returns for each split.
    """
    print("running linear classifier on train")
    ncc_handler = NCCClassifierHandler(
        feature_bank,
        targets_bank,
        from_layer=from_layer,
        end_layer=end_layer,
    )

    # The handler is always built from the train banks; only the data handed
    # to predict_ncc differs between the two calls below.
    print("running NCC classifiers")
    predictions_on_train = ncc_handler.predict_ncc(feature_bank, targets_bank)

    print("running NCC classifier on test")
    predictions_on_test = ncc_handler.predict_ncc(feature_bank_test, targets_bank_test)

    return predictions_on_train, predictions_on_test


def get_random_targets(rand_targets_folder):
    """Load the random-target JSON files for a fixed set of epochs.

    Every entry directly inside ``rand_targets_folder`` is inspected. The
    epoch is taken from the file name: the text between the first and second
    underscore, up to the first dot (``<prefix>_<epoch>.<ext>``). Only epochs
    10, 30, 100, 500 and 1000 are kept.

    Args:
        rand_targets_folder: Directory containing the target files.

    Returns:
        A dict mapping the integer epoch to the parsed JSON content of its
        file. Files whose name does not carry a parsable epoch, or whose
        epoch is not in the accepted set, are skipped. The dict is empty when
        nothing matches.
    """
    wanted_epochs = {10, 30, 100, 500, 1000}
    targets = {}
    for file_path in glob(os.path.join(rand_targets_folder, "*")):
        # basename() copes with the platform's path separator, unlike
        # splitting on a literal '/'.
        file_name = os.path.basename(file_path)
        try:
            epoch = int(file_name.split('_')[1].split('.')[0])
        except (IndexError, ValueError):
            # Unrelated file in the folder (no "<prefix>_<epoch>" pattern):
            # ignore it instead of aborting the whole scan.
            continue
        if epoch not in wanted_epochs:
            continue
        # Context manager so the handle is closed deterministically.
        with open(file_path, "r", encoding="utf-8") as target_file:
            targets[epoch] = json.load(target_file)
    return targets


def load_hparams(output_path):
    """Read ``version_0/hparams.yaml`` under ``output_path`` and return it.

    Args:
        output_path: Directory that contains the ``version_0`` sub-folder.

    Returns:
        The object produced by ``yaml.full_load`` for the file's content.
    """
    yaml_location = f"{output_path}/version_0/hparams.yaml"
    with open(yaml_location, 'r') as stream:
        parsed_hparams = yaml.full_load(stream)
    return parsed_hparams


def _checkpoint_epoch_token(checkpoint_path):
    """Return the epoch text embedded in a checkpoint file name, or None.

    The token is the part of the base name between the first ``=`` and the
    following ``-``. ``None`` is returned when the name has no such part or
    when the part is not an integer, e.g. for a sub-directory like
    ``version_0`` sitting next to the checkpoints.
    """
    file_name = os.path.basename(checkpoint_path)
    try:
        token = file_name.split('=')[1].split('-')[0]
        int(token)  # validation only; the original text is what gets returned
    except (IndexError, ValueError):
        return None
    return token


def run_ncc_evaluation(checkpoint_folder, output_path, rand_targets_folder=None):
    """Evaluate selected checkpoints with NCC classifiers and save JSON results.

    For each entry in ``checkpoint_folder`` whose file name carries one of the
    epochs 19, 39, 99, 399 or 999, the checkpoint's ``state_dict`` is loaded
    into a ``VICRegModel``, the model's feature/target banks are refreshed for
    the train and test splits, and ``run_ncc_preds`` is run on them. Results
    are written to ``<output_path>/CHECKPOINT_<epoch>.json``; a checkpoint
    whose result file already exists is skipped.

    Args:
        checkpoint_folder: Directory holding the checkpoint files. It is also
            passed to ``load_hparams``, which reads ``version_0/hparams.yaml``
            below it.
        output_path: Directory for the result files; created if missing.
        rand_targets_folder: Optional directory of random-target files (see
            ``get_random_targets``). When omitted, only the ground-truth
            targets are evaluated.

    Returns:
        None. Each result file holds a dict with a ``'gt'`` entry and one
        entry per random-target epoch, each with ``'train'`` and ``'test'``
        predictions.
    """
    # Function-scope import kept from the original code.
    from offline_evaluation.load_datasets import load_datasets
    _, dataloader_train_eval, dataloader_test_eval = load_datasets(augment=False)
    pl.seed_everything(0)

    hparams = load_hparams(checkpoint_folder)
    print(f"hparams:{hparams}")

    # An empty dict (rather than None) keeps the random-target loop below a
    # no-op when no folder was supplied.
    random_targets = get_random_targets(rand_targets_folder) if rand_targets_folder else {}

    interesting_epochs = {19, 39, 99, 399, 999}

    checkpoint_files = glob(os.path.join(checkpoint_folder, "*"))
    random.Random(203).shuffle(checkpoint_files)

    # makedirs also creates missing parent directories.
    os.makedirs(output_path, exist_ok=True)

    benchmark_model = VICRegModel(dataloader_kNN=dataloader_train_eval,
                                  dataloader_test=dataloader_test_eval,
                                  num_classes=classes,
                                  resnet_type=hparams['resnet_type'],
                                  depth=hparams['depth'], width=hparams['width'])
    # Only move to the GPU when one is present, so CPU-only hosts do not fail.
    if torch.cuda.is_available():
        benchmark_model = benchmark_model.cuda()

    print(f"benchmark_model:{benchmark_model}, device:{benchmark_model.dummy_param.device}")
    from_layer = 0
    to_layer = 5

    for checkpoint_path in tqdm(checkpoint_files, total=len(checkpoint_files)):
        epoch = _checkpoint_epoch_token(checkpoint_path)
        # Entries without a parsable epoch (e.g. the version_0 directory) and
        # epochs outside the selected set are ignored.
        if epoch is None or int(epoch) not in interesting_epochs:
            continue

        new_output_path = os.path.join(output_path, f"CHECKPOINT_{epoch}.json")

        if os.path.isfile(new_output_path):
            print(f"SKIPPING PATH: {new_output_path}, exists!")
            continue

        print(f"SSL checkpoint on {checkpoint_path}")

        checkpoint = torch.load(checkpoint_path, map_location='cpu')['state_dict']
        benchmark_model.load_state_dict(checkpoint)
        results = {}

        benchmark_model.eval()

        # getting features on both train and test sets
        # NOTE(review): presumably these calls fill feature_bank/targets_bank
        # and their *_test counterparts read below -- confirm in VICRegModel.
        benchmark_model.training_epoch_end(outputs=None, is_offline=True, from_layer=from_layer, mode='eval')
        benchmark_model.training_epoch_end(outputs=None, is_offline=True, from_layer=from_layer, use_test=True,
                                           mode='eval')

        # Run for GT targets
        train_predictions, test_predictions = run_ncc_preds(benchmark_model.feature_bank,
                                                            benchmark_model.targets_bank,
                                                            from_layer, to_layer,
                                                            benchmark_model.feature_bank_test,
                                                            benchmark_model.targets_bank_test)
        results['gt'] = {'train': train_predictions, 'test': test_predictions}

        for random_epoch, targets in random_targets.items():
            print(f"random targets: {random_epoch}")
            train_targets, test_targets = torch.tensor(targets['train']), torch.tensor(targets['test'])
            # The same random targets are reused for every evaluated layer;
            # the range follows to_layer instead of a second hard-coded 5.
            cur_targets_train = {idx: train_targets for idx in range(from_layer, to_layer)}
            cur_targets_test = {idx: test_targets for idx in range(from_layer, to_layer)}

            train_predictions, test_predictions = run_ncc_preds(benchmark_model.feature_bank,
                                                                cur_targets_train,
                                                                from_layer, to_layer,
                                                                benchmark_model.feature_bank_test,
                                                                cur_targets_test)

            results[random_epoch] = {'train': train_predictions, 'test': test_predictions}

        print(results)
        # Context manager guarantees the JSON is flushed and the handle closed.
        with open(new_output_path, 'w') as output_file:
            json.dump(results, output_file)
        print(f"saved to: {new_output_path}")
