import numpy as np
import torch

from adaboost import (
    update_adaboost_weight,
    calculate_adaboost_alphas,
)


def dict_of_lists_to_list_of_dicts(d):
    """Transpose a mapping of columns into a list of per-row dicts.

    Rows are produced with ``zip``, so the result stops at the shortest
    column; an empty mapping yields an empty list.
    """
    rows = []
    for row_values in zip(*d.values()):
        rows.append(dict(zip(d.keys(), row_values)))
    return rows
def list_of_dicts_to_dict_of_lists(l):
    """Transpose a list of row dicts into a mapping of columns.

    The keys of the first row define the columns; every row must contain
    those keys (a missing key raises ``KeyError``). An empty list yields an
    empty dict instead of failing on ``l[0]``.
    """
    if len(l) == 0:
        return {}
    return {key: [row[key] for row in l] for key in l[0]}
def flatten_nested_list(nested_list):
    """Flatten one level of nesting and remember each sub-list's length.

    Returns a ``(flat_items, lengths)`` tuple, where ``lengths[i]`` is the
    length of ``nested_list[i]``.
    """
    flat_items = []
    for sublist in nested_list:
        flat_items.extend(sublist)
    # Second pass over the input, kept separate from the flattening pass.
    lengths = list(map(len, nested_list))
    return flat_items, lengths
def pack_nested_list(flattened_list, lengths):
    """Split a flat sequence back into sub-lists of the given lengths.

    The split goes through ``np.split``, so the elements of the returned
    sub-lists are NumPy scalars/arrays rather than the original objects.

    With no lengths there are no sub-lists, so ``[]`` is returned; calling
    ``np.split`` with no cut points would otherwise produce a single empty
    chunk (``[[]]``).
    """
    if len(lengths) == 0:
        return []
    # Cut points are the running totals of all lengths but the last.
    cut_points = np.cumsum(lengths)[:-1]
    return [list(chunk) for chunk in np.split(flattened_list, cut_points)]


def predict_error_token(
    model,
    train_dataset,
    t,
    batch_size,
    is_weight_by_token,
    is_completion_only,
    mode="weak",
):
    """Add a column marking where the model's argmax predictions miss the labels.

    The model is run batch by batch through ``train_dataset.map``. For each
    example the argmax token ids are compared with ``sft_labels`` and the
    result is stored in ``error_{t}`` (``mode == "weak"``) or
    ``error_strong_{t}`` (any other mode).

    Args:
        model: Callable taking ``input_ids`` and ``attention_mask`` keyword
            tensors and returning an object with a ``.logits`` attribute;
            must also expose ``.device``.
        train_dataset: Dataset whose batches provide ``input_ids``,
            ``attention_mask``, ``sft_labels`` and ``end_index`` /
            ``start_index``.
        t: Round index, used in the output column name and progress text.
        batch_size: Batch size forwarded to ``map``.
        is_weight_by_token: If true, keep the per-token list of booleans;
            otherwise store their mean (one number per example).
        is_completion_only: If true, the label span starts at ``end_index``,
            otherwise at ``start_index``.
        mode: ``"weak"`` selects the ``error_{t}`` column, anything else
            ``error_strong_{t}``.

    Returns:
        The dataset returned by ``map``, with the error column added.
    """
    # The output column depends only on the mode, so resolve it once.
    error_column = f"error_{t}" if mode == "weak" else f"error_strong_{t}"

    def func(example):
        torch.cuda.empty_cache()
        input_ids = torch.tensor(example["input_ids"], dtype=torch.int32).to(
            model.device
        )
        attention_mask = torch.tensor(example["attention_mask"], dtype=torch.int8).to(
            model.device
        )
        # Inference only: without no_grad the forward pass would keep
        # activations for a backward pass that never happens.
        with torch.no_grad():
            logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).logits.detach()
        # Move the predictions to the host once instead of once per row.
        pred_ids = (
            torch.argmax(logits, dim=-1, keepdim=False).cpu().to(torch.int32).numpy()
        )
        indices = example["end_index"] if is_completion_only else example["start_index"]
        pred_errors = []
        for i in range(logits.size(0)):
            # The slice starts one position before the first label and drops
            # the final position -- presumably the usual next-token shift;
            # confirm against how sft_labels is built.
            pred_error = (
                pred_ids[i, indices[i] - 1 : -1]
                != np.array(example["sft_labels"][i], dtype=np.int32)
            ).tolist()
            if not is_weight_by_token:
                pred_error = np.mean(pred_error)
            pred_errors.append(pred_error)
        return {error_column: pred_errors}

    train_dataset = train_dataset.map(
        func,
        batched=True,
        # Important, do not change
        batch_size=batch_size,
        num_proc=1,
        desc=f"Round {t}: predicting error on train/eval",
    )

    return train_dataset


def predict_logits_token(
    model, dataset, t, batch_size, is_completion_only, logits_top_k
):
    """Store each batch's top-k label-position logits as a sparse tensor.

    For every batch the model logits are zeroed outside the label span and
    outside the per-position top ``logits_top_k`` entries, converted to a
    sparse COO tensor, and written to the columns ``logits_indices_{t}``,
    ``logits_values_{t}`` and ``logits_size_{t}``.

    The sparse tensor describes the *whole batch*, so it is stored on the
    first row of the batch only; the remaining rows receive empty lists.
    ``combine_pred_error_token`` reads element ``[0]`` of each batch, so the
    same ``batch_size`` has to be used there.

    Args:
        model: Callable taking ``input_ids`` and ``attention_mask`` keyword
            tensors and returning an object with ``.logits``; must expose
            ``.device``.
        dataset: Dataset whose batches provide ``input_ids``,
            ``attention_mask``, ``text`` and ``end_index`` / ``start_index``.
        t: Round index, used in the column names and progress text.
        batch_size: Batch size forwarded to ``map``.
        is_completion_only: If true, the label span starts at ``end_index``,
            otherwise at ``start_index``.
        logits_top_k: Number of largest logits kept per position.

    Returns:
        The dataset returned by ``map``, with the three logits columns added.
    """

    def func(example):
        torch.cuda.empty_cache()
        input_ids = torch.tensor(example["input_ids"], dtype=torch.int32).to(
            model.device
        )
        attention_mask = torch.tensor(example["attention_mask"], dtype=torch.int8).to(
            model.device
        )
        # Inference only: avoid keeping activations for a backward pass.
        with torch.no_grad():
            logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).logits.detach()

        # Zero the non-label part: everything before position (start - 1) ...
        label_start = (
            example["end_index"] if is_completion_only else example["start_index"]
        )
        label_start = (
            torch.tensor(label_start, dtype=torch.long, device=model.device) - 1
        )
        positions = torch.arange(logits.size(1), device=model.device).view(1, -1, 1)
        logits.masked_fill_(positions < label_start.view(-1, 1, 1), 0.0)
        # ... and the final sequence position, which the error computation
        # also skips (it slices ``[start - 1 : -1]``). This must index the
        # sequence axis (dim 1), not the vocabulary axis (dim 2).
        logits[:, -1, :] = 0.0

        # Zero the non-topk part
        drop_mask = torch.ones_like(logits, dtype=torch.bool, device=model.device)
        drop_mask.scatter_(
            dim=-1,
            index=logits.topk(logits_top_k, dim=-1, largest=True, sorted=False).indices,
            value=False,
        )
        logits.masked_fill_(drop_mask, 0.0)

        # Convert to sparse tensor and get indices and values
        sparse_logits = logits.to_sparse()
        sparse_indices = sparse_logits.indices()
        sparse_values = sparse_logits.values()
        sparse_size = sparse_logits.size()
        # Placeholder entries for every row of the batch except the first.
        padding = [[]] * (len(example["text"]) - 1)
        return {
            f"logits_indices_{t}": [sparse_indices.cpu().to(torch.int32)] + padding,
            f"logits_values_{t}": [sparse_values.cpu().to(torch.float32)] + padding,
            f"logits_size_{t}": [sparse_size] + padding,
        }

    dataset = dataset.map(
        func,
        batched=True,
        # Important, do not change
        batch_size=batch_size,
        num_proc=1,
        desc=f"Round {t}: predicting logits on eval",
    )

    return dataset


def combine_pred_error_token(
    model,
    eval_dataset,
    t,
    adaboost_alphas,
    batch_size,
    is_weight_by_token,
    is_completion_only,
    is_combine_probs,
    is_top_k_pooling,
):
    """Combine the stored sparse logits of rounds 1..t and score the ensemble.

    The AdaBoost weights are normalised to sum to one. For each batch the
    sparse tensors saved in ``logits_indices_{r}``, ``logits_values_{r}`` and
    ``logits_size_{r}`` (first row of the batch) are rebuilt for every round
    ``r`` in ``1..t``, weighted, summed, and the argmax of the dense result is
    compared with ``sft_labels``. The outcome is written to ``error_{t}``.

    Args:
        model: Only ``model.device`` is used, to place the tensors.
        eval_dataset: Dataset holding the per-round logits columns plus
            ``text``, ``sft_labels`` and ``end_index`` / ``start_index``.
        t: Current round; rounds ``1..t`` are combined.
        adaboost_alphas: Per-round weights, normalised here.
        batch_size: Batch size forwarded to ``map``.
        is_weight_by_token: If true, keep the per-token list of booleans;
            otherwise store their mean.
        is_completion_only: If true, the label span starts at ``end_index``,
            otherwise at ``start_index``.
        is_combine_probs: If true, apply a sparse softmax over the last
            dimension to each round before weighting.
        is_top_k_pooling: If true, replace every stored value with 1.0 so a
            round only contributes which entries it kept.

    Returns:
        The dataset returned by ``map``, with ``error_{t}`` added.
    """
    adaboost_alphas = np.array(adaboost_alphas) / np.sum(adaboost_alphas)

    def _load_round(example, inner_t):
        # Rebuild the batch-level sparse tensor saved for one round.
        coords = torch.tensor(
            example[f"logits_indices_{inner_t}"][0],
            dtype=torch.int32,
            device=model.device,
        )
        if is_top_k_pooling:
            values = torch.ones(
                len(example[f"logits_values_{inner_t}"][0]),
                dtype=torch.float32,
                device=model.device,
            )
        else:
            values = torch.tensor(
                example[f"logits_values_{inner_t}"][0],
                dtype=torch.float32,
                device=model.device,
            )
        return torch.sparse_coo_tensor(
            coords, values, example[f"logits_size_{inner_t}"][0]
        )

    def _weighted(alpha, sparse_logits):
        # One round's contribution: optionally softmaxed, then scaled.
        if is_combine_probs:
            sparse_logits = torch.sparse.softmax(sparse_logits, dim=-1)
        return alpha * sparse_logits

    def func(example):
        torch.cuda.empty_cache()
        per_round = [_load_round(example, inner_t) for inner_t in range(1, t + 1)]

        combined = _weighted(adaboost_alphas[0], per_round[0])
        for alpha, sparse_logits in zip(adaboost_alphas[1:], per_round[1:]):
            combined += _weighted(alpha, sparse_logits)
        combined = combined.coalesce()

        dense_pred = combined.to_dense().argmax(dim=-1, keepdim=False)
        pred_ids = dense_pred.cpu().to(torch.int32).numpy()
        if is_completion_only:
            indices = example["end_index"]
        else:
            indices = example["start_index"]

        pred_errors = []
        for row in range(len(example["text"])):
            labels = np.array(example["sft_labels"][row], dtype=np.int32)
            mismatches = (pred_ids[row, indices[row] - 1 : -1] != labels).tolist()
            pred_errors.append(
                mismatches if is_weight_by_token else np.mean(mismatches)
            )
        return {f"error_{t}": pred_errors}

    eval_dataset = eval_dataset.map(
        func,
        batched=True,
        batch_size=batch_size,
        # Important to set num_proc=1 to avoid deadlock
        num_proc=1,
        desc=f"Round {t}: combining logits and predict error on eval",
    )

    return eval_dataset


def predict_error_option(model, train_dataset, t, batch_size, mode="weak"):
    """Pick the best-scoring option per example and record whether it is wrong.

    Each example carries several candidate options (``input_ids_list`` and
    friends). All options of a batch are flattened into one model call; each
    option is scored by the mean of the logits assigned to its
    ``sft_options_list`` tokens, and the highest-scoring option is compared
    with ``target``.

    Args:
        model: Callable taking ``input_ids`` and ``attention_mask`` keyword
            tensors and returning an object with ``.logits``; must expose
            ``.device``.
        train_dataset: Dataset whose batches provide ``input_ids_list``,
            ``attention_mask_list``, ``end_index_list``, ``sft_options_list``
            and ``target``.
        t: Round index, used in the column names and progress text.
        batch_size: Batch size forwarded to ``map``.
        mode: ``"weak"`` writes ``option_target_{t}`` and
            ``option_error_{t}``; any other value writes only
            ``option_error_strong_{t}``.

    Returns:
        The dataset returned by ``map``, with the new column(s) added.
    """

    def func(example):
        torch.cuda.empty_cache()
        input_ids_flat, lengths = flatten_nested_list(example["input_ids_list"])
        attention_mask_flat, _ = flatten_nested_list(example["attention_mask_list"])
        end_index_flat, _ = flatten_nested_list(example["end_index_list"])
        sft_options_flat, _ = flatten_nested_list(example["sft_options_list"])
        input_ids = torch.tensor(input_ids_flat, dtype=torch.int32).to(model.device)
        attention_mask = torch.tensor(attention_mask_flat, dtype=torch.int8).to(
            model.device
        )
        # Inference only: avoid keeping activations for a backward pass.
        with torch.no_grad():
            logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).logits.detach()

        option_scores_flat = []
        for i in range(logits.size(0)):
            # Logits at the positions that predict the option tokens.
            option_logits = logits[i, end_index_flat[i] - 1 : -1, :]
            # Pick, at each position, the logit of the expected option token.
            # NOTE(review): these are raw logits, not log-softmax
            # probabilities; confirm that comparing unnormalised scores
            # across options is intended.
            token_scores = option_logits[
                torch.arange(option_logits.size(0)), sft_options_flat[i]
            ]
            option_scores_flat.append(
                torch.mean(token_scores).cpu().to(torch.float32).numpy()
            )
        # Regroup the flat per-option scores by example.
        option_scores_nested = pack_nested_list(option_scores_flat, lengths)

        option_targets = []
        option_errors = []
        for option_scores, target in zip(option_scores_nested, example["target"]):
            predicted = np.argmax(option_scores)
            option_targets.append(predicted)
            option_errors.append(predicted != target)

        if mode == "weak":
            return {
                f"option_target_{t}": option_targets,
                f"option_error_{t}": option_errors,
            }
        return {
            f"option_error_strong_{t}": option_errors,
        }

    train_dataset = train_dataset.map(
        func,
        batched=True,
        # Important, do not change
        batch_size=batch_size,
        num_proc=1,
        desc=f"Round {t}: predicting options on train/eval",
    )

    return train_dataset


def predict_probs_option(model, dataset, t, batch_size):
    """Store a softmax distribution over each example's candidate options.

    All options of a batch are flattened into one model call; each option is
    scored by the mean of the logits assigned to its ``sft_options_list``
    tokens, and the scores of one example are turned into probabilities with
    a softmax. The result is written to ``option_probs_{t}``.

    Args:
        model: Callable taking ``input_ids`` and ``attention_mask`` keyword
            tensors and returning an object with ``.logits``; must expose
            ``.device``.
        dataset: Dataset whose batches provide ``input_ids_list``,
            ``attention_mask_list``, ``end_index_list`` and
            ``sft_options_list``.
        t: Round index, used in the column name and progress text.
        batch_size: Batch size forwarded to ``map``.

    Returns:
        The dataset returned by ``map``, with ``option_probs_{t}`` added.
    """

    def func(example):
        torch.cuda.empty_cache()
        input_ids_flat, lengths = flatten_nested_list(example["input_ids_list"])
        attention_mask_flat, _ = flatten_nested_list(example["attention_mask_list"])
        end_index_flat, _ = flatten_nested_list(example["end_index_list"])
        sft_options_flat, _ = flatten_nested_list(example["sft_options_list"])
        input_ids = torch.tensor(input_ids_flat, dtype=torch.int32).to(model.device)
        attention_mask = torch.tensor(attention_mask_flat, dtype=torch.int8).to(
            model.device
        )
        # Inference only: avoid keeping activations for a backward pass.
        with torch.no_grad():
            logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).logits.detach()

        option_scores_flat = []
        for i in range(logits.size(0)):
            # Logits at the positions that predict the option tokens.
            option_logits = logits[i, end_index_flat[i] - 1 : -1, :]
            # Pick, at each position, the logit of the expected option token.
            token_scores = option_logits[
                torch.arange(option_logits.size(0)), sft_options_flat[i]
            ]
            # Converted to NumPy like in predict_error_option, so the
            # regrouping below always operates on plain float32 values.
            option_scores_flat.append(
                torch.mean(token_scores).cpu().to(torch.float32).numpy()
            )
        # Regroup the flat per-option scores by example.
        option_scores_nested = pack_nested_list(option_scores_flat, lengths)

        option_probs_list = [
            torch.softmax(torch.tensor(option_scores, dtype=torch.float32), dim=0)
            for option_scores in option_scores_nested
        ]
        return {f"option_probs_{t}": option_probs_list}

    dataset = dataset.map(
        func,
        batched=True,
        # Important, do not change
        batch_size=batch_size,
        num_proc=1,
        desc=f"Round {t}: predicting logits of options on eval",
    )

    return dataset


def combine_pred_error_option(
    model,
    eval_dataset,
    t,
    adaboost_alphas,
    batch_size,
):
    """Combine the option probabilities of rounds 1..t and score the ensemble.

    The AdaBoost weights are normalised to sum to one. For every example the
    ``option_probs_{r}`` vectors of rounds ``r = 1..t`` are weighted and
    summed; the argmax is stored in ``option_target_{t}`` and its mismatch
    with ``target`` in ``option_error_{t}``.

    Args:
        model: Only ``model.device`` is used, to place the tensors.
        eval_dataset: Dataset holding ``option_probs_{r}`` for each round
            plus ``text`` and ``target``.
        t: Current round; rounds ``1..t`` are combined.
        adaboost_alphas: Per-round weights, normalised here.
        batch_size: Batch size forwarded to ``map``.

    Returns:
        The dataset returned by ``map``, with the two option columns added.
    """
    adaboost_alphas = np.array(adaboost_alphas) / np.sum(adaboost_alphas)

    def func(example):
        torch.cuda.empty_cache()
        # One dict per round, mapping row position -> probability tensor.
        per_round = []
        for inner_t in range(1, t + 1):
            round_probs = {}
            for row, op in enumerate(example[f"option_probs_{inner_t}"]):
                round_probs[row] = torch.tensor(
                    op, dtype=torch.float32, device=model.device
                )
            per_round.append(round_probs)
        # Regroup so that each row position maps to its list of round tensors.
        per_row = list_of_dicts_to_dict_of_lists(per_round)

        option_targets, option_errors = [], []
        for row in range(len(example["text"])):
            combined = sum(
                alpha * prob for alpha, prob in zip(adaboost_alphas, per_row[row])
            )
            pred_option = combined.argmax().item()
            option_targets.append(pred_option)
            option_errors.append(pred_option != example["target"][row])

        return {
            f"option_target_{t}": option_targets,
            f"option_error_{t}": option_errors,
        }

    eval_dataset = eval_dataset.map(
        func,
        batched=True,
        batch_size=batch_size,
        # Important to set num_proc=1 to avoid deadlock
        num_proc=1,
        desc=f"Round {t}: combining logits and predict options on eval",
    )

    return eval_dataset


def predict_token_train(model, train_dataset, t, sargs):
    """Score the training set for round ``t`` and refresh the AdaBoost weights.

    Token errors are always computed with ``predict_error_token``; the
    weights are updated through ``update_adaboost_weight`` only for
    ``t >= 1``.
    """
    scored = predict_error_token(
        model,
        train_dataset,
        t,
        batch_size=sargs.pred_batch_size,
        is_weight_by_token=sargs.is_weight_by_token,
        is_completion_only=sargs.is_completion_only,
    )
    if t < 1:
        return scored
    return update_adaboost_weight(
        scored, t, sargs.num_proc, sargs.is_weight_by_token
    )


def predict_token_eval(model, train_dataset, eval_dataset, t, sargs, mode="weak"):
    """Score the eval set at token level for round ``t``.

    For ``t <= 0`` or ``mode == "strong"`` the model is scored on its own
    with ``predict_error_token``. Otherwise its top-k logits are stored,
    the AdaBoost weights are computed from ``train_dataset``, and the
    combined prediction of all rounds is scored.
    """
    single_model_only = t <= 0 or mode == "strong"
    if single_model_only:
        return predict_error_token(
            model,
            eval_dataset,
            t,
            batch_size=sargs.pred_batch_size,
            is_weight_by_token=sargs.is_weight_by_token,
            is_completion_only=sargs.is_completion_only,
            mode=mode,
        )

    with_logits = predict_logits_token(
        model,
        eval_dataset,
        t,
        batch_size=sargs.pred_batch_size,
        is_completion_only=sargs.is_completion_only,
        logits_top_k=sargs.logits_top_k,
    )
    alphas = calculate_adaboost_alphas(train_dataset, t, sargs.is_weight_by_token)
    return combine_pred_error_token(
        model,
        with_logits,
        t,
        adaboost_alphas=alphas,
        batch_size=sargs.pred_batch_size,
        is_weight_by_token=sargs.is_weight_by_token,
        is_completion_only=sargs.is_completion_only,
        is_combine_probs=sargs.is_combine_probs,
        is_top_k_pooling=sargs.is_top_k_pooling,
    )


def predict_option_eval(model, train_dataset, eval_dataset, t, sargs, mode="weak"):
    """Score the eval set at option level for round ``t``.

    For ``t <= 0`` or ``mode == "strong"`` the model is scored on its own
    with ``predict_error_option``. Otherwise its option probabilities are
    stored, the AdaBoost weights are computed from ``train_dataset``, and
    the combined prediction of all rounds is scored.
    """
    single_model_only = t <= 0 or mode == "strong"
    if single_model_only:
        return predict_error_option(
            model,
            eval_dataset,
            t,
            batch_size=sargs.pred_batch_size,
            mode=mode,
        )

    with_probs = predict_probs_option(
        model,
        eval_dataset,
        t,
        batch_size=sargs.pred_batch_size,
    )
    alphas = calculate_adaboost_alphas(train_dataset, t, sargs.is_weight_by_token)
    return combine_pred_error_option(
        model,
        with_probs,
        t,
        adaboost_alphas=alphas,
        batch_size=sargs.pred_batch_size,
    )
