import os, re
import torch
import numpy as np
from tqdm import tqdm
from mi_boiler.compression_tools import load_from_compact_chunks

from evaluation.qa_datasets import CoQA, TriviaQA, KUQDataset, COLLIEDataset, SQUADv2
from evaluation.code_data import BCBDataset
from evaluation.multimodal_data import SceneParsingOcclusion
from evaluation.base_dataset import DatasetWithExactSolution, DatasetWithReference, DatasetWithPerturbation, DatasetWithOODInfo
from tqdm import tqdm
import pandas as pd
from multiprocessing import current_process

from .perturbation_text import CoQAPerturb, SQUADPerturb


# Registry of dataset objects keyed by dataset code. The code is read from the
# 'dataset' field of a loaded record and used to pick the object that scores
# that chunk. Every entry is constructed when this module is imported, in the
# order listed here.
datasets_map = dict(
    COQA=CoQA(),
    TRIVIA=TriviaQA(),
    COLLIE=COLLIEDataset(),
    KUQ=KUQDataset(),
    BCB=BCBDataset(),
    SPAR=SceneParsingOcclusion(),
    SQUAD=SQUADv2(),
    # Perturbed variants: each gets its own strength list.
    COQAPERT=CoQAPerturb(
        perturb_strength=[0., 0.2],
        perturb_type='shuffle',
        slim=1024,
    ),
    SQUADPERT=SQUADPerturb(
        perturb_strength=[0., 0.2],
        perturb_type='shuffle',
        slim=1024,
    ),
)


def _judge_config_id(model, prompt_style, max_tokens, temperature):
    """Return the string that identifies one judge configuration.

    The temperature is normalised through ``float`` so that ``1``, ``1.`` and
    ``1.0`` all produce the same identifier; otherwise a judge stored without
    an explicit temperature would never match the configured default.
    Values that cannot be parsed as a float are used verbatim.
    """
    try:
        # Going through str() first keeps short reprs such as '0.7' for
        # narrow float types instead of widening them to a long decimal.
        temperature = float(str(temperature))
    except ValueError:
        pass
    return f"{model}_{prompt_style}_{max_tokens}_{temperature}"


def _save_records(rows, path):
    """Write ``rows`` to ``path`` as a parquet file; do nothing if ``rows`` is empty."""
    if len(rows) > 0:
        pd.DataFrame.from_records(rows).to_parquet(path)


def evaluate_correctness_dirs(
    rdir,
    calculate_correctness_on = 'bs',
    use_scores_for_reference_values = ['rouge', 'bleu'],
    use_judges = None,
    starting_dset_id = 0,
    ending_dset_id = -1,
    restrict_sets = None,
    piggz_tmp_dir='.',
    overwrite=False,
    perform_exact_for_bcb=False,
):
    """Compute correctness tables for one results directory and store them as parquet.

    Records are loaded from ``rdir`` with ``load_from_compact_chunks``; the
    dataset object is selected from ``datasets_map`` using the ``'dataset'``
    field of the first ``'ms'`` record. Depending on which base classes that
    object derives from, the following files are written into ``rdir``:

    * ``corr_exact.parquet``   -- ``DatasetWithExactSolution``
    * ``corr_ref.parquet``     -- ``DatasetWithReference`` (reference scores)
    * ``corr_judge_<i>.parquet`` -- ``DatasetWithReference`` (one per judge)
    * ``id_ood.parquet``       -- ``DatasetWithOODInfo``
    * ``id_perturb.parquet``   -- ``DatasetWithPerturbation``

    Each table is written as soon as it has been computed, so a failure in a
    later stage does not discard earlier results.

    Args:
        rdir: Directory holding the compact chunks; also the output directory.
        calculate_correctness_on: Key of the record subset to score. The
            BCB batch evaluation always uses ``records['bs']`` regardless.
        use_scores_for_reference_values: Score names forwarded as
            ``scorelist`` to ``ds.calculate_correctness``. Empty or ``None``
            skips the reference table.
        use_judges: ``None`` disables judge evaluation. Otherwise a list of
            judge config dicts; the keys ``use_model``, ``prompt_for_qa``,
            ``use_max_tokens`` and ``temperature`` are read here to identify
            the configuration, and the whole dict is forwarded as ``judge``.
            An empty list evaluates a single empty config ``{}``.
        starting_dset_id: Forwarded to the loader as ``range_start``.
        ending_dset_id: Forwarded to the loader as ``range_end``.
        restrict_sets: Forwarded to the loader as ``restrict_sets``.
        piggz_tmp_dir: Forwarded to the loader as ``pigz_fast_temp_path``.
            NOTE(review): the original default literal was truncated in the
            source after ``'.``; ``'.'`` is the visible part -- confirm.
        overwrite: Recompute tables even when their output file exists. Judge
            files are then numbered from 0 without inspecting existing ones.
        perform_exact_for_bcb: For ``BCBDataset``, run the batch exact
            evaluation. Only permitted on worker 0.

    Returns:
        None. Results are written to disk.
    """
    # The worker index only offsets the tqdm bars so parallel workers do not
    # draw on the same terminal line. `_identity` is a private attribute of
    # multiprocessing processes; when it has no first element we are treated
    # as worker 0.
    try:
        worker_id = current_process()._identity[0] - 1
    except (IndexError, AttributeError):
        worker_id = 0

    def progress(iterable, desc):
        # All loops of one worker share the same bar position.
        return tqdm(iterable, desc=desc, position=worker_id + 1)

    # Private copy so the (mutable) default list is never handed to callees.
    scorelist = list(use_scores_for_reference_values or [])

    records = load_from_compact_chunks(
        rdir,
        range_start=starting_dset_id,
        range_end=ending_dset_id,
        restrict_sets=restrict_sets,
        pigz_fast_temp_path=piggz_tmp_dir,
        load_arrays=False,  # correctness usually does not need heavy arrays such as logits
    )

    # Small type conversion: numpy arrays under 'txt*' keys become plain
    # lists, every other numpy array becomes a torch tensor.
    for subset in ['ms', 'bs']:
        for record in progress(records[subset], f"Normalizing the chunk on {worker_id}"):
            for key, value in record.items():
                if not isinstance(value, np.ndarray):
                    continue
                if key.startswith('txt'):
                    record[key] = list(value)
                else:
                    record[key] = torch.from_numpy(value)

    dataset_code = records['ms'][0]['dataset']
    print(dataset_code)

    ds = datasets_map[dataset_code]

    # ---- exact correctness -------------------------------------------------
    exact_path = os.path.join(rdir, 'corr_exact.parquet')
    exact_accum = []
    if (overwrite or not os.path.exists(exact_path)) and isinstance(ds, DatasetWithExactSolution):
        if isinstance(ds, BCBDataset):
            if perform_exact_for_bcb:
                # use the bcb server to evaluate solutions for all records
                assert worker_id == 0, "Not going to do exact correctness for BCB in parallel mode!"
                correctness, _ = ds.calculate_exact_batch_correctness(
                    records['bs'],
                    manual_upload=True,
                )
                exact_accum = correctness
            else:
                print("Configured to skip the exact solution to BCB...")
        else:
            # collie and some other lightweight stuff evaluate locally
            print("Computing exact solutions.")
            for entry in progress(
                records[calculate_correctness_on],
                f"Exact solution for {rdir} on {worker_id}",
            ):
                entry_dict = ds.calculate_exact_correctness(
                    entry['dataset_idx'],
                    entry['txt_y'][0],
                )
                entry_dict['dataset_idx'] = entry['dataset_idx']
                exact_accum.append(entry_dict)
    _save_records(exact_accum, exact_path)

    if isinstance(ds, DatasetWithReference):
        # ---- reference-based scores ---------------------------------------
        ref_path = os.path.join(rdir, 'corr_ref.parquet')
        ref_accum = []
        if (overwrite or not os.path.exists(ref_path)) and len(scorelist) > 0:
            print("Computing correctness with reference solution.")
            for entry in progress(
                records[calculate_correctness_on],
                f"Reference metrics for {rdir} on {worker_id}",
            ):
                entry_dict = ds.calculate_correctness(
                    entry['dataset_idx'],
                    [entry['txt_y'][0]],
                    scorelist=scorelist,
                )
                entry_dict['dataset_idx'] = entry['dataset_idx']
                ref_accum.append(entry_dict)
        _save_records(ref_accum, ref_path)

        # ---- judge-based scores -------------------------------------------
        if use_judges is not None:
            judge_file_pattern = re.compile(r'corr_judge_([0-9]+)\.parquet')
            existing_judge_models = set()
            # Index of the next judge file to write. Derived from the highest
            # existing index (not the file count) so that a gap in the
            # numbering can never cause an existing file to be overwritten.
            next_judge_idx = 0
            if not overwrite:
                print('Judge models already evaluated: ')
                for jf in sorted(os.listdir(rdir)):
                    matched = judge_file_pattern.fullmatch(jf)
                    if matched is None:
                        continue
                    next_judge_idx = max(next_judge_idx, int(matched.group(1)) + 1)
                    existing_judge = pd.read_parquet(os.path.join(rdir, jf))
                    columns = existing_judge.columns
                    if '_judge_model' in columns:
                        jname = existing_judge['_judge_model'].iloc[0]
                    else:
                        jname = existing_judge['judge_model'].iloc[0]
                    # if prompt style not listed, qa prompt was used
                    jprompt = existing_judge['_prompt_style'].iloc[0] if '_prompt_style' in columns else 'qa'
                    jlen = str(existing_judge['_use_max_tokens'].iloc[0]) if '_use_max_tokens' in columns else '1'
                    jtemp = existing_judge['_temperature'].iloc[0] if '_temperature' in columns else 1.
                    # unique string describing the configuration
                    existing_id = _judge_config_id(jname, jprompt, jlen, jtemp)
                    existing_judge_models.add(existing_id)
                    print(f'{jf} : {existing_id}')

            judge_cfgs = use_judges if len(use_judges) > 0 else [{}]
            for judge_cfg in judge_cfgs:
                judge_str_id = _judge_config_id(
                    # .get() so the empty fallback config does not raise KeyError
                    judge_cfg.get('use_model'),
                    'qa' if judge_cfg.get('prompt_for_qa', True) else 'gen',
                    judge_cfg.get('use_max_tokens', 1),
                    judge_cfg.get('temperature', 1.),
                )
                if judge_str_id in existing_judge_models:
                    print(f"Skipping {judge_cfg}, judge already evaluated!")
                    continue
                judge_rows = []
                for entry in progress(
                    records[calculate_correctness_on],
                    f"Judge {judge_str_id} for {rdir} on {worker_id}",
                ):
                    entry_dict = ds.calculate_correctness(
                        entry['dataset_idx'],
                        [entry['txt_y'][0]],
                        scorelist=['judge'],
                        judge=judge_cfg,
                    )
                    entry_dict['dataset_idx'] = entry['dataset_idx']
                    judge_rows.append(entry_dict)
                # Saved per judge; an empty result is not written because an
                # empty file would break the `.iloc[0]` lookups on the next run.
                _save_records(judge_rows, os.path.join(rdir, f'corr_judge_{next_judge_idx}.parquet'))
                next_judge_idx += 1

    # ---- OOD identifiers ---------------------------------------------------
    ood_path = os.path.join(rdir, 'id_ood.parquet')
    ood_accum = []
    if (overwrite or not os.path.exists(ood_path)) and isinstance(ds, DatasetWithOODInfo):
        for entry in progress(
            records[calculate_correctness_on],
            f"ood info for {rdir} on {worker_id}",
        ):
            ood_accum.append({
                'ood_label': ds.get_ood_identifier(entry['dataset_idx']),
                'dataset_idx': entry['dataset_idx'],
            })
    _save_records(ood_accum, ood_path)

    # ---- perturbation info -------------------------------------------------
    perturb_path = os.path.join(rdir, 'id_perturb.parquet')
    perturb_accum = []
    if (overwrite or not os.path.exists(perturb_path)) and isinstance(ds, DatasetWithPerturbation):
        for entry in progress(
            records[calculate_correctness_on],
            f"perturbation strength for {rdir} on {worker_id}",
        ):
            perturb_accum.append({
                'group_with_ids': ds.get_perutrbation_group_ids(entry['dataset_idx']),
                # NOTE(review): the perturbation strength is stored under the
                # key 'ood_label', unchanged from the original -- confirm that
                # downstream readers expect this column name.
                'ood_label': ds.get_perturbation_strength(entry['dataset_idx']),
                'dataset_idx': entry['dataset_idx'],
            })
    _save_records(perturb_accum, perturb_path)


if __name__=='__main__':
    import os
    import sys
    import time
    from functools import partial

    # Show where this script lives and put the directory above it on sys.path
    # before the project imports that follow.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    print(script_dir)
    sys.path.append(os.path.dirname(script_dir))
    from mi_boiler.sigma_conf import load_master_config, getenv_roundup

    # The run can be configured here if needed.
    conf, printable_conf, oconf = load_master_config(
        'evaluation/correctness_config.yaml',
        abspath=True,
        return_original_conf=True,
    )
    # Expose every config entry as a module-level name; prefix_dir, rematch,
    # corr_config and num_workers used below come from this update.
    globals().update(conf)
    # Report the environment variables in use, then the resolved config.
    print(getenv_roundup())
    print(printable_conf)

    from evaluation.correctness_evaluation import evaluate_correctness_dirs

    # Collect the sub-directories of prefix_dir whose name fully matches rematch.
    dirs = []
    for entry_name in os.listdir(prefix_dir):
        if not re.fullmatch(rematch, entry_name):
            continue
        candidate = os.path.join(prefix_dir, entry_name)
        if os.path.isdir(candidate):
            dirs.append(candidate)
    dirs.sort()
    print(dirs)
    # Pause so the directory list can be read before work starts.
    time.sleep(5.)

    par_evaluate_cordir = partial(evaluate_correctness_dirs, **corr_config)

    if num_workers > 1:
        from multiprocessing import Pool
        with Pool(num_workers) as pool:
            pool.map(par_evaluate_cordir, dirs)
    else:
        for run_dir in dirs:
            par_evaluate_cordir(run_dir)

