# import cProfile

def example_to_prompt(entry, prompter, ds, add_assistant_prompt=True):
    """Turn one dataset entry into a chat-style prompt.

    The conversation is assembled as a list of ``{'role', 'message'}`` dicts:
    a system message (the dataset's system instruction, with
    ``entry['background']`` appended when present), optional few-shot
    user/assistant pairs taken from ``entry['fs_in']`` / ``entry['fs_out']``,
    the question ``entry['x']`` as a user message and, when
    ``add_assistant_prompt`` is true, an empty assistant message flagged with
    ``'last_msg': True``. The list is then handed to ``prompter.build_prompt``.

    Returns:
        A tuple ``(entry['x'], prompter.build_prompt(conversation))``.

    Raises:
        AssertionError: if both few-shot lists are present but differ in length.
    """
    # system turn: instruction text, optionally extended by the background
    system_role = prompter.get_system_role()
    system_text = ds.get_problem_system_instruction()
    if 'background' in entry:
        system_text = system_text + entry['background']
    conversation = [{'role': system_role, 'message': system_text}]

    # few-shot turns are only used when both halves are provided
    if 'fs_in' in entry and 'fs_out' in entry:
        shots_in, shots_out = entry['fs_in'], entry['fs_out']
        assert len(shots_in) == len(shots_out), "Few Shot examples must come in pairs!"
        for shot_question, shot_answer in zip(shots_in, shots_out):
            conversation.append({'role': prompter.get_user_role(), 'message': shot_question})
            conversation.append({'role': prompter.get_assistant_role(), 'message': shot_answer})

    # the actual question
    question = entry['x']
    conversation.append({'role': prompter.get_user_role(), 'message': question})

    # open an empty assistant turn for the model to continue
    if add_assistant_prompt:
        conversation.append({
            'role': prompter.get_assistant_role(),
            'message': '',
            'last_msg': True,
        })

    return question, prompter.build_prompt(conversation)


from filelock import FileLock, Timeout, SoftFileLock


def wait_for_lock_file_nice(lock_file: FileLock, kickback_immediately=False):
    """Acquire ``lock_file``, either in a single attempt or by waiting for it.

    Args:
        lock_file: lock object exposing ``acquire(timeout=...)``.
        kickback_immediately: when true, make one attempt with a zero timeout
            and report the outcome instead of waiting.

    Returns:
        ``True`` once the lock has been acquired; ``False`` only in
        ``kickback_immediately`` mode when the single attempt timed out.
    """
    if not kickback_immediately:
        # keep retrying in 120 second slices, logging after each timeout
        acquired = False
        while not acquired:
            try:
                lock_file.acquire(timeout=120.)
            except Timeout:
                print("Failed to obtain a chunk dumping lock in 120 secs, wait more...")
            else:
                acquired = True
        return True

    # single non-waiting attempt
    try:
        lock_file.acquire(timeout=0.)
    except Timeout:
        return False
    return True


from functools import partial
from typing import Callable
from transformers import LogitsProcessor, LogitsProcessorList
from .gen_utils import LogitsProcessorWithState, ExtraEOSTokenLogitsProcessorWithConstructor, ExtraEOSTokenLogitsProcessor

def prepare_processors(logitproc, **kwargs):
    """Recursively turn a logits-processor spec into ready-to-use processors.

    Handling by type of ``logitproc``:

    * list / ``LogitsProcessorList``: every element is prepared and the results
      are wrapped in a new ``LogitsProcessorList``;
    * dict: same, applied to the dict's values;
    * callable that is not a ``LogitsProcessor`` instance: called with
      ``tokenizer=kwargs['tokenizer']`` and ``vocab_size=kwargs['vocab_size']``
      and its result returned;
    * ``LogitsProcessorWithState``: ``reset()`` is called, then the same object
      is returned;
    * anything else is returned unchanged.

    Raises:
        KeyError: if a callable spec is met and ``tokenizer`` or ``vocab_size``
            was not passed as a keyword argument.
    """
    # containers: pick the iterable of members, then recurse over it
    if isinstance(logitproc, (list, LogitsProcessorList)):
        members = logitproc
    elif isinstance(logitproc, dict):
        members = logitproc.values()
    else:
        members = None
    if members is not None:
        return LogitsProcessorList([prepare_processors(member, **kwargs) for member in members])

    # constructors / factories that still need the tokenizer and vocab size
    is_factory = isinstance(logitproc, Callable) and not isinstance(logitproc, LogitsProcessor)
    if is_factory:
        return logitproc(tokenizer=kwargs['tokenizer'], vocab_size=kwargs['vocab_size'])

    # stateful processors are reset before being reused
    if isinstance(logitproc, LogitsProcessorWithState):
        logitproc.reset()
    return logitproc


def extract_stop_tokens_from_processors(logitproc):
    """Collect ``considered_tokens`` from the extra-EOS processors in ``logitproc``.

    A list (including nested lists) is flattened recursively into a new list.
    A single ``ExtraEOSTokenLogitsProcessorWithConstructor`` or
    ``ExtraEOSTokenLogitsProcessor`` yields its ``considered_tokens`` attribute
    as-is. Any other object contributes an empty list.
    """
    if isinstance(logitproc, list):
        return [
            token
            for member in logitproc
            for token in extract_stop_tokens_from_processors(member)
        ]
    eos_processor_types = (
        ExtraEOSTokenLogitsProcessorWithConstructor,
        ExtraEOSTokenLogitsProcessor,
    )
    if isinstance(logitproc, eos_processor_types):
        return logitproc.considered_tokens
    return []


if __name__=="__main__":
    # Script entry point, setup phase: load the run configuration, build the
    # tokenizer and model, fill in missing pad-token settings, prepare the
    # output directory and create the lock used by the save steps below.
    import sys, os, yaml, time
    print(os.path.dirname(os.path.abspath(__file__)))
    # make the parent of this file's directory importable
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from mi_boiler.sigma_conf import load_master_config, getenv_roundup

    # can configure the run here if needed
    master_config_path = os.environ.get('MASTER_CONFIG_PATH', './generation/run_config.yaml')
    conf, printable_conf, oconf = load_master_config(master_config_path, abspath=True, return_original_conf=True)
    # unpack into globals: every key of `conf` becomes a module-level name.
    # NOTE(review): names used below without a visible definition (model_id,
    # tokenizer_kwargs, model_kwargs, use_device, rollouts, ds, prompter,
    # logitproc, save_dir, save_every, dset_*, pigz_temp_storage, ...) are
    # presumably supplied by this config -- confirm against the config schema.
    globals().update(conf)
    # print the used environment variables
    print(getenv_roundup())
    print(printable_conf)

    # set up hf environment
    # (HF_HOME override intentionally left disabled)
    from datasets import load_dataset, Dataset
    from types import SimpleNamespace as sn
    import torch
    import pandas as pd
    # from .prompts import PROMPT_COLLIE_PHI_INSTRUCT_CUSTOMIZED as prompt
    from .gen_utils import ExtraEOSTokenLogitsProcessor, get_generation_length_ids_from_batch, probe_token_id, reassign_ids
    from mi_boiler.compression_tools import save_records_compact
    from mi_boiler.swiss_army import any_to_device_nested
    from tqdm import trange

    # load the model
    from transformers import AutoTokenizer, AutoModelForCausalLM
    tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_kwargs)
    # defaults applied only when the config does not set them explicitly
    model_kwargs.setdefault('torch_dtype', torch.bfloat16)
    model_kwargs.setdefault('device_map', use_device)
    model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

    # check if we have to load any adapters
    if globals().get('model_adapters'):
        adapter_config = globals().get('model_adapters')
        # 'name' is removed from the config dict; the rest is passed as kwargs
        adapter_name = adapter_config.pop('name')
        model.load_adapter(adapter_name, **adapter_config)

    # some tokenizers (llama3 was the motivating case) ship without a pad
    # token id: fall back to the EOS token id
    if tokenizer.pad_token_id is None:
        # TODO: doesnt seem to solve it somehow, although assignment goes through. Possibly needs to be updated elsewhere too.
        print('Pad token id stupidity rectified.')
        # tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    print(tokenizer.pad_token_id)
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.eos_token_id
    if model.generation_config.pad_token_id is None:
        model.generation_config.pad_token_id = tokenizer.eos_token_id
        model.generation_config.padding_side = "left"
    # TODO: the padding_side may need to be fixed in config, although we only feed in identical sequence
    # so this should not be causing problems at least here
    print(model.generation_config.pad_token_id)

    # process the stop on token list
    # and also solve the multiple stop token issues
    stop_on_tokens = None # make configurable if used

    # results: one list of per-example records for every rollout configuration
    reslists = {k: [] for k in rollouts.keys()}
    torch.set_grad_enabled(False)
    last_checkpoint = dset_range_start
    # exist_ok avoids the check-then-create race when several workers start
    # at the same time and share the same save_dir
    os.makedirs(save_dir, exist_ok=True)
    # save the configuration used
    used_vars = getenv_roundup(clear=True)
    used_vars_repr = '\n'.join([f'export {k}={v}' for k, v in used_vars.items()])
    with open(os.path.join(save_dir, 'used_var.sh'), 'wt') as f:
        f.write(used_vars_repr)
    with open(os.path.join(save_dir, 'oconf.yaml'), 'wt') as f:
        yaml.safe_dump(oconf, f)

    ## initialize the lock file in the temp dir for pigz
    # NOTE: the lock path is built by plain string concatenation, so
    # pigz_temp_storage acts as a prefix of the lock file name.
    # TODO: a bug waiting to happen, or not...
    dump_file_lock = SoftFileLock(pigz_temp_storage+'.dump_lock', timeout=120.)
    if dump_file_lock.is_locked:
        dump_file_lock.release()
    
    # Main loop: one iteration per dataset index in
    # [dset_range_start, dset_range_end), or up to len(ds) when dset_range_end
    # is negative. Each iteration scores the prompt with one forward pass and
    # then runs every configured rollout, appending one record per rollout to
    # reslists[roll_name].
    for ds_idx in (tqdm_range_pbar := trange(dset_range_start, dset_range_end if dset_range_end>=0 else len(ds))):
        last_update_time = time.time() # to account for saves and such
        entry = ds[ds_idx]
        # TODO: super crutchy, might not hold, would make sense to make a prompt class e.g.
        # prompter :PromptBuilder, construct a prompt
        txt_x_no_prompt, input_text = example_to_prompt(entry, prompter, ds)
        # the bare system instruction, stored below as 'txt_prompt'
        prompt = ds.get_problem_system_instruction()
        # print(txt_x_no_prompt, input_text)

        # NOTE(review): the condition requires both an upper-case and a
        # lower-case 'x' in model_id; which models this is meant to select is
        # not visible here -- confirm the intent.
        if 'X' in model_id and 'x' in model_id:
            torch.cuda.set_device(use_device)
        # single copy of the prompt, used only for the prefill scoring pass
        toks = tokenizer([input_text,]*1, add_special_tokens=True, return_tensors='pt').to(device=use_device)
        # print(toks.input_ids.device)
        oall = model(toks['input_ids'].to(device=use_device))
        prompt_seq = toks.input_ids[0].cpu()
        prompt_logits = oall.logits[0].cpu()
        # for every prompt position: log-softmax of that position's logits,
        # indexed by the token id found at the same position.
        # NOTE(review): logits and token ids are paired without a one-step
        # shift -- confirm this alignment is what downstream code expects.
        prompt_seltok_logits = torch.stack([torch.log_softmax(o, -1)[ids] for o, ids in zip(prompt_logits, prompt_seq)]).cpu()

        tqdm_range_pbar.set_description(
            desc=f"Prefill: {toks['input_ids'].shape[1]:04d} tokens / {time.time()-last_update_time:.2f} sec -> {toks['input_ids'].shape[1]/(time.time()-last_update_time):.2f} tps",
            refresh=True
        )
        last_update_time = time.time()
        del oall # important to free the memory
        # except Exception as e:
        #     print(f"Failed to compute logits for the prompt, error: {e}")
        #     prompt_seltok_logits = None

        # tokenize the input
        # try:

        for roll_name, roll_args in rollouts.items():
            # prepare logits processors; `logitproc` is rebound to the prepared
            # result, so later iterations re-prepare the already built
            # processors (prepare_processors resets the stateful ones)
            logitproc = prepare_processors(logitproc=logitproc, tokenizer=tokenizer, vocab_size=len(tokenizer.vocab))
            # computed once, on the first rollout of the first example
            if stop_on_tokens is None:
                stop_on_tokens = extract_stop_tokens_from_processors(logitproc) # that we can recompute indexes correctly for beam search / sampling

            # tokenize and batch the required number of replicates of the sequences
            toks = tokenizer([input_text,]*roll_args['multiply_seqs_when_batching'], add_special_tokens=True, return_tensors='pt').to(device=use_device)
            
            # generate based on the sequences
            outputs_all = model.generate(
                **toks, 
                output_hidden_states=True,
                output_scores=True,
                return_dict_in_generate=True,
                # stopping_criteria=additional_stopping_crits
                logits_processor=logitproc,
                **roll_args['generate_kwargs'],
            )
            # scores are logits after all of the logits processing (those are some fat matrices, log sm them on the gpu)
            ret_logits = [torch.log_softmax(ol, dim=-1).detach().cpu() for ol in outputs_all.scores]
            # now put the outputs to the cpu
            outputs_all = sn(**any_to_device_nested(outputs_all, device='cpu', raise_if_no_match=False)) # some generate outputs return cache objects as past key values, do not raise if no match
            ret_ids = outputs_all.sequences
            # how the hidden states are returned
            # 1. generated tokens; 2. layers; 3. [batch_size, n_tokens_processed, dim]
            # now the other part is trickier, since sequences do not have to be the same length
            input_prompt_len = toks.input_ids.shape[-1]
            # since the prompt may contain EOS for some models, have to scan only through the generated token
            batch_lens = get_generation_length_ids_from_batch(outputs_all.sequences[:, input_prompt_len:], eos_token_id=model.config.eos_token_id)
            # batch_lens = batch_lens # get the index of the last non eos generated token
            # try:
            # except Exception as e:
            #     print('____!____!____!____!')
            #     print(batch_lens, input_prompt_len, ds_idx)
            #     raise e
            # ends up being [n_samples, hidden_size]
            # and to top it off save the query and the resulting text separately
            # NOTE(review): full_x_query and full_xy_response are not read
            # again within this loop body.
            full_x_query = input_text
            full_xy_response = outputs_all.sequences
            # optionally return all hiddens (not needed most of the time, configured to not be done by default)
            if roll_args['store_model_state']:
                ret_hs_at_x = outputs_all.hidden_states[0][-1][0,-1,:] # save just the last layer state of the prompt+input
                ret_hs_at_xy = [
                    # the very first entry is the prompt+input # TODO: not 100% sure -1 is necessary
                    outputs_all.hidden_states[ltid-1][-1][sample_id,-1,:] for sample_id, ltid in enumerate(batch_lens)
                ] # save all of the final states for when the generation is done
                ret_hs_at_xy = torch.stack(ret_hs_at_xy, dim=0)
                all_hidden = outputs_all.hidden_states
            else:
                all_hidden = None
            # compute transition probs; the beam-search path additionally
            # needs beam_indices (the fiddly case)
            if roll_args['generate_kwargs']['num_beams'] == 1:
                transition_scores = model.compute_transition_scores(
                    outputs_all.sequences, ret_logits, normalize_logits=False
                )
            else:
                # reassigning not necessary - bad token ids are already -inf and should not have been selected
                # tscores does not accumulate stuff - therefore doesn't care about eos_token_id
                transition_scores = model.compute_transition_scores(
                    # reassign_ids(outputs_all.sequences, stop_on_tokens, model.config.eos_token_id, starting_from=input_prompt_len), 
                    outputs_all.sequences,
                    ret_logits,
                    outputs_all.beam_indices,
                    normalize_logits=False
                )
            
            # record for this (example, rollout) pair
            local_retdict = {}

            local_retdict.update({
                # token outputs and transition probabilities
                'sequences': ret_ids, #  all ids
                'sequences_no_input': ret_ids[:, input_prompt_len:], # ids without prompt
                'sequences_len': batch_lens, # length of each sampled output
                'seq_no_input_transition': transition_scores, # these are -(selected token nll for corresponding positions)
                # decoded text versions
                'tok_x': prompt_seq,
                'tok_x_nlls': prompt_seltok_logits,
                'txt_prompt': prompt,
                'txt_x': input_text,
                'txt_x_no_prompt': txt_x_no_prompt,
                'txt_y': tokenizer.batch_decode(outputs_all.sequences[:, input_prompt_len:], skip_special_tokens=True),
                'txt_xy': tokenizer.batch_decode(outputs_all.sequences, skip_special_tokens=True),
                'txt_xy_full': tokenizer.batch_decode(outputs_all.sequences, skip_special_tokens=False),
                # backmapping info
                'dataset_idx': ds_idx,
                'model': model_id,
                'dataset': dset_id,
                'split': dset_split,
                # sampling parameters
                # NOTE(review): labelled 'MS' only when num_beams is None, while
                # the transition-score branch above treats num_beams == 1 as the
                # non-beam case; a config with num_beams=1 is labelled 'BS' here
                # -- confirm which is intended.
                'sampling_type': 'MS' if roll_args['generate_kwargs']['num_beams'] is None else 'BS',
                'rollout_config_name': roll_name,
                'status': 'good',
            })
            # optionally save logits
            if roll_args['store_logits']:
                local_retdict.update({
                    'logits': ret_logits
                })
            # optionally store model hidden states
            if roll_args['store_model_state']:
                local_retdict.update({
                    'hidden_at_x': ret_hs_at_x,
                    'hidden_at_xy': ret_hs_at_xy,
                    'all_hidden_states': all_hidden,
                })            
            reslists[roll_name].append(local_retdict)

            # print(local_retdict['txt_x'].replace('\n', '[nl]'))
            # print('--------')
            # print(local_retdict['txt_y'][1].replace('\n', '[nl]'))
            # print(local_retdict['sequences_no_input'][1, :5])
            # progress bar shows the generated-token count and throughput for this rollout
            tqdm_range_pbar.set_description(
                desc=f"{roll_name}: {batch_lens.sum().item():04d} tokens / {time.time()-last_update_time:.2f} sec -> {batch_lens.sum().item()/(time.time()-last_update_time):.2f} tps",
                refresh=True
            )
            last_update_time = time.time()
        
        # Periodic checkpoint: every `save_every` processed examples, dump the
        # accumulated records of each rollout and start fresh lists.
        # `save_every > 0` is tested first so that save_every == 0 disables
        # checkpointing instead of raising ZeroDivisionError in the modulo.
        if save_every > 0 and (ds_idx + 1 - dset_range_start) % save_every == 0:
            # saves are serialised through the dump lock
            wait_for_lock_file_nice(dump_file_lock)
            try:
                for roll_name, reslist in reslists.items():
                    # file name carries the covered index range: first index
                    # since the previous checkpoint and the current index
                    save_records_compact(
                        reslist,
                        f"{save_dir}/{roll_name}_{last_checkpoint:05d}_{ds_idx:05d}",
                        digitize_conf={'logits': {'table':'lsm'}},
                        use_pigz=True,
                        pigz_fast_temp_path=pigz_temp_storage
                    )
                    reslists[roll_name] = []
                last_checkpoint = ds_idx + 1
            finally:
                # always hand the lock back, even when a save failed;
                # the exception itself still propagates
                if dump_file_lock.is_locked:
                    dump_file_lock.release()

    # final piece: dump whatever is left over since the last checkpoint.
    # Skipped entirely when nothing is pending; this also covers an empty
    # index range, where the loop never ran and `ds_idx` was never bound.
    if any(reslists.values()):
        wait_for_lock_file_nice(dump_file_lock)
        try:
            for roll_name, reslist in reslists.items():
                if reslist:
                    save_records_compact(
                        reslist,
                        f"{save_dir}/{roll_name}_{last_checkpoint:05d}_{ds_idx:05d}",
                        digitize_conf={'logits': {'table':'lsm'}},
                        use_pigz=True,
                        pigz_fast_temp_path=pigz_temp_storage
                    )
                    reslists[roll_name] = []
            last_checkpoint = ds_idx + 1
        finally:
            # always hand the lock back; a failed save still propagates
            if dump_file_lock.is_locked:
                dump_file_lock.release()

    print("Great succ!")