from models.classifier import Classifier
from models.classifier_vit import ClassifierViT
from models.ssl_models import VICRegModel
import torch
import os
from glob import glob
import random
from tqdm import tqdm
import json
import numpy as np

class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native Python values."""

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        NumPy integers become ``int``, NumPy floats become ``float`` and
        NumPy arrays become (nested) lists. Anything else is handed to the
        base encoder, which raises ``TypeError`` for unsupported objects.
        """
        conversions = (
            (np.integer, int),
            (np.floating, float),
            (np.ndarray, lambda array: array.tolist()),
        )
        for numpy_type, to_native in conversions:
            if isinstance(obj, numpy_type):
                return to_native(obj)
        return super().default(obj)

def _predict_labels(classifier, dataloader, device):
    """Return the arg-max prediction (over dim 1) for every sample of ``dataloader``.

    Each batch must unpack as ``(data, targets, index)``; only ``data`` is used.
    The result is a flat list of Python ints in dataloader order.
    """
    predictions = []
    # Inference only: without no_grad an autograd graph would be built
    # (and kept alive) for every batch.
    with torch.no_grad():
        for data, _targets, _index in tqdm(dataloader, total=len(dataloader)):
            outputs = classifier(data.to(device))
            predictions.extend(torch.argmax(outputs, dim=1).cpu().tolist())
    return predictions


def get_model_predictions(backbone, dataloader_train_kNN, dataloader_test, checkpoint_folder,
                          output_path, use_vit=False):
    """Run every checkpoint in a folder over two dataloaders and save the predictions.

    For each file matched by ``<checkpoint_folder>/*`` the classifier weights
    are loaded, predictions are computed for ``dataloader_train_kNN`` and
    ``dataloader_test``, and the result is written to
    ``<output_path>/CHECKPOINT_<epoch>.json`` as ``{"train": [...], "test": [...]}``.

    Args:
        backbone: Feature extractor wrapped by ``Classifier``. Ignored when
            ``use_vit`` is true.
        dataloader_train_kNN: Loader yielding ``(data, targets, index)`` batches.
        dataloader_test: Loader yielding ``(data, targets, index)`` batches.
        checkpoint_folder: Directory whose entries are all treated as checkpoints.
            Each file must hold a mapping with a ``'state_dict'`` key, and its
            name must contain ``=<epoch>-`` (the text between the first ``=``
            and the following ``-`` is used as the epoch tag).
        output_path: Directory for the JSON files; created if missing.
        use_vit: Build a ``ClassifierViT`` instead of ``Classifier``. Defaults
            to ``False``, which matches the previous hard-coded behaviour.

    Raises:
        IndexError: If a checkpoint file name contains no ``=``.
    """
    # Run on the GPU when one is present, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint_files = glob(os.path.join(checkpoint_folder, "*"))
    # Checkpoints are processed in random order.
    random.shuffle(checkpoint_files)

    # makedirs also creates missing parent directories.
    os.makedirs(output_path, exist_ok=True)

    if use_vit:
        classifier = ClassifierViT(num_classes=100,
                                   depth=10, head=8, mlp_hidden=384,
                                   hidden=384)
    else:
        classifier = Classifier(backbone, layer_idx=-1,
                                num_classes=100)
    # Place the whole model on the same device the inputs are sent to.
    classifier = classifier.to(device)
    # load_state_dict does not change the train/eval mode, so set it once.
    classifier.eval()

    for checkpoint_path in tqdm(checkpoint_files, total=len(checkpoint_files)):
        # basename instead of split('/') so the parsing does not depend on
        # the platform's path separator.
        epoch = os.path.basename(checkpoint_path).split('=')[1].split('-')[0]
        print(f"running on {checkpoint_path}")

        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict']
        classifier.load_state_dict(state_dict)

        results = {'train': _predict_labels(classifier, dataloader_train_kNN, device),
                   'test': _predict_labels(classifier, dataloader_test, device)}
        print(f"epoch:{epoch}, results:{results}")

        new_output_path = os.path.join(output_path, f"CHECKPOINT_{epoch}.json")
        # Context manager guarantees the file is flushed and closed.
        with open(new_output_path, 'w') as output_file:
            json.dump(results, output_file, cls=NpEncoder)


def run_model_predictions(checkpoint_folder, output_path):
    """Build a ResNet-18 VICReg backbone and dump predictions for every checkpoint.

    Loads the datasets with ``augment=True``, constructs a 100-class
    ``VICRegModel`` on the GPU, and forwards its backbone together with the
    two evaluation loaders to ``get_model_predictions``.

    Args:
        checkpoint_folder: Directory containing the checkpoints to evaluate.
        output_path: Directory that receives the per-checkpoint JSON files.
    """
    from offline_evaluation.load_datasets import load_datasets

    # load_datasets returns three loaders; the first (SSL training) one is
    # not needed for this evaluation.
    _, train_eval_loader, test_eval_loader = load_datasets(augment=True)

    vicreg = VICRegModel(
        dataloader_kNN=train_eval_loader,
        dataloader_test=test_eval_loader,
        num_classes=100,
        resnet_type="resnet-18",
    ).cuda()

    get_model_predictions(
        backbone=vicreg.backbone,
        dataloader_train_kNN=train_eval_loader,
        dataloader_test=test_eval_loader,
        checkpoint_folder=checkpoint_folder,
        output_path=output_path,
    )