import re
import sys
import json
import logging
from collections import OrderedDict

from termcolor import colored

from .outline_scheduler import OutlineScheduler
from sot.utils import _print_to_streams


class OutlineBatchScheduler(OutlineScheduler):
    """Outline scheduler that expands every outline point in one batched call.

    The model is first asked for an outline. Numbered points ("1. ...",
    "2. ...") are parsed out of it; only the first occurrence of each index is
    kept. One point-expansion request per point is then sent to the model as a
    single batch (``batch=True``).
    """

    # Numbered outline entry such as "1. Do X" or "2.Do Y": group 1 is the
    # point index, group 2 the (non-greedy) text up to the next newline or the
    # end of the outline.
    # TODO: make it more robust.
    _POINT_PATTERN = re.compile(r"(\d+)\.\s?([\s\S]+?)(?=\n|\n*$)")

    def set_model(self, model):
        """Set the model used for both the outline and point-expansion stages."""
        self._model = model

    def print_info(self):
        """Log the base scheduler info, then the outline and point prompts."""
        super().print_info()
        logging.info(
            colored("OutlineScheduler *outline prompt*: ", "magenta")
            + f"'''{self._outline_prompt}'''"
        )
        logging.info(
            colored("OutlineScheduler *point prompt*: ", "magenta")
            + f"'''{self._point_prompt}'''"
        )

    @staticmethod
    def command_line_parser():
        """Build the command-line parser for this scheduler.

        Starts from the parser of ``OutlineScheduler``'s parent class,
        deliberately bypassing ``OutlineScheduler.command_line_parser``. It then
        adds the prompt options, all of which are optional here.
        """
        # NOTE(review): presumably OutlineScheduler's own parser declares these
        # prompt options as required (see the commented-out `required=True`
        # below); skipping it keeps them optional for this scheduler.
        parser = super(OutlineScheduler, OutlineScheduler).command_line_parser()
        parser.add_argument(
            "--prompt-file",
            type=str,
            help=(
                "The path of the JSON file containing `outline_prompt` and"
                " `point_prompt`."
            ),
            default=None,
        )
        parser.add_argument(
            "--outline-prompt", type=str, default=None
        )  # , required=True)
        parser.add_argument(
            "--point-prompt", type=str, default=None
        )  # , required=True)
        return parser

    @classmethod
    def _parse_outline_points(cls, outline):
        """Extract the unique numbered points from an outline.

        Args:
            outline: The full outline text.

        Returns:
            A ``(points, point_outlines)`` pair of lists. ``points`` holds the
            point indexes as strings and ``point_outlines`` the matching point
            texts, in order of appearance. Only the first occurrence of each
            index is kept. Both lists are empty when nothing matches.
        """
        points = []
        point_outlines = []
        seen = set()
        for point, point_outline in cls._POINT_PATTERN.findall(outline):
            if point in seen:
                continue
            seen.add(point)
            points.append(point)
            point_outlines.append(point_outline)
        return points, point_outlines

    def _format_point_requests(self, request, outline, points, point_outlines):
        """Format one point-expansion request per parsed point."""
        return [
            self.format_point_prompt(
                request=request,
                point=point,
                outline=outline,
                point_outline=point_outline,
            )
            for point, point_outline in zip(points, point_outlines)
        ]

    @staticmethod
    def _prepend_partial(sub_request, point_resp):
        """Prefix a point response with its request's partial answer, if any.

        Element ``[1]`` of a formatted request is treated as text that precedes
        the model's answer. When it is non-empty, it is joined to the response
        with a single space.
        """
        partial = sub_request[1]
        return partial + " " + point_resp if partial else point_resp

    def _get_response_stream(self, request):
        """Run both stages, yielding intermediate outputs as they arrive.

        Yields:
            1. The outline-stage model output(s), tagged ``stage="outline"``.
            2. Every streamed chunk of the batched point expansion, tagged
               ``stage="expand"``. In these chunks ``text`` is replaced by the
               newline-joined point texts produced so far.
            3. A final ``stage="summarize"`` dict with the outline, parsed
               points and final content.
        """
        outline_request = self.format_outline_prompt(request=request)
        # NOTE(review): the outline is requested non-streamed; iterating the
        # result presumably yields one output dict per request (one here).
        for outputs in self._model.get_response([outline_request], stream=False):
            outputs["stage"] = "outline"
            yield outputs
        outline = outputs["text"]
        # Element [1] of the formatted request precedes the generated outline.
        if outline_request[1]:
            outline = outline_request[1] + outline

        points, point_outlines = self._parse_outline_points(outline)

        # Tracked explicitly. Reusing `outputs` after the loop would return the
        # outline text if the expand stream produced no chunks.
        content = ""
        if points:
            sub_requests = self._format_point_requests(
                request, outline, points, point_outlines
            )
            for outputs in self._model.get_response(
                sub_requests, batch=True, stream=True
            ):
                outputs["stage"] = "expand"
                outputs["text"] = "\n".join(
                    self._prepend_partial(sub_request, point_resp.strip())
                    for sub_request, point_resp in zip(sub_requests, outputs["text"])
                )
                content = outputs["text"]
                yield outputs

        yield {
            "stage": "summarize",
            "request": request,
            "response": "",
            "text": content,
            "outline": outline,
            "contents": content,
            "points": points,
            "point_outlines": point_outlines,
        }

    def get_response(self, request, stream=False):
        """Answer ``request`` via outline generation plus batched point expansion.

        Args:
            request: The user request passed to the prompt formatters.
            stream: If true, return the generator from ``_get_response_stream``
                instead of a result dict.

        Returns:
            A dict with these keys:
              - ``response``: the newline-joined point texts.
              - ``contents``: the per-point texts.
              - ``outline``: the full outline.
              - ``points`` and ``point_outlines``: the parsed points as lists.
              - ``outline_time`` and ``point_time``: timings from the model.
                ``point_time`` is ``0`` when no points were found.
        """
        if stream:
            return self._get_response_stream(request)

        outline_request = self.format_outline_prompt(request=request)
        outline_response = self._model.get_response([outline_request])[0]
        outline = outline_response["text"]
        outline_time = outline_response["time"]
        # Element [1] of the formatted request precedes the generated outline.
        if outline_request[1]:
            outline = outline_request[1] + outline

        # Same parsing (including de-duplication) as the streaming path, so
        # both modes issue identical point requests for the same outline.
        points, point_outlines = self._parse_outline_points(outline)

        contents = []
        # Previously unbound when there were no points (UnboundLocalError).
        point_time = 0
        if points:
            sub_requests = self._format_point_requests(
                request, outline, points, point_outlines
            )
            outputs = self._model.get_response(sub_requests, batch=True)
            point_time = outputs["time"]
            contents = [
                self._prepend_partial(sub_request, point_resp)
                for sub_request, point_resp in zip(sub_requests, outputs["text"])
            ]

        return {
            "request": request,
            "response": "\n".join(contents),
            "outline": outline,
            "contents": contents,
            "points": points,
            "point_outlines": point_outlines,
            "outline_time": outline_time,
            "point_time": point_time,
        }
