import random
import os.path as osp

import json

from embodied_cd.common.dataset_utils import PromptTemplate


class FewShotMixIn:
    """Mixin that builds and samples few-shot demonstration prompts.

    The host class must define ``self.name``; it selects the demonstration
    format used by :meth:`load_few_shot_pool`.
    """

    def load_few_shot_pool(self, dataset_dir):
        """Load ``full_dataset.jsonl`` from *dataset_dir* into ``self.few_shot_pool``.

        The demonstration format is chosen by ``self.name``:

        * ``"fewshot"``, ``"saycan"``, ``"gpt"``: state / action demos.
        * ``"react"``, ``"reflexion"``, ``"react_saycan"``: demos that also
          include a ``think`` step per action.
        * ``"saycanpay"``: single-line ``[Goal] ... [Step i] ...`` demos.

        For any other name the dataset is still read, but
        ``self.few_shot_pool`` is left untouched.
        """
        dataset_path = osp.join(dataset_dir, "full_dataset.jsonl")
        dataset = self._load_jsonl_dataset(dataset_path)

        builders = {
            "fewshot": self._build_few_shot_pool,
            "saycan": self._build_few_shot_pool,
            "gpt": self._build_few_shot_pool,
            "react": self._build_react_few_shot_pool,
            "reflexion": self._build_react_few_shot_pool,
            "react_saycan": self._build_react_few_shot_pool,
            "saycanpay": self._build_saycanpay_few_shot_pool,
        }
        builder = builders.get(self.name)
        if builder is not None:
            self.few_shot_pool = builder(dataset)

    def sample_few_shot_prompt(self, goal, k):
        """Return the *k* demos whose goals are most similar to *goal*.

        Similarity is the Jaccard index over lower-cased, whitespace-split
        words. The selected demos are joined with blank lines.

        Args:
            goal: Natural-language goal of the current task.
            k: Maximum number of demos to include.

        Returns:
            The concatenated few-shot prompt string.
        """
        # Shuffle first so demos with equal similarity come out in random
        # order (``sorted`` is stable). Note: this reorders
        # ``self.few_shot_pool`` in place.
        random.shuffle(self.few_shot_pool)

        goal_words = set(goal.lower().split())

        def jaccard_similarity(demo_goal):
            demo_words = set(demo_goal.split())
            union = goal_words | demo_words
            if not union:
                # Both goals are empty: treat as "no overlap" instead of
                # dividing by zero.
                return 0.0
            return len(goal_words & demo_words) / len(union)

        ranked = sorted(
            self.few_shot_pool,
            key=lambda demo: jaccard_similarity(self._extract_goal(demo)),
            reverse=True,
        )
        return "\n\n".join(ranked[:k])

    def _load_jsonl_dataset(self, dataset_path):
        """Group the step records of a JSONL file into episodes.

        Each non-blank line is a JSON object with keys ``instruction``,
        ``history``, ``state``, ``action`` and ``think``. A record whose
        ``history`` is ``"No action history."`` starts a new episode.

        Returns:
            A list of dicts with keys ``goal``, ``states``, ``actions`` and
            ``thinks``. The list is empty if the file holds no records.
        """
        dataset = []
        episode = None
        with open(dataset_path, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                record = json.loads(line)

                starts_new = record["history"] == "No action history."
                if episode is None or (episode["actions"] and starts_new):
                    # The goal comes from the episode's own first record,
                    # not from whichever record closes the previous one.
                    episode = {
                        "goal": record["instruction"],
                        "states": [],
                        "actions": [],
                        "thinks": [],
                    }
                    dataset.append(episode)

                episode["states"].append(record["state"])
                episode["actions"].append(record["action"])
                episode["thinks"].append(record["think"])
        return dataset

    def _extract_goal(self, demo):
        """Return the lower-cased goal text embedded in a demo string.

        Supports both the ``"Your task is to: <goal>.\\n..."`` format and the
        saycanpay ``"[Goal] <goal> [Initial State] ..."`` format.
        """
        marker = "Your task is to: "
        if marker not in demo and "[Goal] " in demo:
            goal = demo.split("[Goal] ", 1)[1].split(" [Initial State]", 1)[0]
            return goal.strip(" .").lower()
        return demo.split(".\n")[0].split(marker)[1].strip(" .").lower()

    def _build_few_shot_pool(self, dataset):
        """Build state / action demos, one string per episode."""
        few_shot_pool = []
        for sample in dataset:
            lines = [f"Your task is to: {sample['goal']}."]
            for state, action in zip(sample["states"], sample["actions"]):
                lines.append(PromptTemplate.preprocess(state))
                lines.append(f"> {action}")
            few_shot_pool.append("\n".join(lines).strip("\n"))
        return few_shot_pool

    def _build_react_few_shot_pool(self, dataset):
        """Build ReAct-style demos (state, think, ``OK.``, action) per episode."""
        few_shot_pool = []
        for sample in dataset:
            lines = [f"Your task is to: {sample['goal']}."]
            for state, think, action in zip(
                sample["states"], sample["thinks"], sample["actions"]
            ):
                lines.append(PromptTemplate.preprocess(state))
                lines.append(f"> think: {think}")
                lines.append("OK.")
                lines.append(f"> {action}")
            few_shot_pool.append("\n".join(lines).strip("\n"))
        return few_shot_pool

    def _build_saycanpay_few_shot_pool(self, dataset):
        """Build single-line saycanpay demos ending with a ``done task`` step."""
        few_shot_pool = []
        for sample in dataset:
            actions = sample["actions"]
            parts = [f"[Goal] {sample['goal']} [Initial State] {sample['states'][0]}"]
            parts.extend(f"[Step {i}] {action}" for i, action in enumerate(actions))
            parts.append(f"[Step {len(actions)}] done task")
            few_shot_pool.append(" ".join(parts).strip("\n"))
        return few_shot_pool
