# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#    http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Generator, Iterator, List

import json5

from qwen_agent import Agent
from qwen_agent.agents.assistant import Assistant
from qwen_agent.agents.writing import ExpandWriting, OutlineWriting
from qwen_agent.llm.schema import ASSISTANT, CONTENT, USER, Message

# Fixed plan executed by WriteFromScratch, as a JSON5 object: keys are step ids
# (iterated in sorted order), values name the step to run. The raw text is also
# shown to the user verbatim as the first streamed message.
default_plan = """{"action1": "summarize", "action2": "outline", "action3": "expand"}"""


def is_roman_numeral(s):
    """Return True if ``s`` begins with a Roman-numeral token, e.g. ``"IV. Results"``.

    Used to pick top-level headings out of a generated outline. The run of
    numeral letters must not be followed directly by another Latin letter, so
    ordinary words that merely start with I/V/X/L/C/D/M (``"Introduction"``,
    ``"Conclusion"``, ``"Methods"``) are not mistaken for headings. Non-Latin
    text right after the numeral (``"XII、总结"``) is still accepted. Leading
    whitespace is not skipped.
    """
    # re.match anchors at the start; the negative lookahead rejects English words.
    return re.match(r'^[IVXLCDM]+(?![A-Za-z])', s) is not None


class WriteFromScratch(Agent):
    """Write an article from reference material by executing ``default_plan``.

    Steps: ``summarize`` condenses ``knowledge`` with an :class:`Assistant`,
    ``outline`` drafts an outline from that summary with :class:`OutlineWriting`,
    and ``expand`` writes one section per top-level (Roman-numeral) outline
    heading with :class:`ExpandWriting`. Progress is streamed as a growing list
    of assistant messages.
    """

    # User request sent to the summarizing assistant, keyed by ``lang``.
    _SUMMARIZE_REQUESTS = {
        'zh': '总结参考资料的主要内容',
        'en': 'Summarize the main content of reference materials.',
    }

    def _run(self, messages: List[Message], knowledge: str = '', lang: str = 'en') -> Iterator[List[Message]]:
        """Execute the default plan and stream the accumulated transcript.

        Args:
            messages: Conversation forwarded unchanged to the outline and expand agents.
            knowledge: Reference material; summarized in the ``summarize`` step and
                passed again in full to every ``expand`` call.
            lang: ``'zh'`` or ``'en'``; selects the summarize request and is forwarded
                to every sub-agent.

        Yields:
            The transcript so far. Step/section headers are yielded as the shared
            ``response`` list itself (which keeps growing afterwards); streamed
            sub-agent output is yielded as a fresh ``response + chunk`` list.

        Raises:
            NotImplementedError: if ``lang`` is neither ``'zh'`` nor ``'en'``. This is
                only detected when the ``summarize`` step starts, after the plan and
                that step's header have already been yielded.
        """
        response = [Message(ASSISTANT, f'>\n> Use Default plans: \n{default_plan}')]
        yield response
        res_plans = json5.loads(default_plan)

        summ = ''
        outline = ''
        # NOTE(review): plain string sort of step ids; fine for "action1".."action9",
        # but a plan with ten or more steps would be misordered ("action10" < "action2").
        for plan_id in sorted(res_plans):
            plan = res_plans[plan_id]
            if plan == 'summarize':
                response.append(Message(ASSISTANT, '>\n> Summarize Browse Content: \n'))
                yield response

                user_request = self._SUMMARIZE_REQUESTS.get(lang)
                if user_request is None:
                    raise NotImplementedError(f'Unsupported language: {lang!r}')
                sum_agent = Assistant(llm=self.llm)
                res_sum = sum_agent.run(messages=[Message(USER, user_request)], knowledge=knowledge, lang=lang)
                final = yield from self._relay(response, res_sum)
                if final:
                    summ = final[-1][CONTENT]
            elif plan == 'outline':
                response.append(Message(ASSISTANT, '>\n> Generate Outline: \n'))
                yield response

                otl_agent = OutlineWriting(llm=self.llm)
                res_otl = otl_agent.run(messages=messages, knowledge=summ, lang=lang)
                final = yield from self._relay(response, res_otl)
                if final:
                    outline = final[-1][CONTENT]
            elif plan == 'expand':
                response.append(Message(ASSISTANT, '>\n> Writing Text: \n'))
                yield response

                # Only top-level outline headings (lines starting with a Roman numeral)
                # become sections.
                headings = [line for line in outline.split('\n') if is_roman_numeral(line)]
                # One writer serves every section; each run() call is independent.
                exp_agent = ExpandWriting(llm=self.llm)
                for index, heading in enumerate(headings, start=1):
                    response.append(Message(ASSISTANT, '>\n# '))
                    yield response

                    # The following heading (if any) tells the writer where this section ends.
                    capture_later = headings[index].strip() if index < len(headings) else ''
                    res_exp = exp_agent.run(
                        messages=messages,
                        knowledge=knowledge,
                        outline=outline,
                        index=str(index),
                        capture=heading.strip(),
                        capture_later=capture_later,
                        lang=lang,
                    )
                    yield from self._relay(response, res_exp)
            # Any other step name is ignored.

    @staticmethod
    def _relay(response: List[Message],
               stream: Iterator[List[Message]]) -> Generator[List[Message], None, List[Message]]:
        """Stream a sub-agent's output on top of ``response`` and commit its final chunk.

        Every chunk from ``stream`` is yielded as ``response + chunk``. After the
        stream is exhausted, a non-empty final chunk is appended to ``response`` in
        place and returned; otherwise ``response`` is left untouched and an empty
        list is returned.
        """
        chunk = None
        for chunk in stream:
            yield response + chunk
        if not chunk:
            return []
        response.extend(chunk)
        return chunk
