import copy
import csv
import json
import logging
import os
import random
import re
import sys
from collections import defaultdict
from itertools import islice
from typing import List, Dict

import matplotlib.pyplot as plt
import numpy as np
import wandb


# Module-wide logger shared by the helpers below. getLogger() without a name
# returns the root logger.
logger = logging.getLogger()


class HiddenPrints:
    """
    Context manager that silences everything written to ``sys.stdout``.

    On entry ``sys.stdout`` is redirected to ``os.devnull``; on exit the stream that was
    active at entry is put back and the devnull handle is closed.
    """

    def __enter__(self):
        """
        Redirects ``sys.stdout`` to the null device.
        :return: this context manager, so ``with HiddenPrints() as hp`` binds a usable object
        """
        self._original_stdout = sys.stdout
        # Keep our own reference to the handle so __exit__ closes exactly the file opened
        # here, even if the body of the with-statement rebinds sys.stdout in the meantime.
        self._devnull = open(os.devnull, "w")
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Restores the original ``sys.stdout`` and closes the devnull handle.
        Returns None, so exceptions raised inside the with-block propagate.
        """
        # Restore first so stdout is valid again before the devnull handle is closed.
        sys.stdout = self._original_stdout
        self._devnull.close()


def chunks(items, chunk_size):
    """
    Lazily splits items into batches of at most chunk_size elements and yields each batch transposed.
    :param items: an iterable whose elements are themselves iterables (e.g. tuples of fields)
    :param chunk_size: maximum number of elements taken from items per batch
    :return: a generator of lists; the i-th inner list holds the i-th field of every element in the batch
    """
    source = iter(items)
    # iter(callable, sentinel) keeps pulling batches until an empty one signals exhaustion.
    for batch in iter(lambda: list(islice(source, chunk_size)), []):
        yield [list(column) for column in zip(*batch)]


def log_config(arguments):
    """
    Logs every configuration entry at INFO level, framed by two separator lines.
    :param arguments: a mapping from argument name to value (anything with ``.items()``)
    """
    separator = "-" * 50
    logger.info("Logging used config:")
    logger.info(separator)
    for name, setting in arguments.items():
        logger.info(f"{name}: {setting}")
    logger.info(separator)


def log_prompt_templates(prompt_templates, k_shot, mask_token=None):
    """
    Renders each prompt template on a fixed demo example and logs the resulting prompt.
    :param prompt_templates: iterable of template objects exposing ``prompt_instruction_set`` and ``prompt(...)``
    :param k_shot: how many copies of the fixed in-context example are passed as ``prompt_examples``
    :param mask_token: forwarded unchanged to every ``prompt(...)`` call
    """
    # NOTE: a wandb table of the prompt variations used to be logged here as well; it is disabled.
    few_shot_example = {
        "utterance": "Are you going to the party on Friday?",
        "response": "Is the pope catholic?",
        "implicature": "yes",
    }
    # The same dict object is repeated k_shot times.
    prompt_examples = [few_shot_example] * k_shot
    test_example = {
        "utterance": "Have you found him yet?",
        "response": "We're still looking.",
        "implicature": "no",
    }
    for variation_idx, template in enumerate(prompt_templates):
        logger.info(f"Prompt variation {variation_idx}:")
        call_kwargs = {"is_false_example": False}
        # In-context examples are only passed to templates that have prompt_instruction_set set.
        if template.prompt_instruction_set:
            call_kwargs["prompt_examples"] = prompt_examples
        call_kwargs["mask_token"] = mask_token
        rendered = template.prompt(test_example, **call_kwargs)
        logger.info("\n" + rendered)


def get_correct_document_index(example: Dict[str, str]):
    """
    Finds the position of the example's correct meaning among its options.
    :param example: a dict with an "options" sequence and a "meaning" entry
    :return: index of the first option equal to example["meaning"]; ValueError if it is not listed
    """
    candidates = example["options"]
    target = example["meaning"]
    return candidates.index(target)


def completion_is_correct(text: str, correct_document_idx: int) -> bool:
    """
    Checks whether a completion mentions the index of the correct document.

    The index has to appear as a standalone number: it must not be directly preceded or
    followed by another digit. A plain substring test would wrongly accept e.g. "10" or "21"
    when the correct index is 1.
    :param text: the completion text to inspect
    :param correct_document_idx: the index of the correct document
    :return: True if the index occurs in text as a number of its own, False otherwise
    """
    # Look-arounds reject matches that are only part of a longer run of digits.
    pattern = r"(?<!\d)" + re.escape(str(correct_document_idx)) + r"(?!\d)"
    return re.search(pattern, text) is not None


def similarity_ranking_is_correct(
    scores_per_document: List[float],
    correct_document_idx: int,
    doc_is_similar_score=200,
) -> bool:
    """
    Decides whether the similarity scores single out exactly the correct document.

    The ranking is rejected as soon as the correct document scores below doc_is_similar_score
    or any other document scores at or above it; otherwise it is accepted.
    :param scores_per_document: a list of semantic similarity scores per document
    :param correct_document_idx: the index of the correct document for the query
    :param doc_is_similar_score: min score for a doc to be similar
    :return: boolean indicating whether the ranking is correct
    """

    def _breaks_ranking(position, score):
        # The correct document must not fall below the threshold ...
        if position == correct_document_idx:
            return score < doc_is_similar_score
        # ... and every other document must not reach it.
        return score >= doc_is_similar_score

    return not any(
        _breaks_ranking(position, score)
        for position, score in enumerate(scores_per_document)
    )


class MultiMetric:
    """
    Holds a fixed number of accuracy metrics and summarises them (mean / std. dev. across
    metrics) while still giving access to each individual one.
    """

    def __init__(self, num_metrics: int):
        """
        :param num_metrics: how many metrics to keep
        """
        self._metrics = []
        for _ in range(num_metrics):
            self._metrics.append(AccuracyMetric(name="Accuracy Metric"))

    def update(self, index_to_update: int, correct: int, **kwargs):
        """
        Records one prediction outcome on a single metric. Extra keyword arguments are accepted and ignored.
        :param index_to_update: which metric to update
        :param correct: whether the example was correctly predicted according to the task
        :return: whether the example was correctly predicted according to the task
        """
        target_metric = self._metrics[index_to_update]
        return target_metric.update(correct=correct)

    def get_accuracy(self, index: int):
        """
        :param index: which metric to read
        :return: the accuracy currently held by that metric
        """
        selected = self._metrics[index]
        return selected.get_accuracy

    def get_mean_and_std(self):
        """
        :return: tuple (mean, standard deviation) of the accuracies over all kept metrics, as floats
        """
        per_metric = [entry.get_accuracy for entry in self._metrics]
        mean_accuracy = float(np.mean(per_metric))
        spread = float(np.std(per_metric))
        return mean_accuracy, spread

    def reset(self):
        """
        Resets every kept metric.
        """
        for entry in self._metrics:
            entry.reset()

    def __str__(self):
        mean_accuracy, spread = self.get_mean_and_std()
        header = "Mean accuracy: %.2f\nStd. Dev: %.2f\nAccuracy per metric:\n" % (
            mean_accuracy,
            spread,
        )
        # Each metric's own string is followed by a newline.
        body = "".join(str(entry) + "\n" for entry in self._metrics)
        return header + body


class AccuracyMetric:
    """
    Running accuracy: counts correct predictions against the total number of updates.
    """

    def __init__(self, name: str):
        """
        :param name: label used when the metric is printed
        """
        self._name = name
        self._correct = 0
        self._total = 0

    def reset(self):
        """
        Sets both counters back to zero; the name is kept.
        """
        self._correct, self._total = 0, 0

    @property
    def get_accuracy(self):
        """
        Accuracy as a percentage (0-100). This is a property, so it is read without parentheses.
        Fails an assertion when no update has been recorded yet.
        """
        assert self._total, "Can't get accuracy with total=0."
        fraction_correct = self._correct / self._total
        return fraction_correct * 100.0

    def update(self, correct):
        """
        Records one outcome.
        :param correct: a value in [0, 1]; it is truncated with int() before being counted
        :return: the value passed in, unchanged
        """
        within_bounds = 0 <= correct <= 1
        assert within_bounds, (
            "correct must be between 0 and 1 but is %.2f" % correct
        )
        self._total += 1
        self._correct += int(correct)
        return correct

    def __str__(self):
        return f"{self._name} accuracy: {self.get_accuracy:.2f}\n"


class WandbTable:
    """
    Collects per-example predictions in two wandb tables: one with scores per example and one
    with the scored texts. The second table is created on the first call to ``add``.
    """

    def __init__(self, particularised: bool):
        """
        :param particularised: must be truthy; any other value raises NotImplementedError
        """
        if not particularised:
            raise NotImplementedError()
        example_columns = ["Utterance", "Response", "Implicature"]
        self._result_table = wandb.Table(
            columns=example_columns
            + [
                "Model ID",
                "Prompt Template no.",
                "Implicature Correct score",
                "Implicature False score",
                "Implicature Correct",
            ]
        )
        self._text_table_columns = example_columns + ["Prompt Template no."]
        # Created lazily because its extra columns depend on the first result's scored texts.
        self._text_table = None

    def _setup_text_table(self, scored_texts):
        """
        Creates the text table with one extra column per key of scored_texts.
        :param scored_texts: mapping whose keys become the additional column names
        """
        self._text_table_columns.extend(scored_texts.keys())
        self._text_table = wandb.Table(columns=self._text_table_columns)

    def add(
        self,
        original_example,
        model_id,
        prompt_template_n,
        implicature_result,
        implicature_correct,
    ):
        """
        Appends one row to each table.
        :param original_example: dict with "utterance", "response" and "implicature"
        :param model_id: identifier stored in the "Model ID" column
        :param prompt_template_n: number of the prompt template that produced the result
        :param implicature_result: dict with "correct_score", "false_score" and "scored_texts"
        :param implicature_correct: value stored in the "Implicature Correct" column
        """
        scored_texts = implicature_result["scored_texts"]
        if not self._text_table:
            self._setup_text_table(scored_texts)

        example_fields = (
            original_example["utterance"],
            original_example["response"],
            original_example["implicature"],
        )
        self._result_table.add_data(
            *example_fields,
            model_id,
            prompt_template_n,
            implicature_result["correct_score"],
            implicature_result["false_score"],
            implicature_correct,
        )
        self._text_table.add_data(
            *example_fields,
            prompt_template_n,
            *scored_texts.values(),
        )

    def log_table(self):
        """
        Logs both tables to wandb under the keys "predictions" and "scored texts".
        """
        for key, table in (
            ("predictions", self._result_table),
            ("scored texts", self._text_table),
        ):
            wandb.log({key: table})


def plot_grouped_barchart(
    labels,
    group_one_values,
    group_one_label,
    chart_title,
    group_one_errors=None,
    use_wandb=False
):
    """
    Draws a bar chart of accuracies (y-axis fixed to 0-100), optionally with error bars, and
    optionally logs it to wandb as an image.

    The chart is rendered to the temporary file "tmpfig.png" in the current working directory.
    That file is removed and the figure is closed before returning, also when rendering or
    logging raises.
    :param labels: one x-axis tick label per bar
    :param group_one_values: bar heights, one per label
    :param group_one_label: legend label for the bars
    :param chart_title: title of the chart; also the key under which the image is logged to wandb
    :param group_one_errors: optional error-bar sizes per bar; skipped when None or empty
    :param use_wandb: whether to log the rendered image to wandb
    """
    x = np.arange(len(labels))  # the label locations
    width = 0.35  # the width of the bars
    image_path = "tmpfig.png"

    fig, ax = plt.subplots()
    try:
        ax.set_ylim(0, 100)
        ax.bar(x - width / 2, group_one_values, width, label=group_one_label)
        if group_one_errors:
            ax.errorbar(
                x - width / 2,
                group_one_values,
                label=f"{group_one_label} std. over prompts",
                yerr=group_one_errors,
                fmt="o",
                color="darkblue",
            )

        # Add some text for labels, title and custom x-axis tick labels, etc.
        ax.set_ylabel("Accuracy")
        ax.set_title(chart_title)
        ax.set_xticks(x, labels)
        ax.legend()
        plt.setp(ax.get_xticklabels(), rotation=45)
        fig.tight_layout()
        fig.savefig(image_path, dpi=300, bbox_inches="tight")
        if use_wandb:
            wandb.log({chart_title: wandb.Image(image_path)})
    finally:
        # pyplot keeps every figure alive until it is closed explicitly; without this each
        # call would leave one more open figure behind.
        plt.close(fig)
        if os.path.exists(image_path):
            os.remove(image_path)


def save_results_to_file(
        num_prompt_templates,
        models,
        all_prediction_results,
        implicature_data,
        write_data_to,
        write_results_to,
        arguments):
    """
    Writes the evaluation data and the aggregated results to two JSON files.
    :param num_prompt_templates: number of prompt templates each model was evaluated with
    :param models: iterable of dicts with a "model_id" and an "implicature_metrics" object that
        provides get_mean_and_std() and get_accuracy(i)
    :param all_prediction_results: stored as-is under the "predictions" key
    :param implicature_data: the data used for evaluation; dumped to write_data_to + "_implicature.json"
    :param write_data_to: path prefix of the data file
    :param write_results_to: full path of the results file
    :param arguments: run configuration; converted with dict() and stored under "arguments"
    """
    metric_labels = ["implicature_metrics"]
    all_results = {}
    for model_d in models:
        mean, std = model_d["implicature_metrics"].get_mean_and_std()
        model_results = all_results.setdefault(model_d["model_id"], {})
        model_results["mean_accuracy"] = mean
        model_results["std"] = std

        for template_idx in range(num_prompt_templates):
            for metric_label in metric_labels:
                template_results = model_results.setdefault(
                    f"prompt_template_{template_idx}", {}
                )
                template_results[metric_label] = model_d[metric_label].get_accuracy(
                    template_idx
                )

    # Write the data used for the evaluation to json files.
    data_path = write_data_to + "_implicature.json"
    with open(data_path, "w") as data_file:
        json.dump(implicature_data, data_file, indent=4)

    all_results["arguments"] = dict(arguments)
    all_results["predictions"] = all_prediction_results
    with open(write_results_to, "w") as results_file:
        json.dump(all_results, results_file, indent=4)


def log_all_results(
    num_prompt_templates,
    models,
    use_wandb=False
):
    """
    Logs the implicature scores of every model and plots them: one chart per model (accuracy
    per prompt template) and one summary chart (mean accuracy per model with std. dev. as
    error bars). With use_wandb set, the charts and a results table are also logged to wandb.
    :param num_prompt_templates: number of prompt templates each model was evaluated with
    :param models: iterable of dicts with a "model_id" and an "implicature_metrics" object that
        provides get_mean_and_std() and get_accuracy(i)
    :param use_wandb: whether to send the table and the charts to wandb
    """
    implicature_chart = {"labels": [], "values": [], "errors": []}

    metric_labels = ["implicature_metrics"]
    table = None
    if use_wandb:
        columns = ["model"]
        for i in range(num_prompt_templates):
            for metric_label in metric_labels:
                columns.append(f"{i}_{metric_label}")
        columns.extend(["mean impl", "std impl"])
        table = wandb.Table(columns=columns)

    template_labels = [f"Prompt template {i}" for i in range(num_prompt_templates)]
    for model_d in models:
        logger.info(f"Scores for model card {model_d['model_id']}")
        logger.info("Implicature score:")
        logger.info(model_d["implicature_metrics"])
        mean, std = model_d["implicature_metrics"].get_mean_and_std()
        implicature_chart["labels"].append(model_d["model_id"])
        implicature_chart["values"].append(mean)
        implicature_chart["errors"].append(std)

        prompt_template_accuracies = []
        values_per_group = {metric_label: [] for metric_label in metric_labels}
        for i in range(num_prompt_templates):
            for metric_label in metric_labels:
                accuracy = model_d[metric_label].get_accuracy(i)
                prompt_template_accuracies.append(accuracy)
                values_per_group[metric_label].append(accuracy)
        if use_wandb:
            table.add_data(
                model_d["model_id"],
                *prompt_template_accuracies,
                mean,
                std
            )

        # use_wandb has to be forwarded: the plotting helper deletes the rendered image before
        # returning, so a chart that is not logged there is lost.
        plot_grouped_barchart(
            template_labels,
            values_per_group["implicature_metrics"],
            "Implicature",
            f"Implicature results for {model_d['model_id']}",
            use_wandb=use_wandb,
        )

    plot_grouped_barchart(
        implicature_chart["labels"],
        implicature_chart["values"],
        "Implicature",
        "Implicature results per model",
        implicature_chart["errors"],
        use_wandb=use_wandb,
    )
    if use_wandb:
        wandb.log({"Results per template": table})


def _filter_examples(input_line: Dict[str, str]) -> bool:
    """
    Takes an input_line from the csv file and filters all examples where the implicature is not a simple yes or no.
    :param input_line: a line read from a csv file with data
    :return:
    """
    if not input_line:
        return False
    if "yes" in input_line["Implicature"].lower()[:5]:
        return True
    elif "no" in input_line["Implicature"].lower()[:4]:
        return True
    else:
        return False


def make_dataset_splits(
    input_data_path: str,
    dev_output_path: str,
    test_output_path: str,
    num_test_examples=600,
):
    """
    Splits the yes/no examples of a csv file into a test and a dev csv file.

    Rows are filtered with _filter_examples and shuffled with the ``random`` module (seed it
    beforehand for a reproducible split). The first num_test_examples shuffled rows form the
    test set and the rest the dev set. Both output files get the input file's header.
    :param input_data_path: path of the csv file to read
    :param dev_output_path: path the dev split is written to
    :param test_output_path: path the test split is written to
    :param num_test_examples: exact number of rows the test split must contain
    """
    assert os.path.exists(input_data_path), (
        "No input data file found at: %s\n"
        "Current working direction: %s" % (input_data_path, os.getcwd())
    )

    with open(input_data_path, newline="") as source:
        reader = csv.DictReader(source)
        fieldnames = list(reader.fieldnames)
        kept_examples = [row for row in reader if _filter_examples(row)]

    random.shuffle(kept_examples)
    test_rows, dev_rows = [], []
    for example in kept_examples:
        # Fill the test split first; everything after it is full goes to dev.
        destination = dev_rows if len(test_rows) == num_test_examples else test_rows
        destination.append(example)
    assert (
        len(test_rows) == num_test_examples
    ), f"Got {len(test_rows)} test examples but wanted {num_test_examples}"

    def _write_split(path, rows):
        # One header line followed by the rows, in the column order of the input file.
        with open(path, "w", newline="") as target:
            writer = csv.DictWriter(target, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)

    _write_split(dev_output_path, dev_rows)
    _write_split(test_output_path, test_rows)
    print(f"Wrote {len(test_rows)} test and {len(dev_rows)} dev examples.")


def get_negative_binary_example(example):
    """
    Builds the false counterpart of a binary implicature example.
    :param example: a dict whose "implicature" is either "yes" or "no"
    :return: a deep copy of the input with the implicature flipped (yes to no and vice-versa)
    :raises ValueError: if the implicature is neither "yes" nor "no"
    """
    # Validate before copying so unknown implicatures fail fast.
    if example["implicature"] not in ("yes", "no"):
        raise ValueError("Unknown implicature %s" % example["implicature"])
    flipped = "no" if example["implicature"] == "yes" else "yes"
    negated = copy.deepcopy(example)
    negated["implicature"] = flipped
    return negated


def set_seed(seed: int):
    """
    Seeds Python's ``random`` module and NumPy's global random state with the same value.
    :param seed: the seed to use for both generators
    """
    for seed_generator in (random.seed, np.random.seed):
        seed_generator(seed)


if __name__ == "__main__":
    # Running this file as a script creates the dev/test split of the implicature dataset,
    # using paths relative to the current working directory.
    make_dataset_splits(
        input_data_path="data/conversational_implicatures.csv",
        dev_output_path="data/dev_conversational_implicatures.csv",
        test_output_path="data/test_conversational_implicatures.csv",
    )
