from collections import OrderedDict

import numpy as np
import pandas as pd
import tqdm

from src.data.utils import CustomColName


def sort_chatml_msg_keys(example, **kwargs):
    """
    Move ``example["messages"]`` to a new column whose message dicts have
    alphabetically sorted keys.

    This works around Arrow schema errors such as:

    ArrowTypeError: struct fields don't match or are in the wrong order: Input
    fields: struct<role: string, content: string> output fields: struct<content:
    string, role: string>

    The ``"messages"`` entry is removed from ``example``, which is mutated in
    place. The reordered list is stored under
    ``CustomColName.REORDERED_MSGS.value``, and the same dict is returned.
    Extra keyword arguments are accepted and ignored.
    """
    # Snapshot the turn structure before touching the example.
    roles_before = [msg["role"] for msg in example["messages"]]
    turns_before = len(example["messages"])

    source_messages = example.pop("messages")
    out_key = CustomColName.REORDERED_MSGS.value
    reordered = []
    for msg in source_messages:
        # Sorting (key, value) pairs gives every message the same field order.
        reordered.append(OrderedDict(sorted(msg.items())))
    example[out_key] = reordered

    # Sanity check: if reordering ever changed roles or turn count, every
    # downstream training example would be silently wrong.
    assert roles_before == [msg["role"] for msg in example[out_key]]
    assert turns_before == len(example[out_key])

    return example


def get_statistics_for_messages_data(
    dataset,
    split="train",
    messages_key="messages",
    tokenizer="philschmid/meta-llama-3-tokenizer",
    chars_per_token=3.6,
):
    """
    Compute summary statistics for one split of a chat-style (messages) dataset.

    Messages with ``role == "system"`` are ignored. Lengths are approximate
    token counts, estimated as ``len(content) / chars_per_token``. No tokenizer
    is loaded.

    Args:
        dataset: Indexable by split name. ``dataset[split]`` must support
            ``len()``, iteration over row dicts, integer indexing and a
            ``.features`` mapping (e.g. a Hugging Face ``DatasetDict``).
        split: Name of the split to analyse.
        messages_key: Row column holding the list of ``{"role", "content"}``
            message dicts.
        tokenizer: Currently unused; kept for backward compatibility.
        chars_per_token: Characters-per-token ratio used for the length
            estimate (default 3.6).

    Returns:
        dict with:
        - ``num_instances``: number of rows in the split.
        - ``turns_summary``: ``pd.Series.describe()`` dict of the number of
          non-system messages per row.
        - ``user_prompt_lengths_summary``: the same summary over all user
          message lengths.
        - ``assistant_response_lengths_summary``: the same summary over all
          assistant message lengths.
        - ``total_lengths_summary``: the same summary over per-row total
          (user + assistant) length.
        - ``num_instances_with_total_length_gt_{N}``: count of rows whose
          total length exceeds N, for N in 512, 768, 1024, 1536, 2048, 4096.
        - ``top_100_longest_instances``: ``id`` values of up to 100 rows,
          longest first, or ``None`` if the split has no ``id`` feature.
    """
    data = dataset[split]
    num_instances = len(data)

    # System messages are skipped inline; the previous ``dataset.map`` call
    # rewrote every split of the dataset just to drop them.
    num_of_turns = []
    user_prompt_lengths = []
    assistant_response_lengths = []
    instance_lengths = []
    for instance in tqdm.tqdm(data, desc="Processing instances"):
        turns = 0
        instance_length = 0
        for message in instance[messages_key]:
            role = message["role"]
            if role == "system":
                continue
            turns += 1
            if role == "user":
                # ``content`` may be None (e.g. tool-call style messages).
                length = len(message["content"] or "") / chars_per_token
                user_prompt_lengths.append(length)
                instance_length += length
            elif role == "assistant":
                length = len(message["content"] or "") / chars_per_token
                assistant_response_lengths.append(length)
                instance_length += length
        num_of_turns.append(turns)
        instance_lengths.append(instance_length)

    lengths = np.array(instance_lengths)

    top_indices = np.argsort(lengths)[-100:][::-1].tolist()
    if "id" in data.features:
        top_100_longest_instances = [data[i]["id"] for i in top_indices]
    else:
        top_100_longest_instances = None

    result = {
        "num_instances": num_instances,
        "turns_summary": pd.Series(num_of_turns).describe(),
        "user_prompt_lengths_summary": pd.Series(user_prompt_lengths).describe(),
        "assistant_response_lengths_summary": pd.Series(assistant_response_lengths).describe(),
        "total_lengths_summary": pd.Series(instance_lengths).describe(),
    }
    for threshold in (512, 768, 1024, 1536, 2048, 4096):
        result[f"num_instances_with_total_length_gt_{threshold}"] = np.sum(lengths > threshold)
    result["top_100_longest_instances"] = top_100_longest_instances

    # Convert everything to plain dicts / Python scalars (JSON-friendly).
    # ``np.integer`` covers any platform-dependent int type ``np.sum`` returns.
    for key, value in list(result.items()):
        if isinstance(value, pd.Series):
            result[key] = value.to_dict()
        elif isinstance(value, np.ndarray):
            result[key] = value.tolist()
        elif isinstance(value, np.integer):
            result[key] = int(value)

    return result
