# import cProfile

def example_to_prompt(entry, prompter, ds, add_assistant_prompt=True):
    """Turn one dataset entry into a chat-style prompt.

    The conversation is assembled in this order:

    1. a system message holding ``ds.get_problem_system_instruction()``, with
       ``entry['background']`` appended when that key is present;
    2. alternating user/assistant messages for every few-shot pair found in
       ``entry['fs_in']`` / ``entry['fs_out']`` (only when both keys exist);
    3. the question ``entry['x']`` as a user message;
    4. when ``add_assistant_prompt`` is true, an empty assistant message
       flagged with ``'last_msg': True``.

    Returns:
        A tuple ``(entry['x'], prompter.build_prompt(messages))``.

    Raises:
        AssertionError: if the two few-shot lists differ in length.
    """
    # system turn: instruction, optionally followed by the entry's background
    system_role = prompter.get_system_role()
    system_text = ds.get_problem_system_instruction()
    if 'background' in entry:
        system_text = system_text + entry['background']
    conversation = [{'role': system_role, 'message': system_text}]

    # few-shot demonstrations are only used when both halves are provided
    if 'fs_in' in entry and 'fs_out' in entry:
        shots_in, shots_out = entry['fs_in'], entry['fs_out']
        assert len(shots_in) == len(shots_out), "Few Shot examples must come in pairs!"
        for shot_question, shot_answer in zip(shots_in, shots_out):
            conversation.extend((
                {'role': prompter.get_user_role(), 'message': shot_question},
                {'role': prompter.get_assistant_role(), 'message': shot_answer},
            ))

    # the actual question
    question = entry['x']
    conversation.append({'role': prompter.get_user_role(), 'message': question})

    # open an (empty) assistant turn for the model to continue
    if add_assistant_prompt:
        conversation.append({
            'role': prompter.get_assistant_role(),
            'message': '',
            'last_msg': True,
        })

    return question, prompter.build_prompt(conversation)


from filelock import FileLock, Timeout, SoftFileLock


def wait_for_lock_file_nice(lock_file: FileLock, kickback_immediately=False):
    """Acquire ``lock_file``, either in one non-blocking attempt or by retrying until it works.

    With ``kickback_immediately`` set, a single ``acquire(timeout=0.)`` is
    attempted and ``False`` is returned if it raises ``Timeout``.  Otherwise
    acquisition is retried in rounds of 120 seconds, printing a notice after
    every round that times out, and the function only returns once the lock
    is held.

    Returns:
        ``True`` when the lock was acquired, ``False`` only in the
        ``kickback_immediately`` case when the single attempt timed out.
    """
    if not kickback_immediately:
        # patient mode: keep trying, one 120 s round at a time
        acquired = False
        while not acquired:
            try:
                lock_file.acquire(timeout=120.)
            except Timeout:
                print("Failed to obtain a chunk dumping lock in 120 secs, wait more...")
            else:
                acquired = True
        return True

    # impatient mode: a single attempt, report the outcome
    try:
        lock_file.acquire(timeout=0.)
    except Timeout:
        return False
    return True


from functools import partial
from typing import Callable
from transformers import LogitsProcessor, LogitsProcessorList
from .gen_utils import LogitsProcessorWithState, ExtraEOSTokenLogitsProcessorWithConstructor, ExtraEOSTokenLogitsProcessor

def prepare_processors(logitproc, **kwargs):
    """Recursively turn a logits-processor specification into usable processors.

    Dispatch, in order:

    * ``list`` / ``LogitsProcessorList``: every element is prepared
      recursively and the results are wrapped in a new ``LogitsProcessorList``.
    * ``dict`` (e.g. freshly loaded from a config): its values are prepared
      the same way; the keys are dropped.
    * a callable that is not a ``LogitsProcessor``: treated as a constructor
      and called with ``tokenizer=kwargs['tokenizer']`` and
      ``vocab_size=kwargs['vocab_size']``; its result is returned.
    * ``LogitsProcessorWithState``: ``reset()`` is called and the same object
      is returned.
    * anything else is returned unchanged.

    Raises:
        KeyError: if a constructor-style callable is met and ``tokenizer`` or
            ``vocab_size`` was not passed as a keyword argument.
    """
    # containers: recurse into the members and re-wrap
    is_sequence = isinstance(logitproc, (list, LogitsProcessorList))
    if is_sequence or isinstance(logitproc, dict):
        members = logitproc if is_sequence else logitproc.values()
        return LogitsProcessorList([prepare_processors(member, **kwargs) for member in members])

    # constructor-like callables want the tokenizer and the vocab size
    if isinstance(logitproc, Callable) and not isinstance(logitproc, LogitsProcessor):
        return logitproc(
            tokenizer=kwargs['tokenizer'],
            vocab_size=kwargs['vocab_size'],
        )

    # stateful processors start every run from a clean state
    if isinstance(logitproc, LogitsProcessorWithState):
        logitproc.reset()
    return logitproc


def extract_stop_tokens_from_processors(logitproc):
    """Collect the ``considered_tokens`` of every extra-EOS processor in ``logitproc``.

    ``logitproc`` may be a single object or an arbitrarily nested ``list``.
    Lists are flattened recursively; an ``ExtraEOSTokenLogitsProcessor`` or
    ``ExtraEOSTokenLogitsProcessorWithConstructor`` contributes its
    ``considered_tokens``; any other object contributes nothing.

    Returns:
        For a list input, a new flat list of all collected tokens; for a
        single extra-EOS processor, its ``considered_tokens`` attribute
        itself; otherwise an empty list.
    """
    if isinstance(logitproc, list):
        # flatten the per-member results into one list
        return [
            token
            for member in logitproc
            for token in extract_stop_tokens_from_processors(member)
        ]
    if isinstance(logitproc, (ExtraEOSTokenLogitsProcessorWithConstructor, ExtraEOSTokenLogitsProcessor)):
        return logitproc.considered_tokens
    return []


from vllm import LLM, RequestOutput, SamplingParams
from vllm.sampling_params import BeamSearchParams
from vllm.beam_search import BeamSearchOutput

from typing import List, Dict, Tuple, Callable


def pad_list_to_max_len(ls: List[List], pad_value=-1) -> List[List]:
    """Right-pad every sequence in ``ls`` with ``pad_value`` to the longest length.

    Args:
        ls: sequences to pad.  Each row may be any iterable (list, tuple, ...);
            rows are copied into fresh lists, the input is never modified.
        pad_value: value appended to rows shorter than the longest one.

    Returns:
        A new list of lists, all of the same length.  An empty input yields
        an empty list (instead of failing inside ``max``).
    """
    # copy to real lists so that tuple rows can be extended as well
    rows = [list(row) for row in ls]
    if not rows:
        return []
    target_len = max(len(row) for row in rows)
    return [row + [pad_value] * (target_len - len(row)) for row in rows]


def reconstruct_oall_from_vllm_output(
    oall: RequestOutput,
    tokenizer_breadth: int,
    pad_token_id: int,
    tokenizer=None,
) -> dict:
    """Convert one sampled vLLM ``RequestOutput`` into the legacy rollout record dict.

    Args:
        oall: the request output; ``prompt``, ``prompt_token_ids``,
            ``prompt_logprobs`` and, for every entry of ``outputs``,
            ``index``, ``token_ids``, ``text``, ``logprobs`` and
            ``cumulative_logprob`` are read.
        tokenizer_breadth: size of the last axis of the dense ``logits``
            tensor (every token id found in the logprob dicts must be below it).
        pad_token_id: id used to right-pad generations shorter than the longest.
        tokenizer: object providing ``batch_decode``.  Optional for backward
            compatibility: when omitted, the module-level ``tokenizer``
            variable is used, which is what older call sites rely on.

    Returns:
        dict with the keys ``logits``, ``sequences``, ``sequences_no_input``,
        ``sequences_len``, ``seq_no_input_transition``, ``tok_x``,
        ``tok_x_nlls``, ``txt_x``, ``txt_y``, ``txt_xy`` and ``txt_xy_full``.

    Raises:
        NameError: if no tokenizer is passed and none exists at module level.
        AssertionError: if the summed per-token logprobs disagree with vLLM's
            ``cumulative_logprob``.
    """
    # local import: at module level torch is only imported inside the
    # ``__main__`` section, so importing it here keeps the function usable
    # when the module is imported rather than run
    import torch

    if tokenizer is None:
        # legacy behaviour: fall back to the tokenizer defined by the script body
        try:
            tokenizer = globals()['tokenizer']
        except KeyError:
            raise NameError("name 'tokenizer' is not defined; pass `tokenizer=` explicitly") from None

    # --- prompt statistics ---
    prompt_seltok = list(oall.prompt_token_ids)
    prompt_txt = oall.prompt
    prompt_len = len(prompt_seltok)

    # logprob of every prompt token; positions for which vLLM reports no
    # logprob info (``None`` entries) stay at 0
    prompt_seltok_logits = torch.zeros(size=(1, prompt_len))
    for pos, p_id in enumerate(prompt_seltok):
        logprob_info = oall.prompt_logprobs[pos]
        if logprob_info is not None:
            prompt_seltok_logits[0, pos] = logprob_info[p_id].logprob

    # --- per-generation statistics ---
    generations = oall.outputs
    n_generations = len(generations)
    # copy to plain lists so the ids can be concatenated with the prompt ids
    generated_token_ids = [list(o.token_ids) for o in generations]
    generated_output_length = [len(ids) for ids in generated_token_ids]
    max_gen_len = max(generated_output_length)
    # reference cumulative logprobs, used to cross-check the aggregation below
    gen_refcumlogprobs = [o.cumulative_logprob for o in generations]
    generated_token_text = [o.text for o in generations]

    # dense logprob tensor; entries vLLM did not report stay at -inf
    gen_logits = torch.full(
        (n_generations, max_gen_len, tokenizer_breadth), float('-inf'), dtype=torch.bfloat16
    )
    # logprob of the token actually selected at each step; NaN marks padding
    genseq_transition_scores = torch.full(
        (n_generations, max_gen_len), float('nan'), dtype=torch.float32
    )
    for row, o in enumerate(generations):
        # NOTE(review): logits rows are addressed by ``o.index`` while all other
        # tensors use the position in ``oall.outputs`` - presumably these
        # coincide; confirm against the vLLM version in use
        for pos_id, lps_dict in enumerate(o.logprobs):
            for tid, lp in lps_dict.items():
                gen_logits[o.index, pos_id, tid] = lp.logprob
        for pos_id, (lps_dict, token_id) in enumerate(zip(o.logprobs, generated_token_ids[row])):
            genseq_transition_scores[row, pos_id] = lps_dict[token_id].logprob

    # sanity check: our per-token bookkeeping must add up to vLLM's totals
    assert torch.allclose(torch.tensor(gen_refcumlogprobs),torch.nansum(genseq_transition_scores, -1), atol=1e-2, rtol=1e-2), "Discrepancy between aggregated and vllm cummulative probs!"

    # --- tensor conversion ---
    padded_generated_ids = pad_list_to_max_len(generated_token_ids, pad_value=pad_token_id)
    sequences_no_input = torch.tensor(padded_generated_ids, dtype=torch.long)
    sequences = torch.tensor([prompt_seltok + outseq for outseq in padded_generated_ids], dtype=torch.long)
    # built once, straight from the python list (re-wrapping an existing
    # tensor with torch.tensor() copies it and emits a UserWarning)
    sequences_len = torch.tensor(generated_output_length, dtype=torch.long)
    prompt_seltok = torch.tensor(prompt_seltok, dtype=torch.long)

    # return in a layout compatible with the previous record format
    return {
        'logits': gen_logits,
        # token outputs and transition log-probabilities
        'sequences': sequences,  # prompt ids + padded generated ids
        'sequences_no_input': sequences_no_input,  # ids without prompt
        'sequences_len': sequences_len,  # length of each sampled output
        'seq_no_input_transition': genseq_transition_scores,  # logprob (= -nll) of the selected token per position
        'tok_x': prompt_seltok,
        'tok_x_nlls': prompt_seltok_logits,  # despite the key name these are logprobs, not negated
        'txt_x': prompt_txt,
        'txt_y': generated_token_text,
        'txt_xy': tokenizer.batch_decode(sequences, skip_special_tokens=True),
        'txt_xy_full': tokenizer.batch_decode(sequences, skip_special_tokens=False),
    }


def reconstruct_oall_from_vllm_bs_output(
    oall: BeamSearchOutput,
    tokenizer,
    input_ids,
    tokenizer_breadth: int,
    pad_token_id: int,
    prompt=None,
    input_text=None,
) -> dict:
    """Convert one vLLM ``BeamSearchOutput`` into the legacy rollout record dict.

    Args:
        oall: beam search result; for every entry of ``sequences`` the
            attributes ``tokens``, ``logprobs`` and ``cum_logprob`` are read.
        tokenizer: object providing ``batch_decode``.
        input_ids: prompt token ids; the first ``len(input_ids)`` tokens of
            every beam are treated as the prompt and split off.
        tokenizer_breadth: size of the last axis of the dense ``logits`` tensor.
        pad_token_id: id used to right-pad beams shorter than the longest.
        prompt: value stored under ``'txt_prompt'``.  Optional for backward
            compatibility: when omitted, the module-level ``prompt`` variable
            is used, which is what older call sites rely on.
        input_text: value stored under ``'txt_x'``; falls back to the
            module-level ``input_text`` variable in the same way.

    Returns:
        dict with the keys ``logits``, ``sequences_no_input``,
        ``sequences_len``, ``seq_no_input_transition``, ``tok_x``,
        ``txt_prompt``, ``txt_x``, ``txt_y``, ``txt_xy`` and ``txt_xy_full``.

    Raises:
        NameError: if ``prompt`` / ``input_text`` is neither passed nor
            defined at module level.
        AssertionError: if the summed per-token logprobs disagree with the
            beams' ``cum_logprob``.
    """
    # local import: at module level torch is only imported inside the
    # ``__main__`` section, so importing it here keeps the function usable
    # when the module is imported rather than run
    import torch

    def _script_fallback(name, given):
        # an explicit argument wins; otherwise read the module-level variable
        # that this function used to pick up implicitly
        if given is not None:
            return given
        try:
            return globals()[name]
        except KeyError:
            raise NameError(f"name '{name}' is not defined; pass `{name}=` explicitly") from None

    prompt = _script_fallback('prompt', prompt)
    input_text = _script_fallback('input_text', input_text)

    beams = oall.sequences
    # number of logprob entries recorded per beam
    sequences_len = [len(beam.logprobs) for beam in beams]
    max_len = max(sequences_len)

    # split the prompt off every beam
    prompt_len = len(input_ids)
    sequences = [list(beam.tokens) for beam in beams]
    sequences_no_input = [tokens[prompt_len:] for tokens in sequences]
    seq_no_input_cumprobs = [beam.cum_logprob for beam in beams]

    # dense logprob tensor; entries that were not reported stay at -inf
    genseq_logits = torch.full(
        (len(beams), max_len, tokenizer_breadth), float('-inf'), dtype=torch.float32
    )
    # logprob of the token actually selected at each step; NaN marks padding
    genseq_transition_scores = torch.full(
        (len(beams), max_len), float('nan'), dtype=torch.float32
    )
    for batch_index, beam in enumerate(beams):
        for sl, lpdict in enumerate(beam.logprobs):
            for k, lp in lpdict.items():
                genseq_logits[batch_index, sl, k] = lp.logprob
        for sl, (lpdict, token_id) in enumerate(zip(beam.logprobs, sequences_no_input[batch_index])):
            genseq_transition_scores[batch_index, sl] = lpdict[token_id].logprob

    # sanity check: our per-token bookkeeping must add up to vLLM's totals
    assert torch.allclose(torch.tensor(seq_no_input_cumprobs),torch.nansum(genseq_transition_scores, -1), atol=1e-2, rtol=1e-2), "Discrepancy between aggregated and vllm cummulative probs!"

    # --- tensor conversion ---
    prompt_seltok = torch.tensor(input_ids, dtype=torch.long)
    sequences_no_input = torch.tensor(
        pad_list_to_max_len(sequences_no_input, pad_value=pad_token_id), dtype=torch.long
    )
    sequences = torch.tensor(
        pad_list_to_max_len(sequences, pad_value=pad_token_id), dtype=torch.long
    )
    sequences_len = torch.tensor(sequences_len, dtype=torch.long)

    # text is decoded once, from the padded tensors (the earlier per-beam
    # decodes were computed and then never used)
    return {
        'logits': genseq_logits,
        # token outputs and transition log-probabilities
        'sequences_no_input': sequences_no_input,  # ids without prompt
        'sequences_len': sequences_len,  # number of logprob entries of each beam
        'seq_no_input_transition': genseq_transition_scores,  # logprob (= -nll) of the selected token per position
        # decoded text versions
        'tok_x': prompt_seltok,
        'txt_prompt': prompt,
        'txt_x': input_text,
        'txt_y': tokenizer.batch_decode(sequences_no_input, skip_special_tokens=True),
        'txt_xy': tokenizer.batch_decode(sequences, skip_special_tokens=True),
        'txt_xy_full': tokenizer.batch_decode(sequences, skip_special_tokens=False),
    }


if __name__=="__main__":
    # Script entry point.  Loads the master config, unpacks it into module
    # globals, builds the tokenizer and the vLLM engine, then for every dataset
    # index in the configured range runs each configured rollout (sampling or
    # beam search), converts the output into a record dict and periodically
    # dumps the accumulated records to ``save_dir``.
    import sys, os, yaml, time
    print(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from mi_boiler.sigma_conf import load_master_config, getenv_roundup

    # can configure the run here if needed
    master_config_path = os.environ.get('MASTER_CONFIG_PATH', './generation/run_config.yaml')
    conf, printable_conf, oconf = load_master_config(master_config_path, abspath=True, return_original_conf=True)
    # unpack into globals: names used below without a visible definition
    # (model_id, rollouts, ds, prompter, save_dir, save_every, ...) come from here
    globals().update(conf)
    # print the used environment variables
    print(getenv_roundup())
    print(printable_conf)

    # set up hf environment
    # (os.environ['HF_HOME'] can be overridden here, before anything is loaded)
    from datasets import load_dataset, Dataset
    from types import SimpleNamespace as sn
    import torch
    import pandas as pd
    # from .prompts import PROMPT_COLLIE_PHI_INSTRUCT_CUSTOMIZED as prompt
    from .gen_utils import ExtraEOSTokenLogitsProcessor, get_generation_length_ids_from_batch, probe_token_id, reassign_ids
    from mi_boiler.compression_tools import save_records_compact
    from mi_boiler.swiss_army import any_to_device_nested
    from tqdm import trange

    # load the model
    from transformers import AutoTokenizer

    # load the tokenizer; fall back to EOS as the padding token when none is set
    tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_kwargs)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    # constant for the whole run, so computed once instead of per rollout
    tokenizer_vocab_len = len(tokenizer.vocab)

    # initialize the vLLM model
    model = LLM(
        model=model_id,
        **model_kwargs,
        **hardware_config,
    )

    # results: one list of record dicts per rollout configuration
    reslists = {k: [] for k in rollouts.keys()}
    torch.set_grad_enabled(False)
    last_checkpoint = dset_range_start
    os.makedirs(save_dir, exist_ok=True)
    # save the configuration used
    used_vars = getenv_roundup(clear=True)
    used_vars_repr = '\n'.join([f'export {k}={v}' for k, v in used_vars.items()])
    with open(os.path.join(save_dir, 'used_var.sh'), 'wt') as f:
        f.write(used_vars_repr)
    with open(os.path.join(save_dir, 'oconf.yaml'), 'wt') as f:
        yaml.safe_dump(oconf, f)

    ## initialize the lock file in the temp dir for pigz
    # TODO: a bug waiting to happen, or not...
    dump_file_lock = SoftFileLock(pigz_temp_storage+'.dump_lock', timeout=120.)
    if dump_file_lock.is_locked:
        dump_file_lock.release()

    for ds_idx in (tqdm_range_pbar := trange(dset_range_start, dset_range_end if dset_range_end>=0 else len(ds))):
        last_update_time = time.time() # to account for saves and such
        entry = ds[ds_idx]
        # TODO: super crutchy, might not hold, would make sense to make a prompt class e.g.
        # prompter :PromptBuilder, construct a prompt
        # NOTE: `prompt`, `input_text` and `tokenizer` must stay module-level
        # names - the reconstruct_* helpers read them from the module scope
        txt_x_no_prompt, input_text = example_to_prompt(entry, prompter, ds)
        prompt = ds.get_problem_system_instruction()

        # NOTE(review): this only triggers for model ids containing both 'X'
        # and 'x' - the intent is not visible from here, confirm
        if 'X' in model_id and 'x' in model_id:
            torch.cuda.set_device(use_device)

        # tokenize once per example (plain list of ids); every rollout reuses it
        toks = tokenizer(input_text, add_special_tokens=True)['input_ids']

        for roll_name, roll_args in rollouts.items():
            gen_kwargs = roll_args['generate_kwargs']
            # stop tokens given as a partial are resolved with the tokenizer;
            # the result replaces the partial, so this happens once per rollout
            if 'stop_token_ids' in gen_kwargs.keys() and isinstance(gen_kwargs['stop_token_ids'], partial):
                gen_kwargs['stop_token_ids'] = gen_kwargs['stop_token_ids'](tokenizer)

            # generate based on the sequences
            if roll_args['beam_search']:
                outputs_all = model.beam_search(
                    prompts=[{'prompt_token_ids': toks}],
                    params=BeamSearchParams(
                        **gen_kwargs
                    )
                )
                retdict_from_vllm = reconstruct_oall_from_vllm_bs_output(
                    outputs_all[0],
                    tokenizer,
                    toks,
                    tokenizer_vocab_len,
                    pad_token_id=tokenizer.pad_token_id
                )
            else:
                outputs_all = model.generate(
                    use_tqdm=False,
                    prompt_token_ids=toks,
                    sampling_params=SamplingParams(
                        **gen_kwargs
                    )
                )
                retdict_from_vllm = reconstruct_oall_from_vllm_output(
                    outputs_all[0],
                    tokenizer_vocab_len,
                    pad_token_id=tokenizer.pad_token_id
                )

            if 'logits' in retdict_from_vllm:
                # do the log softmax just in case
                retdict_from_vllm['logits'] = torch.log_softmax(retdict_from_vllm['logits'], dim=-1)

            # storing model hidden states is not supported on the vLLM path
            if roll_args['store_model_state']:
                raise NotImplementedError(f"Can't do model state with vllm just yet")

            retdict_from_vllm.update({
                # backmapping info
                'dataset_idx': ds_idx,
                'model': model_id,
                'dataset': dset_id,
                'split': dset_split,
                # sampling parameters
                'sampling_type': 'BS' if roll_args['beam_search'] else 'MS',
                'rollout_config_name': roll_name,
                'status': 'good',
            })
            # at this point we are roughly compatible with the standard format
            reslists[roll_name].append(retdict_from_vllm)

            # progress bar: tokens produced by this rollout and its throughput
            n_tokens = retdict_from_vllm['sequences_len'].sum().item()
            elapsed = max(time.time() - last_update_time, 1e-9)  # never divide by zero
            tqdm_range_pbar.set_description(
                desc=f"{roll_name}: {n_tokens:04d} tokens / {elapsed:.2f} sec -> {n_tokens/elapsed:.2f} tps",
                refresh=True
            )
            last_update_time = time.time()

        # periodic dump; `save_every` is checked first so that a value of 0
        # disables dumping instead of raising ZeroDivisionError in the modulo
        if save_every > 0 and (ds_idx+1-dset_range_start)%save_every == 0:
            wait_for_lock_file_nice(dump_file_lock)
            try:
                for roll_name, reslist in reslists.items():
                    save_records_compact(
                        reslist,
                        f"{save_dir}/{roll_name}_{last_checkpoint:05d}_{ds_idx:05d}",
                        digitize_conf={'logits': {'table':'lsm'}},
                        use_pigz=True,
                        pigz_fast_temp_path=pigz_temp_storage
                    )
                    reslists[roll_name] = []
                last_checkpoint=ds_idx+1
            finally:
                # always give the lock back, also when saving failed
                if dump_file_lock.is_locked:
                    dump_file_lock.release()

    # final piece: dump whatever has accumulated since the last checkpoint
    # (`ds_idx` is only read when there is something to save, so an empty
    # index range no longer fails here)
    wait_for_lock_file_nice(dump_file_lock)
    try:
        for roll_name, reslist in reslists.items():
            if len(reslist) > 0:
                save_records_compact(
                    reslist,
                    f"{save_dir}/{roll_name}_{last_checkpoint:05d}_{ds_idx:05d}",
                    digitize_conf={'logits': {'table':'lsm'}},
                    use_pigz=True,
                    pigz_fast_temp_path=pigz_temp_storage
                )
                reslists[roll_name] = []
    finally:
        if dump_file_lock.is_locked:
            dump_file_lock.release()

    print("Great succ!")