import os
import pickle
from sklearn.metrics import f1_score
import numpy as np

# Maps each dataset key to the sub-folder (relative to a seed directory)
# that holds its pickled "new_class" embeddings.
dataset_dict = {
    'ucf101': 'UCF101_embeddings/new_class/',
    'imagenet': 'ImageNet_embeddings/new_class/',
    'food101': 'Food101_embeddings/new_class/',
    'oxford_flowers': 'OxfordFlowers_embeddings/new_class/',
    'oxford_pets': 'OxfordPets_embeddings/new_class/',
    'eurosat': 'EuroSAT_embeddings/new_class/',
    'fgvc_aircraft': 'FGVCAircraft_embeddings/new_class/',
    'sun397': 'SUN397_embeddings/new_class/',
    'caltech101': 'Caltech101_embeddings/new_class/',
    'dtd': 'DescribableTextures_embeddings/new_class/',
    'stanford_cars': 'StanfordCars_embeddings/new_class/',
}


def evaluate_results(dataset, path, class_type='Base'):
    """Print the top-1 accuracy for ``dataset`` averaged over the seed runs.

    For every seed folder (``seed1/``, ``seed2/``, ``seed3/``) the pickled
    batches found under ``path + seed + dataset_dict[dataset]`` are loaded,
    the highest-scoring class of the stored logits is compared with the
    stored labels, and the resulting per-seed accuracy (in percent) is
    averaged across seeds and printed.

    Args:
        dataset: Key of ``dataset_dict`` selecting the embeddings sub-folder.
        path: Prefix that contains the seed folders. It is concatenated
            as-is, so it must already end with a path separator.
        class_type: Label that only appears in the printed summary line.

    Raises:
        KeyError: If ``dataset`` is not a key of ``dataset_dict``.
        OSError: If a seed's embeddings folder or one of its files cannot
            be read.
        ZeroDivisionError: If a seed's folder contributes no samples.
    """
    seed_dirs = ('seed1/', 'seed2/', 'seed3/')
    dataset_subdir = dataset_dict[dataset]

    accuracy_sum = 0
    for seed in seed_dirs:
        embeddings_path = path + seed + dataset_subdir
        # The last two entries (in sorted order) are excluded from scoring.
        # NOTE(review): presumably these are non-batch files stored next to
        # the batch pickles -- confirm against the script that writes them.
        pickle_names = sorted(os.listdir(embeddings_path))[:-2]

        correct = 0
        total = 0
        for pickle_name in pickle_names:
            # NOTE: unpickling can execute arbitrary code; only point this at
            # embedding files produced by a trusted run.
            with open(embeddings_path + pickle_name, 'rb') as handle:
                batch = pickle.load(handle)

            # Expected to be torch tensors: logits [batch, num_classes] and
            # integer labels [batch].
            logits = batch['logits']
            labels = batch['corresponding_labels']

            predictions = logits.max(1)[1]  # index of the top-scoring class
            correct += int(predictions.eq(labels).sum().item())
            total += labels.shape[0]

        accuracy_sum += 100.0 * correct / total

    print(f"{class_type} class accuracy averaged over {len(seed_dirs)} seeds for {dataset} is {accuracy_sum / len(seed_dirs):.2f}%")

