if __name__=="__main__":
    import sys, os, yaml
    # print()
    # Report the directory holding this script, then append its parent directory
    # to sys.path (presumably so that `mi_boiler` below can be imported).
    _script_dir = os.path.dirname(os.path.abspath(__file__))
    print(_script_dir)
    sys.path.append(os.path.dirname(_script_dir))
    from mi_boiler.sigma_conf import load_master_config, getenv_roundup

    # The run can be (re)configured here if needed.
    conf, printable_conf, oconf = load_master_config(
        './generation/vlm_run_config.yaml',
        abspath=True,
        return_original_conf=True,
    )
    # Every config entry becomes a module-level name; the otherwise-undefined
    # names used further down (model_id, save_dir, n_samples, ...) presumably
    # originate here.
    globals().update(conf)
    # Print the used environment variables, followed by the printable config.
    print(getenv_roundup())
    print(printable_conf)

    # set up hf environment
    import os
    # os.environ['HF_HOME']='.    import pickle
    from datasets import load_dataset, Dataset
    from types import SimpleNamespace as sn
    import torch
    import pandas as pd
    # from .prompts import PROMPT_COLLIE_PHI_INSTRUCT_CUSTOMIZED as prompt
    from .gen_utils import ExtraEOSTokenLogitsProcessor, get_generation_length_ids_from_batch, probe_token_id, reassign_ids
    from .gen_utils import stack_and_pad_inputs, pad_left
    from mi_boiler.compression_tools import save_records_compact
    from mi_boiler.swiss_army import any_to_device_nested
    from tqdm import trange
    
    # load the model
    from transformers import AutoProcessor, AutoModelForCausalLM, BatchFeature
    # model_id = 'microsoft/Phi-3.5-mini-instruct'
    processor = AutoProcessor.from_pretrained(model_id, **tokenizer_kwargs)
    # Fixed loading options first, then the config-supplied extras. A key present
    # in both still raises TypeError, exactly as with direct keyword passing.
    _model_load_kwargs = dict(
        attn_implementation="flash_attention_2",
        device_map=use_device,
        torch_dtype=torch.bfloat16,
        **model_kwargs,
    )
    model = AutoModelForCausalLM.from_pretrained(model_id, **_model_load_kwargs)

    # Resolve the configured stop tokens: string entries are mapped to a token id
    # through probe_token_id, every other entry is taken over unchanged.
    # The config stores newlines in escaped form (yaml hates newline characters),
    # so that escape is turned back into a real newline before probing.
    _resolved_stop_tokens = []
    for _tok in stop_on_tokens:
        if isinstance(_tok, str):
            _tok = probe_token_id(_tok.replace(r'\\n', "\n"), processor.tokenizer)
        _resolved_stop_tokens.append(_tok)
    print(_resolved_stop_tokens)
    # Several spellings may resolve to the same id; keep each id only once.
    stop_on_tokens = list(set(_resolved_stop_tokens))
    print(stop_on_tokens)

    # Result accumulators: one record per dataset index is appended to each list
    # and both are flushed to disk at every checkpoint.
    reslist_ms = []  # records tagged 'MS' (generated with do_sample=True)
    reslist_bs = []  # records tagged 'BS' (beam search)
    # Inference only: disable autograd globally for the rest of the script.
    torch.set_grad_enabled(False)
    # First dataset index that has not been written to disk yet.
    last_checkpoint = dset_range_start
    # exist_ok=True avoids the check-then-create race of
    # `if not os.path.exists(...): os.makedirs(...)`, which can raise
    # FileExistsError when several jobs start with the same save_dir.
    os.makedirs(save_dir, exist_ok=True)
    # Save the configuration used, so the run can be reproduced:
    #  - used_var.sh: the environment variables as `export K=V` lines
    #  - oconf.yaml:  the original config returned by load_master_config
    used_vars = getenv_roundup(clear=True)
    used_vars_repr = '\n'.join([f'export {k}={v}' for k, v in used_vars.items()])
    with open(os.path.join(save_dir, 'used_var.sh'), 'wt') as f:
        f.write(used_vars_repr)
    with open(os.path.join(save_dir, 'oconf.yaml'), 'wt') as f:
        yaml.safe_dump(oconf, f)
    
    # Main generation loop: one iteration per dataset index. The upper bound is
    # dset_range_end, or len(ds) when dset_range_end is negative.
    # NOTE(review): `ds` and `messages` are never assigned in the visible part of
    # this script; presumably they arrive through the config globals - confirm.
    for ds_idx in trange(dset_range_start, dset_range_end if dset_range_end>=0 else len(ds)):
        entry = ds[ds_idx]
        images = [entry['image'],]

        # Prepare n_samples inputs from the same messages and the same image list;
        # they are stacked into a single batch further down.
        listof_inputs: list[BatchFeature] = []
        for messages, image in zip([messages for _ in range(n_samples)], [images for _ in range(n_samples)]):
            prompt = processor.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            # print(image)
            inputs = processor(prompt, image, return_tensors="pt").to(use_device)
            listof_inputs.append(inputs)
        # `prompt` is the value left over from the last iteration of the loop above
        # (every iteration renders the same messages).
        input_text = prompt
        # print(input_text)
        # print(inputs)

        # Disabled: teacher-forced pass over the prompt to obtain per-token log-probs.
        # try:
        #     toks = tokenizer([input_text,]*1, add_special_tokens=True, return_tensors='pt').to(device=use_device)
        #     prompt_seq = toks.input_ids[0].cpu()
        #     oall = model(**toks)
        #     prompt_logits = oall.logits[0].cpu()
        #     prompt_seltok_logits = torch.stack([torch.log_softmax(o, -1)[ids] for o, ids in zip(prompt_logits, prompt_seq)]).cpu()
        # except Exception as e:
        #     print(f"Failed to compute logits for the prompt, error: {e}")
        # token ids of the first prepared input, moved to the CPU
        prompt_seq = listof_inputs[0].input_ids[0].cpu()
        # content of the last message in the conversation
        txt_x_no_prompt = messages[-1]['content']
        # prompt log-probs are not computed (see the disabled code above)
        prompt_seltok_logits = None

        # Sampled generation. Any exception is caught below and turned into a
        # placeholder record instead of aborting the whole run.
        try:
            # toks = tokenizer([input_text,]*n_samples, add_special_tokens=True, return_tensors='pt').to(device=use_device)            
            # combine the per-sample inputs into one padded batch
            toks = stack_and_pad_inputs(
                listof_inputs, pad_token_id=processor.tokenizer.pad_token_id,
                device=use_device,
            )
            
            # do_sample=True with num_beams=1; further options come from sampling_kwargs
            outputs_all = model.generate(
                **toks, 
                # max_new_tokens=64, 
                do_sample=True, 
                num_beams=1, 
                # top_p=0.95,
                # no_repeat_ngram_size=4,
                # top_k=10,
                # temperature=1.,
                # repetition_penalty=1.
                output_hidden_states=True,
                # output_logits=True,
                output_scores=True,
                return_dict_in_generate=True,
                # stopping_criteria=additional_stopping_crits
                # num_return_sequences=n_samples,
                logits_processor=[
                    # NoBadWordsLogitsProcessor(bad_words_ids=[[13]], eos_token_id=model.config.eos_token_id),
                    ExtraEOSTokenLogitsProcessor(stop_on_tokens, eos_token_id=model.config.eos_token_id)
                ],
                **sampling_kwargs,
            )
            # scores are logits after all of the logits processing (those are some fat matrices, log-softmax them before moving to the cpu)
            ret_logits = [torch.log_softmax(ol, dim=-1).detach().cpu() for ol in outputs_all.scores]
            # now put the outputs on the cpu (re-wrapped in a SimpleNamespace for attribute access)
            outputs_all = sn(**any_to_device_nested(outputs_all, device='cpu'))
            ret_ids = outputs_all.sequences
            # how the hidden states are returned
            # 1. generated tokens; 2. layers; 3. [batch_size, n_tokens_processed, dim]
            ret_hs_at_x = outputs_all.hidden_states[0][-1][0,-1,:] # save just the last layer state of the prompt+input
            
            # now the other part is trickier, since sequences do not have to be the same length
            input_prompt_len = toks.input_ids.shape[-1]
            # since the prompt may contain EOS for some models, have to scan only through the generated tokens
            batch_lens = get_generation_length_ids_from_batch(outputs_all.sequences[:, input_prompt_len:], eos_token_id=model.config.eos_token_id)
            # batch_lens = batch_lens # get the index of the last non eos generated token
            # try:
            ret_hs_at_xy = [
                # the very first entry is the prompt+input # TODO: not 100% sure -1 is necessary
                outputs_all.hidden_states[ltid-1][-1][sample_id,-1,:] for sample_id, ltid in enumerate(batch_lens)
            ] # save all of the final states for when the generation is done
            # except Exception as e:
            #     print('____!____!____!____!')
            #     print(batch_lens, input_prompt_len, ds_idx)
            #     raise e
            ret_hs_at_xy = torch.stack(ret_hs_at_xy, dim=0)
            # ends up being [n_samples, hidden_size]
            # and to top it off save the query and the resulting text separately
            # NOTE(review): these two names are not read again in the visible code
            full_x_query = input_text
            full_xy_response = outputs_all.sequences
            # optionally return all hiddens (not needed most of the time, configured to not be done by default)
            if return_all_hidden:
                ms_all_hidden = outputs_all.hidden_states
            else:
                ms_all_hidden = None
            # compute the per-token transition scores of the sampled sequences
            # (no beam indices here; the scores passed in are already log-softmaxed)
            ms_transition_scores = model.compute_transition_scores(
                outputs_all.sequences, ret_logits, normalize_logits=False
            )
            # negative ids (presumably image placeholder tokens) are replaced with EOS
            # in a copy, so that the tokenizer can decode the full sequence
            output_sequences_decode_safe = outputs_all.sequences.clone()
            output_sequences_decode_safe[output_sequences_decode_safe<0] = model.config.eos_token_id
            # save all to a dict
            reslist_ms.append({
                'logits': ret_logits, # are provided only for the generated part
                'sequences': ret_ids, #  all ids
                'sequences_no_input': ret_ids[:, input_prompt_len:], # ids without prompt
                'sequences_len': batch_lens, # length of each sampled output
                'seq_no_input_transition': ms_transition_scores, # these are -(selected token nll for corresponding positions)
                'hidden_at_x': ret_hs_at_x,
                'hidden_at_xy': ret_hs_at_xy,
                'all_hidden_states': ms_all_hidden,
                'tok_x': prompt_seq,
                'tok_x_nlls': prompt_seltok_logits,
                'txt_prompt': prompt,
                'txt_x': input_text,
                'txt_x_no_prompt': txt_x_no_prompt,
                'txt_y': processor.tokenizer.batch_decode(outputs_all.sequences[:, input_prompt_len:], skip_special_tokens=True),
                'txt_xy': processor.tokenizer.batch_decode(output_sequences_decode_safe, skip_special_tokens=True),
                'txt_xy_full': processor.tokenizer.batch_decode(output_sequences_decode_safe, skip_special_tokens=False),
                'dataset_idx': ds_idx,
                'model': model_id,
                'dataset': dset_id,
                'split': dset_split,
                'sampling_type': 'MS',
                'status': 'good'
            })
        except Exception as e:
            # Best effort: report the failure and store a placeholder record with the
            # same keys, the generation outputs set to None and the exception text
            # in 'status'.
            print(f"Error occurred at idx {ds_idx}: {e}")
            reslist_ms.append({
                'logits': None, # are provided only for the generated part
                'sequences': None, #  all ids
                'sequences_no_input': None, # ids without prompt
                'sequences_len': None, # length of each sampled output
                'seq_no_input_transition': None,
                'hidden_at_x': None,
                'hidden_at_xy': None,
                'all_hidden_states': None,
                'tok_x': prompt_seq,
                'tok_x_nlls': prompt_seltok_logits,
                'txt_prompt': prompt,
                'txt_x': input_text,
                'txt_x_no_prompt': txt_x_no_prompt,
                'txt_y': None,
                'txt_xy': None,
                'txt_xy_full': None,
                'dataset_idx': ds_idx,
                'model': model_id,
                'dataset': dset_id,
                'split': dset_split,
                'sampling_type': 'MS',
                'status': f"Exception: {e}"
            })

        # Beam-search pass for the same dataset entry; performed only when
        # `ok_beamer` is falsy. Produces exactly one record in reslist_bs per
        # dataset index: a full record on success, a placeholder on failure.
        if not ok_beamer:
            try:
                # now perform the beam search
                prompt = processor.tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                # `images` is the very object the sampling loop left behind in its
                # loop variable `image`; naming it directly avoids depending on a
                # leaked loop variable.
                toks = processor(prompt, images, return_tensors="pt").to(use_device)

                # do_sample=False with top_p/top_k cleared; the remaining options
                # (e.g. the number of beams) come from bs_kwargs
                outputs_all = model.generate(
                    **toks,
                    # max_new_tokens=64,
                    do_sample=False,
                    # num_beams=10,
                    top_p=None,
                    # no_repeat_ngram_size=4,
                    top_k=None,
                    # temperature=1.,
                    # repetition_penalty=1.
                    output_hidden_states=True,
                    # output_logits=True,
                    output_scores=True,
                    return_dict_in_generate=True,
                    # stopping_criteria=additional_stopping_crits # those don't stop beam_search most of the time
                    logits_processor=[
                        # NoBadWordsLogitsProcessor(bad_words_ids=[[13]], eos_token_id=model.config.eos_token_id),
                        ExtraEOSTokenLogitsProcessor(stop_on_tokens, eos_token_id=model.config.eos_token_id)
                    ],
                    **bs_kwargs
                )
                # scores are logits after all of the logits processing (those are some fat matrices, log-softmax them before moving to the cpu)
                ret_logits = [torch.log_softmax(ol, dim=-1).detach().cpu() for ol in outputs_all.scores]
                # now put the outputs on the cpu (re-wrapped in a SimpleNamespace for attribute access)
                outputs_all = sn(**any_to_device_nested(outputs_all, device='cpu'))
                ret_ids = outputs_all.sequences
                # how the hidden states are returned
                # 1. generated tokens; 2. layers; 3. [batch_size, n_tokens_processed, dim]
                ret_hs_at_x = outputs_all.hidden_states[0][-1][0,-1,:] # save just the last layer state of the prompt+input

                # now the other part is trickier, since sequences do not have to be the same length
                input_prompt_len = toks.input_ids.shape[-1]
                # since the prompt may contain EOS for some models, have to scan only through the generated tokens
                batch_lens = get_generation_length_ids_from_batch(outputs_all.sequences[:, input_prompt_len:], eos_token_id=model.config.eos_token_id)
                ret_hs_at_xy = [
                    # the very first entry is the prompt+input
                    outputs_all.hidden_states[ltid-1][-1][sample_id,-1,:] for sample_id, ltid in enumerate(batch_lens)
                ] # save all of the final states for when the generation is done
                ret_hs_at_xy = torch.stack(ret_hs_at_xy, dim=0)
                # ends up being [n_seqs, hidden_size]
                # optionally return all hiddens (not needed most of the time, configured to not be done by default)
                if return_all_hidden:
                    bs_all_hidden = outputs_all.hidden_states
                else:
                    bs_all_hidden = None
                # and to top it off save the query and the resulting text separately
                full_x_query = input_text
                full_xy_response = outputs_all.sequences

                # compute transition scores from the beam search output
                # also seems like beam_search() has extra logic to insert the first thing from the generation config
                # try to fix that with reassigning the ids to the eos id expected by the logits processor
                bs_transition_scores = model.compute_transition_scores(
                    reassign_ids(outputs_all.sequences, stop_on_tokens, model.config.eos_token_id, starting_from=input_prompt_len),
                    ret_logits,
                    outputs_all.beam_indices,
                    normalize_logits=False
                )

                # negative ids (presumably image placeholder tokens) are replaced with EOS
                # in a copy, so that the tokenizer can decode the full sequence
                output_sequences_decode_safe = outputs_all.sequences.clone()
                output_sequences_decode_safe[output_sequences_decode_safe<0] = model.config.eos_token_id
                # save all to a dict
                reslist_bs.append({
                    'logits': ret_logits, # are provided only for the generated part
                    'sequences': ret_ids, #  all ids
                    'sequences_no_input': ret_ids[:, input_prompt_len:], # ids without prompt
                    'seq_no_input_transition': bs_transition_scores, # should be [n_seqs, seq_len]
                    'sequences_len': batch_lens, # length of each returned output
                    'hidden_at_x': ret_hs_at_x,
                    'hidden_at_xy': ret_hs_at_xy,
                    'all_hidden_states': bs_all_hidden,
                    'tok_x': prompt_seq,
                    'tok_x_nlls': prompt_seltok_logits,
                    'txt_prompt': prompt,
                    'txt_x': input_text,
                    'txt_x_no_prompt': txt_x_no_prompt,
                    'txt_y': processor.tokenizer.batch_decode(outputs_all.sequences[:, input_prompt_len:], skip_special_tokens=True),
                    'txt_xy': processor.tokenizer.batch_decode(output_sequences_decode_safe, skip_special_tokens=True),
                    'txt_xy_full': processor.tokenizer.batch_decode(output_sequences_decode_safe, skip_special_tokens=False),
                    'dataset_idx': ds_idx,
                    'model': model_id,
                    'dataset': dset_id,
                    'split': dset_split,
                    'sampling_type': 'BS',
                    'status': 'good'
                })
                # the (large) per-step logits are kept only when explicitly requested
                if not return_logits_for_bs:
                    del reslist_bs[-1]['logits']
            except Exception as e:
                # Best effort: report the failure and store a placeholder record with
                # the generation outputs set to None and the exception text in 'status'.
                print(f"Error occurred at idx {ds_idx}: {e}")
                reslist_bs.append({
                    'logits': None, # are provided only for the generated part
                    'sequences': None, #  all ids
                    'sequences_no_input': None, # ids without prompt
                    'sequences_len': None, # length of each returned output
                    'seq_no_input_transition': None,
                    'hidden_at_x': None,
                    'hidden_at_xy': None,
                    'all_hidden_states': None,
                    'tok_x': prompt_seq,
                    'tok_x_nlls': prompt_seltok_logits,
                    'txt_prompt': prompt,
                    'txt_x': input_text,
                    'txt_x_no_prompt': txt_x_no_prompt,
                    'txt_y': None,
                    'txt_xy': None,
                    'txt_xy_full': None,
                    'dataset_idx': ds_idx,
                    'model': model_id,
                    'dataset': dset_id,
                    'split': dset_split,
                    # this record belongs to the beam-search list, so it is tagged
                    # 'BS' (it used to be mislabelled 'MS')
                    'sampling_type': 'BS',
                    'status': f"Exception: {e}"
                })
                # keep the key set identical to that of successful records
                if not return_logits_for_bs:
                    del reslist_bs[-1]['logits']
        
        # Periodic checkpoint: after every `save_every` processed entries, write the
        # accumulated records to disk and start fresh lists.
        # `save_every > 0` has to be tested first: with the modulo evaluated first,
        # save_every == 0 raised ZeroDivisionError instead of simply disabling the
        # periodic checkpoints.
        if save_every > 0 and (ds_idx + 1 - dset_range_start) % save_every == 0:
            # file names carry the first and the last dataset index they contain
            save_records_compact(
                reslist_ms,
                f"{save_dir}/ms_{last_checkpoint:05d}_{ds_idx:05d}",
                digitize_conf={'logits': {'table':'lsm'}},
                use_pigz=True,
                pigz_fast_temp_path=pigz_temp_storage
            )
            # beam-search records exist only when the beam-search pass was run
            if not ok_beamer:
                save_records_compact(
                    reslist_bs,
                    f"{save_dir}/bs_{last_checkpoint:05d}_{ds_idx:05d}",
                    # BS records carry no 'logits' unless return_logits_for_bs is set
                    digitize_conf={'logits': {'table':'lsm'}} if return_logits_for_bs else {},
                    use_pigz=True,
                    pigz_fast_temp_path=pigz_temp_storage
                )
            reslist_ms = []
            reslist_bs = []
            # the next checkpoint file starts at the entry following this one
            last_checkpoint = ds_idx + 1

    # Final piece: write whatever has accumulated since the last periodic
    # checkpoint. A non-empty reslist_ms implies that the loop ran at least once,
    # so `ds_idx` is guaranteed to be bound inside this branch.
    if reslist_ms:
        save_records_compact(
            reslist_ms,
            f"{save_dir}/ms_{last_checkpoint:05d}_{ds_idx:05d}",
            digitize_conf={'logits': {'table':'lsm'}},
            use_pigz=True,
            pigz_fast_temp_path=pigz_temp_storage
        )
        # beam-search records exist only when the beam-search pass was run
        if not ok_beamer:
            save_records_compact(
                reslist_bs,
                f"{save_dir}/bs_{last_checkpoint:05d}_{ds_idx:05d}",
                digitize_conf={'logits': {'table':'lsm'}} if return_logits_for_bs else {},
                use_pigz=True,
                pigz_fast_temp_path=pigz_temp_storage
            )
        # Advance the marker the same way the periodic checkpoint does. This used
        # to be an unconditional `last_checkpoint = ds_idx` after the branch, which
        # raised NameError for an empty index range (the loop never bound ds_idx).
        last_checkpoint = ds_idx + 1

    print("Great succ!")