import os
import re

# Offline repair pass over stored generation chunks.
#
# For every run directory under $PREFIX_DIR whose name fully matches
# $APPROVED_REGEX (and that is assigned to this worker by GRID_SIZE / GRID_ID)
# the script:
#   1. re-exports a few variables recorded in `used_var.sh`,
#   2. loads the run configuration from `oconf.yaml`,
#   3. walks the compact chunks of the $FIX_PART subset, optionally caps
#      `sequences_len` and/or recomputes the transition log-probabilities,
#   4. writes each chunk back in place.
#
# Environment:
#   PREFIX_DIR      (required) directory that contains the run directories
#   USE_DEVICE      (required) device passed to the model / tensors
#   PIGGZ_TEMP_DIR  (required) temp path handed to the compression helpers
#   GRID_SIZE       (optional, default 1) number of parallel workers
#   GRID_ID         (optional, default 0) 0-based index of this worker
#   FIX_PART        (optional, default 'bs') subset of the records to fix
#   APPROVED_REGEX  (optional, default '.*') full-match filter on dir names


def env_int(name, default):
    """Return the integer value of environment variable *name*.

    Falls back to *default* when the variable is not set. A value that is set
    but is not a valid integer raises ``ValueError`` (from ``int``).
    """
    raw = os.environ.get(name)
    return default if raw is None else int(raw)


def parse_export_line(line):
    """Parse one ``export NAME=value`` line of a shell variable file.

    Returns ``(name, value)`` or ``None`` for lines that carry no assignment
    (blank lines, comments, shebangs). The value is everything after the
    FIRST ``=``, so values that themselves contain ``=`` are kept intact.
    The ``export `` prefix is optional. Quotes are not interpreted.
    """
    text = line.rstrip('\r\n').lstrip()
    if text.startswith('export '):
        text = text[len('export '):]
    name, sep, value = text.partition('=')
    name = name.strip()
    if not sep or not name or name.startswith('#'):
        return None
    return name, value


def select_shard(items, grid_size, grid_id):
    """Return the round-robin share of *items* owned by worker *grid_id*.

    Item ``i`` belongs to worker ``i % grid_size``.

    Raises:
        ValueError: if ``grid_size < 1`` or ``grid_id`` is outside
            ``[0, grid_size)`` - both would otherwise crash on the modulo or
            silently select nothing.
    """
    if grid_size < 1:
        raise ValueError(f"GRID_SIZE must be >= 1, got {grid_size}")
    if not 0 <= grid_id < grid_size:
        raise ValueError(f"GRID_ID must be in [0, {grid_size}), got {grid_id}")
    return [item for i, item in enumerate(items) if i % grid_size == grid_id]


if __name__=='__main__':
    import sys, yaml, pickle
    # print()
    print(os.path.dirname(os.path.abspath(__file__)))
    # make the parent of this file's directory importable for the project packages
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from mi_boiler.sigma_conf import load_master_config, getenv_roundup
    from mi_boiler.compression_tools import iterate_over_compact_chunks, save_records_compact
    from tqdm import tqdm, trange
    import numpy as np
    import torch
    from collections import defaultdict
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # switches for the two independent repair steps
    recompute_logits = False
    fix_batch_lengths = True

    # variables of the original run that are re-exported before loading its config;
    # everything else (e.g. the cuda device) has to be set manually beforehand
    forwarded_vars = ('SAMPLE_TEMPERATURE', 'USE_MODEL_SHORT', 'USE_MODEL_REPO', 'DATASET_ID')

    # # can configure the run here if needed
    # conf, printable_conf, oconf = load_master_config('uncertainty/uncertainties_computation.yaml', abspath=True, return_original_conf=True)
    # # unpack into globals
    # globals().update(conf)
    # # print the used environment variables
    # print(getenv_roundup())
    # print(printable_conf)
    prefix_dir = os.environ['PREFIX_DIR']
    # each grid variable is read on its own; a missing GRID_SIZE means "single worker"
    grid_size = env_int('GRID_SIZE', 1)
    grid_id = env_int('GRID_ID', 0)
    device = os.environ['USE_DEVICE']
    pigz_temp_path = os.environ['PIGGZ_TEMP_DIR']
    fix_part = os.environ.get('FIX_PART', 'bs')

    re_striction = os.environ.get('APPROVED_REGEX', '.*')

    # get all the directories eligible; plain files are skipped because every
    # entry is later opened as a run directory
    dirs = sorted(
        os.path.join(prefix_dir, d)
        for d in os.listdir(prefix_dir)
        if re.fullmatch(re_striction, d) and os.path.isdir(os.path.join(prefix_dir, d))
    )
    print(dirs)
    # leave only ones assigned to the current grid id
    dirs = select_shard(dirs, grid_size, grid_id)
    print(dirs)

    # now can go over every applicable dir
    for rdir in dirs:
        # load the environment recorded for that specific directory
        with open(os.path.join(rdir, 'used_var.sh')) as fp:
            for line in fp:
                parsed = parse_export_line(line)
                if parsed is None:
                    continue
                var_name, var_val = parsed
                if var_name in forwarded_vars:
                    # set the env variables if one of the interesting vars
                    print(var_name, var_val)
                    os.environ[var_name] = var_val

        # load the configuration of the run to get the model and the tokenizer
        conf, _, _ = load_master_config(os.path.join(rdir, 'oconf.yaml'), abspath=True, return_original_conf=True)

        if recompute_logits:
            # get model and tokenizer
            tokenizer = AutoTokenizer.from_pretrained(conf['model_id'], **conf['tokenizer_kwargs'])
            model = AutoModelForCausalLM.from_pretrained(conf['model_id'], device_map=device, torch_dtype=torch.bfloat16, **conf['model_kwargs'])

            from generation.gen_utils import probe_token_id, reassign_ids

            stop_on_tokens = [
                probe_token_id(t.replace(r'\\n', "\n"), tokenizer) if isinstance(t, str) else t for t in conf['stop_on_tokens']
                ] # yaml hates newline characters :\K
            print(stop_on_tokens)
            stop_on_tokens = list(set(stop_on_tokens))
            # loop-invariant: every stop token plus EOS shares one merged probability
            # NOTE(review): assumes model.config.eos_token_id is a single int - confirm
            stop_ids = stop_on_tokens + [model.config.eos_token_id]

        # iterator precomputes the list, can safely modify the files
        for records, chunkpath in iterate_over_compact_chunks(
            rdir,
            range_start=0,
            range_end=-1,
            restrict_sets=[fix_part],
            pigz_fast_temp_path=pigz_temp_path,
            load_arrays=True,
            return_file_prefix=True,
        ):
            # small type conversion: 'txt*' arrays become lists, all other arrays become tensors
            print(f"Normalizing the chunk: {chunkpath}")
            for subset_records in records.values():
                for r in subset_records:
                    # only existing keys are reassigned, so iterating items() is safe
                    for k, v in r.items():
                        if not isinstance(v, np.ndarray):
                            continue
                        r[k] = list(v) if k.startswith('txt') else torch.from_numpy(v)

            part_records = records[fix_part]

            with torch.no_grad():
                if fix_batch_lengths:
                    print("Fixing Sequences LEngths")
                    for record in tqdm(part_records):
                        # at one point there was an issue with assigning the correct sequence length to a generation
                        # the maximum should be -1 from the sequence length to get some of the things to work
                        seqlens = record['sequences_len']
                        max_len = record['sequences_no_input'].shape[-1] - 1
                        seqlens[seqlens > max_len] = max_len
                        record['sequences_len'] = seqlens
                if recompute_logits:
                    print("Recomputing logits")
                    for record in tqdm(part_records):
                        mins = record['sequences']
                        mouts = model(input_ids=mins.to(device))
                        logits = mouts.logits.log_softmax(dim=-1)
                        # assign the same added up probability to all stop tokens (fine to do here, since the sampling is not done, just reconstruction)
                        logits[:, :, stop_ids] = torch.logsumexp(logits[:, :, stop_ids], dim=-1).unsqueeze(-1)

                        tok_x_nlls = record['tok_x_nlls']

                        # transition score of position t = log-prob (at step t) of the token
                        # actually present at t+1; one gather replaces the per-token loop
                        next_ids = mins[:, 1:].to(logits.device).long().unsqueeze(-1)
                        transition_scores = (
                            logits[:, :-1].gather(-1, next_ids).squeeze(-1).to(torch.float32).cpu()
                        )
                        # update the transition probs
                        record['seq_no_input_transition'] = transition_scores[:, tok_x_nlls.shape[0]-1:]
            # save records; an empty chunk has no first record to probe for 'logits'
            has_logits = len(part_records) > 0 and 'logits' in part_records[0]
            save_records_compact(
                part_records,
                chunkpath,
                use_pigz=True,
                pigz_fast_temp_path=pigz_temp_path,
                digitize_conf={'logits': {'table':'lsm'}} if has_logits else {}
            )
