"""Function returning clinc150 datasets compatible with some other places.

https://huggingface.co/datasets/contemmcm/clinc150
"""
import re
from typing import Any, Dict, List, Optional, Sequence, Union

import datasets
from lighteval.tasks.requests import Doc
import torch
from transformers import PreTrainedTokenizer

from npeff_torch.datasets import preprocessing_common


###############################################################################
###############################################################################

# Every clinc150 intent, written as '<domain>:<intent>'.
# The position of an entry in this tuple is the integer id of that label.
# The '<intent>' part (after the colon) is unique across all of the labels.
LABELS = (
    # Out-of-scope.
    'oos:oos',
    # Banking.
    'banking:freeze_account',
    'banking:routing',
    'banking:pin_change',
    'banking:bill_due',
    'banking:pay_bill',
    'banking:account_blocked',
    'banking:interest_rate',
    'banking:min_payment',
    'banking:bill_balance',
    'banking:transfer',
    'banking:order_checks',
    'banking:balance',
    'banking:spending_history',
    'banking:transactions',
    'banking:report_fraud',
    # Credit cards.
    'credit_cards:replacement_card_duration',
    'credit_cards:expiration_date',
    'credit_cards:damaged_card',
    'credit_cards:improve_credit_score',
    'credit_cards:report_lost_card',
    'credit_cards:card_declined',
    'credit_cards:credit_limit_change',
    'credit_cards:apr',
    'credit_cards:redeem_rewards',
    'credit_cards:credit_limit',
    'credit_cards:rewards_balance',
    'credit_cards:application_status',
    'credit_cards:credit_score',
    'credit_cards:new_card',
    'credit_cards:international_fees',
    # Kitchen and dining.
    'kitchen_and_dining:food_last',
    'kitchen_and_dining:confirm_reservation',
    'kitchen_and_dining:how_busy',
    'kitchen_and_dining:ingredients_list',
    'kitchen_and_dining:calories',
    'kitchen_and_dining:nutrition_info',
    'kitchen_and_dining:recipe',
    'kitchen_and_dining:restaurant_reviews',
    'kitchen_and_dining:restaurant_reservation',
    'kitchen_and_dining:meal_suggestion',
    'kitchen_and_dining:restaurant_suggestion',
    'kitchen_and_dining:cancel_reservation',
    'kitchen_and_dining:ingredient_substitution',
    'kitchen_and_dining:cook_time',
    'kitchen_and_dining:accept_reservations',
    # Home.
    'home:what_song',
    'home:play_music',
    'home:todo_list_update',
    'home:reminder',
    'home:reminder_update',
    'home:calendar_update',
    'home:order_status',
    'home:update_playlist',
    'home:shopping_list',
    'home:calendar',
    'home:next_song',
    'home:order',
    'home:todo_list',
    'home:shopping_list_update',
    'home:smart_home',
    # Auto and commute.
    'auto_and_commute:current_location',
    'auto_and_commute:oil_change_when',
    'auto_and_commute:oil_change_how',
    'auto_and_commute:uber',
    'auto_and_commute:traffic',
    'auto_and_commute:tire_pressure',
    'auto_and_commute:schedule_maintenance',
    'auto_and_commute:gas',
    'auto_and_commute:mpg',
    'auto_and_commute:distance',
    'auto_and_commute:directions',
    'auto_and_commute:last_maintenance',
    'auto_and_commute:gas_type',
    'auto_and_commute:tire_change',
    'auto_and_commute:jump_start',
    # Travel.
    'travel:plug_type',
    'travel:travel_notification',
    'travel:translate',
    'travel:flight_status',
    'travel:international_visa',
    'travel:timezone',
    'travel:exchange_rate',
    'travel:travel_suggestion',
    'travel:travel_alert',
    'travel:vaccines',
    'travel:lost_luggage',
    'travel:book_flight',
    'travel:book_hotel',
    'travel:carry_on',
    'travel:car_rental',
    # Utility.
    'utility:weather',
    'utility:alarm',
    'utility:date',
    'utility:find_phone',
    'utility:share_location',
    'utility:timer',
    'utility:make_call',
    'utility:calculator',
    'utility:definition',
    'utility:measurement_conversion',
    'utility:flip_coin',
    'utility:spelling',
    'utility:time',
    'utility:roll_dice',
    'utility:text',
    # Work.
    'work:pto_request_status',
    'work:next_holiday',
    'work:insurance_change',
    'work:insurance',
    'work:meeting_schedule',
    'work:payday',
    'work:taxes',
    'work:income',
    'work:rollover_401k',
    'work:pto_balance',
    'work:pto_request',
    'work:w2',
    'work:schedule_meeting',
    'work:direct_deposit',
    'work:pto_used',
    # Small talk.
    'small_talk:who_made_you',
    'small_talk:meaning_of_life',
    'small_talk:who_do_you_work_for',
    'small_talk:do_you_have_pets',
    'small_talk:what_are_your_hobbies',
    'small_talk:fun_fact',
    'small_talk:what_is_your_name',
    'small_talk:where_are_you_from',
    'small_talk:goodbye',
    'small_talk:thank_you',
    'small_talk:greeting',
    'small_talk:tell_joke',
    'small_talk:are_you_a_bot',
    'small_talk:how_old_are_you',
    'small_talk:what_can_i_ask_you',
    # Meta.
    'meta:change_speed',
    'meta:user_name',
    'meta:whisper_mode',
    'meta:yes',
    'meta:change_volume',
    'meta:no',
    'meta:change_language',
    'meta:repeat',
    'meta:change_accent',
    'meta:cancel',
    'meta:sync_device',
    'meta:change_user_name',
    'meta:change_ai_name',
    'meta:reset_settings',
    'meta:maybe',
)

# Only the '<intent>' part of each label, in the same order as LABELS.
LABELS_LAST_PART = tuple(label.rpartition(':')[-1] for label in LABELS)

###############################################################################
###############################################################################

# All intent names (the part after the colon), one per line. Embedded in the
# 'prompt_2' template below as the list of choices.
_NL_LABELS = "\n".join(LABELS_LAST_PART)

# Prompt templates for the open-generation task, keyed by subtask name.
# Each template contains a single `{query}` placeholder that is filled in
# with `str.format`.
_OPEN_SUBTASK_TO_TEMPLATE = dict(
    prompt_1='Query: {query}\nIn a few words, what is the intent of this query?\nIntent:',
    # The doubled braces keep `{query}` as a literal placeholder inside the f-string.
    prompt_2=f'Query: {{query}}\nWhat is the intent of this query? Choose from:\n{_NL_LABELS}.\nIntent:',
    #
    # These mirror the lm_suffix_mc tasks.
    version_3='Query: {query}\nIntent:',
)


def load_default_open_task(
    task: str,
    # The subtask can be used to specify a pre-set template. It must be None if template is provided.
    subtask: Optional[str],
    split: str,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
    *,
    template: Optional[str] = None,
):
    """Build the clinc150 dataset for open text generation.

    This won't really have a good way to evaluate accuracy on the task.

    Args:
        task: Must be 'clinc150'.
        subtask: Key into `_OPEN_SUBTASK_TO_TEMPLATE` selecting a pre-set
            prompt template. Must be None when `template` is given.
        split: Must be 'complete'.
        tokenizer: Tokenizer used to encode each formatted prompt.
        sequence_length: Passed to `encode_open_example` as the target length.
        template: Optional custom prompt template with a `{query}`
            placeholder. When given, it is used instead of a pre-set one.

    Returns:
        An iterable dataset restricted to the 'input_ids', 'attention_mask'
        and 'labels' columns, where 'labels' is the example's 'intent' value.

    Raises:
        ValueError: If `task` is not 'clinc150', if `split` is not
            'complete', or if both `subtask` and `template` are provided.
        KeyError: If `template` is None and `subtask` is not a known subtask.
    """
    if task != 'clinc150':
        raise ValueError(f'The task must be "clinc150". Instead received: {task}')

    if split != 'complete':
        # NOTE: The dataset encodes split via the "split" feature.
        raise ValueError('Only currently supporting the "complete" split.')

    if template is None:
        template = _OPEN_SUBTASK_TO_TEMPLATE[subtask]
    elif subtask is not None:
        # Raise explicitly rather than `assert`: assertions are stripped under
        # `python -O`, which would let a conflicting subtask be silently ignored.
        raise ValueError(
            f'The subtask must be None when a template is provided. Instead received: {subtask}'
        )

    base_ds = datasets.load_dataset("contemmcm/clinc150")
    ds = base_ds[split].to_iterable_dataset()

    # Tokenize the formatted prompt of every example.
    ds = ds.map(
        lambda ex: encode_open_example(tokenizer, ex, sequence_length, template=template),
        batched=False,
    )
    # Expose the integer intent under the 'labels' column name.
    ds = ds.map(
        lambda ex: {'labels': ex['intent']},
        batched=False,
    )
    ds = ds.select_columns(['input_ids', 'attention_mask', 'labels'])

    return ds


###############################################################################


def encode_open_example(
    tokenizer: PreTrainedTokenizer,
    example: Dict[str, Any],
    sequence_length: int,
    *,
    template: str,
) -> Dict[str, torch.Tensor]:
    """Format and tokenize a single example for the open-generation task.

    The example's 'text' is stripped of surrounding whitespace and substituted
    for `{query}` in `template`. The resulting prompt is tokenized, and the
    token ids and attention mask are each passed through
    `preprocessing_common.single_truncate_and_pad` together with
    `sequence_length`.

    Returns:
        A dict with the keys 'input_ids' and 'attention_mask'.
    """
    prompt = template.format(query=example['text'].strip())

    tokenizer_kwargs = dict(
        return_tensors="pt",
        # Disabling special tokens ensures that no EOS token gets appended to the end.
        add_special_tokens=False,
        max_length=sequence_length,
        truncation=True,
        padding=True,
    )
    encoded = tokenizer(prompt, **tokenizer_kwargs)

    # Fill value handed to the helper for each field: the tokenizer's pad
    # token id for the token ids, and 0 for the attention mask.
    fill_values = {
        'input_ids': tokenizer.pad_token_id,
        'attention_mask': 0,
    }
    return {
        key: preprocessing_common.single_truncate_and_pad(encoded[key], fill, sequence_length)
        for key, fill in fill_values.items()
    }


###############################################################################
###############################################################################


# Prompt templates for the lm_suffix_mc task, keyed by subtask name. Each
# template contains a single `{query}` placeholder.
_LM_SUFFIX_MC_SUBTASK_TO_TEMPLATE = dict(
    version_1='Query: {query}\nWhat is the intent of this query?\nIntent:',
    version_2='Query: {query}\nIntent:',
    version_3='Query: {query}\nIntent:',
    #
    # Alternatives previously tried for the 'dummy' subtask:
    #   'Query: {query}\nIntent:'
    #   'Query: {query}\nWhat is the intent of this query?'
    #   'The intent of the query "{query}" is'
    dummy='What is the intent of the following query?\n{query}\nIntent:',
)

# Answer options for the lm_suffix_mc task, keyed by subtask name. The options
# follow the order of LABELS, so an option's index is the label's integer id.
_LM_SUFFIX_MC_SUBTASK_TO_OPTIONS = dict(
    version_1=LABELS_LAST_PART,
    version_2=LABELS_LAST_PART,
    version_3=tuple(name.replace('_', ' ') for name in LABELS_LAST_PART),
    #
    # Alternatives previously tried for the 'dummy' subtask. The trailing
    # number is the value recorded next to each one (presumably a measured
    # score -- TODO confirm):
    #   LABELS  # 0.09999999403953552
    #   tuple(x.replace(':', ': ').replace('_', ' ') for x in LABELS)  # 0.17000000178813934
    #   tuple(x.replace('_', ' ').capitalize() for x in LABELS_LAST_PART)  # 0.14999999105930328
    #   tuple(x.upper() for x in LABELS_LAST_PART)  # 0.14999999105930328
    #   tuple(x.replace('_', ' ') for x in LABELS_LAST_PART)[::-1]  # 0.009999999776482582
    dummy=tuple(name.replace('_', ' ') for name in LABELS_LAST_PART),
)


def load_default_lm_suffix_mc_task(
    task: str,
    subtask: str,
    split: str,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
):
    """Build the clinc150 dataset for the lm_suffix_mc (multiple choice) task.

    The `subtask` selects both the prompt template and the list of answer
    options. Every example is encoded with `encode_lm_suffix_mc_example`.

    Returns:
        An iterable dataset restricted to the 'input_ids', 'attention_mask',
        'labels' and 'context_length' columns.

    Raises:
        ValueError: If `task` is not 'clinc150' or `split` is not 'complete'.
        KeyError: If `subtask` is not a known subtask.
    """
    if task != 'clinc150':
        raise ValueError(f'The task must be "clinc150". Instead received: {task}')
    if split != 'complete':
        # NOTE: The dataset encodes split via the "split" feature.
        raise ValueError('Only currently supporting the "complete" split.')

    prompt_template = _LM_SUFFIX_MC_SUBTASK_TO_TEMPLATE[subtask]
    choices = list(_LM_SUFFIX_MC_SUBTASK_TO_OPTIONS[subtask])

    def _encode(ex):
        # Per-example encoder bound to this task's tokenizer, template and options.
        return encode_lm_suffix_mc_example(
            tokenizer,
            ex,
            sequence_length,
            template=prompt_template,
            options=choices,
        )

    kept_columns = ['input_ids', 'attention_mask', 'labels', 'context_length']
    stream = datasets.load_dataset("contemmcm/clinc150")[split].to_iterable_dataset()
    return stream.map(_encode, batched=False).select_columns(kept_columns)


###############################################################################


def encode_lm_suffix_mc_example(
    tokenizer: PreTrainedTokenizer,
    example: Dict[str, Any],
    sequence_length: int,
    *,
    template: str,
    options: List[str],
) -> Dict[str, torch.Tensor]:
    """Encode a single example for the lm_suffix_mc task.

    The example's stripped 'text' is substituted for `{query}` in `template`
    to form the prompt. The prompt, the answer `options` and the example's
    'intent' (as the gold index) are wrapped in a `Doc`, and the actual
    encoding is delegated to `preprocessing_common.encode_lm_suffix_mc_example`.
    """
    prompt = template.format(query=example['text'].strip())
    gold = example['intent']

    doc = Doc(query=prompt, choices=options, gold_index=gold)

    return preprocessing_common.encode_lm_suffix_mc_example(
        doc,
        tokenizer=tokenizer,
        sequence_length=sequence_length,
    )


###############################################################################
###############################################################################


@torch.no_grad()
def reprocess_saved_lm_suffix_mc_first_example(
    # Must be a single example.
    example: Dict[str, torch.Tensor],
    *,
    tokenizer: PreTrainedTokenizer,
    subtask: str = 'version_3',
):
    """Converts an example saved with --lm_suffix_mc_save_only_first_example=true to an lm_suffix_mc_example

    The non-masked tokens of the saved example are decoded back to text, the
    original query is recovered from that text, and the query is re-encoded
    with `encode_lm_suffix_mc_example` using the subtask's template and
    options.

    Args:
        example: A single (un-batched) example with 1-D 'input_ids' and
            'attention_mask' tensors.
        tokenizer: Tokenizer used both to decode the saved tokens and to
            re-encode the recovered query.
        subtask: Subtask whose template and options are used. Only
            'version_3' is supported.

    Returns:
        The re-encoded example with its 'labels' entry removed. The true label
        is not known here, so a dummy one is used during encoding.

    Raises:
        ValueError: If `subtask` is not 'version_3', if the example is not a
            single 1-D example, or if the decoded text does not start with
            the expected 'Query: ...\\nIntent:' prompt.
    """
    # Raise explicitly rather than `assert`: assertions are stripped under
    # `python -O`, and the regex below is specific to the 'version_3' template.
    if subtask != 'version_3':
        raise ValueError("Currently only supporting subtask='version_3'.")

    input_ids = example['input_ids'].detach().cpu().numpy()
    attention_mask = (example['attention_mask'] != 0).detach().cpu().numpy()

    # Make sure this is a single example.
    if not (input_ids.ndim == attention_mask.ndim == 1):
        raise ValueError(
            'Expected a single example with 1-D input_ids and attention_mask. '
            f'Instead received shapes: {input_ids.shape} and {attention_mask.shape}'
        )

    # Re-encode to the same length as the saved example.
    sequence_length, = input_ids.shape

    # Decode only the positions the attention mask marks as non-zero.
    og_text = tokenizer.decode(input_ids[attention_mask])

    # `.` does not match newlines here, so the query must sit on a single
    # line directly between 'Query: ' and '\nIntent:'.
    match = re.search(r'^Query: (.+)\nIntent:', og_text)
    if not match:
        raise ValueError(f'Invalid example: {og_text}')

    query = match.group(1)

    template = _LM_SUFFIX_MC_SUBTASK_TO_TEMPLATE[subtask]
    options = list(_LM_SUFFIX_MC_SUBTASK_TO_OPTIONS[subtask])

    ret = encode_lm_suffix_mc_example(
        tokenizer=tokenizer,
        # The intent (i.e. label) here is a dummy value.
        example={'text': query, 'intent': 0},
        sequence_length=sequence_length,
        template=template,
        options=options,
    )
    # The label was a placeholder, so do not expose it to callers.
    del ret['labels']
    return ret
