import random

import jax.tree_util
import numpy as np
import pandas as pd
from absl import flags
from datasets import load_dataset
from tqdm import tqdm

from EasyLM import logging_utils
from structured_prompting.dataset import get_dataset
from structured_prompting.utils.functional import setup_seed

# Command-line flags consumed by the evaluation helpers in this module.
# `downstream_eval_verbose`: when set, `verify` prints the decoded model input,
# token counts, target tokens and per-token accuracy for every example.
flags.DEFINE_boolean("downstream_eval_verbose", False, "Prints the model input")
# `max_considered_length`: read by `downstream_task_eval`, which subtracts the
# longest input+answer length (plus one) from it and passes the result to
# `get_chunk` as `fake_max_length`.
flags.DEFINE_integer("max_considered_length", 0, "Max length of prompts, for padding")
# `mmlu_use_valid_examples`: when set, `mmlu_eval` appends the "validation"
# split to the "dev" split as the pool of few-shot demonstrations.
flags.DEFINE_boolean("mmlu_use_valid_examples", False, "Use validation examples")
FLAGS = flags.FLAGS


def toy_task_eval_step(
    prompt,
    answer,
    train_state,
    sharded_rng,
    augmented_eval_step,
    clear_xl_fn,
    tokenizer,
    max_seq_len,
):
    """Score a single toy-task sample made of a prompt and its expected answer.

    The prompt and answer strings are concatenated and tokenized together, the
    answer is tokenized on its own, and both are handed to ``verify`` with
    ``max_seq_len`` as the chunk length.

    Returns:
        ``(train_state, sharded_rng, task_metrics)`` - note that the order
        differs from the tuple returned by ``verify``.
    """
    prompt_and_answer_ids = tokenizer.encode(prompt + answer)
    answer_ids = tokenizer.encode(answer)

    next_rng, sample_metrics, next_state = verify(
        full_input_tokens=prompt_and_answer_ids,
        answer_tokens=answer_ids,
        max_chunk_len=max_seq_len,
        train_state=train_state,
        sharded_rng=sharded_rng,
        augmented_eval_step=augmented_eval_step,
        clear_xl_fn=clear_xl_fn,
        tokenizer=tokenizer,
    )
    return next_state, next_rng, sample_metrics


def verify(
    full_input_tokens,
    answer_tokens,
    max_chunk_len,
    train_state,
    sharded_rng,
    augmented_eval_step,
    clear_xl_fn,
    tokenizer,
):
    """Score the model's prediction of an answer located at the end of an input.

    A BOS token is prepended to ``full_input_tokens``. If the result is longer
    than ``max_chunk_len``, the cache is reset with ``clear_xl_fn`` and the
    leading full-size chunks are pushed through ``augmented_eval_step`` one by
    one; only the final chunk is scored. The final chunk is right-padded to
    ``max_chunk_len`` and the loss mask selects exactly the target positions
    that correspond to the answer tokens.

    Args:
        full_input_tokens: list of token ids (no BOS) ending with the answer.
        answer_tokens: token ids of the answer; re-tokenized via decode/encode
            and stripped of leading space tokens before comparison.
        max_chunk_len: number of tokens fed to the model per call.
        train_state: state object passed to ``augmented_eval_step`` and
            ``clear_xl_fn``.
        sharded_rng: RNG passed to and returned from ``augmented_eval_step``.
        augmented_eval_step: evaluation step callable.
        clear_xl_fn: callable taking and returning ``train_state``.
        tokenizer: tokenizer with ``add_bos_token`` and ``add_eos_token``
            disabled.

    Returns:
        ``(sharded_rng, task_metrics, train_state)`` where ``task_metrics`` has
        the keys ``retrieval_sequence_accuracy``, ``retrieval_token_accuracy``
        and ``retrieval_token_loss``, all computed over the answer positions.

    Raises:
        AssertionError: if the tokenizer adds BOS/EOS itself, or the final
            chunk does not end with the (normalized) answer tokens.
        ValueError: if the answer is empty after stripping leading spaces, or
            the final chunk holds nothing but the answer.
    """
    assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False
    input_tokens = [tokenizer.bos_token_id] + full_input_tokens
    # Leading spaces are dropped from the answer so that only its visible
    # content is scored; the decode/encode round trip normalizes the ids first.
    answer_tokens = tokenizer.encode(tokenizer.decode(answer_tokens))
    space_id = tokenizer.encode(" ")[0]
    while answer_tokens and answer_tokens[0] == space_id:
        answer_tokens = answer_tokens[1:]
    if not answer_tokens:
        raise ValueError("answer contains no tokens besides leading spaces")
    answer_len = len(answer_tokens)

    if FLAGS.downstream_eval_verbose:
        print("Full model input:\n", tokenizer.decode(input_tokens))
        print("Input tokens length: ", len(input_tokens))
    # NOTE(review): the cache is only cleared on the long-input path below; if
    # the final step reads cached state, a short example following a long one
    # may see stale segments - confirm against augmented_eval_step.
    if len(input_tokens) > max_chunk_len:
        train_state = clear_xl_fn(train_state)
        # When the length is an exact multiple of the chunk size the remainder
        # is 0; keep one full chunk as the scored suffix instead, because a
        # zero-length slice bound would turn the whole input into the suffix.
        rest_suffix_len = len(input_tokens) % max_chunk_len or max_chunk_len
        prefix = input_tokens[:-rest_suffix_len]
        rest_suffix = input_tokens[-rest_suffix_len:]
        assert len(prefix) % max_chunk_len == 0
        # Add segments to the XL cache
        for i in range(0, len(prefix), max_chunk_len):
            prefix_segment = prefix[i : i + max_chunk_len]
            batch = {
                "input_tokens": np.array(prefix_segment),
                "target_tokens": np.array(prefix_segment),
                "loss_masks": np.ones_like(prefix_segment).astype(np.float32),
            }
            batch = jax.tree_util.tree_map(lambda x: np.expand_dims(x, 0), batch)
            train_state, sharded_rng, _ = augmented_eval_step(
                train_state=train_state, rng=sharded_rng, batch=batch
            )
        input_tokens = rest_suffix
    # A completely full final chunk needs no padding and is still valid.
    assert len(input_tokens) <= max_chunk_len
    assert input_tokens[-len(answer_tokens) :] == answer_tokens
    if len(input_tokens) <= answer_len:
        # The token that predicts the first answer token sits in an earlier
        # chunk, so the answer cannot be scored from this chunk alone.
        raise ValueError(
            "final chunk contains only the answer; no context token precedes it"
        )
    if FLAGS.downstream_eval_verbose:
        print("Input suffix length: ", len(input_tokens))
    padded_tokens = np.pad(
        input_tokens,
        (0, max_chunk_len - len(input_tokens)),
        mode="constant",
        constant_values=tokenizer.pad_token_id,
    )
    # Targets are the inputs shifted left by one position.
    target_tokens = np.concatenate([padded_tokens[1:], [tokenizer.pad_token_id]])
    batch = {
        "input_tokens": padded_tokens,
        "target_tokens": target_tokens,
        "loss_masks": np.concatenate(
            [
                # Minus one because of the input/target shift: the answer's
                # first token is the target of the position just before it.
                np.zeros(len(input_tokens) - answer_len - 1),
                np.ones(answer_len),
                np.zeros(max_chunk_len - len(input_tokens) + 1),
            ]
        ),
    }
    batch = jax.tree_util.tree_map(lambda x: np.expand_dims(x, 0), batch)
    bool_mask = batch["loss_masks"][0].astype(bool)
    assert list(target_tokens[bool_mask]) == answer_tokens
    # NOTE(review): this call unpacks two return values while the prefix loop
    # above unpacks three from the same callable - confirm which one matches
    # the actual signature of augmented_eval_step.
    sharded_rng, metrics = augmented_eval_step(train_state, sharded_rng, batch)
    token_level_accuracy = metrics.pop("token_level_accuracy")
    token_level_loss = metrics.pop("token_level_loss")
    masked_accuracy = token_level_accuracy[0][bool_mask]
    masked_loss = token_level_loss[0][bool_mask]
    if FLAGS.downstream_eval_verbose:
        print("Target chars: ", tokenizer.decode(answer_tokens))
        print("Target tokens: ", answer_tokens)
        print("Accuracy: ", masked_accuracy)
    # np.all replaces np.alltrue, which was removed in NumPy 2.0.
    sequence_accuracy = np.all(masked_accuracy).astype(np.float32)
    token_accuracy = np.mean(masked_accuracy)
    token_loss = np.mean(masked_loss)
    task_metrics = dict(
        retrieval_sequence_accuracy=sequence_accuracy,
        retrieval_token_accuracy=token_accuracy,
        retrieval_token_loss=token_loss,
    )
    return sharded_rng, task_metrics, train_state


def toy_task_eval(
    input_generating_fn,
    train_state,
    sharded_rng,
    augmented_eval_step,
    clear_xl_fn,
    tokenizer,
    max_seq_len,
    n_eval_samples=100,
):
    """Evaluate the model on ``n_eval_samples`` generated toy-task samples.

    Python's ``random`` module is seeded with a fixed value first.
    ``input_generating_fn`` is called once per sample and must return a
    ``(prompt, answer)`` pair.

    Returns:
        ``(train_state, metrics_list)`` with one metrics dict per sample.
    """
    random.seed(2137)

    collected = []
    for _sample_idx in tqdm(range(n_eval_samples)):
        sample_prompt, sample_answer = input_generating_fn()
        step_result = toy_task_eval_step(
            sample_prompt,
            sample_answer,
            train_state,
            sharded_rng,
            augmented_eval_step,
            clear_xl_fn,
            tokenizer,
            max_seq_len,
        )
        # The state and RNG are threaded through to the next sample.
        train_state, sharded_rng, sample_metrics = step_result
        collected.append(sample_metrics)
    return train_state, collected


def downstream_task_eval(
    dataset,
    train_state,
    sharded_rng,
    augmented_eval_step,
    clear_xl_fn,
    tokenizer,
    window_len,
    demo_max_length,
    n_shots=0,
    n_eval_samples=100,
    repeat_num=1,
):
    """Evaluate a downstream dataset with few-shot demonstrations.

    For each of ``repeat_num`` seeds a single chunk of demonstrations is drawn
    from the training split, prepended to every validation example, and the
    answer is scored with ``verify`` using ``window_len`` as the chunk length.
    The demonstration budget is ``demo_max_length`` reduced by the longest
    tokenized input+answer in the validation split, plus one.

    Returns:
        ``(train_state, per_trial_metrics, inter_trial_metrics)`` where
        ``per_trial_metrics`` holds one ``downstream/``-prefixed metrics dict
        per evaluated example across all seeds, and ``inter_trial_metrics``
        holds one ``downstream/accuracy_per_trial`` entry per seed.
    """
    train_split = get_dataset(dataset, is_train=True)
    eval_split = get_dataset(dataset, is_train=False, max_data_num=n_eval_samples)

    longest_query_len = get_maximum_input_plus_answer_length(eval_split, tokenizer)
    print("Maximum input + answer length: ", longest_query_len)
    # Reserve room for the longest query and one extra token in both budgets.
    demo_budget = demo_max_length - longest_query_len - 1
    padded_prompt_budget = FLAGS.max_considered_length - longest_query_len - 1

    per_trial_metrics = []
    inter_trial_metrics = []
    for seed in tqdm(range(repeat_num), desc=f"Evaluating with {repeat_num} seeds"):
        setup_seed(seed)
        demo_batches, mask_batches, demo_count = train_split.get_chunk(
            tokenizer,
            demo_budget,
            fake_max_length=padded_prompt_budget,
            strategy="skip",
            chunk_num=None,
            shot=n_shots,
        )
        print("Number of examples: ", demo_count)
        assert len(demo_batches) == len(mask_batches) == 1
        assert len(demo_batches[0]) <= demo_budget
        demo_ids = demo_batches[0]

        trial_metrics = []
        for input_str, output_str, answer in eval_split:
            # TODO: support multiple-choice tasks
            # An answer of -1 means output_str is the answer itself; otherwise
            # it indexes into output_str.
            answer_str = output_str[answer] if answer != -1 else output_str
            answer_ids = tokenizer.encode(answer_str)
            query_ids = tokenizer.encode(input_str + answer_str)

            sharded_rng, raw_metrics, train_state = verify(
                demo_ids + query_ids,
                answer_ids,
                window_len,
                train_state,
                sharded_rng,
                augmented_eval_step,
                clear_xl_fn,
                tokenizer,
            )
            trial_metrics.append(
                {f"downstream/{name}": value for name, value in raw_metrics.items()}
            )

        sequence_hits = [
            entry["downstream/retrieval_sequence_accuracy"] for entry in trial_metrics
        ]
        per_trial_metrics.extend(trial_metrics)
        inter_trial_metrics.append(
            {"downstream/accuracy_per_trial": np.mean(sequence_hits)}
        )

    return train_state, per_trial_metrics, inter_trial_metrics


def get_maximum_input_plus_answer_length(dataset, tokenizer):
    """Return the largest token count of ``input + answer`` over ``dataset``.

    Each dataset item is an ``(input_str, output_str, answer)`` triple. When
    ``answer`` is ``-1`` the answer text is ``output_str`` itself; otherwise it
    is ``output_str[answer]``. Returns 0 for an empty dataset.
    """

    def _query_length(input_str, output_str, answer):
        # Input and answer are tokenized as one concatenated string.
        answer_str = output_str[answer] if answer != -1 else output_str
        return len(tokenizer.encode(input_str + answer_str))

    return max((_query_length(*example) for example in dataset), default=0)


### MMLU evaluation ###
# Subject names evaluated by `mmlu_eval`; each one is passed as the second
# argument of `load_dataset("cais/mmlu", subject)`. The list is sorted so the
# subjects are always visited in alphabetical order.
MMLU_SUBJECT_LIST = sorted(
    [
        "high_school_european_history",
        "business_ethics",
        "clinical_knowledge",
        "medical_genetics",
        "high_school_us_history",
        "high_school_physics",
        "high_school_world_history",
        "virology",
        "high_school_microeconomics",
        "econometrics",
        "college_computer_science",
        "high_school_biology",
        "abstract_algebra",
        "professional_accounting",
        "philosophy",
        "professional_medicine",
        "nutrition",
        "global_facts",
        "machine_learning",
        "security_studies",
        "public_relations",
        "professional_psychology",
        "prehistory",
        "anatomy",
        "human_sexuality",
        "college_medicine",
        "high_school_government_and_politics",
        "college_chemistry",
        "logical_fallacies",
        "high_school_geography",
        "elementary_mathematics",
        "human_aging",
        "college_mathematics",
        "high_school_psychology",
        "formal_logic",
        "high_school_statistics",
        "international_law",
        "high_school_mathematics",
        "high_school_computer_science",
        "conceptual_physics",
        "miscellaneous",
        "high_school_chemistry",
        "marketing",
        "professional_law",
        "management",
        "college_physics",
        "jurisprudence",
        "world_religions",
        "sociology",
        "us_foreign_policy",
        "high_school_macroeconomics",
        "computer_security",
        "moral_scenarios",
        "moral_disputes",
        "electrical_engineering",
        "astronomy",
        "college_biology",
    ]
)

# Hendrycks code (slightly modified)
# Answer letters used by `format_example`: indexed by option position when
# listing the options and by the integer "answer" column when rendering the
# correct answer.
choices = ["A", "B", "C", "D"]


def format_subject(subject):
    """Turn an underscore-separated subject name into space-separated words.

    Every word, including the first, is preceded by a single space, so the
    result always starts with a space (e.g. ``"a_b"`` becomes ``" a b"``).
    """
    return "".join(" " + word for word in subject.split("_"))


def format_example(df, idx, include_answer=True):
    """Render row ``idx`` of ``df`` as a multiple-choice question.

    The question text is taken from the first column, the four options from the
    ``"choices"`` column and the index of the correct option from the
    ``"answer"`` column.

    Returns:
        ``(prompt, answer)``. With ``include_answer=True`` the answer letter
        followed by two newlines is appended to the prompt and returned as
        ``answer``; otherwise the prompt ends with ``"\\nAnswer: "`` and
        ``answer`` is the bare letter.
    """
    question = df.iloc[idx, 0]
    row = df.iloc[idx]
    options = row["choices"]
    assert len(options) == 4

    option_lines = [
        "\n{}. {}".format(choices[position], option)
        for position, option in enumerate(options)
    ]
    answer_template = "{}\n\n" if include_answer else "{}"
    answer = answer_template.format(choices[row["answer"]])

    prompt = question + "".join(option_lines) + "\nAnswer: "
    if include_answer:
        prompt += answer
    return prompt, answer


def gen_prompt(train_df, subject, k=-1):
    """Build the few-shot header plus the first ``k`` solved examples.

    ``k == -1`` means "use every row of ``train_df``". Each example is rendered
    by ``format_example`` with its answer included.
    """
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject)
    )
    n_examples = train_df.shape[0] if k == -1 else k
    solved_examples = [
        format_example(train_df, row_idx)[0] for row_idx in range(n_examples)
    ]
    return header + "".join(solved_examples)


def mmlu_eval(
    train_state,
    sharded_rng,
    augmented_eval_step,
    clear_xl_fn,
    tokenizer,
    window_len,
    demo_max_length,
    n_shots=5,
    n_eval_samples=100,
    repeat_num=1,
):
    """Run few-shot MMLU evaluation over every subject in ``MMLU_SUBJECT_LIST``.

    For each subject the "test" split is evaluated example by example. The
    demonstrations come from the "dev" split (plus the "validation" split when
    the ``mmlu_use_valid_examples`` flag is set). The number of demonstrations
    starts at ``n_shots`` (``-1`` means all available) and is reduced one at a
    time until demonstrations + question + answer fit the token budget, never
    going below zero. Each example is scored with ``verify`` using
    ``window_len`` as the chunk length.

    ``demo_max_length``, ``n_eval_samples`` and ``repeat_num`` are accepted for
    signature compatibility but are not used by this function.

    Returns:
        ``(train_state, global_metrics_list, per_task_metrics)`` where
        ``global_metrics_list`` has one ``mmlu_all/``-prefixed metrics dict per
        test example across all subjects and ``per_task_metrics`` merges the
        aggregated ``mmlu/<subject>/`` logs of every subject.
    """
    # Budget for demonstrations + question + answer tokens; one below 2048 so
    # the BOS token added later still fits.
    max_context_tokens = 2047

    def evaluate_subject(
        subject, train_state, sharded_rng, augmented_eval_step, global_metrics_list
    ):
        """Evaluate one subject; appends to ``global_metrics_list`` in place."""
        hf_ds = load_dataset("cais/mmlu", subject)
        # Demonstrations come from the "dev" split, optionally extended with
        # the "validation" split.
        dev_df = hf_ds["dev"].to_pandas()
        if FLAGS.mmlu_use_valid_examples:
            dev_valid = pd.concat([dev_df, hf_ds["validation"].to_pandas()])
            assert dev_valid.shape[0] > dev_df.shape[0]
            dev_df = dev_valid
        test_df = hf_ds["test"].to_pandas()
        print(f"Few shot prompting on {dev_df.shape[0]} examples")

        # Resolve the starting shot count once. -1 keeps gen_prompt's meaning
        # of "all examples"; other negative values behave like zero shots.
        n_available = dev_df.shape[0]
        if n_shots == -1:
            initial_shots = n_available
        else:
            initial_shots = max(0, min(n_shots, n_available))

        # The demonstration prompt depends only on the shot count, so build and
        # tokenize it once per count instead of once per test example.
        demo_cache = {}

        def get_demonstration(num_shots):
            if num_shots not in demo_cache:
                demo_str = gen_prompt(dev_df, subject, num_shots)
                demo_cache[num_shots] = (demo_str, len(tokenizer.encode(demo_str)))
            return demo_cache[num_shots]

        metrics_list = []
        for i in tqdm(range(test_df.shape[0]), desc=f"evaluating {subject}"):
            input_str, answer_str = format_example(test_df, i, include_answer=False)
            input_example_tokens = tokenizer.encode(input_str)
            answer_tokens = tokenizer.encode(answer_str)
            assert len(answer_tokens) == 1  # ABC tasks only
            query_len = len(input_example_tokens) + len(answer_tokens)

            # Drop demonstrations until the prompt fits. The lower bound is
            # essential: a count of -1 would make gen_prompt use every example
            # again, and any lower count yields the header only, so without the
            # bound an oversized question would loop forever.
            num_shots = initial_shots
            demonstration_str, demo_len = get_demonstration(num_shots)
            while num_shots > 0 and demo_len + query_len >= max_context_tokens:
                num_shots -= 1
                demonstration_str, demo_len = get_demonstration(num_shots)

            # sanity that answer doesnt contain whitespace
            assert tokenizer.decode(tokenizer.encode(answer_str)) == answer_str
            full_input_tokens = tokenizer.encode(
                demonstration_str + input_str + answer_str
            )

            sharded_rng, task_metrics, train_state = verify(
                full_input_tokens,
                answer_tokens,
                window_len,
                train_state,
                sharded_rng,
                augmented_eval_step,
                clear_xl_fn,
                tokenizer,
            )
            global_metrics = {
                f"mmlu_all/{name}": value for name, value in task_metrics.items()
            }
            task_metrics = {
                f"mmlu/{subject}/{name}": value
                for name, value in task_metrics.items()
            }
            metrics_list.append(task_metrics)
            global_metrics_list.append(global_metrics)
        return train_state, metrics_list, sharded_rng

    global_metrics_list = []
    per_task_metrics = {}

    for subject in MMLU_SUBJECT_LIST:
        train_state, local_metrics_list, sharded_rng = evaluate_subject(
            subject, train_state, sharded_rng, augmented_eval_step, global_metrics_list
        )
        agg = logging_utils.LogAggregator(provide_latest=False, device_get_at_add=False)
        agg.add_list(local_metrics_list)
        task_logs = agg.get_logs()
        print("Logs for task: ", subject)
        print(task_logs)
        per_task_metrics.update(task_logs)

    return train_state, global_metrics_list, per_task_metrics
