import os
import random
from scipy.io import loadmat
from collections import defaultdict
import json
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
from .multimodal_base_dataset import MultimodalDatum, MultimodalDatasetBase
from dassl.utils import read_json
import pandas as pd
from tqdm import tqdm


@DATASET_REGISTRY.register()
class ESNLIVE(MultimodalDatasetBase):
    """e-SNLI-VE dataset: (image, hypothesis) pairs with a 3-way entailment label.

    Labels are ``0 = entailment``, ``1 = neutral``, ``2 = contradiction``.

    Expected layout under ``cfg.DATASET.ROOT``::

        esnlive/
            esnlive_train.csv
            esnlive_dev.csv
            esnlive_test.csv
            flick30k-images/<Flickr30kID>

    Each CSV must contain the columns ``Flickr30kID``, ``hypothesis`` and
    ``gold_label``; ``gold_label`` values must be one of the three class names.
    """

    dataset_dir = 'esnlive'

    def __init__(self, cfg):
        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
        # Replace the class-level directory name with the absolute dataset path.
        self.dataset_dir = os.path.join(root, self.dataset_dir)

        self._lab2cname = {0: 'entailment', 1: 'neutral', 2: 'contradiction'}
        self._cname2lab = {v: k for k, v in self._lab2cname.items()}
        self._classnames = list(self._lab2cname.values())
        self._num_classes = len(self._lab2cname)
        print('final num classes: ', self._num_classes)

        train, val, test = self.read_data()

        # NOTE: few-shot subsampling via cfg.DATASET.NUM_SHOTS is disabled;
        # every row of each split CSV is used.
        # num_shots = cfg.DATASET.NUM_SHOTS
        # train = self.generate_fewshot_dataset(train, num_shots=num_shots)
        # val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))

        # NOTE(review): the base-class __init__ is not called; the split
        # attributes are assigned directly here — confirm this is intended.
        self._train_x = train  # labeled training data
        self._train_u = None  # unlabeled training data (optional)
        self._val = val  # validation data (optional)
        self._test = test  # test data

    def read_data(self):
        """Load the train, dev and test splits.

        Returns:
            tuple: ``(train, val, test)``, each a list of ``MultimodalDatum``
            built from ``esnlive_train.csv``, ``esnlive_dev.csv`` and
            ``esnlive_test.csv`` respectively. All rows are kept; no
            subsampling is applied.

        Raises:
            KeyError: if any ``gold_label`` is not a known class name.
        """
        train, val, test = (self._read_split(name) for name in ('train', 'dev', 'test'))
        return train, val, test

    def _read_split(self, name):
        """Parse ``esnlive_{name}.csv`` into a list of ``MultimodalDatum``.

        Args:
            name: split file stem, e.g. ``'train'``, ``'dev'`` or ``'test'``.

        Raises:
            KeyError: if a ``gold_label`` value is not one of the class names.
        """
        csv_path = os.path.join(self.dataset_dir, f"esnlive_{name}.csv")
        table = pd.read_csv(csv_path)
        # The path component is spelled 'flick30k-images'; it must match the
        # image directory name on disk.
        image_dir = os.path.join(self.dataset_dir, 'flick30k-images')

        items = []
        for image_id, hypothesis, cname in zip(table['Flickr30kID'].tolist(),
                                               table['hypothesis'].tolist(),
                                               table['gold_label'].tolist()):
            label = self._cname2lab.get(cname)
            if label is None:
                raise KeyError(f"unknown gold_label {cname!r} in {csv_path}; "
                               f"expected one of {self._classnames}")
            items.append(
                MultimodalDatum(impath=os.path.join(image_dir, str(image_id)),
                                label=label,
                                classname=cname,
                                condition=hypothesis))
        return items
