from util.logger import logger

from typing import Optional, Union, List, Tuple

from pathlib import Path

import numpy as np

import torch

import gc

from transformers import AutoModelForCausalLM, AutoTokenizer

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.yaml_util import load_yaml


def load_promptist_model(
    cfg_yaml_path: Optional[str] = None, 

    device: Optional[str] = "cpu"
) -> Tuple[
    "transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel", 
    "transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast"
]:
    """
    Func: 
        Load the Promptist causal LM together with the GPT-2 tokenizer it is 
        paired with. 

        Checkpoint locations come from two YAML configs: 
        - The Promptist config (`cfg_yaml_path`) provides the Promptist checkpoint 
          root (`promptist.ckpt_root_path`) and the path of the GPT-2 config 
          (`promptist.gpt_2_cfg_yaml_path`). 
        - The GPT-2 config provides the checkpoint root the tokenizer is loaded 
          from (`gpt_2.ckpt_root_path`). 

    Args:
        `cfg_yaml_path` (`Optional[str]`): Path of the Promptist YAML config. Any 
            falsy value (`None`, `""`) falls back to `"./config/model/promptist.yaml"`. 
        `device` (`Optional[str]`): Device the model is moved to. 

    Ret:
        `(promptist_model, promptist_tokenizer)`: The model moved to `device`, and 
            the tokenizer (not moved; it processes on CPU). The tokenizer's 
            `pad_token` is set to its `eos_token`, and it pads on the left. 
    """
    if not cfg_yaml_path:
        cfg_yaml_path = "./config/model/promptist.yaml"

    # Promptist config: its own checkpoint + a pointer to the GPT-2 config.
    promptist_cfg = load_yaml(yaml_path = cfg_yaml_path)["promptist"]
    promptist_ckpt_dir = promptist_cfg["ckpt_root_path"]
    gpt_2_cfg_path = promptist_cfg["gpt_2_cfg_yaml_path"]

    # GPT-2 config: only the checkpoint root is needed (tokenizer source).
    gpt_2_cfg = load_yaml(yaml_path = gpt_2_cfg_path)
    gpt_2_ckpt_dir = gpt_2_cfg["gpt_2"]["ckpt_root_path"]

    model = AutoModelForCausalLM.from_pretrained(promptist_ckpt_dir)
    model = model.to(device)

    # The tokenizer is not moved; only the model lives on `device`.
    tokenizer = AutoTokenizer.from_pretrained(gpt_2_ckpt_dir)
    # Reuse EOS as the pad token and pad on the left (presumably so batched
    # inputs stay right-aligned for decoder-only generation).
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # ---------= [Clean Up] =---------
    del promptist_cfg, gpt_2_cfg
    gc.collect()

    return model, tokenizer


def optimize_text_prompt(
    prompt: str, 

    promptist_model, 
    promptist_tokenizer, 

    *, 
    max_new_tokens: int = 75, 
    num_beams: int = 8
) -> str:
    """
    Func: 
        Optimize text prompt `prompt` with Promptist. 

        The stripped prompt is suffixed with `" Rephrase: "` and fed to the model. 
        Deterministic beam search (`do_sample = False`) generates a continuation. 
        The newly generated tokens of the first returned beam (the top-scoring 
        one in HF beam search) are decoded as the optimized prompt. 

    Args:
        `prompt` (`str`): Plain text prompt. Leading/trailing whitespace is ignored. 
        `promptist_model`: Causal LM, e.g. the model returned by `load_promptist_model()`. 
        `promptist_tokenizer`: Tokenizer matching `promptist_model`. 
        `max_new_tokens` (`int`): Maximum number of tokens to generate. 
        `num_beams` (`int`): Beam width; all beams are returned, only the first is used. 

    Ret:
        `optimized_prompt` (`str`): The optimized prompt, stripped of surrounding whitespace. 
    """
    model_input_text = prompt.strip() + " Rephrase: "

    # input_id_list.shape = (1, prompt_len)
    input_id_list = promptist_tokenizer(
        model_input_text, 

        return_tensors = "pt"
    ).input_ids \
        .to(promptist_model.device)
    prompt_len = input_id_list.shape[-1]

    eos_id = promptist_tokenizer.eos_token_id

    # output_list.shape = (num_beams, prompt_len + generated_len)
    output_list = promptist_model.generate(
        input_id_list, 

        do_sample = False, 
        max_new_tokens = max_new_tokens, 
        num_beams = num_beams, 
        num_return_sequences = num_beams, 
        eos_token_id = eos_id, 
        pad_token_id = eos_id, 
        length_penalty = -1.0
    )

    # For a causal LM, `generate()` echoes the input ids as the prefix of every
    # returned sequence. Slicing them off by token count (instead of
    # string-replacing the prompt text) cannot miss when `prompt` carries
    # surrounding whitespace. It also cannot delete a repeat of the prompt that
    # occurs inside the generated continuation.
    generated_id_list = output_list[0, prompt_len:]
    optimized_prompt = promptist_tokenizer.decode(
        generated_id_list, 

        skip_special_tokens = True
    ).strip()

    # ---------= [Clean Up] =---------
    del input_id_list
    del output_list
    del generated_id_list
    torch.cuda.empty_cache()
    gc.collect()

    # `optimize_text_prompt()` done
    return optimized_prompt
