import ast
import os
import re
import sys
import json
from datetime import datetime

def draw_DAG(edges,save_dir,title="Directed Acyclic Graph",show=False):
    """Draw a directed graph from ``edges`` and save it, plus its edge list, in ``save_dir``.

    Every occurrence of ``"Act "`` and ``"Event "`` is removed from node names
    before drawing. Two files are written inside ``save_dir`` (created if missing):

    * ``{title}.png`` -- the rendered graph at 300 dpi;
    * ``data_{title}.txt`` -- one line per edge: ``"src" "dst" "label"``.

    :param edges: Sequence of edges. Each edge must provide the source name at
        index 0, the target name at index 1, and an indexable item at index 2
        whose first element is written as the edge label.
    :param save_dir: Output directory.
    :param title: Plot title; also used to build both output file names.
    :param show: If truthy, also display the figure with ``plt.show()``.
    """
    os.makedirs(save_dir,exist_ok=True)
    # Imported inside the function, so the plotting packages are only needed when drawing.
    import matplotlib.pyplot as plt
    import networkx as nx

    def _strip_prefixes(name):
        return name.replace("Act ", "").replace("Event ", "")

    add_edges = [(_strip_prefixes(edge[0]), _strip_prefixes(edge[1])) for edge in edges]

    plt.clf()  # start from a blank figure so earlier drawings don't bleed in
    G = nx.DiGraph()
    G.add_edges_from(add_edges)
    plt.title(title)
    nx.draw(G, pos=nx.spring_layout(G), with_labels=True, arrowsize=10, node_size=0, font_size=8,font_color='blue', edge_color='red')
    # os.path.join: plain "save_dir + name" put the PNG *beside* save_dir whenever
    # save_dir had no trailing separator, while the .txt below went inside it.
    plt.savefig(os.path.join(save_dir, f'{title}.png'), dpi=300)
    if show:
        plt.show()

    lines = [f'"{src}" "{dst}" "{edge[2][0]}"' for (src, dst), edge in zip(add_edges, edges)]
    save_text_to_dirtxt("\n".join(lines), save_dir, f'data_{title}')

class CodeParser:
    """Extract ``## Title`` sections and fenced code blocks from markdown-like text."""

    @classmethod
    def parse_block(cls, block: str, text: str) -> str:
        """
        Return the content of the first section whose title contains ``block``.

        :param block: Substring to look for in section titles.
        :param text: Text divided into sections by ``##`` markers.
        :return: The matching section's content, or ``""`` if no title matches.
        """
        for title, content in cls.parse_blocks(text).items():
            if block in title:
                return content
        return ""

    @classmethod
    def parse_blocks(cls, text: str) -> dict[str, str]:
        """
        Split ``text`` on ``##`` into a ``{title: content}`` dict.

        For every non-blank chunk, its first line is the title and the remainder
        is the content; both are whitespace-stripped. A chunk without a newline
        becomes a title with empty content. A repeated title keeps the last chunk.
        """
        block_dict = {}
        for chunk in text.split("##"):
            if not chunk.strip():
                continue
            # partition() yields an empty content when the chunk has no newline.
            title, _, content = chunk.partition("\n")
            block_dict[title.strip()] = content.strip()
        return block_dict

    @classmethod
    def parse_code(cls, block: str, text: str, lang: str = "") -> str:
        """
        Parse a fenced code block from a given text.

        If `block` is not empty, the section whose title contains `block` is
        extracted first and the code block is searched inside it; otherwise the
        whole text is searched.

        :param block: The name of the block to parse, default is empty.
        :param text: The text to parse.
        :param lang: The language tag after the opening fence, default is empty.
        :return: The body of the first matching fenced code block.
        :raises Exception: If no fenced code block matches.
        """
        if block:
            text = cls.parse_block(block, text)
        # re.escape: tags such as "c++" would otherwise be an invalid regex.
        pattern = rf"```{re.escape(lang)}.*?\s+(.*?)```"
        match = re.search(pattern, text, re.DOTALL)
        if match is None:
            raise Exception(f"{pattern} not match following text: {text[:200]}")
        return match.group(1)

    @classmethod
    def parse_str(cls, block: str, text: str, lang: str = ""):
        """
        Return the value of a ``name = "value"`` style code block.

        Takes the text after the last ``=`` and strips whitespace and quotes.
        """
        code = cls.parse_code(block, text, lang)
        code = code.split("=")[-1]
        return code.strip().strip("'").strip('"')

    @classmethod
    def parse_file_list(cls, block: str, text: str, lang: str = "") -> list[str]:
        """
        Return the Python list literal found in a code block.

        The code may be a bare list or an assignment such as ``files = [...]``.

        :raises Exception: If the code block contains no list literal.
        """
        code = cls.parse_code(block, text, lang)
        # Optional "name =" prefix, then the list literal (DOTALL: may span lines).
        match = re.search(r"\s*(.*=.*)?(\[.*\])", code, re.DOTALL)
        if match is None:
            raise Exception(f"No list literal found in code block: {code[:200]}")
        # literal_eval only accepts Python literals; it never executes code.
        return ast.literal_eval(match.group(2))


def print_config(instance): # print every field of a dataclass instance
    """Print a ``[config]`` header, one indented ``name: value`` line per dataclass field, then a blank line."""
    from dataclasses import fields

    print("[config]")
    for name in (fld.name for fld in fields(instance)):
        value = getattr(instance, name)
        print(f"    {name}: {value}")
    print("")

def read_api_keys_from_file(file_path):
    """
    Read an API-key file and return its keys, one per non-blank line.

    Surrounding whitespace is stripped from each key. The file is decoded as
    UTF-8 with an optional byte-order mark (``utf-8-sig``), so a BOM written by
    some editors is not glued onto the first key. Undecodable bytes are dropped.

    :param file_path: Path of the text file containing the API keys.
    :return: List of API keys, in file order.
    :raises ValueError: If the file does not exist or cannot be read
        (the original error is chained as ``__cause__``).
    """
    try:
        with open(file_path, 'r', encoding='utf-8-sig', errors='ignore') as file:
            api_keys = [key for line in file if (key := line.strip())]
    except FileNotFoundError as e:
        raise ValueError(f"[FBI Warning] 文件路径有误！{file_path} 不存在") from e
    except Exception as e:
        raise ValueError(f"读取文件时发生错误: {e}") from e

    return api_keys

# import pdfminer
from docx import Document
from pdfminer.high_level import extract_text
class Read_File: # read novel documents (.docx / .txt / .pdf) as plain text
    """Read a document file into a plain-text string, choosing the reader by extension."""

    def __init__(self):
        pass

    def read_word_to_text(self,filepath):
        """Return every paragraph of a .docx file, each followed by a newline."""
        doc = Document(filepath)
        return "".join(para.text + "\n" for para in doc.paragraphs)

    def read_pdf_to_text(self,filepath):
        """Return the text that pdfminer extracts from a PDF file."""
        # Binary mode must not be given ``errors=`` (open() raises ValueError for it),
        # and the ``with`` block closes the handle once extraction is done.
        with open(filepath, 'rb') as pdf_file:
            return extract_text(pdf_file)

    def read_txt_to_text(self,filepath):
        """Return the whole contents of a UTF-8 text file; undecodable bytes are dropped."""
        with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
            return f.read()

    def autoread(self,path):
        """Read ``path`` with the reader matching its extension (case-insensitive).

        Returns ``""`` when ``path`` is empty/None or does not exist (a warning
        is printed in that case), or when the extension is not .docx/.txt/.pdf.
        """
        if not path or not os.path.exists(path):
            print(f"[FBI Warning] 文件路径有误！{path} 不存在")
            return ""
        readers = {
            '.docx': self.read_word_to_text,
            '.txt': self.read_txt_to_text,
            '.pdf': self.read_pdf_to_text,
        }
        reader = readers.get(os.path.splitext(path)[1].lower())
        return reader(path) if reader is not None else ""

def read_file(path): # read a novel document and normalize its punctuation
    """
    Read a novel document and return its text with Chinese punctuation converted.

    Parameters:
        path (str): Path of the document (.docx, .txt or .pdf).

    Returns:
        str: The document text with Chinese punctuation replaced by English
        equivalents; "" if the file is missing or of an unsupported type.
    """
    raw_text = Read_File().autoread(path)
    return replace_chinese_punc(raw_text)

def save_file(text,path,filename,type='docx'):
    """
    Save ``text`` as ``path/filename.docx`` or ``path/filename.txt``.

    The target directory is created first if it does not exist (even when the
    type turns out to be unsupported).

    Parameters:
        text (str): The text to save.
        path (str): Target directory.
        filename (str): File name without extension.
        type (str): 'docx' (default) or 'txt'.

    Returns:
        str: Path of the written file, or None for an unsupported type.
    """
    directory_missing = not os.path.exists(path)
    if directory_missing:
        os.makedirs(path)

    if type == 'txt':
        return save_text_to_dirtxt(text, path, filename)
    elif type == 'docx':
        return save_text_to_dirword(text, path, filename)
    else:
        return None

def read_json(path,filename=None): # read a JSON file
    """
    Reads a JSON file from a specified path and filename.

    Parameters:
        path (str): Directory of the JSON file, or the full file path when
            ``filename`` is None.
        filename (str): The name of the JSON file. If None, ``path`` must be the
            full path to the file.

    Returns:
        The decoded JSON data (typically a dict).

    Raises:
        FileNotFoundError: If the file does not exist, cannot be read, or does
            not contain valid JSON. For the latter two, the underlying error is
            chained as ``__cause__``.
    """
    if path is not None and filename is not None:
        # os.path.join already copes with a trailing separator on ``path``.
        path = os.path.join(path, filename)
    if not path or not os.path.exists(path):
        raise FileNotFoundError(f"Not found '{path}'.")
    try:
        # utf-8-sig: JSON is UTF-8; a leading BOM would otherwise make json.loads fail.
        with open(path, 'r', encoding='utf-8-sig', errors='ignore') as file:
            return json.loads(file.read())
    except (OSError, ValueError, RecursionError) as err:
        # JSONDecodeError is a ValueError; keep raising FileNotFoundError as before.
        raise FileNotFoundError("Json data is broken.") from err

def save_json(json_data,path,filename): # write JSON data to disk
    """
    Serialize ``json_data`` to ``path/{filename}.json`` with a 4-space indent.

    The directory is created if needed. ``json.dumps`` keeps its default
    ``ensure_ascii=True``, so the written text is pure ASCII.

    :param json_data: JSON-serializable object.
    :param path: Target directory.
    :param filename: File name without the ``.json`` extension.
    :return: Full path of the written file.
    """
    os.makedirs(path, exist_ok=True)
    target = os.path.join(path, f'{filename}.json')
    with open(target, mode='w', errors='ignore') as handle:
        serialized = json.dumps(json_data, indent=4)
        handle.write(serialized)
    return target

def save_text_to_word(text, filename):
    """Write ``text`` to a new .docx file at ``filename``, one paragraph per line.

    The whole text is whitespace-stripped first, then split on newline
    characters; every resulting piece (including empty ones) becomes a paragraph.
    """
    document = Document()
    paragraphs = text.strip().split('\n')
    for paragraph in paragraphs:
        document.add_paragraph(paragraph)
    document.save(filename)

def save_text_to_dirword(text,save_dir,name): # write text to save_dir/name.docx
    """Save ``text`` as ``save_dir/{name}.docx`` and return that path."""
    docx_path = os.path.join(save_dir, f"{name}.docx")
    save_text_to_word(text, docx_path)
    return docx_path
def save_text_to_dirtxt(text,save_dir,name): # write text to save_dir/name.txt
    """Save ``text`` as UTF-8 to ``save_dir/{name}.txt`` and return that path.

    ``save_dir`` must already exist. Characters that cannot be encoded are dropped.
    """
    target = os.path.join(save_dir, f"{name}.txt")
    with open(target, mode='w', encoding='utf-8', errors='ignore') as handle:
        handle.write(text)
    return target

def count_words(text=""): # count whitespace-separated words plus punctuation marks
    """Return the number of whitespace-separated tokens plus the number of
    ASCII punctuation marks , . ! ? : ; ' " ( ) in ``text``; None counts as ""."""
    if text is None:
        text = ""
    word_total = len(text.split())
    punc_total = len(re.findall(r'[\,\.\!\?\:\;\'\"\(\)]', text))
    return word_total + punc_total

def var2dict(**kwargs):
    """Return the keyword arguments as a dict, e.g. ``var2dict(a=1) == {'a': 1}``."""
    return dict(kwargs)

def contains_chinese(text): # check whether the text contains any Chinese characters
    """Return True if ``text`` contains a CJK Unified Ideograph from the basic
    block (U+4E00..U+9FFF) or Extension A (U+3400..U+4DBF); None gives False."""
    if text is None:
        return False
    cjk_patterns = (r'[\u4e00-\u9fff]', r'[\u3400-\u4dbf]')
    return any(re.search(pattern, text) is not None for pattern in cjk_patterns)

def contain_english_character(text): # check whether the text contains any ASCII letter
    """Return True if ``text`` contains a letter a-z or A-Z; None gives False."""
    return text is not None and re.search(r'[a-zA-Z]', text) is not None

def contain_english_punc(text): # check whether the text contains English punctuation
    """Return True if ``text`` contains any of , . ! ? : ; " ' ( ) [ ] - or a
    space (an ASCII space counts as punctuation here); None gives False."""
    if text is None:
        return False
    english_punc = {',', '.', '!', '?', ':', ';', '"', "'", '(', ')', '[', ']', '-', '...', ' '}
    for punc in english_punc:
        if punc in text:
            return True
    return False

def text_punc_fix(text): # strip English punctuation from both ends
    """Repeatedly drop the first, then the last, character while it is English
    punctuation (as judged by ``contain_english_punc``, which includes spaces).

    After each removal the remainder is whitespace-stripped.
    """
    while text:
        if not contain_english_punc(text[0]):
            break
        text = text[1:].strip()
    while text:
        if not contain_english_punc(text[-1]):
            break
        text = text[:-1].strip()
    return text

def text_word_fix(text): # strip ASCII letters from both ends
    """Repeatedly drop the first, then the last, character while it is an ASCII
    letter (as judged by ``contain_english_character``).

    After each removal the remainder is whitespace-stripped; any other
    character (digit, punctuation, CJK, space) stops the trimming.
    """
    while text:
        if not contain_english_character(text[0]):
            break
        text = text[1:].strip()
    while text:
        if not contain_english_character(text[-1]):
            break
        text = text[:-1].strip()
    return text

def replace_chinese_punc(text): # replace Chinese punctuation with English equivalents
    """Return ``text`` with Chinese (full-width) punctuation mapped to ASCII.

    The full-width space becomes an ASCII space and 'é' becomes 'e'. Afterwards,
    each non-overlapping pair of spaces is replaced by one space. This is a
    single pass, so runs of three or more spaces are shortened but not fully
    collapsed.
    """
    single_char_map = {
        '，': ',',
        '。': '.',
        '！': '!',
        '？': '?',
        '：': ':',
        '；': ';',
        '“': '"',
        '”': '"',
        '‘': "'",
        '’': "'",
        '（': '(',
        '）': ')',
        '【': '[',
        '】': ']',
        '—': '-',
        '…': '...',
        '　': ' ',  # full-width (ideographic) space
        'é': 'e',  # \u00e9
    }
    # One translate() pass covers every single-character mapping. None of the
    # replacements produces another key, so this equals replacing them one by one.
    text = text.translate(str.maketrans(single_char_map))
    # Must run after the table: full-width spaces have just become ASCII spaces.
    return text.replace('  ', ' ')  # two consecutive spaces -> one

# Chapter-heading spellings recognized by the chapter helpers in this module.
# A broader set that was once considered:
# ["EPISODE","Episode","CHAPTER","Chapter","PART","BAB","SECTION","Part","Section"]
_chapter_label=["CHAPTER","Chapter"]
def replace_chapter_label(text,new_label='\nChapter '): # normalize every chapter label in the text to new_label
    """Replace every occurrence of a label from ``_chapter_label`` with ``new_label``.

    All labels are replaced in one left-to-right regex pass. Replacing them one
    after another re-matched labels contained in ``new_label`` itself: with the
    default label, "CHAPTER" used to gain an extra newline and space compared
    with "Chapter". Matching is by plain substring, so "Chapters" becomes
    ``new_label + "s"``.

    :param text: Full text to normalize.
    :param new_label: Replacement inserted verbatim (no regex escapes are applied).
    :return: The text with every label replaced.
    """
    # Longest first, so the alternation prefers the longest label at a position.
    labels = sorted(_chapter_label, key=len, reverse=True)
    if not labels:
        return text  # an empty alternation would match between every character
    pattern = "|".join(re.escape(label) for label in labels)
    # A function replacement keeps new_label literal (no backslash/group expansion).
    return re.sub(pattern, lambda _match: new_label, text)
# Substrings that mark a chapter number: ASCII digits, Chinese numerals 一..十,
# and English cardinal/ordinal words. delete_chapter_num_title and
# delete_chapter_num look for them in lower-cased text. delete_chapter_num
# removes them one after another in this list order, so the order matters
# (e.g. 'hundred-and' / 'hundredand' are listed before 'hundred').
# NOTE(review): 'seven' precedes 'seventeen' (likewise 'eight' / 'eighteen'),
# so delete_chapter_num reduces "seventeen" to "teen". Short tokens such as
# 'one' or 'ten' also match inside ordinary words ("stone", "often").
_chapter_number = [
    '1','2','3','4','5','6','7','8','9','0',
    '一','二','三','四','五','六','七','八','九','十',
    'one','two','three','four','five','six','seven','eight','nine','ten',
    'eleven','twelve','thirteen','fourteen','fifteen','sixteen','seventeen','eighteen','nineteen',
    'twenty','thirty','forty','fifty','sixty','seventy','eighty','ninety',
    'hundred-and','hundredand','hundred',
    'first','second','third','fourth','fifth','sixth','seventh','eighth','ninth','tenth',
]
def delete_chapter_num_title(text): # drop a leading 'Chapter {number} {title}' or '{number} {title}' line from one chapter (a title on the next line is kept so body text is not deleted)
    """Remove the first line of ``text`` if it looks like a chapter heading.

    ``text`` is stripped, and only its first 100 characters are inspected.
    Trailing newlines/tabs of that head are handed back to the remainder. The
    head's first line counts as a heading when it contains any entry of
    ``_chapter_label`` (case-sensitive) or any entry of ``_chapter_number``
    (case-insensitive substring); such a line is removed.

    NOTE(review): matching is by substring, so an ordinary first line that just
    contains, say, a digit or "one"/"ten" is treated as a heading as well. If
    the first line is longer than 100 characters, only its first 100
    characters are removed.
    """
    text = text.strip()
    head, rest = text[:100], text[100:]
    # Move trailing newlines/tabs of the inspected head back onto the remainder.
    while head and head[-1] in ('\n', '\t'):
        rest = head[-1] + rest
        head = head[:-1]

    lines = head.split('\n')
    first_line = lines[0]
    looks_like_heading = (
        any(label in first_line for label in _chapter_label)
        or any(number in first_line.lower() for number in _chapter_number)
    )
    if looks_like_heading:
        lines = lines[1:]
    return '\n'.join(lines) + rest

def delete_chapter_num(text): # Remove a leading 'Chapter {number}' or '{number}' from a single chapter's text
    """Strip a leading chapter label and/or chapter number from ``text``.

    ``text`` is stripped, and only its first 100 characters are examined. The
    rest is appended back unchanged, and trailing newlines/tabs of that head
    are moved to the rest. The head is split on single spaces and scanned word
    by word. While a word is a lone ``-``/``:``/``.``/``_`` or empty, it is
    skipped. While a word contains a label from ``_chapter_label``, a number
    token from ``_chapter_number`` (lower-cased substring match), or an
    ``and`` right after a word containing ``hundred``, the matched parts are
    removed. The first word that matches none of these ends the scan; that word
    and all later words are kept.

    Consequences of the word-by-word rebuild:
      * runs of spaces collapse to one space, and a space follows every kept word;
      * leftovers of a word in which a number token was found come back
        lower-cased (e.g. ``"1:Title"`` -> ``"title"``).

    NOTE(review): since number tokens are substrings, words right after the
    number may be mangled too (e.g. "Chapter 1 Stone" -> "st ").

    :param text: Text of a single chapter (str).
    :return: The text with the leading label/number removed.
    """
    text=text.strip()
    cut=min(len(text),100)
    first=text[0:cut]  # only the head is searched for the label/number
    text=text[cut:]
    # Hand trailing newlines/tabs of the head back to the untouched remainder.
    while first and first[-1] in ['\n','\t']:
        text=first[-1]+text
        first=first[:-1]

    words=first.split(' ')
    # fin: the first non-label/non-number word has been seen (stop stripping).
    # last_hundred: the previous word contained 'hundred' (so "and" may follow).
    fin,last_hundred=False,False
    new_first=""
    for word in words:
        if fin==False:
            # A lone separator between label and number is dropped.
            if word=='-' or word==':' or word=='.' or word=='_':
                word=""
            word=word.strip()
            if word:
                flag=False  # True if anything in this word was recognized
                for char in _chapter_label:
                    if char in word:
                        flag=True
                        word=word.replace(char,'').replace('-',' ').replace(':','').replace('.','').replace('-','').strip()
                # "one hundred and five": swallow the "and" after a hundred word.
                if last_hundred==True:
                    if 'and' in word.lower():
                        word=word.lower().replace('and','').strip()
                        flag=True
                last_hundred= 'hundred' in word.lower()
                # Tokens are removed sequentially in _chapter_number order; the
                # leftover is lower-cased.
                for char in _chapter_number:
                    if char in word.lower():
                        flag=True
                        word=word.lower().replace(char,'').replace('-',' ').replace(':','').replace('.','').replace('-','').strip()
                if flag==False:
                    fin=True  # first ordinary word: keep it and everything after
        
        # NOTE(review): a word that is exactly '\n' or '\t' is appended here AND
        # again by the `if word:` below, so it appears twice -- likely unintended.
        if fin==True:
            if word in ['\n','\t']:
                new_first+=word
        # Keep any non-empty word (including leftovers of partially-matched words).
        if word:
            new_first+=word+' '
    return new_first+text

class Logger(object): # tee: mirror everything written to a stream and a log file
    """File-like object that duplicates every write to a stream and a log file.

    Creating a Logger opens ``filepath`` for writing (truncating it) and
    installs the instance as ``sys.stdout``.

    Note: the default ``stream`` is bound when this class is defined, so it is
    the ``sys.stdout`` of import time, not whatever ``sys.stdout`` is later.
    """
    def __init__(self, filepath, stream=sys.stdout):
        self.terminal = stream
        # Explicit UTF-8: with errors='ignore', a non-UTF-8 locale encoding would
        # silently drop non-ASCII (e.g. Chinese) messages from the log.
        self.log = open(filepath, 'w', encoding='utf-8', errors='ignore')
        self.previousMsg = None  # last str passed to write(); not used for output here
        sys.stdout = self

    def write(self, message):
        """Write ``message`` to the stream and the log, then flush the log."""
        if isinstance(message, str):
            self.previousMsg = message
        if self.previousMsg is None:
            self.previousMsg = ""

        self.terminal.write(message)
        self.log.write(message)
        # Flush right away so the file is written during the run, not only at exit.
        self.log.flush()

    def flush(self):
        """Flush both the stream and the log (honours ``print(..., flush=True)``)."""
        self.terminal.flush()
        self.log.flush()

def connect_logging(save_dir,name="",unique_id=None): # save logs
    """Tee stdout and stderr into ``save_dir/{id}_{name}.log`` and return that path.

    ``id`` is ``unique_id`` when given, otherwise the current local time as
    ``YYYY-mm-dd-HH-MM-SS``. ``save_dir`` is created if needed.
    """
    logfileName = unique_id if unique_id is not None else datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    os.makedirs(save_dir, exist_ok=True)
    filepath=os.path.join(save_dir,f"{logfileName}_{name}.log")
    # A single Logger (one file handle) serves both streams. Two Logger(filepath)
    # instances would each open the file with 'w' and keep their own offset, so
    # stderr output overwrote stdout output from the start of the file.
    logger = Logger(filepath)
    sys.stdout = sys.stderr = logger  # capture normal print output and tracebacks
    return filepath

def get_unique_id():
    """Return the current local time formatted as ``YYYY-mm-dd-HH-MM-SS``.

    Resolution is one second: calls within the same second return the same id.
    """
    return f"{datetime.now():%Y-%m-%d-%H-%M-%S}"

def sanitize_path(path):
    """Make ``path`` safer to use as a file name.

    Each of ``. < > : " | ? *`` is replaced by ``_`` and single quotes are
    removed. Slashes and backslashes are left alone. Because '.' is included,
    file extensions are turned into ``_...`` as well.
    """
    underscored = re.sub(r'[.<>:"|?*]', '_', path)
    return underscored.replace("'", "")

def path_basename(path):
    """Return the last component of ``path``, passed through ``sanitize_path``."""
    base = os.path.basename(path)
    return sanitize_path(base)