from benchmarking_utils.linear_classifiers import LinearClassifierHandler
from models.classifier import Classifier
from models.ssl_models import VICRegModel, classes, SimCLRModel
import torch
import pytorch_lightning as pl
import os
from glob import glob
from tqdm import tqdm
import json
import random
import yaml

# Report the installed torch version as soon as this module is imported or run.
print(f"torch version: {torch.__version__}")


def run_lstsq_preds(feature_bank, targets_bank, from_layer, end_layer,
                    feature_bank_test, targets_bank_test):
    """Build a ``LinearClassifierHandler`` and collect its train/test predictions.

    The handler is constructed from the train banks together with the
    ``from_layer`` / ``end_layer`` bounds, then ``predict_sklearn`` is called
    once with the train banks and once with the test banks.

    Returns:
        A ``(train_predictions, test_predictions)`` tuple holding whatever
        ``predict_sklearn`` returned for each split.
    """
    print(f"running linear classifier on train")
    handler = LinearClassifierHandler(
        feature_bank,
        targets_bank,
        from_layer=from_layer,
        end_layer=end_layer,
    )

    print(f"running sklearn classifiers")
    preds_on_train = handler.predict_sklearn(feature_bank, targets_bank)

    print(f"running linear classifier on test")
    preds_on_test = handler.predict_sklearn(feature_bank_test, targets_bank_test)

    return preds_on_train, preds_on_test


def run_dl_lin_classifier_preds(benchmark_model, dataloader_train_kNN, dataloader_test,
                                from_layer, end_layer, max_epochs=100):
    """Train a ``Classifier`` head on each requested backbone layer.

    For every layer index in ``range(from_layer, end_layer)`` a ``Classifier``
    is built on ``benchmark_model.backbone`` with ``train_backbone=False`` and
    fitted with a Lightning ``Trainer``; the metrics logged by the trainer are
    collected per layer.

    Args:
        benchmark_model: Model exposing a ``backbone`` attribute.
        dataloader_train_kNN: Dataloader passed to ``trainer.fit`` as the
            training loader.
        dataloader_test: Dataloader passed to ``trainer.fit`` as the
            validation loader.
        from_layer: First layer index (inclusive).
        end_layer: Last layer index (exclusive).
        max_epochs: Number of epochs for each fit.

    Returns:
        dict mapping each layer index to a dict of the trainer's logged
        metrics, keyed ``"<metric>_<layer_idx>_<iter>_linear"``.
    """
    use_cuda = torch.cuda.is_available()
    gpus = 1 if use_cuda else 0
    result = {}
    for layer_idx in range(from_layer, end_layer):
        # A single seed/iteration is run per layer; the loop is kept so the
        # iteration index stays part of the log line and the metric keys.
        for seed in range(1):
            print(f"running on layer:{layer_idx}, iter:{seed}")
            pl.seed_everything(seed)
            classifier = Classifier(benchmark_model.backbone, layer_idx=layer_idx,
                                    train_backbone=False, num_classes=100)
            # Only move to the GPU when one exists, consistent with the
            # ``gpus`` value handed to the trainer below.
            if use_cuda:
                classifier = classifier.cuda()

            trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus,
                                 progress_bar_refresh_rate=100,
                                 check_val_every_n_epoch=10)
            trainer.fit(
                classifier,
                dataloader_train_kNN,
                dataloader_test
            )

            metrics = {f'{k}_{layer_idx}_{seed}_linear': v
                       for k, v in trainer.logged_metrics.items()}
            # ``result`` starts empty, so the per-layer dict must be created
            # on first use before it can be updated.
            result.setdefault(layer_idx, {}).update(metrics)

    return result


def get_random_targets(rand_targets_folder, epochs=None):
    """Load the random-target JSON files of a folder, keyed by epoch.

    The epoch is parsed from each file name, which must look like
    ``<prefix>_<epoch>.<ext>`` (for example ``targets_100.json``).

    Args:
        rand_targets_folder: Directory whose entries are all JSON target
            files following the naming pattern above.
        epochs: Optional collection of epochs to keep. ``None`` (the default)
            loads every file in the folder.

    Returns:
        dict mapping the integer epoch to the decoded JSON content.

    Raises:
        IndexError: If a file name contains no ``_`` separator.
        ValueError: If the epoch part of a file name is not an integer.
    """
    targets = {}
    for file_path in glob(os.path.join(rand_targets_folder, "*")):
        # basename() instead of split('/') so the parsing also works with
        # platform-specific path separators.
        file_name = os.path.basename(file_path)
        epoch = int(file_name.split('_')[1].split('.')[0])
        if epochs is not None and epoch not in epochs:
            continue
        # Context manager guarantees the handle is closed after decoding.
        with open(file_path, "r") as target_file:
            targets[epoch] = json.load(target_file)
    return targets


def load_hparams(output_path):
    """Read ``<output_path>/version_0/hparams.yaml`` and return its parsed content.

    The return value is whatever ``yaml.full_load`` produces for that file.
    """
    # NOTE(review): yaml.full_load is not the restricted "safe" loader;
    # presumably this file is only ever written by our own training runs -
    # confirm it is never untrusted input.
    hparams_path = f"{output_path}/version_0/hparams.yaml"
    with open(hparams_path, 'r') as stream:
        loaded = yaml.full_load(stream)
    return loaded


def run_linear_evaluation(checkpoint_folder, output_path, rand_targets_folder=None):
    """Evaluate linear classifiers on selected SSL checkpoints and save the results.

    The model type and architecture are read from the hparams file under
    ``output_path``. For every checkpoint in ``checkpoint_folder`` whose epoch
    (parsed from a file name of the form ``...=<epoch>-...``) is one of the
    hard-coded interesting epochs, the weights are loaded into the benchmark
    model, the train and test feature banks are recomputed, and
    ``run_lstsq_preds`` is run against the ground-truth targets and against
    every set of random targets. The results of each checkpoint are written to
    ``<output_path>/CHECKPOINT_<epoch>.json``; checkpoints whose result file
    already exists are skipped.

    Args:
        checkpoint_folder: Directory containing the checkpoint files.
        output_path: Directory holding ``version_0/hparams.yaml``; also the
            destination of the result JSON files.
        rand_targets_folder: Optional directory of random-target JSON files.
            When omitted, only the ground-truth evaluation is performed.
    """
    # Imported inside the function, so the dataset module is only loaded when
    # an evaluation is actually run.
    from offline_evaluation.load_datasets import load_datasets
    # The first returned loader (the SSL training loader) is not needed here.
    _, dataloader_train_eval, dataloader_test_eval = \
        load_datasets(augment=False, use_imagenet_transforms=True)

    pl.seed_everything(0)
    hparams = load_hparams(output_path)
    print(f"hparams:{hparams}")

    # An empty mapping means "no random-target runs", so the loop over
    # random targets below is safe when no folder was given.
    random_targets = {}
    if rand_targets_folder:
        random_targets = get_random_targets(rand_targets_folder)

    checkpoint_files = list(glob(os.path.join(checkpoint_folder, "*")))
    print(f"checkpoint_files:{checkpoint_files}")

    interesting_epochs = {19, 39, 99, 399, 999}

    # Reverse, then shuffle with a fixed seed: the processing order is a
    # deterministic function of the order returned by glob.
    checkpoint_files = checkpoint_files[::-1]
    random.Random(3402).shuffle(checkpoint_files)

    # Both model classes take the same constructor arguments; only the class
    # depends on the configured model name.
    model_cls = VICRegModel if hparams['model_name'] == "VICReg" else SimCLRModel
    benchmark_model = model_cls(dataloader_kNN=dataloader_train_eval,
                                dataloader_test=dataloader_test_eval,
                                num_classes=classes,
                                resnet_type=hparams['resnet_type'],
                                depth=hparams['depth'], width=hparams['width'])
    # Move to the GPU only when one is available.
    if torch.cuda.is_available():
        benchmark_model = benchmark_model.cuda()

    print(f"benchmark_model:{benchmark_model}, device:{benchmark_model.dummy_param.device}")
    from_layer = 0
    to_layer = 5

    for checkpoint_path in tqdm(checkpoint_files, total=len(checkpoint_files)):
        # basename() instead of split('/') so platform-specific separators work.
        checkpoint_name = os.path.basename(checkpoint_path)
        try:
            epoch = checkpoint_name.split('=')[1].split('-')[0]
            epoch_number = int(epoch)
        except (IndexError, ValueError):
            # Entries that do not follow the expected naming pattern are
            # reported and skipped instead of aborting the whole run.
            print(f"SKIPPING PATH: {checkpoint_path}, cannot parse epoch from file name")
            continue
        if epoch_number not in interesting_epochs:
            continue

        new_output_path = os.path.join(output_path, f"CHECKPOINT_{epoch}.json")

        if os.path.isfile(new_output_path):
            print(f"SKIPPING PATH: {new_output_path}, exists!")
            continue

        print(f"SSL checkpoint on {checkpoint_path}")

        checkpoint = torch.load(checkpoint_path, map_location='cpu')['state_dict']
        benchmark_model.load_state_dict(checkpoint)
        results = {}

        benchmark_model.eval()

        # getting features on both train and test sets
        benchmark_model.training_epoch_end(outputs=None, is_offline=True, from_layer=from_layer, mode='eval')
        benchmark_model.training_epoch_end(outputs=None, is_offline=True, from_layer=from_layer, use_test=True,
                                           mode='eval')

        # Run for GT targets
        train_predictions, test_predictions = run_lstsq_preds(benchmark_model.feature_bank,
                                                              benchmark_model.targets_bank,
                                                              from_layer, to_layer,
                                                              benchmark_model.feature_bank_test,
                                                              benchmark_model.targets_bank_test)
        results['gt'] = {'train': train_predictions, 'test': test_predictions}

        for random_epoch, targets in random_targets.items():
            print(f"random targets: {random_epoch}")
            train_targets, test_targets = torch.tensor(targets['train']), torch.tensor(targets['test'])
            # The same random targets are used for every evaluated layer.
            cur_targets_train = {idx: train_targets for idx in range(from_layer, to_layer)}
            cur_targets_test = {idx: test_targets for idx in range(from_layer, to_layer)}

            train_predictions, test_predictions = run_lstsq_preds(benchmark_model.feature_bank,
                                                                  cur_targets_train,
                                                                  from_layer, to_layer,
                                                                  benchmark_model.feature_bank_test,
                                                                  cur_targets_test)

            results[random_epoch] = {'train': train_predictions, 'test': test_predictions}

        print(results)
        # Serialize before opening the file: if the results cannot be encoded,
        # no partial file is left behind to be mistaken for a finished
        # checkpoint by the "exists" check on a later run.
        serialized = json.dumps(results)
        with open(new_output_path, 'w') as output_file:
            output_file.write(serialized)
        print(f"saved to: {new_output_path}")

