import numpy as np

def dict_of_lists_to_list_of_dicts(d):
    """Transpose a column-oriented mapping into a list of row dicts.

    ``{"a": [1, 2], "b": [3, 4]}`` becomes ``[{"a": 1, "b": 3}, {"a": 2, "b": 4}]``.
    Rows are truncated to the shortest column (``zip`` semantics); an empty
    mapping yields ``[]``.
    """
    return [dict(zip(d.keys(), vals)) for vals in zip(*d.values())]


def list_of_dicts_to_dict_of_lists(l):  # noqa: E741 - name kept for keyword callers
    """Transpose a list of row dicts into a column-oriented mapping.

    The keys of the first row define the columns; every row must contain them.
    An empty list yields ``{}`` (there are no rows to take keys from).
    """
    if not l:
        return {}
    return {k: [row[k] for row in l] for k in l[0]}


def _windowed_mean(values, window_size):
    """Return the centred moving average of ``values`` (one output per input).

    Position ``i`` averages ``values[max(0, i - h) : min(n, i + h + 1)]`` with
    ``h = window_size // 2``. Windows are clipped, not padded, at the edges, and
    an even ``window_size`` spans ``window_size + 1`` elements.
    """
    values = np.asarray(values, dtype=float)
    n = len(values)
    half = window_size // 2
    positions = np.arange(n)
    lo = np.maximum(0, positions - half)
    hi = np.minimum(n, positions + half + 1)
    # Prefix sums give every window sum in O(1) instead of re-summing each slice.
    prefix = np.concatenate(([0.0], np.cumsum(values)))
    return (prefix[hi] - prefix[lo]) / (hi - lo)


def _biased_distribution(weights, bias):
    """Normalise ``weights`` to sum to 1, add ``bias`` to every entry, renormalise."""
    probs = np.asarray(weights, dtype=float)
    probs = probs / probs.sum()
    probs = probs + bias
    return probs / probs.sum()


def _sample_counts(probs):
    """Draw ``len(probs)`` indices with replacement; return how often each was drawn."""
    n = len(probs)
    sampled_indices = np.random.choice(np.arange(n), size=n, replace=True, p=probs)
    return np.bincount(sampled_indices, minlength=n)


def adaboost_sampling(
    train_dataset,
    t,
    num_proc,
    is_weight_by_token,
    probability_bias,
    token_prob_window_size,
):
    """Resample the training set according to the round-``t`` AdaBoost weights.

    Draws ``N`` samples with replacement from the weight distribution, where
    ``N`` is the number of rows (or, when ``is_weight_by_token``, the total
    number of tokens). It then records how often each row/token was drawn in a new
    ``adaboost_count_{t}`` column. Sampling uses NumPy's global RNG
    (``np.random.choice``), so seed ``np.random`` for reproducibility.

    Args:
        train_dataset: dataset exposing an ``adaboost_weight_{t}`` column and a
            ``map(fn, with_indices=..., num_proc=..., desc=...)`` method
            (presumably a Hugging Face ``datasets.Dataset``). In token mode each
            entry is a per-token weight sequence.
        t: AdaBoost round index; selects the weight column and names the output.
        num_proc: number of worker processes forwarded to ``map``.
        is_weight_by_token: sample individual tokens instead of whole rows.
        probability_bias: constant added to every normalised probability before
            renormalising; a positive value pulls the distribution toward uniform.
        token_prob_window_size: width of the centred moving average applied to
            each row's token weights (``ws // 2`` neighbours on each side).
            Only used when ``is_weight_by_token``.

    Returns:
        The mapped dataset with ``adaboost_count_{t}`` added. Each entry is one int
        per row, or, in token mode, an int array with one count per token.

    Raises:
        ValueError: if ``token_prob_window_size`` is negative in token mode, or
            (from ``np.random.choice``) if the weights are not a valid
            distribution, e.g. they sum to zero.
    """
    weight_column = f"adaboost_weight_{t}"
    if not is_weight_by_token:
        probs = _biased_distribution(train_dataset[weight_column], probability_bias)
        adaboost_count = _sample_counts(probs)
    else:
        if token_prob_window_size < 0:
            raise ValueError(
                f"token_prob_window_size must be non-negative, got {token_prob_window_size}"
            )
        smoothed_rows = [
            _windowed_mean(row_weights, token_prob_window_size)
            for row_weights in train_dataset[weight_column]
        ]
        probs = _biased_distribution(np.concatenate(smoothed_rows), probability_bias)
        counts = _sample_counts(probs)
        # Cut the flat per-token counts back into one array per row.
        row_lengths = [len(row) for row in smoothed_rows]
        adaboost_count = np.split(counts, np.cumsum(row_lengths[:-1]))

    return train_dataset.map(
        lambda _, idx: {f"adaboost_count_{t}": adaboost_count[idx]},
        with_indices=True,
        num_proc=num_proc,
        desc=f"Round {t}: adaboost sampling",
    )


def reformat_dataset(train_dataset, t, num_proc, is_weight_by_token):
    """Turn the round-``t`` sampling counts into physical dataset rows.

    Row mode: a row with count ``c`` is repeated ``c`` times, each copy carrying
    ``adaboost_count_{t} == 1``. Rows with count 0 are dropped.

    Token mode: a row whose per-token counts have maximum ``m`` is expanded into
    ``m`` copies. In copy ``i`` the count list holds 1 for every token whose
    count exceeds ``i`` and 0 otherwise, so token ``j`` is marked in exactly
    ``count_j`` copies. Rows whose counts are all zero, or that have no tokens
    at all, are dropped.

    Args:
        train_dataset: dataset with ``column_names`` and a batched ``map``
            (presumably a Hugging Face ``datasets.Dataset``) containing
            ``adaboost_count_{t}``.
        t: AdaBoost round index.
        num_proc: number of worker processes forwarded to ``map``.
        is_weight_by_token: whether counts are per token (list) or per row (int).

    Returns:
        A new dataset containing only the rows produced by the expansion.
    """
    count_column = f"adaboost_count_{t}"

    if not is_weight_by_token:

        def repeat_rows(example):
            # batch_size=1, so every column holds a one-element list.
            count = example[count_column][0]
            if count == 0:
                return {column_name: [] for column_name in example}
            return {
                column_name: example[column_name] * count
                for column_name in example
                if column_name != count_column
            } | {count_column: [1] * count}

        expand_row = repeat_rows
    else:

        def split_rows(example):
            count_list = example[count_column][0]
            # default=0: a row with no tokens has an empty count list; drop it
            # instead of letting max() raise ValueError.
            max_count = max(count_list, default=0)
            if max_count == 0:
                return {column_name: [] for column_name in example}
            new_batch = {
                column_name: [example[column_name][0]] * max_count
                for column_name in example
                if column_name != count_column
            }
            new_batch[count_column] = [
                [1 if count > i else 0 for count in count_list]
                for i in range(max_count)
            ]
            return new_batch

        expand_row = split_rows

    # All input columns are removed so each call may emit any number of rows.
    return train_dataset.map(
        expand_row,
        batched=True,
        batch_size=1,
        num_proc=num_proc,
        remove_columns=train_dataset.column_names,
        desc=f"Round {t}: adaboost reformatting dataset",
    )


def replace_option_dataset(transfer_dataset, t, num_proc):
    """Collapse each ``*_list`` column to the option chosen for round ``t``.

    For every example, the integer in ``option_target_{t}`` indexes into the
    nine list-valued columns below. The selected entry is written to the
    matching singular column (e.g. ``response_list[k]`` -> ``response``).
    Returns the mapped dataset.
    """
    target_column = f"option_target_{t}"
    # (output column, source list column)
    column_pairs = (
        ("response", "response_list"),
        ("text", "text_list"),
        ("input_ids", "input_ids_list"),
        ("attention_mask", "attention_mask_list"),
        ("start_index", "start_index_list"),
        ("end_index", "end_index_list"),
        ("sft_ids", "sft_ids_list"),
        ("sft_labels", "sft_labels_list"),
        ("sft_options", "sft_options_list"),
    )

    def pick_target_option(example):
        return {
            destination: example[source][example[target_column]]
            for destination, source in column_pairs
        }

    return transfer_dataset.map(
        pick_target_option,
        batched=False,
        num_proc=num_proc,
        desc=f"Round {t}: replace options on transfer dataset",
    )


def calculate_avg_unweighted_error(dataset, t, is_weight_by_token, mode="weak"):
    """Return the plain mean of the round-``t`` error column.

    Args:
        dataset: object with ``column_names`` and column indexing.
        t: round index.
        is_weight_by_token: if true, each entry is a per-token list. The lists are
            flattened so the mean is taken over all tokens rather than per row.
        mode: ``"weak"`` reads ``error_{t}``; any other value reads
            ``error_strong_{t}``.

    Raises:
        RuntimeError: if the selected column is missing.
    """
    prefix = "error_" if mode == "weak" else "error_strong_"
    column = f"{prefix}{t}"
    if column not in dataset.column_names:
        raise RuntimeError(f"{prefix}{t} column not found in dataset.")
    errors = dataset[column]
    if is_weight_by_token:
        # Linear-time flatten; sum(rows, []) re-copies the accumulator per row.
        errors = [error for row in errors for error in row]
    return np.mean(errors)


def calculate_avg_weighted_error(train_dataset, t, is_weight_by_token):
    """Return the AdaBoost weighted error of round ``t``.

    The weights come from ``adaboost_weight_{t-1}`` (``adaboost_weight_0`` when
    ``t == 0``). They are normalised to sum to 1 and combined with ``error_{t}``
    as ``sum_i w_i * error_i``. In token mode both columns are flattened across
    rows first. With uniform weights the result equals the unweighted mean error.

    Raises:
        RuntimeError: if the ``error_{t}`` column is missing.
    """
    if f"error_{t}" not in train_dataset.column_names:
        raise RuntimeError(f"error_{t} column not found in dataset.")
    weight_column = f"adaboost_weight_{max(t-1,0)}"
    if not is_weight_by_token:
        weights = np.array(train_dataset[weight_column])
        errors = np.array(train_dataset[f"error_{t}"])
    else:
        weights = np.concatenate(train_dataset[weight_column])
        errors = np.concatenate(train_dataset[f"error_{t}"])
    weights = weights / weights.sum()
    # The weights already sum to 1, so the weighted error is the weighted SUM.
    # A mean would divide by N again and give err/N, which is not comparable to
    # the unweighted error scale used for K and the alpha formula.
    return np.sum(weights * errors)


def calculate_avg_option_error(dataset, t, mode="weak"):
    """Return the mean of the round-``t`` option-error column.

    ``mode == "weak"`` reads ``option_error_{t}``; any other value reads
    ``option_error_strong_{t}``. Raises ``RuntimeError`` if the column is absent.
    """
    if mode == "weak":
        prefix = "option_error_"
    else:
        prefix = "option_error_strong_"
    column = f"{prefix}{t}"
    if column in dataset.column_names:
        return np.mean(dataset[column])
    raise RuntimeError(f"{prefix}{t} column not found in dataset.")


def calculate_adaboost_k(train_dataset, is_weight_by_token):
    """Return ``K = 1 / (1 - err_0)``, where ``err_0`` is the round-0 unweighted error.

    This inverts ``err = 1 - 1/K``, the error rate of uniform guessing among
    ``K`` classes in the SAMME setting. Presumably round 0 plays that
    baseline role; confirm against the caller. An ``err_0`` of 1 makes the
    denominator zero.
    """
    baseline_error = calculate_avg_unweighted_error(
        train_dataset, t=0, is_weight_by_token=is_weight_by_token
    )
    return 1.0 / (1.0 - baseline_error)


def calculate_adaboost_alphas(train_dataset, t, is_weight_by_token):
    """Return the alpha of every round ``1..t`` (SAMME form).

    Each entry is ``log((1 - err_r) / err_r) + log(K - 1)``. Here ``err_r`` comes
    from ``calculate_avg_weighted_error`` for round ``r`` and ``K`` from
    ``calculate_adaboost_k``. For ``t <= 1`` the fixed list ``[1.0]`` is
    returned without reading the dataset.

    NOTE(review): for ``t >= 2`` the first entry is computed from ``error_1``
    instead of being 1.0. The round-1 alpha therefore depends on which ``t``
    the caller asks for; confirm this is intended.
    """
    if t > 1:
        adaboost_k = calculate_adaboost_k(train_dataset, is_weight_by_token)
        round_errors = (
            calculate_avg_weighted_error(
                train_dataset, t=round_idx, is_weight_by_token=is_weight_by_token
            )
            for round_idx in range(1, t + 1)
        )
        return [
            np.log((1 - err) / err) + np.log(adaboost_k - 1) for err in round_errors
        ]
    return [1.0]


def update_adaboost_weight(train_dataset, t, num_proc, is_weight_by_token):
    """Add ``adaboost_weight_{t}`` computed from the previous round's weights.

    Each weight (per row, or per token when ``is_weight_by_token``) is
    multiplied by ``exp(alpha_t * error_t)``, where ``alpha_t`` is the last
    entry of ``calculate_adaboost_alphas(train_dataset, t, ...)``. Entries with
    zero error keep their weight. No renormalisation happens here.

    Raises:
        RuntimeError: if ``adaboost_weight_{t-1}`` or ``error_{t}`` is missing.
    """
    prev_weight_column = f"adaboost_weight_{t-1}"
    error_column = f"error_{t}"
    if prev_weight_column not in train_dataset.column_names:
        raise RuntimeError(f"adaboost_weight_{t-1} column not found in dataset.")
    # Validate up front: for t == 1 the alpha lookup never reads error_{t}, so a
    # missing column would otherwise surface as a KeyError inside map workers.
    if error_column not in train_dataset.column_names:
        raise RuntimeError(f"error_{t} column not found in dataset.")

    adaboost_alpha = calculate_adaboost_alphas(train_dataset, t, is_weight_by_token)[-1]

    def func(example):
        if not is_weight_by_token:
            errors = example[error_column]
        else:
            # Per-token errors arrive as a list; make it an array for broadcasting.
            errors = np.array(example[error_column])
        adaboost_weight = example[prev_weight_column] * np.exp(adaboost_alpha * errors)
        return {f"adaboost_weight_{t}": adaboost_weight}

    return train_dataset.map(
        func,
        batched=False,
        num_proc=num_proc,
        desc=f"Round {t}: updating adaboost weights",
    )
