import re,copy,json
from .utils import read_file, replace_chapter_label, connect_logging, get_unique_id, path_basename, save_json, read_json, save_file, draw_DAG
from .models import CallModel
from .config import ConfigGlobalExtract

from .prompt import ADD_PLOTLINE_PARA, SYSTEM_PLOTLINE, USER_PLOTLINE
from .prompt import SYSTEM_PROFILE, USER_PROFILE
from .prompt import SYSTEM_REFINE, USER_REFINE, SYSTEM_REFINE2, USER_REFINE2
from .prompt import SYSTEM_IMPROVE, USER_IMPROVE, SYSTEM_IMPROVE2, USER_IMPROVE2
from .prompt import SYSTEM_DAG, USER_DAG

from .output_format import plotline_format,extract_plotline
from .output_format import profile_format,extract_profile
from .output_format import refine_format1,refine_format2, extract_refine
from .output_format import improve_format1,improve_format2, extract_improve
from .output_format import get_history_refinements, refinements_sort
from .output_format import dag_format,extract_DAG

class GlobalExtractor:

    def __init__(self, config: ConfigGlobalExtract):
        """Resolve the output directory, load (or read and checkpoint) the novel, and create the model client.

        Side effects on ``config``: fills in ``unique_id`` (when None) and ``output_basename`` (when
        empty, from the novel file name), and rewrites ``output_save_dir`` to the run directory
        ``<output_save_dir>/<unique_id>_<output_basename>/``.

        When ``config.load_data`` is true, data saved by a previous run in that directory is
        reloaded so an interrupted run can resume.
        """
        self.debug_list = []
        self.log_file = None
        # Normally set by extract_overall_elements(); initialised here so get_context() and
        # self_refine() also work when called directly.
        self.ablation = None

        # --- Output location ---
        if config.unique_id is None:  # auto-generate one (the original note says it is time-based)
            config.unique_id = get_unique_id()
        if config.print_log:
            self.log_file = connect_logging(save_dir='./logs', name="GlobalExtractor")  # log file for this run
        if config.output_basename == "":  # default to the submitted file's basename
            config.output_basename = path_basename(config.novel_path)
        # Resolve the run directory BEFORE loading a checkpoint, so a resumed run reads from the
        # same directory that save_data_0() writes to.
        if not config.output_save_dir.endswith('/'): config.output_save_dir += '/'
        config.output_save_dir += f'{config.unique_id}_{config.output_basename}/'
        # Information needed to resume from a checkpoint
        self.output_save_dir = config.output_save_dir
        self.basename = config.output_basename
        self.unique_id = config.unique_id

        # --- Novel text ---
        self.upload_novel = {}
        self.loaded_upload_novel = False
        if config.load_data:  # resume: reload the previously uploaded novel
            self.load_data_0(self.output_save_dir)
        self.upload_novel = self.read_process_file(config.novel_path)
        self.raw_chapters = self.upload_novel['chapters']
        if not self.loaded_upload_novel:  # nothing was loaded, so checkpoint the freshly read novel
            self.save_data_0(self.upload_novel)

        # --- Configuration and model ---
        self.config = config
        self.model_config = config.model_config
        self.model = self.set_model()  # CallModel(config)

        # --- Generated results ---
        self.plot_lines = []
        self.profiles = []
        self.refine_info = []
        self.refined_plot_lines = []
        self.refined_profiles = []
        self.plots_edges = {}

        self.loaded_plot_lines = False
        self.loaded_profiles = False
        self.loaded_refine = False
        self.loaded_DAG = False
        if config.load_data:  # resume: reload previously generated results
            self.load_data()

        # --- Other parameters ---
        self.max_retry = 5

    # 【数据加载】
    def load_data_0(self,output_save_dir):
        """Reload the checkpointed novel ('json_data/upload_novel.json') written by a previous run.

        On success the data is stored in self.upload_novel and self.loaded_upload_novel is set.
        A missing file or empty content is only logged via datasl().
        """
        path = output_save_dir + 'json_data/'
        try:
            upload_novel = read_json(path,'upload_novel.json')
        except FileNotFoundError as e:
            self.datasl(f"Loaded no 'upload_novel' from '{path}'. Error message: {str(e)}")
            return
        if not upload_novel:
            self.datasl(f"Loaded empty 'upload_novel'")
            return
        self.upload_novel = upload_novel
        self.loaded_upload_novel = True
        self.datasl(f"Successfully loaded 'upload_novel' from '{path}'.")

    def load_data(self):
        """Reload generated results saved by a previous run from '<output_save_dir>json_data/'.

        Each of plot_lines / profiles / refine_info / plots_edges is read from '<name>.json'.
        A non-empty result is stored on the attribute of the same name and its loaded_* flag is set.
        Missing files and empty results are only logged.
        """
        path = self.output_save_dir + 'json_data/'
        # (attribute name == JSON file stem, flag marking that it was restored)
        checkpoints = (
            ('plot_lines', 'loaded_plot_lines'),
            ('profiles', 'loaded_profiles'),
            ('refine_info', 'loaded_refine'),
            ('plots_edges', 'loaded_DAG'),
        )
        for name, flag in checkpoints:
            try:
                data = read_json(path, f'{name}.json')
            except FileNotFoundError as e:
                self.datasl(f"Loaded no '{name}' from '{path}'. Error message: {str(e)}")
                continue
            if not data:
                self.datasl(f"Loaded empty '{name}'.")
                continue
            setattr(self, name, data)
            setattr(self, flag, True)
            self.datasl(f"Successfully loaded '{name}' from '{path}'.")

        if not any((self.plot_lines, self.profiles, self.refine_info, self.plots_edges)):
            self.datasl(f"Loaded no data from '{path}'. Please check the path or the data file.")

    # 【工具】
    def set_model(self, **kwargs):
        """Create a CallModel client from self.model_config.

        An optional ``model_name`` keyword overrides the configured name. The override is
        written onto self.model_config itself; no copy is made.
        """
        model_config = self.model_config
        model_config.model_name = kwargs.get("model_name", model_config.model_name)
        return CallModel(model_config)

    def read_process_file(self, novel_path):
        if self.loaded_upload_novel==True: return self.upload_novel

        novel_text = read_file(novel_path)  # 读取文件
        if novel_text == "": raise ValueError(f"Failed to read novel from '{novel_path}'. Please check the path or the data file.")

        novel_text = replace_chapter_label(novel_text, '\nChapter ')  # 将全文中各种稀奇古怪的章节标签替换成'Chapter'

        chapters = novel_text.split("Chapter")
        chapters = [chapter.strip() for chapter in chapters]
        # chapter_begin = 1  # 舍弃最前面的一段
        upload_novel = {'novel_path': novel_path, 'chapter_num': len(chapters)-1, 'chapters': chapters, 'novel_text': novel_text}
        
        self.info("Successfully read novel. Total {} chapters.".format(len(chapters)-1))
        return upload_novel

    def get_response(self, system_prompt, user_prompt, task, **generate_kwargs): # 把调用对话的函数精简缩短一下
        return self.model.get_response(system_prompt, user_prompt, task=task, save_dir=self.output_save_dir, **generate_kwargs)
    
    # 【调试信息输出】
    def add_debug(self, text):
        if text != None and text != "":
            print(text)
            self.debug_list.append(text)
            return self.debug_list
        return []
    def info(self, text, tab=0):
        if text != None and text != "":
            if tab: return self.add_debug(f"        [Info] {text}")
            return self.add_debug(f"    [Info] {text}")
        return []
    def warning(self, text, tab=0):
        if text != None and text != "":
            if tab: return self.add_debug(f"        [Warning] {text}")
            return self.add_debug(f"    [Warning] {text}")
        return []
    def datasl(self,text):
        if text != None and text != "":
            return self.add_debug(f"      [Data] {text}")
        return []

    # 【Act x-Event x 与 Event x之间的转换】
    def ActID(self,ss,ID=0):
        """Convert between 'Act x-Event y' and 'Event y' style event IDs.

        A leading 'Act x-' prefix is removed, splitting only at the first '-' so the rest of
        the ID (including any further '-') is kept intact. Then, if ``ID`` is truthy,
        'Act {ID}-' is prepended. ``ID=0`` therefore only strips the prefix, and
        ``ActID(ActID(s, n)) == s`` for any ID ``s`` without an 'Act' prefix.

        :param ss: event ID string.
        :param ID: act number to prepend; 0 means "strip only".
        :return: the converted ID.
        :raises Exception: if ``ss`` is not a str.
        :raises ValueError: if ``ss`` starts with 'Act' but has no '-' separator.
        """
        if not isinstance(ss, str): raise Exception(f'ActID: input must be str, but got: {ss}.')
        if ss.startswith('Act'):  # remove the 'Act x-' prefix
            _act, sep, rest = ss.partition('-')
            if not sep:
                raise ValueError(f"ActID: expected an ID like 'Act x-Event y', but got: {ss}.")
            ss = rest
        if ID: ss = f'Act {ID}-{ss}'  # add the 'Act ID-' prefix
        return ss

    def ActID_List(self,List,ID=0):
        # print(f"ActID_List: ID={ID}, List=",json.dumps(List,indent=4))
        if type(List)!=dict and type(List)!=list: raise Exception(f'ActID: input must be dict or list, but got: {List}.')
        if List and type(List[0])!=str: raise Exception(f'ActID: input must be str list, but got List[0]: {List[0]}.')
        return [self.ActID(element,ID) for element in List]
    
    # def ActID_CR(self,CR,ID=0): # Causal Relationship
    #     # print(f"ActID_CR: ID={ID}, CR=",json.dumps(CR,indent=4))
    #     CR['Preceding Events'] = self.ActID_List(CR['Preceding Events'],ID)
    #     CR['Subsequent Events'] = self.ActID_List(CR['Subsequent Events'],ID)
    #     return CR

    def ActID_Events(self,events,ID=0):
        # print(f"ActID_Events: ID={ID}, events=",json.dumps(events,indent=4))
        for event in events:
            event['Event ID'] = self.ActID(event['Event ID'],ID)
            # if event.get('Causal Relationship'):
                # event['Causal Relationship'] = self.ActID_CR(event['Causal Relationship'],ID)
        return events

    def ActID_Plotline(self,plotline,ID=0):
        # print(f"ActID_Plotline: ID={ID}, plotline=",json.dumps(plotline,indent=4))
        plotline['Plot Chains List'] = self.ActID_Events(plotline['Plot Chains List'],ID)
        plotline['Main Plots List'] = self.ActID_List(plotline['Main Plots List'],ID)
        plotline['Minor Plots List'] = self.ActID_List(plotline['Minor Plots List'],ID)
        return plotline

    def ActID_Profile(self,profile,ID=0):
        CB = profile['Character Biographies']
        for cha,bio in CB.items():
            CB[cha] = self.ActID_Events(CB[cha],ID)
        profile['Character Biographies'] = CB
        return profile

    # 【核心函数】：

    def extract_plot_lines(self, chapter_nums: int = 1, chapter_windows: int = 2, plot_chain_limit: str = None, **generate_kwargs):
        # 在长度为chapter_windows的窗口段落中，取前chapter_nums个章节提取出一个act
        
        if self.loaded_plot_lines: return self.plot_lines

        if generate_kwargs.get('temperature') == None:  # 温度默认使用0
            generate_kwargs['temperature'] = 0

        task = f'extract_plot_lines'
        # fragments = self.merge_chapters(fragments_range)
        raw_chapters = self.raw_chapters[1:] #用读取到的原始章节作为段
        self.info(f"{task}(): Total {len(raw_chapters)} chapters. Window Size: {chapter_windows} chapters. Act Size: {chapter_nums} chapters.")
        plot_lines,frag_range = [],[]
        for frag_st in range(0,len(raw_chapters),chapter_nums):
            index = len(plot_lines)+1

            frag_mid = min(frag_st + chapter_nums, len(raw_chapters))
            frag_ed = min(frag_st + chapter_windows, len(raw_chapters))
            novel_frags_cur = "\n\n".join(raw_chapters[frag_st: frag_mid]) #窗口前半（改编内容）
            novel_frags_ref = "\n\n".join(raw_chapters[frag_mid: frag_ed]) #窗口后半（后续参考）
            frag_range.append(f"{frag_st+1}-{frag_mid}")

            self.info(f"{task}(): Act {index} from the first {chapter_nums} chapters of chapter {frag_st+1}-{frag_ed}.")

            # 提取需要用到的信息
            previous_act = None if not plot_lines else plot_lines[-1]

            # 导入prompt
            system_prompt = SYSTEM_PLOTLINE.format(output_format=plotline_format)
            user_input = USER_PLOTLINE.format(
                novel_frags_cur=novel_frags_cur, novel_frags_ref=novel_frags_ref, previous_act=previous_act
            )
            if plot_chain_limit:
                user_input += ADD_PLOTLINE_PARA.format(plot_chain_limit=plot_chain_limit)
            for attempt in range(self.max_retry+1):
                # 获取模型响应
                response = self.get_response(system_prompt, user_input, f"{task}_{index}", **generate_kwargs)
                act = extract_plotline(response)
                if act: break
                elif attempt < self.max_retry: self.warning(f"{task}(): Failed extract_content! Retry {attempt+1}/{self.max_retry}...")
            plot_lines.append(act)

        for idx in range(len(plot_lines)):
            plot = {'Act ID': idx+1,'Chapter ID Range': frag_range[idx]} #把ACT ID放最前面
            plot.update(plot_lines[idx])
            plot_lines[idx] = self.ActID_Plotline(plot, idx+1) # 将plot_lines中的所有 Event x 手动改成 Act x-Event x

        return plot_lines

    def extract_profiles(self, **generate_kwargs):
        # 在长度为chapter_windows的窗口段落中，取前chapter_nums个章节提取出一个act，基于这个act提取出profile
        
        if self.loaded_profiles: return self.profiles

        if generate_kwargs.get('temperature') == None:  # 温度默认使用0
            generate_kwargs['temperature'] = 0

        if not self.plot_lines: raise Exception("Plot_lines is empty!!!")
        
        task = f'extract_profiles'
        plot_lines = self.plot_lines
        # raw_chapters = self.raw_chapters[1:] #用读取到的原始章节作为段
        self.info(f"{task}(): Total {len(plot_lines)} acts.")
        profiles = []
        for ind,act in enumerate(plot_lines):
            index = ind +1
            frag_st,frag_mid = act['Chapter ID Range'].split('-')
            frag_st,frag_mid = int(frag_st),int(frag_mid)
            novel_frags_cur = "\n\n".join(self.raw_chapters[frag_st: frag_mid+1]) #窗口前半（改编内容）

            self.info(f"{task}(): Profile {index} from Act {index} and chapter {frag_st}-{frag_mid}.")

            # 提取需要用到的信息
            basic_info = act['Basic Content']
            # events_list = act['Plot Chains List']
            events_list = copy.deepcopy(act['Plot Chains List']) # 这里用深拷贝，不然下面两行会改变self.plot_lines的原始数据
            # for event in events_list: event.popitem() # 这里不需要Causal Relationship
            events_list = self.ActID_Events(events_list)# 这里把events_list中的所有 Act x-Event x 手动改回 Event x以避免多余信息的干扰

            # 导入prompt
            system_prompt = SYSTEM_PROFILE.format(output_format = profile_format)
            user_input = USER_PROFILE.format(
                novel_frags = novel_frags_cur, basic_info = basic_info, events_list = events_list
            )
            profile={}
            for attempt in range(self.max_retry+1):
                # 获取模型响应
                response = self.get_response(system_prompt, user_input, f"{task}_{index}", **generate_kwargs)
                profile = extract_profile(response)
                if profile: break
                elif attempt < self.max_retry: self.warning(f"{task}(): Failed extract_content! Retry {attempt+1}/{self.max_retry}...")
            profiles.append(profile)

        for idx in range(len(profiles)):
            plot = {'Act ID': idx+1} #把ACT ID放最前面
            plot.update(profiles[idx])
            profiles[idx] = self.ActID_Profile(plot, idx+1) # 将profile中的所有 Event x 手动改成 Act x-Event x

        return profiles
    
    def get_Act_ID(self,location):
        """Parse the act number from a refinement's location (e.g. 'Act 3' -> 3).

        The first run of digits is used. Non-string locations (e.g. a bare integer in the
        model's JSON output) are converted with str() instead of making re.search raise TypeError.

        :return: the act number as int, or None when ``location`` is empty or contains no digits.
        """
        if not location: return None
        match = re.search(r'\d+', str(location))
        return int(match.group(0)) if match else None
    def get_context(self,Act_ID):
        if self.ablation=="Reference": return "No reference context." #消融
        if not Act_ID: return ""
        try:
            act = self.plot_lines[Act_ID-1]
            frag_st,frag_mid = act['Chapter ID Range'].split('-')
            return "\n\n".join(self.raw_chapters[int(frag_st): int(frag_mid)+1]) #窗口前半（改编内容）
        except: return ""
    
    def self_refine(self, refine_rounds: int = 3, **generate_kwargs):
        """
        Refine the plot lines and character profiles multiple times up to the refine_rounds.
        Self-refine extracted plot lines and character lists based on custom checks.
        第一步，提示大模型检查所有幕中的情节和幕与幕之间的情节是否保持一致、情节事件是否存在重复或遗漏、因果关系是否正确，并且所有幕中的角色经历是否与这一幕情节吻合，如果有问题，则需要给出相应问题和改进建议（例如改写哪一幕的情节事件和角色，或者删除重复和不合理的情节和角色设置）；
        第二步，提示大模型根据前面给出的改进建议和根据问题情节检索到的相关章节上下文进一步修改生成的情节链和角色经历列表
        Step 1: Initial check for potential issues in plot lines and character lists
        Step 2: Generate a self-refinement prompt for the model
        Step 3: Get refinement suggestions from the model
        Step 4: Apply the refinements

        :param plot_lines: The list of extracted plot lines to refine
        :param generate_kwargs: Additional keyword arguments to pass to the model
        :return: The refined list of plot lines
        
        :param refine_rounds: The maximum number of self-refinement rounds.
        """

        if self.loaded_refine: return self.refine_info
        if generate_kwargs.get('temperature') == None:  # 温度默认使用0
            generate_kwargs['temperature'] = 0
        if not self.plot_lines or not self.profiles: raise Exception("Plot_lines or Profiles is empty!!!")
        refined_plot_lines = self.plot_lines
        refined_profiles = self.profiles
        # for idx,plot in enumerate(refined_plot_lines): #这样写self.xxx也会被更改
            # del plot['Chapter ID Range'] #删掉暂不需要的辅助信息

        refine_info = []
        all_refinements1,all_refinements2 = [],[]

        for index in range(1,refine_rounds+1):
            
            self.info(f"self_refine(): Round = {index}/{refine_rounds}...")

            #-----------------------------#
            # 1.【对plotlines提出修改意见】# 
            task = f'self_refine_plotlines'
            self.info(f"{task}(): Getting refinements of Plot Lines...")
            # 提取需要用到的信息
            history1 = get_history_refinements(all_refinements1)
            # 导入prompt
            system_prompt = SYSTEM_REFINE.format(output_format = refine_format1)
            user_prompt = USER_REFINE.format(
                plot_lines = refined_plot_lines, history1 = history1, profiles = refined_profiles,
            )
            # 获取模型响应
            for attempt in range(self.max_retry + 1):
                response = self.get_response(system_prompt, user_prompt, f"{task}_{index}", **generate_kwargs)
                refinements1 = extract_refine(response)
                if refinements1: break
                elif attempt < self.max_retry: self.warning(f"{task}(): Failed extract_content! Retry {attempt+1}/{self.max_retry}...")
            refinements1,cnt1_info = refinements_sort(refinements1)
            cnt1_ = len(refinements1)
            cnt1 = max(0,cnt1_-index+1)
            self.info(f"{task}(): Got {cnt1_} refinements({cnt1_info}) for Plot Lines. Will try to improve the first {cnt1}.")

            #------------------------------#
            # 2.【根据修改意见提升plotlines】#
            task = f'self_improve_plotlines'
            self.info(f"{task}(): Getting improvements of Plot Lines...")
            
            refinements1_ok,improvements1_ok = [],[]
            for idx in range(cnt1):
                refinement = refinements1[idx]
                # 提取需要用到的信息
                location = refinement.get("location") #形如'Act x'的字符串
                Act_ID = self.get_Act_ID(location) #将字符串解析成数字
                context = self.get_context(Act_ID)
                if context:
                    # 导入prompt
                    system_prompt = SYSTEM_IMPROVE.format()
                    user_prompt = USER_IMPROVE.format(
                        suggestion=refinement, context=context, profile=refined_profiles[Act_ID-1],
                        plot_line=json.dumps(refined_plot_lines[Act_ID-1],indent=4) if self.ablation=='Reference' else refined_plot_lines[Act_ID-1], 
                        output_format=improve_format1
                    )
                    # 获取模型响应
                    for attempt in range(self.max_retry + 1):
                        response = self.get_response(system_prompt, user_prompt, f"{task}_{index}-{idx+1}", **generate_kwargs)
                        improvement = extract_improve(response)
                        if improvement: break
                        elif attempt < self.max_retry: self.warning(f"{task}(): Failed extract_content! Retry {attempt+1}/{self.max_retry}...",tab=1)
                    if improvement: # 数据处理
                        # improvement = self.ActID_Plotline(improvement,Act_ID) # 输出有时会把ActID给去掉，这里补上
                        if refined_plot_lines[Act_ID-1] == improvement: # 没有做改动
                            self.info(f"{task}(): Improve '{location}' has no changes. Delete.(from refinement {idx+1}/{cnt1})",tab=1)
                        else:
                            refinements1_ok.append(refinement)
                            improvements1_ok.append(improvement)
                            refined_plot_lines[Act_ID-1] = improvement
                            self.info(f"{task}(): Improve '{location}' succeeded. (from refinement {idx+1}/{cnt1})",tab=1)
                    else: self.warning(f"{task}(): Improve '{location}' output JSON broken. Delete.(from refinement {idx+1}/{cnt1})",tab=1)
                else: self.warning(f"{task}(): Location '{location}' broken! Delete. (from refinement {idx+1}/{cnt1})",tab=1)
            
            self.info(f"{task}(): Got {len(improvements1_ok)} improvements of Plot lines.")

            #--------------------------------------#
            # 3.【对character profiles提出修改意见】#
            task = f'self_refine_profiles'
            self.info(f"{task}(): Getting refinements of Character Profiles...")
            # 提取需要用到的信息
            history2 = get_history_refinements(all_refinements2)
            # 导入prompt
            system_prompt = SYSTEM_REFINE2.format(output_format = refine_format2)
            user_prompt = USER_REFINE2.format(
                profiles = refined_profiles, history2 = history2, plot_lines = refined_plot_lines
            )
            # 获取模型响应
            for attempt in range(self.max_retry + 1):
                response = self.get_response(system_prompt, user_prompt, f"{task}_{index}", **generate_kwargs)
                refinements2 = extract_refine(response)
                if refinements2: break
                elif attempt < self.max_retry: self.warning(f"{task}(): Failed extract_content! Retry {attempt+1}/{self.max_retry}...")
            refinements2,cnt2_info = refinements_sort(refinements2)
            cnt2_ = len(refinements2)
            cnt2 = max(0,cnt2_-index+1)
            self.info(f"{task}(): Got {cnt2_} refinements({cnt2_info}) for Character Profiles. Will try to improve the first {cnt2}.")

            #------------------------------#
            # 4.【根据修改意见提升profiles】#
            task = f'self_improve_profiles'
            self.info(f"{task}(): Getting improvements of Character Profiles...")

            refinements2_ok,improvements2_ok = [],[]
            for idx in range(cnt2):
                refinement = refinements2[idx]
                # 提取需要用到的信息
                location = refinement.get("location") #形如'Act x'的字符串
                Act_ID = self.get_Act_ID(location) #将字符串解析成数字
                context = self.get_context(Act_ID)
                if context:
                    # 导入prompt
                    system_prompt = SYSTEM_IMPROVE2.format()
                    user_prompt = USER_IMPROVE2.format(
                        suggestion=refinement, context=context, plot_line=refined_plot_lines[Act_ID-1],
                        profile=refined_profiles[Act_ID-1], output_format=improve_format2
                    )
                    # 获取模型响应
                    for attempt in range(self.max_retry + 1):
                        response = self.get_response(system_prompt, user_prompt, f"{task}_{index}-{idx+1}", **generate_kwargs)
                        improvement = extract_improve(response)
                        if improvement: break
                        elif attempt < self.max_retry: self.warning(f"{task}(): Failed extract_content! Retry {attempt+1}/{self.max_retry}...",tab=1)
                    if improvement: # 数据处理
                        # improvement = self.ActID_Profile(improvement,Act_ID) # 输出有时会把ActID给去掉，这里补上
                        if refined_profiles[Act_ID-1] == improvement: # 没有做改动
                            self.info(f"{task}(): Improve '{location}' has no changes. Delete.(from refinement {idx+1}/{cnt2})",tab=1)
                        else:
                            refinements2_ok.append(refinement)
                            improvements2_ok.append(improvement)
                            refined_profiles[Act_ID-1] = improvement
                            self.info(f"{task}(): Improve '{location}' succeeded. (from refinement {idx+1}/{cnt2})",tab=1)
                    else: self.warning(f"{task}(): Improve '{location}' output JSON broken. Delete.(from refinement {idx+1}/{cnt2})",tab=1)
                else: self.warning(f"{task}(): Location '{location}' broken! Delete. (from refinement {idx+1}/{cnt2})",tab=1)
            
            self.info(f"{task}(): Got {len(improvements2_ok)} improvements of Character Profiles.")
            
            #------------------------------#
            if not improvements1_ok and not improvements2_ok:
                self.info(f"self_refine(): No Improvements in Round {index}/{refine_rounds}. Iteration stopped.")
                break
            
            all_refinements1.extend(refinements1_ok)
            all_refinements2.extend(refinements2_ok)

            refine_info.append({ # 每个优化轮次的中间信息列表
                "plot_lines_refinements": refinements1, # 本轮次中提出的原始建议列表
                "profiles_refinements": refinements2, # 本轮次中提出的原始建议列表

                "plot_lines_refinements_ok": refinements1_ok, # 成功执行修改的建议列表
                "profiles_refinements_ok": refinements2_ok, # 成功执行修改的建议列表

                "plot_lines_improvements": improvements1_ok, # 每次执行修改后的单幕结果列表
                "profiles_improvements": improvements2_ok, # 每次执行修改后的单幕结果列表

                "refined_plot_lines": refined_plot_lines, # 本轮次优化后的全幕情节链
                "refined_profiles": refined_profiles # 本轮次优化后的全幕角色传记
            })
            # self.save_data_3_temp(refine_info)

        return refine_info

    def calc_refine_times(self, refine_info):
        cnt1 = [len(refine['plot_lines_refinements']) for refine in refine_info]
        cnt2 = [len(refine['profiles_refinements']) for refine in refine_info]
        sum1,sum2 = sum(cnt1),sum(cnt2)
        self.info(f"Total get refinements: {str(cnt1)} + {str(cnt2)} = {sum1} + {sum2} = {sum1+sum2}")

        cnt1 = [len(refine['plot_lines_improvements']) for refine in refine_info]
        cnt2 = [len(refine['profiles_improvements']) for refine in refine_info]
        sum1,sum2 = sum(cnt1),sum(cnt2)
        self.info(f"Total use improvements: {str(cnt1)} + {str(cnt2)} = {sum1} + {sum2} = {sum1+sum2}")
        return 

    def calc_events(self, plot_lines):
        """Count the events in all acts' 'Plot Chains List' (acts without one count as 0)."""
        return sum(len(act["Plot Chains List"]) for act in plot_lines if act.get("Plot Chains List"))

    def get_new_edges(self, plots_edges, new_plots_edges, limit=0):
        """Return the edges of ``new_plots_edges`` that are not already in ``plots_edges``.

        Two edges are the same when (source event, target event, strength) match; the details
        field is ignored. Repeats inside ``new_plots_edges`` are dropped too, keeping the first
        occurrence. A falsy ``new_plots_edges`` (e.g. None from a failed extraction) yields [].

        :param limit: accepted for backward compatibility; currently unused.
        """
        seen = {(edge[0], edge[1], edge[2]) for edge in plots_edges}  # set: O(1) membership
        new_edges = []
        for edge in new_plots_edges or []:
            key = (edge[0], edge[1], edge[2])
            if key in seen:
                continue
            seen.add(key)
            new_edges.append(edge)
        return new_edges
    def get_important_edges(self,new_plots_edges):
        """Drop weak ('LOW') edges whose two events belong to the same act; keep all others."""
        def within_one_act(edge):
            return edge[0].split('-')[0] == edge[1].split('-')[0]
        return [edge for edge in new_plots_edges if not (within_one_act(edge) and edge[2] == "LOW")]
    def add_new_edges(self,plots_edges, new_plots_edges):
        """Append ``new_plots_edges`` to the ``plots_edges`` list in place and return that same list."""
        plots_edges += new_plots_edges  # list += iterable extends in place
        return plots_edges
    def get_history_edges(self,plots_edges):
        """Render edges [source, target, strength, details] as numbered text blocks for the DAG prompt."""
        return "".join(
            f"RELATIONSHIP {number}:\n- Events: {edge[0]} -> {edge[1]}\n- Strength: {edge[2]}\n- Details: {edge[3]}\n"
            for number, edge in enumerate(plots_edges, start=1)
        )
    def sort_EventID(self,edges):
        """Sort edges in place by (source act, source event, target act, target event) and return them.

        Event IDs are expected in the 'Act x-Event y' form produced by ActID(). Both numbers are
        compared as integers, so 'Act 10' sorts after 'Act 2'. The sort is stable.
        """
        def get_ID(event):
            nums = event.replace("Act ","").replace("Event ","").split('-')
            return int(nums[0]), int(nums[1])
        edges.sort(key=lambda edge: get_ID(edge[0]) + get_ID(edge[1]))
        return edges
    def fix_graph(self,plot_lines,plots_edges): #破环为链
        events = []
        for act in plot_lines:
            for event in act.get('Plot Chains List',[]):
                events.append(event['Event ID'])
        # print("事件读取完成")
        
        edges = []
        for edge in plots_edges: #有时会把不存在的点拿来建边
            if edge[0] in events and edge[1] in events:
                edges.append(copy.deepcopy(edge))
        #print("边读取完成")

        # 初始化节点度数（度数越小，分数越小，优先级越高）
        deg = {}
        for event in events: deg[event]=0 #每个节点的度数（包括入边和出边）
        for edge in edges:
            deg[edge[0]] += 1
            deg[edge[1]] += 1
        
        # print("度数完成")
        #从小到大排序，第一关键字为因果关系强度，第二关键字为节点度数
        def cmp(i,j): #小于号比较函数
            score = {"HIGH": 1, "MEDIUM": 2, "LOW": 3} #因果关系越强，分数越小，优先级越高
            if score[edges[i][2]] != score[edges[j][2]]:
                return score[edges[i][2]] < score[edges[j][2]]
            else: return deg[edges[i][0]]+deg[edges[i][1]] < deg[edges[j][0]]+deg[edges[j][1]]
        cnt = len(edges) #边数
        for i in range(cnt): #冒泡排序
            for j in range(cnt-i-1): #i在j的右边
                if cmp(i,j): #如果i的分数小于j（优先级更高，放到前面）
                    edges[i],edges[j] = edges[j],edges[i]
        # print("排序完成")
        #不断贪心加边
        final_edges = []
        to,fr = {},{}
        for event in events: to[event] = {event} #每个节点能到达哪些节点
        for event in events: fr[event] = {event} #每个节点能被哪些节点到达
        for edge in edges:
            a,b = edge[0],edge[1]
            if b in fr[a]: continue #如果b能到达a，说明这条边加进来会形成环
            for x in fr[a]: #能到达a的点（包括a）
                for y in to[b]: #b能到达的点（包括b）
                    to[x].add(y)
                    fr[y].add(x)
            final_edges.append(edge)
        final_edges = self.sort_EventID(final_edges)
        return final_edges
    def extract_DAG(self, **generate_kwargs): # Build the DAG of the whole plot chain from event causal relations
        """Build the causal edge list over the events of the refined plot lines.

        The model is queried for ``2 * len(self.refined_plot_lines)`` rounds.
        Each round sends the refined plot lines together with the edges accepted
        so far (formatted by ``self.get_history_edges``), parses the reply with the
        module-level ``extract_DAG`` parser (retrying up to ``self.max_retry``
        extra times while the parse is empty), filters the parsed edges with
        ``self.get_new_edges`` (and, after the first round, with
        ``self.get_important_edges``), merges them via ``self.add_new_edges`` and
        draws a snapshot graph into ``graph_data/``.

        Args:
            **generate_kwargs: Forwarded to ``self.get_response``. ``temperature``
                is set to 0 when it is missing or None.

        Returns:
            The accumulated edge list, or ``self.plots_edges`` unchanged when a
            saved DAG was loaded (``self.loaded_DAG``).

        Raises:
            Exception: If ``self.refined_plot_lines`` is empty.
        """
        if self.loaded_DAG:
            # Resumed run: reuse the DAG loaded from disk instead of querying the model again.
            return self.plots_edges
        if generate_kwargs.get('temperature') is None:  # temperature defaults to 0
            generate_kwargs['temperature'] = 0

        if not self.refined_plot_lines: raise Exception("Plot_lines is empty!!!")

        task = 'extract_DAG'
        plot_lines = self.refined_plot_lines
        self.info(f"{task}(): Total {len(plot_lines)} Acts, {self.calc_events(plot_lines)} Events.")
        plots_edges = []

        # The system prompt is identical for every round, so build it once.
        system_prompt = SYSTEM_DAG.format(output_format=dag_format)
        max_round = len(plot_lines)*2
        for index in range(max_round):
            # The user prompt embeds the edges accepted so far, so it is rebuilt each round.
            user_input = USER_DAG.format(plot_lines=plot_lines, plots_edges=self.get_history_edges(plots_edges))

            for attempt in range(self.max_retry+1):
                response = self.get_response(system_prompt, user_input, f"{task}_{index+1}", **generate_kwargs)
                # Module-level parser imported from output_format (not this method).
                new_plots_edges = extract_DAG(response)
                if new_plots_edges: break
                elif attempt < self.max_retry: self.warning(f"{task}(): Failed extract_content! Retry {attempt+1}/{self.max_retry}...",tab=1)
            else:
                # Every attempt yielded an empty parse; make the final failure visible
                # instead of silently continuing with the empty result.
                self.warning(f"{task}(): Failed extract_content after {self.max_retry+1} attempts in round {index+1}/{max_round}.",tab=1)

            # is_limit is 0 on the first round and 1 afterwards; its effect is defined by get_new_edges.
            is_limit = 1 if index>0 else 0
            new_edges = self.get_new_edges(plots_edges, new_plots_edges, is_limit)
            # First round keeps every new edge; later rounds keep only what get_important_edges selects.
            important_new_edges = new_edges if index==0 else self.get_important_edges(new_edges)
            plots_edges = self.add_new_edges(plots_edges, important_new_edges)
            self.info(f"{task}(): Got {len(new_edges)} edges in the {index+1}/{max_round} round. Added {len(important_new_edges)} of them (Total {len(plots_edges)} now).",tab=1)
            draw_DAG(plots_edges, self.output_save_dir+'graph_data/', f'Plots DAG_{index+1}', show=False)

        self.info(f"{task}(): Total Got {len(plots_edges)} edges.")
        return plots_edges

    # 【主函数】：
    def extract_overall_elements(self, chapter_nums: int = 1, chapter_windows: int = 2, refine_rounds: int = 3, ablation = None, **generate_kwargs):
        """Run the whole global-extraction pipeline and collect its results.

        Stages, in order. Each stage is best-effort: an exception is logged through
        ``self.warning`` and that stage's result stays empty.

        1. ``extract_plot_lines``: initial plot lines. Saved by ``save_data_1``
           unless ``self.loaded_plot_lines``.
        2. ``extract_profiles``: initial character profiles. Saved by
           ``save_data_2`` unless ``self.loaded_profiles``.
        3. ``self_refine``: per-round refinement info. Saved by ``save_data_3``
           unless ``self.loaded_refine``. Its last round supplies the refined
           plot lines / profiles, which are saved by ``save_data_4`` when both
           are non-empty.
        4. ``extract_DAG`` then ``fix_graph``: event edges, saved by
           ``save_data_5``. This stage is skipped when ``ablation == "PlotGraph"``.

        Args:
            chapter_nums: Forwarded to ``extract_plot_lines``.
            chapter_windows: Forwarded to ``extract_plot_lines``.
            refine_rounds: Forwarded to ``self_refine``. When 0, an empty
                refinement result does not count as a failure.
            ablation: ``"PlotGraph"`` disables DAG construction.
            **generate_kwargs: Forwarded to every model-calling stage.

        Returns:
            ``{'global_elements': <dict of all results>, 'message': 'Success'}``
            (the dict is also saved via ``save_data_all``), or
            ``{'global_elements': {}, 'message': 'Fail'}`` when a required
            result of this run is empty.
        """
        self.global_elements = {}
        self.ablation = ablation

        # Stage 1: initial plot lines.
        plot_lines = []
        try:
            plot_lines = self.extract_plot_lines(chapter_nums, chapter_windows, **generate_kwargs)
            self.plot_lines = plot_lines
            if self.loaded_plot_lines == False:  # only save data that was not loaded from disk
                self.save_data_1(plot_lines)
            self.info("Successfully extract plot_lines.")
        except Exception as e:
            self.warning(f"Failed to extract plot_lines. Error message: {str(e)}")

        # Stage 2: initial character profiles.
        profiles = []
        try:
            profiles = self.extract_profiles(**generate_kwargs)
            self.profiles = profiles
            if self.loaded_profiles == False:  # only save data that was not loaded from disk
                self.save_data_2(profiles)
            self.info("Successfully extract profiles.")
        except Exception as e:
            self.warning(f"Failed to extract profiles. Error message: {str(e)}")

        # Stage 3: iterative self-refinement.
        refine_info = []
        try:
            refine_info = self.self_refine(refine_rounds = refine_rounds,**generate_kwargs)
            self.refine_info = refine_info
            if self.loaded_refine == False:  # only save data that was not loaded from disk
                self.save_data_3(refine_info)
            self.info(f"Successfully self refine {len(refine_info)} rounds.")
            self.calc_refine_times(refine_info)
        except Exception as e:
            self.warning(f"Failed to self refine. Error message: {str(e)}")

        # Without refinement results, fall back to the initial extraction.
        self.refined_plot_lines = plot_lines if not refine_info else refine_info[-1]["refined_plot_lines"]
        self.refined_profiles = profiles if not refine_info else refine_info[-1]["refined_profiles"]
        if self.refined_plot_lines and self.refined_profiles:
            self.save_data_4(self.refined_plot_lines,self.refined_profiles)

        # Stage 4: causal event graph (disabled by the "PlotGraph" ablation).
        plots_edges = {}
        if ablation!="PlotGraph":
            try:
                plots_edges = self.extract_DAG(**generate_kwargs)
                self.plots_edges = plots_edges
                if self.loaded_DAG == False:  # only save data that was not loaded from disk
                    self.save_data_5(plots_edges)
                self.info(f"Successfully build DAG with {len(plots_edges)} edges and {self.calc_events(self.refined_plot_lines)} points.")
            except Exception as e:
                self.warning(f"Failed to build DAG. Error message: {str(e)}")

            try:
                plots_edges = self.fix_graph(self.refined_plot_lines,plots_edges)  # break cycles into chains
                self.info(f"fix_graph(): Total Got {len(plots_edges)} final edges.")
                self.plots_edges = plots_edges
                self.save_data_5(plots_edges)
                self.info(f"Successfully fix graph with {len(plots_edges)} edges and {self.calc_events(self.refined_plot_lines)} points.")
            except Exception as e:
                self.warning(f"Failed to fix graph. Error message: {str(e)}")

        # Judge success on this run's local results. They are exactly what is
        # stored below and always exist. The corresponding self.<attr> may be
        # missing (a stage raised before assigning it, which used to escape as
        # AttributeError) or stale from an earlier call/load (which used to
        # report 'Success' while saving empty data).
        refine_ok = bool(refine_info) or refine_rounds == 0
        graph_ok = bool(plots_edges) or ablation == "PlotGraph"
        if not (plot_lines and profiles and refine_ok and graph_ok):
            return {'global_elements': self.global_elements, 'message': 'Fail'}

        self.global_elements = {
            "plot_lines": plot_lines,  # initially extracted plot lines
            "profiles": profiles,  # initially extracted character profiles
            "refine_info": refine_info,  # intermediate info of every refinement round
            "refined_plot_lines": self.refined_plot_lines,  # final plot lines after refinement
            "refined_profiles": self.refined_profiles,  # final character profiles after refinement
            "plots_edges": plots_edges,  # event DAG edges
            "raw_chapters": self.raw_chapters
        }
        self.save_data_all(self.global_elements)
        self.info("Successfully extract_overall_elements.")
        return {'global_elements': self.global_elements, 'message': 'Success'}

    # 【数据储存和展示】：
    def save_data_0(self,upload_novel):
        """Save the uploaded novel as JSON and as a readable docx.

        Writes ``upload_novel`` with ``save_json`` under
        ``<output_save_dir>json_data/`` (logged as ``upload_novel.json``). It
        then writes a text version via ``save_file(..., 'docx')`` under
        ``<output_save_dir>docx_data/``. That version has a title line followed
        by every chapter under a 0-based ``[Chapter i]`` heading.

        Args:
            upload_novel: Mapping with a ``'chapters'`` sequence; each chapter
                is inserted into the text verbatim.
        """
        path = self.output_save_dir + 'json_data/'
        save_json(upload_novel,path,'upload_novel')
        self.datasl(f"Data 'upload_novel.json' has been saved to '{path}'.")
        # One join instead of repeated `+=`, which is quadratic on long novels.
        chapters_text = "".join(
            f"""\n[Chapter {i}]:\n{chapter_text}\n"""
            for i, chapter_text in enumerate(upload_novel['chapters'])
        )
        full_text = f"The Uploaded Novel of '{self.basename}':\n" + chapters_text
        save_file(full_text,self.output_save_dir + 'docx_data/','upload_novel','docx')
        
    def save_data_1(self, plot_lines):
        """Persist the initially extracted plot lines (``plot_lines.json`` under ``json_data/``) and log it."""
        json_dir = self.output_save_dir + 'json_data/'
        save_json(plot_lines, json_dir, 'plot_lines')
        self.info(f"Data 'plot_lines.json' has been saved to '{json_dir}'.")

    def save_data_2(self, profiles):
        """Persist the initially extracted character profiles (``profiles.json`` under ``json_data/``) and log it."""
        json_dir = self.output_save_dir + 'json_data/'
        save_json(profiles, json_dir, 'profiles')
        self.info(f"Data 'profiles.json' has been saved to '{json_dir}'.")

    def save_data_3(self, refine_info):
        """Persist the per-round self-refinement info (``refine_info.json`` under ``json_data/``) and log it."""
        json_dir = self.output_save_dir + 'json_data/'
        save_json(refine_info, json_dir, 'refine_info')
        self.info(f"Data 'refine_info.json' has been saved to '{json_dir}'.")
    
    # def save_data_3_temp(self, refine_info):
    #     path = self.output_save_dir + 'json_data/'
    #     save_json(refine_info, path, 'refine_info_temp')

    def save_data_4(self,refined_plot_lines,refined_profiles):
        """Persist the final (post-refinement) plot lines and character profiles.

        Writes ``refined_plot_lines.json`` and then ``refined_profiles.json``
        under ``<output_save_dir>json_data/``, logging each via ``self.datasl``.
        """
        json_dir = self.output_save_dir + 'json_data/'

        save_json(refined_plot_lines, json_dir, 'refined_plot_lines')
        self.datasl(f"Data 'refined_plot_lines.json' has been saved to '{json_dir}'.")

        save_json(refined_profiles, json_dir, 'refined_profiles')
        self.datasl(f"Data 'refined_profiles.json' has been saved to '{json_dir}'.")
    
    def save_data_5(self, plots_edges):
        """Draw the event graph and persist its edge list.

        The graph is rendered with ``draw_DAG`` into ``<output_save_dir>graph_data/``,
        titled after ``self.config.output_basename``. The edges are then saved as
        ``plots_edges.json`` under ``<output_save_dir>json_data/``.
        """
        graph_dir = self.output_save_dir + 'graph_data/'
        graph_title = 'Plots DAG of ' + self.config.output_basename
        draw_DAG(plots_edges, graph_dir, graph_title, show=False)

        json_dir = self.output_save_dir + 'json_data/'
        save_json(plots_edges, json_dir, 'plots_edges')
        self.datasl(f"Data 'plots_edges.json' has been saved to '{json_dir}'.")

    def save_data_all(self,global_elements):
        """Persist the aggregated pipeline result (``global_elements.json`` under ``json_data/``) and log it."""
        json_dir = self.output_save_dir + 'json_data/'
        save_json(global_elements, json_dir, 'global_elements')
        self.datasl(f"Data 'global_elements.json' has been saved to '{json_dir}'.")
        