import os
# Point the Hugging Face cache at a local directory. This is set before the
# `evaluation` imports below, presumably because they pull in libraries that
# read HF_HOME at import time (assumption — confirm).
# NOTE(review): '.' (current working directory) is a placeholder — confirm the
# intended cache path.
os.environ['HF_HOME'] = '.'
from evaluation.correctness_evaluation import evaluate_correctness_dirs

# Root directory whose immediate subdirectories are processed one by one.
prefix = 'sample/workdir'
# Keep only subdirectories: every entry of `dirs` is later handed to
# `collect_lengths`, which reads chunks from it and writes `reslens.pkl` into it,
# so a stray regular file under `prefix` would crash its worker. Sorting makes
# the processing order (and the printed list) deterministic, since
# `os.listdir` returns entries in arbitrary order.
dirs = sorted(
    os.path.join(prefix, name)
    for name in os.listdir(prefix)
    if os.path.isdir(os.path.join(prefix, name))
)
print(dirs)

import os
import torch
import numpy as np
from tqdm import tqdm
from mi_boiler.compression_tools import iterate_over_compact_chunks

from evaluation.datasets import CoQA, TriviaQA, KUQDataset, COLLIEDataset
from evaluation.datasets import SQUADv2
from evaluation.multimodal_data import SceneParsingOcclusion
from evaluation.code_data import BCBDataset
from evaluation.base_dataset import DatasetWithExactSolution, DatasetWithReference, DatasetWithPerturbation, DatasetWithOODInfo
from tqdm import tqdm
import pandas as pd
import pickle

# Maps the dataset code stored in each record's 'dataset' field to the dataset
# object consulted for reference answers. All datasets are instantiated
# eagerly, at import time, in the order listed here.
datasets_map = dict(
    COQA=CoQA(),
    TRIVIA=TriviaQA(),
    COLLIE=COLLIEDataset(),
    KUQ=KUQDataset(),
    BCB=BCBDataset(),
    SPAR=SceneParsingOcclusion(),
    SQUAD=SQUADv2(),
)

# Alias for the list of record directories discovered above.
record_dirs = dirs

def _normalize_record(record):
    """Convert one record's numpy payloads in place.

    Fields whose name starts with ``txt`` and hold a numpy array become plain
    Python lists; every other numpy array becomes a torch tensor via
    ``torch.from_numpy`` (which shares memory with the array).
    """
    # Snapshot the items so values can be reassigned while looping.
    for key, value in list(record.items()):
        if key.startswith('txt') and isinstance(value, np.ndarray):
            record[key] = list(value)
        elif isinstance(value, np.ndarray):
            record[key] = torch.from_numpy(value)


def _reference_word_length(dset, dataset_idx):
    """Return the mean space-separated token count of the reference answer(s).

    Returns NaN when the dataset has no reference answer for this item.

    Raises:
        ValueError: if the answer is neither a list nor a str.
    """
    ref_answers = dset.get_answer(dataset_idx)
    if len(ref_answers) == 0:
        # can happen, e.g. if the ood label is true in trivia
        return np.nan
    if isinstance(ref_answers, list):
        return sum(len(a.split(' ')) for a in ref_answers) / len(ref_answers)
    if isinstance(ref_answers, str):
        return len(ref_answers.split(' '))
    raise ValueError(
        f"Unexpected reference answer type: {type(ref_answers).__name__}"
    )


def _word_counts(texts, limit):
    """Space-separated token counts of (at most) the first ``limit`` texts."""
    return [len(t.split(' ')) for t in texts[:limit]]


def collect_lengths(rdir, n_ms_samples=10, n_bs_samples=1, pigz_fast_temp_path='.'):
    """Collect per-item output-length statistics for one record directory.

    Iterates over every compact chunk in ``rdir``. Each chunk holds aligned
    'ms' and 'bs' record lists. For every item it records:

    - ``dataset_idx``;
    - ``lens_ms`` / ``lens_bs``: the records' ``sequences_len`` values;
    - ``wlens_ms`` / ``wlens_bs``: space-separated token counts of the first
      ``n_ms_samples`` / ``n_bs_samples`` entries of ``txt_y``;
    - ``wlens_ref``: the mean token count of the reference answers. This key
      is present only for datasets that are ``DatasetWithReference``.

    The list of per-item dicts is pickled to ``<rdir>/reslens.pkl``.

    Args:
        rdir: directory containing the compact record chunks.
        n_ms_samples: maximum number of 'ms' texts measured per item.
        n_bs_samples: maximum number of 'bs' texts measured per item.
        pigz_fast_temp_path: temporary path forwarded to
            ``iterate_over_compact_chunks``. NOTE(review): the default '.' is
            a placeholder — confirm the intended temp directory.

    Raises:
        KeyError: if a chunk's dataset code is missing from ``datasets_map``.
        ValueError: if a chunk's 'ms' and 'bs' records are misaligned.
    """
    results = []

    for records in iterate_over_compact_chunks(
        rdir,
        range_start=0,
        range_end=-1,
        restrict_sets=None,
        pigz_fast_temp_path=pigz_fast_temp_path,
        load_arrays=True,
    ):
        ms_records = records['ms']
        bs_records = records['bs']

        # small type conversion
        print("Normalizing the chunk: ")
        for subset in (ms_records, bs_records):
            for record in tqdm(subset):
                _normalize_record(record)

        if not ms_records:
            # Nothing to measure; also avoids indexing ms_records[0] below.
            continue
        if len(ms_records) != len(bs_records):
            raise ValueError(
                f"Mismatch!!!!! {rdir}: {len(ms_records)} 'ms' records vs "
                f"{len(bs_records)} 'bs' records"
            )

        # The dataset is taken from the chunk's first 'ms' record; this assumes
        # a chunk never mixes datasets — TODO confirm against the chunk writer.
        dset = datasets_map[ms_records[0]['dataset']]
        has_reference = isinstance(dset, DatasetWithReference)

        for ms_rec, bs_rec in zip(ms_records, bs_records):
            dataset_idx = ms_rec['dataset_idx']
            # Sanity check that 'ms' and 'bs' describe the same item; raised
            # explicitly rather than via `assert` so it survives `python -O`.
            if dataset_idx != bs_rec['dataset_idx']:
                raise ValueError(
                    f"Mismatch!!!!! 'ms' dataset_idx {dataset_idx!r} != "
                    f"'bs' dataset_idx {bs_rec['dataset_idx']!r}"
                )

            res = {'dataset_idx': dataset_idx}
            if has_reference:
                res['wlens_ref'] = _reference_word_length(dset, dataset_idx)
            res['lens_ms'] = ms_rec['sequences_len']
            res['lens_bs'] = bs_rec['sequences_len']
            # Slicing (not fixed indexing) tolerates items with fewer samples.
            res['wlens_ms'] = _word_counts(ms_rec['txt_y'], n_ms_samples)
            res['wlens_bs'] = _word_counts(bs_rec['txt_y'], n_bs_samples)
            results.append(res)

    with open(os.path.join(rdir, 'reslens.pkl'), 'wb') as f:
        pickle.dump(results, f)


if __name__ == "__main__":
    from multiprocessing import Pool

    # Pool() rejects a process count of 0, so bail out when there is no work.
    if not dirs:
        print("No record directories found; nothing to do.")
    else:
        # One worker per record directory, capped at 16 processes.
        with Pool(min(len(dirs), 16)) as p:
            p.map(collect_lengths, dirs)
