from pathlib import Path
from collections import Counter, defaultdict
import json
import copy
import re
import pandas as pd
from tqdm.auto import tqdm
import datetime
from typing import Callable, List
from manifest import Manifest


class InputOutputPrompt:
    """Render few-shot prompts from a table of input/output examples.

    Each row of the table is turned into ``<input><input_output_sep><output>``
    using the two formatter callables; the rendered rows are joined with
    ``example_sep`` and prefixed with ``instruction``.
    """

    def __init__(self,
        input_formatter: Callable,
        output_formatter: Callable,
        required_keys: List,
        input_output_sep: str = "\n",
        example_sep: str = "\n\n",
        instruction: str = ""
    ):
        """Store the formatters, separators and instruction text.

        Args:
            input_formatter: called with one row, returns the input text.
            output_formatter: called with one row, returns the output text.
            required_keys: column names used to build the placeholder row
                shown by ``repr``.
            input_output_sep: text placed between an input and its output.
            example_sep: text placed between consecutive examples.
            instruction: text placed before the first example.
        """
        self.instruction = instruction
        self.example_sep = example_sep
        self.input_output_sep = input_output_sep
        self.required_keys = required_keys
        self.output_formatter = output_formatter
        self.input_formatter = input_formatter

    def __call__(self, input_output_pairs: pd.DataFrame):
        """Return the prompt string for every row of ``input_output_pairs``.

        With no rows, only the instruction is returned, with trailing
        whitespace removed.
        """
        rendered = [
            f"{self.input_formatter(row)}{self.input_output_sep}{self.output_formatter(row)}"
            for _, row in input_output_pairs.iterrows()
        ]
        if not rendered:
            return f"{self.instruction}".rstrip()
        body = self.example_sep.join(rendered)
        return f"{self.instruction}{body}"

    def __repr__(self):
        """Show the template by rendering one row of ``<KEY>`` placeholders."""
        placeholder_row = {key: f"<{key.upper()}>" for key in self.required_keys}
        return self(pd.DataFrame([placeholder_row]))


def prefix_formatter(ex_keys: List[str], prefix: str, error_on_empty: bool = True) -> Callable[[pd.Series], str]:
    """Build a formatter that renders ``"<prefix> <value>"`` for one example.

    Args:
        ex_keys: candidate keys, tried in order; the first one present in the
            example supplies the value.
        prefix: text placed before the value.
        error_on_empty: when none of ``ex_keys`` is present, raise if True,
            otherwise return the bare prefix.

    Returns:
        A function taking one example (a ``pd.Series`` row) and returning the
        formatted string.

    Raises:
        ValueError: from the returned function, when no key is present and
            ``error_on_empty`` is True.
    """
    def full_prefix_formatter(ex: pd.Series):
        for k in ex_keys:
            if k in ex:
                # Look the value up by label. Attribute access would return
                # the Series' own attribute for keys such as "name", "size"
                # or "index" instead of the row's value.
                return f"{prefix} {ex[k]}"
        if error_on_empty:
            raise ValueError(f"Example {ex} has no value for any of the keys {ex_keys}")
        return f"{prefix}"
    return full_prefix_formatter

class TogetherDataCollectior:
    # https://github.com/togethercomputer/open-models-api/blob/main/dry_run/example_requests.jsonl
    """Collect prompts that still need a model answer and write them as JSON lines.

    ``run`` records every prompt it is given. If a previous results file
    contained an answer for the same run name and prompt, that answer is
    returned; otherwise the placeholder ``@TOGETHER_PLACEHOLDER@`` is
    returned. ``add_golds`` attaches gold labels to the unanswered prompts and
    ``save`` writes them to ``together_output_file``.

    Example invocation:

    for f in ANLI_final.py CB_final.py DBPedia_final.py drop_final.py MR_final.py SST2_final.py WIC_final.py WSC_final.py; do
        python3 $f --num_run -1 --together_model opt-175b --together_output_file ../together_runs/data/500_ex_083022.jsonl;
        python3 $f --num_run -1 --together_model bloom --together_output_file ../together_runs/data/500_ex_083022.jsonl
    done;
    """

    # End-of-sequence marker removed from the front of cached model answers.
    _EOS_TOKEN = "</s>"

    def __init__(self, model_name, task_name, together_output_file, together_input_files: List[str] = None):
        """Set up the collector and optionally load earlier results.

        Args:
            model_name: model identifier written into every request; previous
                results for other models are ignored.
            task_name: task identifier written into every request.
            together_output_file: path of the JSON-lines file written by
                ``save``; its parent directory is created if missing.
            together_input_files: paths of JSON-lines files with earlier
                results. A single path is also accepted.
        """
        self.model_name = model_name
        self.task_name = task_name
        self.output_file = Path(together_output_file)
        self.output_file.parent.mkdir(parents=True, exist_ok=True)
        self.run_name = ""
        self.today = datetime.datetime.today().strftime("%m%d%Y")
        # Per-run counters of cache hits and total prompts seen.
        self.num_hits = defaultdict(int)
        self.total = defaultdict(int)

        # run name -> stripped prompt -> full log line from an earlier run
        self.previous_results = defaultdict(lambda: defaultdict(dict))
        if together_input_files:
            # A lone path would otherwise be iterated character by character.
            if isinstance(together_input_files, (str, Path)):
                together_input_files = [together_input_files]
            c = self._load_previous_results(together_input_files)
            print(f"Loaded {c} previous results from {together_input_files}")

        # run name -> stripped prompt -> request log
        self.data_collection = defaultdict(dict)
        # run name -> stripped prompt -> list of request logs with gold labels
        self.data_collection_for_saving = defaultdict(lambda: defaultdict(list))

    def _load_previous_results(self, input_files):
        """Read earlier result files into ``previous_results``; return the number loaded."""
        loaded = 0
        for input_file in input_files:
            with open(input_file) as f:
                for line in f:
                    if not line.strip():
                        # Tolerate blank lines (e.g. a trailing newline).
                        continue
                    log = json.loads(line)
                    request = log["request"]
                    if request["model"] != self.model_name:
                        continue
                    # hack to fix previous bug
                    if request["run_name"].startswith("together_"):
                        request["run_name"] = request["run_name"].replace("together_", f"{self.model_name}_")
                    # Replace a trailing date with today's date so earlier
                    # results match run names generated today.
                    # NOTE(review): re.match only anchors at the start, so a
                    # last part such as "5shot" is also treated as a date.
                    name_parts = request["run_name"].split("_")
                    if len(name_parts) > 2 and re.match(r"\d+", name_parts[-1]):
                        request["run_name"] = request["run_name"].rsplit("_", 1)[0] + "_" + self.today
                    self.previous_results[request["run_name"]][request["prompt"].strip()] = log
                    loaded += 1
        return loaded

    @classmethod
    def _strip_leading_eos(cls, text):
        """Remove whole leading end-of-sequence tokens from ``text``."""
        # str.lstrip("</s>") would strip any of the characters '<', '/', 's',
        # '>' and so eat the start of real answers such as "sure".
        while text.startswith(cls._EOS_TOKEN):
            text = text[len(cls._EOS_TOKEN):]
        return text

    def set_run_name(self, run_name):
        """Set the run name used by subsequent ``run`` calls."""
        self.run_name = run_name

    def run(
        self,
        prompt,
        gold_choices=None,
        max_tokens=0,
        stop_token=None,
        overwrite_cache=False,
    ):
        """Record a request for ``prompt`` and return a cached answer if one exists.

        Args:
            prompt: prompt text; surrounding whitespace is ignored.
            gold_choices: accepted but unused.
            max_tokens: stored in the request log.
            stop_token: if given, a cached answer is cut at its first occurrence.
            overwrite_cache: accepted but unused.

        Returns:
            The cached model answer for this run name and prompt, or
            ``"@TOGETHER_PLACEHOLDER@"`` when there is none.
        """
        prompt = prompt.strip()
        log = {
            "request_type": "language-model-inference",
            "model": self.model_name,
            "max_tokens": max_tokens,
            "prompt": prompt,
            "run_name": self.run_name,
            "task_name": self.task_name,
            "n": 1,
            "temperature": 0,
            "model_result": "@TOGETHER_PLACEHOLDER@",
        }
        self.total[self.run_name] += 1
        # The same prompt must always be requested with the same token budget.
        if prompt in self.data_collection[self.run_name]:
            assert self.data_collection[self.run_name][prompt]["max_tokens"] == max_tokens
        # See if we have an answer from a previous run
        if prompt in self.previous_results[self.run_name]:
            self.num_hits[self.run_name] += 1
            model_result = self.previous_results[self.run_name][prompt]["result"]["choices"][0]["text"]
            model_result = self._strip_leading_eos(model_result).strip()
            if stop_token:
                model_result = model_result.split(stop_token)[0]
            log["model_result"] = model_result
        self.data_collection[self.run_name][prompt] = log
        return log["model_result"]

    def add_golds(self, run_name, expt_log, dataset="test"):
        """Add gold labels and collect logs for saving.

        Args:
            run_name: run whose recorded prompts are looked up.
            expt_log: mapping of example id to a log dict holding ``gold``,
                ``ind`` and either ``base_prompt`` (one prompt) or ``prompts``
                (a list of boosts, each a list of decomposition-step prompts).
            dataset: label stored with every saved entry.

        Raises:
            ValueError: if a prompt still contains a format placeholder, or
                was never passed to ``run`` under ``run_name``.
        """
        for idx, log in expt_log.items():
            if "base_prompt" in log:
                prompts = [[log["base_prompt"].strip()]]
            else:
                prompt_list = log["prompts"]
                # List of size num boosts where each boosts is size number steps in decomp
                assert isinstance(prompt_list, list), prompt_list
                assert isinstance(prompt_list[0], list), prompt_list
                assert isinstance(prompt_list[0][0], str), prompt_list
                prompts = [[p.strip() for p in prompt_boosts] for prompt_boosts in prompt_list]
                for prompt_boosts in prompts:
                    for p in prompt_boosts:
                        # Checking that no format string is in p
                        if "{" in p:
                            print("INSIDE {")
                            print(p)
                            inside = p.split("{")[1].split("}")[0]
                            if ":" in inside and inside.split(":")[1] == "":
                                raise ValueError(f"Found empty format string in {p}")
                            if re.match(r"[a-zA-Z]+", inside):
                                raise ValueError(f"Found format string in {p}")
            for pb_i, prompt_boosts in enumerate(prompts):
                # -1 marks "no boosting" / "no decomposition" (a single entry).
                boost_idx = -1 if len(prompts) == 1 else pb_i
                # Iterate over decomp steps
                for pd_i, prompt in enumerate(prompt_boosts):
                    decomp_idx = -1 if len(prompt_boosts) == 1 else pd_i
                    if prompt not in self.data_collection[run_name]:
                        raise ValueError(f"Prompt {prompt} not found in run {run_name}")
                    recorded = self.data_collection[run_name][prompt]
                    # Only save down those that do not have answers but are not built from "TOGETHER PLACEHOLDERS"
                    if "@TOGETHER_PLACEHOLDER@" in prompt or recorded["model_result"] != "@TOGETHER_PLACEHOLDER@":
                        continue
                    # Copy so the labels do not leak into data_collection.
                    collected_log = copy.deepcopy(recorded)
                    collected_log["gold"] = log["gold"]
                    collected_log["ind"] = log["ind"]
                    collected_log["boost_idx"] = boost_idx
                    collected_log["decomp_idx"] = decomp_idx
                    collected_log["dataset"] = dataset
                    self.data_collection_for_saving[run_name][prompt].append(collected_log)

    def save(self):
        """Print per-run cache hit rates and write the collected requests as JSON lines."""
        for r in self.total.keys():
            hit_rate = self.num_hits[r] / self.total[r] if self.total[r] else 0.0
            # ".2%" scales the ratio by 100 before appending the percent sign.
            print(f"Run {r} has {self.num_hits[r]} hits out of {self.total[r]} = {hit_rate:.2%}")
        c = 0
        if self.data_collection_for_saving:
            with self.output_file.open("w") as f:
                for run_data in self.data_collection_for_saving.values():
                    for logs in run_data.values():
                        for log in logs:
                            assert "@TOGETHER_PLACEHOLDER@" not in log["prompt"]
                            c += 1
                            f.write(json.dumps(log) + "\n")
        print(f"Saved {c} Together examples to {self.output_file}")


def get_manifest_session(
    client_name="huggingface",
    client_engine=None,
    client_connection="http://127.0.0.1:5000",
    cache_connection=None,
    temperature=0,
    top_p=1.0,
):
    """Create a ``Manifest`` session with an sqlite cache and return it with a model name.

    Args:
        client_name: ``"huggingface"``, ``"openai"`` or ``"ai21"``.
        client_engine: engine passed to the openai/ai21 clients.
        client_connection: connection string passed to ``Manifest``.
        cache_connection: cache connection passed to ``Manifest``.
        temperature: sampling temperature; must be 0 for ``"huggingface"``.
        top_p: nucleus sampling parameter passed to the client.

    Returns:
        ``(manifest, model_name)`` where ``model_name`` is the client's
        reported ``model_name``, suffixed with ``_<engine>`` when the client
        reports an engine.

    Raises:
        ValueError: for an unknown client name, or for ``"huggingface"`` with
            a non-zero temperature.
    """
    if client_name == "huggingface":
        if temperature != 0:
            # This used to fall through to the "not a valid client name"
            # error, which pointed at the wrong argument.
            raise ValueError(
                f"huggingface client only supports temperature=0 here, got temperature={temperature}"
            )
        # Sampling is disabled; the tiny non-zero temperature is presumably
        # there because the backend rejects exactly 0 -- TODO confirm.
        client_params = {
            "temperature": 0.001,
            "do_sample": False,
            "top_p": top_p,
        }
    elif client_name in {"openai", "ai21"}:
        client_params = {
            "temperature": temperature,
            "top_p": top_p,
            "engine": client_engine,
        }
    else:
        raise ValueError(f"{client_name} is not a valid client name")
    manifest = Manifest(
        client_name=client_name,
        client_connection=client_connection,
        cache_name="sqlite",
        cache_connection=cache_connection,
        session_id=None,
        **client_params,
    )
    model_params = manifest.client.get_model_params()
    model_name = model_params["model_name"]
    if "engine" in model_params:
        model_name += f"_{model_params['engine']}"
    return manifest, model_name


def get_response(
    prompt,
    manifest,
    overwrite=False,
    max_toks=10,
    stop_token=None,
    gold_choices=None,
    verbose=False,
):
    """Query ``manifest`` with ``prompt`` and return the model's answer.

    Args:
        prompt: prompt text; surrounding whitespace is stripped.
        manifest: object exposing ``run(prompt, ...)``.
        overwrite: forwarded as ``overwrite_cache``.
        max_toks: token budget for free-form generation (not used when
            ``gold_choices`` is given).
        stop_token: forwarded to ``manifest.run`` for free-form generation.
        gold_choices: optional answer candidates. When non-empty, each is
            stripped and prefixed with a single space and the call is made
            with ``gold_choices`` instead of ``max_tokens``/``stop_token``.
        verbose: print the prompt and the response.

    Returns:
        ``(response, log_prob)`` when ``gold_choices`` is given and the
        response carries a log-probability, otherwise just ``response``.
    """
    prompt = prompt.strip()
    log_prob = None
    if gold_choices:
        gold_choices = [" " + g.strip() for g in gold_choices]
        response_obj = manifest.run(
            prompt, gold_choices=gold_choices, overwrite_cache=overwrite, return_response=True
        )
        response_obj = response_obj.get_json_response()["choices"][0]
        log_prob = response_obj["text_logprob"]
        response = response_obj["text"]
    else:
        response = manifest.run(
            prompt,
            max_tokens=max_toks,
            stop_token=stop_token,
            overwrite_cache=overwrite,
        )
    if verbose:
        print("\n***Prompt***\n", prompt)
        print("\n***Response***\n", response)
    # Compare against None: a log-probability of exactly 0.0 is a valid value
    # and must not change the return shape.
    if log_prob is not None:
        return response, log_prob
    return response


def save_log(task_name, expt_name, log, final_run_dir="/home/final_runs"):
    """Write ``log`` as JSON to ``<final_run_dir>/<task_name>/<expt_name>.json``.

    The task directory is created if needed. The first entry of ``log`` must
    contain the keys ``ind``, ``example``, ``pred`` and ``gold``.
    """
    task_dir = Path(final_run_dir) / task_name
    task_dir.mkdir(parents=True, exist_ok=True)
    target = task_dir / f"{expt_name}.json"

    print("Saving to", target)
    # Only the first entry is checked for the expected fields.
    first_entry = list(log.values())[0]
    missing = [field for field in ["ind", "example", "pred", "gold"] if field not in first_entry.keys()]
    assert not missing
    with open(target, "w") as out:
        json.dump(log, out)


def text_f1(preds, golds):
    """Compute average F1 of text spans.
    Taken from Squad without prob threshold for no answer.

    Each prediction/gold pair is split on whitespace and scored by token
    overlap. If either side has no tokens the pair scores 1 when both are
    empty and 0 otherwise. The sum is divided by ``len(golds)``.

    Returns 0.0 when ``golds`` is empty.
    """
    if len(golds) == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0
    total_f1 = 0
    for pred, gold in zip(preds, golds):
        pred_toks = pred.split()
        gold_toks = gold.split()
        if not gold_toks or not pred_toks:
            # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
            total_f1 += int(gold_toks == pred_toks)
            continue
        # Multiset intersection counts each shared token at most as often as
        # it appears on both sides.
        num_same = sum((Counter(pred_toks) & Counter(gold_toks)).values())
        if num_same == 0:
            continue
        precision = 1.0 * num_same / len(pred_toks)
        recall = 1.0 * num_same / len(gold_toks)
        total_f1 += (2 * precision * recall) / (precision + recall)
    return total_f1 / len(golds)

def accuracy_span_overlap(preds, golds):
    """Return the fraction of examples where a predicted span overlaps a gold span.

    ``preds`` and ``golds`` are parallel sequences; each element is an
    iterable of strings. An example counts as correct when, for some
    prediction/gold pair, the shorter string is a case-insensitive substring
    of the longer one. The count is divided by ``len(preds)``.

    Returns 0.0 when ``preds`` is empty.
    """
    if len(preds) == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0

    def overlaps(p, g):
        # Test the shorter string for containment in the longer one.
        if len(p) < len(g):
            return p.lower() in g.lower()
        return g.lower() in p.lower()

    correct = 0
    for pred, gold in zip(preds, golds):
        # any() stops at the first matching pair.
        if any(overlaps(p, g) for p in pred for g in gold):
            correct += 1
    return correct / len(preds)


