import copy,json
from .utils import connect_logging, get_unique_id, save_json, read_json, save_file
from .models import CallModel
from .config import ConfigScriptRewriter

from .rewriter_prompt import SYSTEM_OUTLINE, USER_OUTLINE
from .rewriter_prompt import SYSTEM_GENSCENE, USER_GENSCENE
from .rewriter_prompt import SYSTEM_REFINE3, USER_REFINE3, SYSTEM_IMPROVE3, USER_IMPROVE3
from .rewriter_prompt import SYSTEM_REFINE4, USER_REFINE4, SYSTEM_IMPROVE4, USER_IMPROVE4

from .output_format import outline_format, extract_outline
from .output_format import script_format, extract_script
from .output_format import refine_format3,refine_format4, extract_refine34
from .output_format import improve_format3,improve_format4, extract_improve3
from .output_format import get_history_refinements, refinements_sort

class Rewriter:

    def __init__(self, config: ConfigScriptRewriter):
        """Set up output paths, optional logging, the model client and empty generation state.

        ``config`` is updated in place:
        - a missing ``unique_id`` is generated;
        - an empty ``output_basename`` becomes "no_basename";
        - ``output_save_dir`` is normalised to end with '/' and then extended with
          "<unique_id>_<output_basename>/".

        Because of that last step, reusing the same config object for a second
        Rewriter nests the directory again.

        Args:
            config: rewriter configuration. This method reads unique_id,
                print_log, output_basename, output_save_dir, model_config and
                load_data.
        """
        self.debug_list = []  # every message emitted through add_debug()
        self.log_file = None

        # --- Output location ---
        if config.unique_id is None:
            # Auto-generate an id (per the original note: based on the current time).
            config.unique_id = get_unique_id()
        if config.print_log:
            self.log_file = connect_logging(save_dir='./logs', name="ScriptRewriter")
        if config.output_basename == "":
            config.output_basename = "no_basename"
        if not config.output_save_dir.endswith('/'):
            config.output_save_dir += '/'
        config.output_save_dir += f'{config.unique_id}_{config.output_basename}/'
        # Kept on the instance so an interrupted run can be resumed from the same directory.
        self.output_save_dir = config.output_save_dir
        self.basename = config.output_basename
        self.unique_id = config.unique_id

        # --- Configuration and model ---
        self.config = config
        self.model_config = config.model_config
        self.model = self.set_model()

        # --- Run parameters (set before load_data() so they always exist) ---
        self.max_retry = 5  # extra attempts when a model response cannot be parsed
        # Ablation switch. script_rewriter() overwrites it; defaulting it here lets
        # methods that read self.ablation be called on their own.
        self.ablation = None

        # --- Generation state ---
        self.rewriter_elements = []
        self.diagram_info = {}
        self.outlines = []
        self.outlines_refine_info = []
        self.refined_outlines = []
        self.screenplays = []
        self.screenplays_refine_info = []

        # Set to True by load_data() when results of a previous run are reused.
        self.loaded_outlines = False
        self.loaded_screenplays = False
        self.loaded_outlines_refine = False
        if config.load_data:  # resume from a previous run
            self.load_data()

    def load_data(self, load_to_number: int = -1):
        """Restore outlines and outline-refine info saved by a previous run.

        Each file is read from ``<output_save_dir>/json_data/``. When the file is
        non-empty, the matching attribute is replaced and its ``loaded_*`` flag is
        set. Missing files are reported, not raised. ``load_to_number`` is
        currently unused.
        """
        json_dir = self.output_save_dir + 'json_data/'
        # (attribute / file stem, flag marking a successful load)
        targets = (
            ('outlines', 'loaded_outlines'),
            ('outlines_refine_info', 'loaded_outlines_refine'),
        )
        for key, flag in targets:
            try:
                payload = read_json(json_dir, f'{key}.json')
                if not payload:
                    self.datasl(f"Loaded empty '{key}'.")
                    continue
                setattr(self, key, payload)
                setattr(self, flag, True)
                self.datasl(f"Successfully loaded '{key}' from '{json_dir}'.")
            except FileNotFoundError as err:
                self.datasl(f"Loaded no '{key}' from '{json_dir}'. Error message: {str(err)}")
        # if load_to_number == 0:
        #     self.info(f"load_data(): load_to_number=0, do nothing.")
        #     return

        # path = self.output_save_dir + 'json_data/'
        # try:
        #     rewriter_elements = read_json(path, 'rewriter_elements.json')
        #     if rewriter_elements:
        #         self.loaded_outline = True
        #         self.loaded_screenplay = True
        #         self.rewriter_elements = rewriter_elements
        #         self.current_gen_number = len(rewriter_elements) + 1
        #         self.info(
        #             f"Successfully loaded 'rewriter_elements {1}-{self.current_gen_number-1}' from '{path}'.")
        #     else:
        #         self.info(f"Successfully loaded 'rewriter_elements 0")
        #     # if load_to_number >= 1:
        #     #     while len(self.chapter_elements) > load_to_number:
        #     #         self.backspace_data()
        # except FileNotFoundError:
        #     self.info(f"Loaded no data from '{path}'.")

    # def backspace_data(self):
    #     if self.current_gen_number > 1:
    #         self.info(
    #             f"Successfully backspace 'rewriter_elements {1}-{self.current_gen_number-1}' to 'rewriter_elements {1}-{self.current_gen_number-2}'")
    #         self.chapter_elements = self.chapter_elements[:-1]
    #         self.current_gen_number -= 1
    #     else:
    #         self.info("No rewriter_elements data for backspace.")

    # def clear_data(self):  # 清空生成过程中保存的数据
    #     self.rewriter_elements = []
    #     self.current_gen_number = 1

    # 【工具】
    def set_model(self, **kwargs):
        """Build a CallModel client from ``self.model_config``.

        An optional ``model_name`` keyword overrides the configured name. The
        override is written into ``self.model_config`` itself, because the
        config object is shared, not copied.
        """
        model_config = self.model_config
        model_config.model_name = kwargs.get("model_name", model_config.model_name)
        return CallModel(model_config)
    def get_response(self, system_prompt, user_prompt, task, **generate_kwargs):
        """Shorthand for ``self.model.get_response`` that saves under this run's output directory."""
        model = self.model
        return model.get_response(
            system_prompt,
            user_prompt,
            task=task,
            save_dir=self.output_save_dir,
            **generate_kwargs,
        )

    # 【调试信息输出】
    def add_debug(self, text):
        if text != None and text != "":
            print(text)
            self.debug_list.append(text)
            return self.debug_list
        return []
    def info(self, text, tab=0):
        if text != None and text != "":
            if tab: return self.add_debug(f"        [Info] {text}")
            return self.add_debug(f"    [Info] {text}")
        return []
    def warning(self, text, tab=0):
        if text != None and text != "":
            if tab: return self.add_debug(f"        [Warning] {text}")
            return self.add_debug(f"    [Warning] {text}")
        return []
    def datasl(self,text):
        if text != None and text != "":
            return self.add_debug(f"      [Data] {text}")
        return []


    # 【构建情节图信息】
    def build_diagram(self, global_elements, diagram_mode='chapter'):
        """Build the plot diagram from the refined plot lines and their causal edges.

        Args:
            global_elements: dict providing 'refined_plot_lines' and 'plots_edges'.
            diagram_mode: mode passed to ``build_diagram.build_diagram``
                (e.g. 'BFS', 'DFS', 'chapter'). Under the 'PlotGraph' ablation,
                'chapter' is always used instead.

        Returns:
            dict with keys:
            - 'diagram_mode': the requested mode, as passed in;
            - 'event_ID_dict', 'events_list', 'edges_list';
            - 'diagram': the diagram string.
        """
        plot_lines = global_elements['refined_plot_lines']
        plots_edges = global_elements['plots_edges']
        from .build_diagram import build_diagram as build
        # getattr: self.ablation is normally assigned by script_rewriter(); tolerate
        # a direct call on an instance where it was never set.
        effective_mode = 'chapter' if getattr(self, 'ablation', None) == "PlotGraph" else diagram_mode
        event_ID_dict, events_list, edges_list, diagram_str = build(plot_lines, plots_edges, effective_mode)
        return {
            "diagram_mode": diagram_mode,
            "event_ID_dict": event_ID_dict,
            "events_list": events_list,
            "edges_list": edges_list,
            "diagram": diagram_str
        }
    
    # 【改编大纲】
    def get_outline_reference(self, global_elements, diagram_info):
        profiles = copy.deepcopy(global_elements['refined_profiles'])
        event_ID_dict = diagram_info['event_ID_dict']
        for idx,profile in enumerate(profiles):
            CB = profile['Character Biographies']
            for cha,bio in CB.items(): #把角色经历里面的Act x-Event y换成新编号
                for event in CB[cha]:
                    if event.get('Event ID'):
                        if event['Event ID'] in event_ID_dict:
                            event['Event ID'] = event_ID_dict[event['Event ID']]
                        else:
                            self.warning(f"{event['Event ID']} of Profiles is not in the event_ID_dict!",tab=1)
                    else: self.warning(f"A Event of {cha} in Profile {idx+1} has no Event ID!")
        return diagram_info['diagram'],profiles
    def adapt_outline(self, global_elements, screenplay_structure: str = None, **generate_kwargs):
        """
        Adapts the extracted plot lines into a screenplay outline. If no screenplay structure is provided,
        the Agent will choose the most suitable structure.
        """
        if self.loaded_outlines: return self.outlines

        if generate_kwargs.get('temperature') is None:
            generate_kwargs['temperature'] = 0  # Default temperature for generation
        if not self.diagram_info: raise Exception("Diagram is empty!!!")
        task = f'adapt_outline'
        #提取需要的信息
        plot_diagram,profiles = self.get_outline_reference(global_elements, self.diagram_info)
        structure = f"The provided structure setting is {screenplay_structure}." or "No provided structure setting, please select the most appropriate screenplay structure based on the plot diagram."# If screenplay_structure is None, let the Agent choose the most appropriate one
        self.info(f"{task}(): Adapting outlines with structure \"{screenplay_structure or 'Agent-selected'}\".")

        # Iterate through all plot lines to adapt into screenplay outline
        outlines = []
        # 导入prompt
        system_prompt = SYSTEM_OUTLINE.format(output_format = outline_format,)
        user_input = USER_OUTLINE.format(
            # plot_lines=plot_lines,
            plot_diagram=plot_diagram, profiles=profiles,
            screenplay_structure=structure
        )
        for attempt in range(self.max_retry + 1):
            # Call the model
            response = self.get_response(system_prompt, user_input, f"{task}", **generate_kwargs)
            adapted_acts = extract_outline(response)

            if adapted_acts: break # Break the retry loop if successful
            elif attempt < self.max_retry: self.warning(f"{task}(): Failed to extract adapted content! Retry {attempt + 1}/{self.max_retry}...")
        outlines = adapted_acts

        return outlines
    
    # 【大纲self_refine】
    def get_ID(self, location):
        """Parse 'Act x-Scene y' into the integer pair (x, y); return None for an empty location."""
        if not location:
            return None
        numbers = location.replace("Act ", "").replace("Scene ", "").split("-")
        act_no, scene_no = numbers
        return int(act_no), int(scene_no)

    def outline_self_refine(self, global_elements, refine_rounds: int = 3, **generate_kwargs):
        if self.loaded_outlines_refine: return self.outlines_refine_info
        if generate_kwargs.get('temperature') == None:  # 温度默认使用0
            generate_kwargs['temperature'] = 0
        if not self.diagram_info or not self.outlines: raise Exception("Diagram or Outlines is empty!!!")
        #提取需要的信息
        diagram_info = self.diagram_info
        refined_outlines = self.outlines
        plot_diagram,profiles = self.get_outline_reference(global_elements, diagram_info)

        refine_info = []
        all_refinements = []
        for index in range(1,refine_rounds+1):
            
            self.info(f"self_refine(): Round = {index}/{refine_rounds}...")

            #-----------------------------#
            # 1.【对outlines提出修改意见】#
            task = f'self_refine_outlines'
            self.info(f"{task}(): Getting refinements of Outlines...")
            # 提取需要用到的信息
            history = get_history_refinements(all_refinements)
            # 导入prompt
            system_prompt = SYSTEM_REFINE3.format(output_format = refine_format3)
            user_prompt = USER_REFINE3.format(
                outlines = refined_outlines,
                plot_diagram = plot_diagram, profiles = profiles,
                history = history
            )
            # 获取模型响应
            for attempt in range(self.max_retry + 1):
                response = self.get_response(system_prompt, user_prompt, f"{task}_{index}", **generate_kwargs)
                refinements = extract_refine34(response)
                if refinements: break
                elif attempt < self.max_retry: self.warning(f"{task}(): Failed extract_content! Retry {attempt+1}/{self.max_retry}...")
            refinements,cnt_info = refinements_sort(refinements)
            cnt_ = len(refinements)
            cnt = max(0,cnt_-index+1)
            self.info(f"{task}(): Got {cnt_} refinements({cnt_info}) for Outlines. Will try to improve the first {cnt}.")

            #------------------------------#
            # 2.【根据修改意见提升outlines】#
            task = f'self_improve_outlines'
            self.info(f"{task}(): Getting improvements of Outlines...")
            
            refinements_ok,improvements_ok = [],[]
            for idx in range(cnt):
                refinement = refinements[idx]
                # 提取需要用到的信息
                acts = refined_outlines.get('Adapted Acts List',[])
                location = refinement.get("location") #形如'Act x-Scene y'的字符串
                reference = ""
                try:
                    Act_ID,Scene_ID = self.get_ID(location) #将字符串解析成数字
                    reference = self.get_reference_by_ID(global_elements,diagram_info,acts,Act_ID,Scene_ID)
                except Exception as e: self.warning(f"{task}(): Location '{location}' broken! Delete. (from refinement {idx+1}/{cnt})",tab=1)
                if reference:
                    # 导入prompt
                    system_prompt = SYSTEM_IMPROVE3.format()
                    user_prompt = USER_IMPROVE3.format(
                        suggestion=refinement, reference=reference, 
                        outline=json.dumps(acts[Act_ID-1],indent=4) if self.ablation=='Reference' else acts[Act_ID-1], 
                        output_format=improve_format3
                    )
                    # 获取模型响应
                    for attempt in range(self.max_retry + 1):
                        response = self.get_response(system_prompt, user_prompt, f"{task}_{index}-{idx+1}", **generate_kwargs)
                        improvement = extract_improve3(response)
                        if improvement: break
                        elif attempt < self.max_retry: self.warning(f"{task}(): Failed extract_content! Retry {attempt+1}/{self.max_retry}...",tab=1)
                    if improvement: # 数据处理
                        if acts[Act_ID-1] == improvement: # 没有做改动
                            self.info(f"{task}(): Improve '{location}' has no changes. Delete.(from refinement {idx+1}/{cnt})",tab=1)
                        else:
                            refinements_ok.append(refinement)
                            improvements_ok.append(improvement)
                            acts[Act_ID-1] = improvement
                            self.info(f"{task}(): Improve '{location}' succeeded. (from refinement {idx+1}/{cnt})",tab=1)
                    else: self.warning(f"{task}(): Improve '{location}' output JSON broken. Delete.(from refinement {idx+1}/{cnt})",tab=1)
                else: self.warning(f"{task}(): Reference of Location '{location}' is empty! Delete. (from refinement {idx+1}/{cnt})",tab=1)
            
            self.info(f"{task}(): Got {len(improvements_ok)} improvements of Outlines.")

            if not improvements_ok:
                self.info(f"self_refine(): No Improvements in Round {index}/{refine_rounds}. Iteration stopped.")
                break
            
            all_refinements.extend(refinements_ok)

            refine_info.append({ # 每个优化轮次的中间信息列表
                "outlines_refinements": refinements, # 本轮次中提出的原始建议列表
                "outlines_refinements_ok": refinements_ok, # 成功执行修改的建议列表
                "outlines_improvements": improvements_ok, # 每次执行修改后的单幕结果列表
                "refined_outlines": refined_outlines, # 本轮次优化后的全幕情节链
            })
            # self.save_data_2_tmp(refine_info)

        return refine_info

    def Act_id2int(self, Act_ID):
        """Parse 'Act a-Event e' into the integer pair (a, e)."""
        cleaned = Act_ID.replace("Act ", "").replace("Event ", "")
        act_part, event_part = cleaned.split('-')
        return int(act_part), int(event_part)
    def event_id2int(self, event_id):
        """Parse 'Event n' into the integer n."""
        return int(event_id.replace("Event ", ""))
    
    def build_original(self, plot_lines,diagram_info,raw_chapters,event_int,novel=True,tab=0):
        tab_text = ""
        if tab==1: tab_text = "  "
        if tab==2: tab_text = "    "
        # print("build_original(): diagram_info=",json.dumps(diagram_info,indent = 4))
        key,value = list(diagram_info['event_ID_dict'].items())[event_int-1]
        Act_ID,Event_ID = self.Act_id2int(key)
        # event = plot_lines[Act_ID-1]['Plot Chains List'][Event_ID-1]
        chain = plot_lines[Act_ID-1]['Plot Chains List']
        for event in chain:
            if event['Event ID'] == f"Act {Act_ID}-Event {Event_ID}":
                # print("event = ",event)
                frag_st,frag_mid = plot_lines[Act_ID-1]['Chapter ID Range'].split('-')
                frag_st,frag_mid = int(frag_st),int(frag_mid)
                novel_frags = str(raw_chapters[frag_st: frag_mid+1])
                text = f"""{tab_text}**Event Diegesis**: {event.get('Diegesis')}
        {tab_text}**Event Description**: {event.get('Description')}.
        {tab_text}**Event Detailed**: {event.get('Detailed')}.
        {tab_text}**Associated Characters**: {event.get('Associated Characters')}."""
                if novel==True: text+=f"""{tab_text}**Original Novel Context:**\n{novel_frags}"""
                text+="\n"
                return text
        self.warning(f"Event ID 'Act {Act_ID}-Event {Event_ID}' not found in plot lines")
        return ""


    def get_reference_by_ID(self, global_elements, diagram_info, acts, Act_ID, Scene_ID):
        try:
            act = acts[Act_ID-1]
            scene = act['Scenes List'][Scene_ID-1]
            turning_points = act.get('Turning Points List', [])
            return self.get_reference(global_elements, diagram_info, scene, turning_points)
        except Exception as e:
            self.warning(f"get_reference_by_ID: Error message: {str(e)}",tab=1)
            return ""
        
    def get_reference(self, global_elements, diagram_info, scene, turning_points):
        """Assemble the reference text used when generating or improving one scene.

        The text contains, in order:
        1. One line per entry of ``turning_points`` whose 'Scene ID' equals the
           scene's 'Scene ID', with that entry's 'Reason'.
        2. The scene's 'Associated Events List'.
        3. For each associated event, its details from ``build_original``. The
           original novel text is included only when
           ``self.ablation != "Reference"``.
        4. Only if the scene is a turning point and ``self.ablation != "Reference"``:
           every edge in ``diagram_info['edges_list']`` whose edge[1] is one of
           the scene's events. Each is printed as "cause edge[0] of edge[1]" with
           the explanation edge[2], followed by the cause event's details (without
           novel text, to keep the prompt short).

        Failures while reading ``global_elements`` or while building one event's
        details are logged as warnings and skipped rather than raised.
        ``self.ablation`` is read here; it is assigned in ``script_rewriter``.

        NOTE(review): if 'refined_plot_lines' or 'raw_chapters' is missing from
        ``global_elements``, ``plot_lines``/``raw_chapters`` stay unbound. Each
        later ``build_original`` call then fails with a NameError, which is
        reported as a broken Event ID.

        Returns:
            The newline-separated reference text.
        """
        # print(f"get_reference(): scene={scene} ,turning_points={turning_points}")
        tp_flag = False  # True when this scene is listed as a turning point
        text = ""
        for tp_scene in turning_points:
            if tp_scene['Scene ID'] == scene['Scene ID']: # this scene is a turning point (key node)
                text += f"- Current scene is a turning_points: {tp_scene['Reason']}\n"
                tp_flag = True
        try:
            plot_lines = global_elements['refined_plot_lines']
            # profiles = global_elements['refined_profiles']
            # plots_edges = global_elements['plots_edges']
            raw_chapters = global_elements['raw_chapters']
        except Exception as e: self.warning(f"build_original: global_elements is broken. Error message: {str(e)}",tab=1)
        # print("text = ",text)
        events = scene.get('Associated Events List',[])
        if not events: self.warning(f"Associated Events List of '{scene['Scene ID']}' is empty!",tab=1)
        text += f"- Current scene has following associated original events: {str(events)}\n"
        relate_edges = []  # edges (cause, effect, explanation) leading into this scene's events
        for event_id in events:
            try:
                plot = self.build_original(
                    plot_lines,diagram_info,raw_chapters,self.event_id2int(event_id),
                    novel=True if self.ablation!="Reference" else False,tab=1) # include the original novel text only when not ablated
                text += f"- Event details of '{event_id}':\n{plot}\n"
            except Exception as e: self.warning(f"build_original_plot: Event ID '{event_id}' is broken. Error message: {str(e)}",tab=1)
            
            if self.ablation!="Reference": # extra context only when not ablated
                if tp_flag: # for turning-point scenes, add further context
                    for act in diagram_info['edges_list']: # collect every edge whose effect is this event (its causal predecessors)
                        for edge in act:
                            if edge[1] == event_id:
                                relate_edges.append(edge)
        # print("text = ",text)
        if relate_edges: text += f"- Current scene has following causally associated original event: {str([edge[0] for edge in relate_edges])}\n"
        for edge in relate_edges:
            text += f"- Cause of '{edge[1]}' is '{edge[0]}': {edge[2]}\n"
            try:
                # Cause events are described without novel text to keep the prompt short.
                plot = self.build_original(plot_lines,diagram_info,raw_chapters,self.event_id2int(edge[0]),novel=False,tab=2)
                text += f"- Event details of '{edge[0]}':\n{plot}\n"
            except Exception as e: self.warning(f"build_original_causeplot: Event ID '{edge[0]}' is broken. Error message: {str(e)}",tab=1)
        # print("text = ",text)
        return text

    # def calc_scenes(self, outlines): #计算outlines中的scene总数
    #     acts = outlines.get('Adapted Acts List',[])
    #     return sum([len(act.get('Scenes List',[])) for act in acts])

    # 【生成剧本】
    def build_outline(self, scene, tab=0):
        """Render one scene outline as "**Field**: value" lines.

        Every line ends with a newline. All fields except 'Place and Time' get a
        trailing '.'. A truthy ``tab`` indents each line by two spaces; missing
        fields render as None.
        """
        indent = "  " if tab else ""
        layout = (
            ('Place and Time', ''),
            ('Background', '.'),
            ('Storyline', '.'),
            ('Storyline Goal', '.'),
            ('Character Experiences', '.'),
            ('End Suspense', '.'),
            ('Associated Events List', '.'),
        )
        return "".join(f"{indent}**{name}**: {scene.get(name)}{suffix}\n" for name, suffix in layout)
    def build_outlines(self, scenes, act_id, tab=0):
        """Render every scene of one act as bullet lines under an "ACT a SCENE s:" header.

        ``act_id`` is the 0-based act index and is printed as ``act_id + 1``. A
        truthy ``tab`` indents the bullet lines by two spaces. Every field line
        ends with '.'.
        """
        indent = "  " if tab else ""
        field_names = (
            'Place and Time', 'Background', 'Storyline', 'Storyline Goal',
            'Character Experiences', 'End Suspense', 'Associated Events List',
        )
        chunks = []
        for scene_no, scene in enumerate(scenes, start=1):
            chunks.append(f"ACT {act_id+1} SCENE {scene_no}:\n")
            chunks.extend(f"{indent}- **{name}**: {scene.get(name)}.\n" for name in field_names)
        return "".join(chunks)
    
    def get_scene_by_id(self, outlines, scene_id):
        """Return the scene dict addressed by 'Act x-Scene y' (1-based) in ``outlines``."""
        act_no, scene_no = scene_id.replace("Act ", "").replace("Scene ", "").split('-')
        acts = outlines['Adapted Acts List']
        return acts[int(act_no) - 1]['Scenes List'][int(scene_no) - 1]
    
    def get_screenplay_history(self, outlines, screeplays):
        # print("screenplay_history start.")
        text = ""
        for idx,screenplay in enumerate(screeplays):
            scene_id,script = screenplay['Scene ID'],screenplay['script']
            if idx < len(screeplays)-1:
                try:outline = self.build_outline(self.get_scene_by_id(outlines,scene_id),tab=1)
                except Exception as e: f"get_screenplay_history: Scene ID '{scene_id}' is broken. Error message: {str(e)}"
                text += f"- Outline of '{scene_id}':\n{outline}\n"#前面的给大纲
            else: text += f"- Script of '{scene_id}':\n{script}\n" #最新的上一集给剧本
        # print("screenplay_history done.")
        return text

    def generate_screenplay(self, global_elements, refine_rounds: int = 3, **generate_kwargs):
        """
        Generates the full screenplay based on the adapted outlines and scene context.
        """
        if self.loaded_screenplays: return self.screenplays # If the screenplay has been generated, return it

        if generate_kwargs.get('temperature') is None:
            generate_kwargs['temperature'] = 0  # Default temperature for generation

        if (self.ablation!="PlotGraph" and not self.diagram_info) or not self.refined_outlines: raise Exception("Diagram or Outlines is empty!!!")
        diagram_info = self.diagram_info
        outlines = self.refined_outlines
        acts = outlines.get('Adapted Acts List',[])
        total_scenes = sum([len(act.get('Scenes List',[])) for act in acts])
        self.info(f"generate_screenplay(): Total {len(acts)} acts with {total_scenes} scenes in outlines.")
        screenplays,screenplays_refine_info = [],[]
        # Iterate through each act in the adapted outline
        for act_id,act in enumerate(acts):
            scenes = act.get('Scenes List', [])

            # 1.【生成剧本】
            task = f"gen_scripts_Act{act_id+1}"
            self.info(f"{task}(): Total {len(scenes)} scenes in Act {act_id+1}/{len(acts)}...")
            turning_points = act.get('Turning Points List', [])
            scripts = []
            # Iterate through each scene in the act
            for scene_id,scene in enumerate(scenes):
                self.info(f"{task}(): Generating script scene {scene_id+1}/{len(scenes)} of Act {act_id+1}/{len(acts)}...",tab=1)
                # 导入prompt
                system_prompt = SYSTEM_GENSCENE.format()
                user_input = USER_GENSCENE.format(
                    current_outline=self.build_outline(scene),
                    previous_scenes=self.get_screenplay_history(outlines,screenplays+scripts),
                    related_references=self.get_reference(global_elements, diagram_info, scene, turning_points),
                    output_example=script_format
                )
                for attempt in range(self.max_retry + 1):
                    # Call the model to get the screenplay for the current scene
                    response = self.get_response(system_prompt, user_input, f"{task}_{scene_id+1}", **generate_kwargs)
                    script = extract_script(response)
                    if script: break # Break the retry loop if successful
                    elif attempt < self.max_retry:
                        self.warning(f"{task}(): Failed extract_content! Retry {attempt + 1}/{self.max_retry}...",tab=1)

                # Append the generated scene script to the final screenplay
                scripts.append({"Scene ID": f"Act {act_id+1}-Scene {scene_id+1}", "script": script})
            
            self.save_screenplay(scripts,act_id+1) #保存本幕
            # screenplays.extend(scripts)
            
            refined_scripts = scripts
            scripts_refine_info = []
            all_refinements = []
            for index in range(1,refine_rounds+1):
                self.info(f"self_refine(): Round = {index}/{refine_rounds}...")

                #-----------------------------#
                # 2.【对script提出修改意见】#
                task = f"self_refine_scripts_Act{act_id+1}"
                self.info(f"{task}(): Getting refinements of Scripts...")
                # 提取需要用到的信息
                history = get_history_refinements(all_refinements)
                # 导入prompt
                system_prompt = SYSTEM_REFINE4.format(output_format = refine_format4)
                user_prompt = USER_REFINE4.format(
                    scripts = refined_scripts,
                    act_outlines = self.build_outlines(scenes,act_id),
                    history = history
                )
                # 获取模型响应
                for attempt in range(self.max_retry + 1):
                    response = self.get_response(system_prompt, user_prompt, f"{task}_{index}", **generate_kwargs)
                    refinements = extract_refine34(response)
                    if refinements: break
                    elif attempt < self.max_retry: self.warning(f"{task}(): Failed extract_content! Retry {attempt+1}/{self.max_retry}...")
                refinements,cnt_info = refinements_sort(refinements)
                cnt_ = len(refinements)
                cnt = max(0,cnt_-index+1)
                self.info(f"{task}(): Got {cnt_} refinements({cnt_info}) for Scripts. Will try to improve the first {cnt}.")

                #------------------------------#
                # 2.【根据修改意见提升script】#
                task = f"self_improve_scripts_Act{act_id+1}"
                self.info(f"{task}(): Getting improvements of Scripts...")
                
                refinements_ok,improvements_ok = [],[]
                for idx in range(cnt):
                    refinement = refinements[idx]
                    # 提取需要用到的信息
                    location = refinement.get("location") #形如'Act x-Scene y'的字符串
                    reference = ""
                    try:
                        Act_ID,Scene_ID = self.get_ID(location) #将字符串解析成数字
                        reference = self.get_reference_by_ID(global_elements, diagram_info, acts, Act_ID, Scene_ID)
                    except Exception as e: self.warning(f"{task}(): Location '{location}' broken! Delete. (from refinement {idx+1}/{cnt})",tab=1)
                    if reference:
                        # 导入prompt
                        system_prompt = SYSTEM_IMPROVE4.format()
                        user_prompt = USER_IMPROVE4.format(
                            suggestion=refinement, reference=reference, 
                            script=scripts[Scene_ID-1], 
                            output_format=improve_format4
                        )
                        # 获取模型响应
                        for attempt in range(self.max_retry + 1):
                            response = self.get_response(system_prompt, user_prompt, f"{task}_{index}-{idx+1}", **generate_kwargs)
                            improvement = extract_improve3(response)
                            if improvement: break
                            elif attempt < self.max_retry: self.warning(f"{task}(): Failed extract_content! Retry {attempt+1}/{self.max_retry}...",tab=1)
                        if improvement: # 数据处理
                            if scripts[Scene_ID-1] == improvement: # 没有做改动
                                self.info(f"{task}(): Improve '{location}' has no changes. Delete.(from refinement {idx+1}/{cnt})",tab=1)
                            else:
                                refinements_ok.append(refinement)
                                improvements_ok.append(improvement)
                                scripts[Scene_ID-1] = improvement
                                self.info(f"{task}(): Improve '{location}' succeeded. (from refinement {idx+1}/{cnt})",tab=1)
                        else: self.warning(f"{task}(): Improve '{location}' output JSON broken. Delete.(from refinement {idx+1}/{cnt})",tab=1)
                    else: self.warning(f"{task}(): Reference of Location '{location}' is empty! Delete. (from refinement {idx+1}/{cnt})",tab=1)
                
                self.info(f"{task}(): Got {len(improvements_ok)} improvements of Scripts.")

                if not improvements_ok:
                    self.info(f"self_refine(): No Improvements in Round {index}/{refine_rounds}. Iteration stopped.")
                    break
                
                all_refinements.extend(refinements_ok)

                scripts_refine_info.append({ # 每个优化轮次的中间信息列表
                    "scripts_refinements": refinements, # 本轮次中提出的原始建议列表
                    "scripts_refinements_ok": refinements_ok, # 成功执行修改的建议列表
                    "scripts_improvements": improvements_ok, # 每次执行修改后的单幕结果列表
                    "refined_scripts": refined_scripts, # 本轮次优化后的全幕情节链
                })
                
                self.save_screenplay(refined_scripts,act_id+1,"refined_") #每轮次的refine结束时保存本幕
            screenplays.extend(refined_scripts) #保存本幕
            self.save_screenplay_refine(scripts_refine_info,act_id+1)
            # screenplays_refine_info.append(scripts_refine_info)

        self.info(f"generate_screenplays(): Total generated {len(screenplays)} Scenes.")
        return screenplays#,screenplays_refine_info

    def calc_refine_times(self, refine_info,label):
        cnt1 = [len(refine[f'{label}_refinements']) for refine in refine_info]
        sum1 = sum(cnt1)
        self.info(f"Total get refinements: {str(cnt1)}  = {sum1}")

        cnt1 = [len(refine[f'{label}_improvements']) for refine in refine_info]
        sum1 = sum(cnt1)
        self.info(f"Total use improvements: {str(cnt1)} = {sum1} ")

    def script_rewriter(self, global_elements, diagram_mode = 'chapter', screenplay_structure = None, refine_rounds: int = 3, ablation = None, **generate_kwargs):
        self.rewriter_elements = {}
        self.ablation = ablation

        #0.建立情节图
        diagram_info = {}
        try:
            diagram_info = self.build_diagram(global_elements, diagram_mode)
            self.diagram_info = diagram_info
            self.save_data_0(diagram_info)
            self.info("Successfully build diagram.")
        except Exception as e: self.warning(f"Failed to build diagram. Error message: {str(e)}")

        # return {'rewriter_elements': self.rewriter_elements, 'message': 'Success'}

        #1.改编成剧本大纲
        outlines = []
        try:
            outlines = self.adapt_outline(global_elements, screenplay_structure, **generate_kwargs)
            self.outlines = outlines
            if self.loaded_outlines == False: self.save_data_1(outlines)# 如果没有加载到数据，则保存当前数据
            self.info("Successfully adapt outlines.")
        except Exception as e: self.warning(f"Failed to adapt outlines. Error message: {str(e)}")

        # return {'rewriter_elements': self.rewriter_elements, 'message': 'Success'}

        #2.剧本大纲self_refine
        outlines_refine_info = []
        try:
            outlines_refine_info = self.outline_self_refine(global_elements, refine_rounds,**generate_kwargs)
            self.outlines_refine_info = outlines_refine_info
            if self.loaded_outlines_refine == False: self.save_data_2(outlines_refine_info)# 如果没有加载到数据，则保存当前数据
            self.info(f"Successfully self refine outlines {len(outlines_refine_info)} rounds.")
            self.calc_refine_times(outlines_refine_info,'outlines')
        except Exception as e: self.warning(f"Failed to self refine outlines. Error message: {str(e)}")

        self.refined_outlines = outlines if not outlines_refine_info else outlines_refine_info[-1]["refined_outlines"]
        if self.refined_outlines: self.save_data_3(self.refined_outlines)

        # return {'rewriter_elements': self.rewriter_elements, 'message': 'Success'}

        #3.生成剧本、self_refine
        screenplays = []
        try:
            screenplays = self.generate_screenplay(global_elements, refine_rounds, **generate_kwargs)
            self.screenplays = screenplays
            # if self.loaded_screenplays == False: self.save_data_4(screenplays)# 如果没有加载到数据，则保存当前数据
            self.info("Successfully generate screenplays.")
        except Exception as e: self.warning(f"Failed to generate screenplays. Error message: {str(e)}")
        
        if (self.diagram_info or ablation=="PlotGraph") and self.outlines and (self.outlines_refine_info or refine_rounds==0)  and self.screenplays:# and self.outlines_refine_info and self.screenplays and self.screenplay_refine_info:
            self.rewriter_elements = {
                "diagram_info": diagram_info,
                "outlines": outlines, #初始提取出的剧本大纲
                "outlines_refine_info": outlines_refine_info, #每轮优化的中间信息
                "refined_outlines": self.refined_outlines, #优化后的最终剧本大纲
                "screenplays": screenplays
            }
            # self.save_data_all(self.rewriter_elements)
            self.info("Successfully rewriter scripts.")
            return {'rewriter_elements': self.rewriter_elements, 'message': 'Success'}
        else: return {'rewriter_elements': self.rewriter_elements, 'message': 'Fail'}
    
    def save_data_0(self, diagram_info):
        path = self.output_save_dir + 'json_data/'
        save_json(diagram_info, path, 'diagram_info')
        self.datasl(f"Data 'diagram_info.json' has been saved to '{path}'.")

        text = f"""[diagram_mode]: {diagram_info['diagram_mode']}\n
[event_ID_dict]:\n{json.dumps(diagram_info['event_ID_dict'],indent=4)}\n
[diagram]:\n\n{diagram_info['diagram']}
"""
        save_file(text,path,"diagram",type='txt')
        self.datasl(f"Data 'diagram.txt' has been saved to '{path}'.")
    
    def save_data_1(self, outlines):
        path = self.output_save_dir + 'json_data/'
        save_json(outlines, path, 'outlines')
        self.datasl(f"Data 'outlines.json' has been saved to '{path}'.")
    
    def save_data_2(self, outlines_refine_info):
        path = self.output_save_dir + 'json_data/'
        save_json(outlines_refine_info, path, 'outlines_refine_info')
        self.datasl(f"Data 'outlines_refine_info.json' has been saved to '{path}'.")
    
    def save_data_3(self, refined_outlines):
        path = self.output_save_dir + 'json_data/'
        save_json(refined_outlines, path, 'refined_outlines')
        self.datasl(f"Data 'refined_outlines.json' has been saved to '{path}'.")

    # def save_data_2_tmp(self, outlines_refine_info_tmp):
    #     path = self.output_save_dir + 'json_data/'
    #     save_json(outlines_refine_info_tmp, path, 'outlines_refine_info_tmp')
    #     self.datasl(f"Data 'outlines_refine_info_tmp.json' has been saved to '{path}'.")
    
    # def save_data_4(self, screenplays):
    #     path = self.output_save_dir + 'json_data/'
    #     save_json(screenplays, path, 'screenplays')
    #     self.datasl(f"Data 'screenplays.json' has been saved to '{path}'.")
    
    def save_screenplay(self, screenplay, _id=1, label = ""):
        path = self.output_save_dir + 'docx_data/'
        cnt=len(screenplay)
        full_text=f"The {label}Sreenplay Act {_id} of '{self.basename}':\n"
        for i in range(cnt):
            episode = screenplay[i]['Scene ID']
            script = screenplay[i]['script']
            full_text+=f"""\n[{episode}]:\n{script}\n"""
        save_file(full_text,path,f'{label}screenplay_{_id}','docx')
        self.datasl(f"Data '{label}screenplay_{_id}.docx' has been saved to '{path}'.")

    def save_screenplay_refine(self, screenplays_refine_info, _id=1):
        path = self.output_save_dir + 'json_data/screenplays/'
        save_json(screenplays_refine_info, path, f'screenplays_refine_info_{_id}')
        self.datasl(f"Data 'screenplays_refine_info_{_id}.json' has been saved to '{path}'.")
