
import re
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from datasets import load_from_disk
import os
import argparse
from torch.amp import autocast

import os
import json
from datasets import load_from_disk, Dataset
from torch.amp import autocast
from multiprocessing import Pool  # 用于多任务处理
import multiprocessing
import subprocess
import csv

import os
import re
import torch
import subprocess
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
import jsonlines
import pyverilog.vparser.ast as vast
from pyverilog.vparser.parser import parse
import signal

# Expose only GPU 0 to this process.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})
# Result accumulators (not populated anywhere in the visible script).
compile_pass_list, sim_pass_list = [], []




def count_tokens(text, tokenizer):
    """Return the number of token ids ``tokenizer.encode`` produces for ``text``."""
    token_ids = tokenizer.encode(text)
    return len(token_ids)

# Checkpoint(s) to evaluate. Only model_path[0] is read: it is loaded as the
# model and tokenizer, and its last path component names the per-model output
# folder. The commented-out lines are alternative checkpoints from earlier
# experiments (some are listed more than once); swap one in to evaluate it.
model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/original_model/Qwen2.5-7B-Instruct-1M"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/Qwen2.5-7B-Instruct-1M_dataset2_eval_test_1/checkpoint-400"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/Qwen2.5-7B-Instruct-1M_400continue_similarity_test_2/checkpoint-468"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/Qwen2.5-7B-Instruct-1M_468continue_similarity_test_1/checkpoint-200"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/Qwen2.5-7B-Instruct-1M_SFTcontinue_penalty/checkpoint-468"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/Qwen2.5-7B-Instruct-1M_SFTcontinue_more_generate_penalty_test_2/checkpoint-500"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/Qwen2.5-7B-Instruct-1M_huge500_continue_penalty_test_1/checkpoint-468"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/Qwen2.5-7B-Instruct-1M_huge500_continue_penalty_test_total/checkpoint-400"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/Qwen2.5-7B-Instruct-1M_huge500_continue_penalty_test_total_2/checkpoint-300"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/Qwen2.5-7B-Instruct-1M_SFTcontinue_more_generate_penalty_test_2/checkpoint-500"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/f_way2_huge500_continue_test/checkpoint-200"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/Qwen2.5-7B-Instruct-1M_400continue_similarity_test_2/checkpoint-468"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/f_way3_countinue_test/checkpoint-200"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/trained_model/Qwen2.5-7B-Instruct-1M_dataset2_SFT_2_eval_train/checkpoint-156"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/f_way3_countinue_test/checkpoint-600"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/f_way4_countinue_test/checkpoint-300"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/SFT_fine_tuned_moddel/f5_try/checkpoint-65"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/f6_try"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/f_way7/checkpoint-25"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/f_way8_countinue_test/checkpoint-200"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/strong_SFT_model/best_data_SFT_2"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/Qwen2.5-7B-Instruct-1M_SFTcontinue_more_generate_penalty_test_2/checkpoint-500"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/last_model/NOPPA_TRY_1/checkpoint-200"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/more_generate_model/f_way4_countinue_test/checkpoint-300"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/disSFT_model/Qwen2.5_test_1"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/After_work/model/huge_data_test_1/checkpoint-300"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/After_work/model/stage_3_try_1/checkpoint-78"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/After_work/model/stage_3_try_2/checkpoint-25"]
# model_path = ["/mnt/sdb/workspace/zxy_workspace/Good_model/good_model_2"]
#model_path = ["/mnt/sdb/workspace/zxy_workspace/GRPO/After_work/model/stage_3_test_1/checkpoint-25"]



def timeout_handler(signum, frame):
    """Signal handler that aborts the interrupted call by raising TimeoutError.

    Installed for SIGALRM; ``signum`` and ``frame`` are ignored.
    """
    message = "Function ran out of time"
    raise TimeoutError(message)

def validate_verilog_code(code):
    """Parse Verilog source text and return its pyverilog AST, or None.

    ``parse`` (pyverilog.vparser.parser) expects a list of *file paths*; it
    runs the Verilog preprocessor on those files. It does not accept source
    text. The code is therefore written to a temporary ``.v`` file, which is
    removed afterwards. Parsing is bounded by a 5-second SIGALRM timeout
    raised through ``timeout_handler``. This only works on Unix and only from
    the main thread.

    Args:
        code: Verilog source text.

    Returns:
        The AST root returned by ``parse``, or None if parsing raised any
        exception (including the timeout) or returned no AST.
    """
    import tempfile  # only needed here

    source_path = None
    try:
        with tempfile.NamedTemporaryFile('w', suffix='.v', delete=False,
                                         encoding='utf-8') as tmp:
            source_path = tmp.name
            tmp.write(code)
        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(5)  # SIGALRM -> TimeoutError after 5 seconds
        ast, _ = parse([source_path])
        return ast
    except Exception:
        # Timeouts, syntax errors and preprocessor failures all mean
        # "not a valid design" for the caller.
        return None
    finally:
        signal.alarm(0)  # cancel any pending alarm
        if source_path is not None:
            try:
                os.remove(source_path)
            except OSError:
                pass

def count_nodes(node):
    """Return the number of nodes in the tree rooted at ``node``, root included.

    Any object exposing a ``children()`` method is treated as an inner node
    whose children are counted recursively; anything else counts as a
    single leaf.
    """
    if not hasattr(node, 'children'):
        return 1
    return 1 + sum(count_nodes(child) for child in node.children())

def _ast_values_equal(value1, value2):
    """Return True if two AST attribute values are equivalent.

    Values of different types are never equal. Tuples/lists are compared
    element-wise. Objects carrying a ``__dict__`` (AST nodes) are compared
    recursively with ``compare_ast``. Everything else (names, constants,
    ``None``) is compared with ``==``.
    """
    if type(value1) != type(value2):
        return False
    if isinstance(value1, (tuple, list)):
        return len(value1) == len(value2) and all(
            _ast_values_equal(item1, item2) for item1, item2 in zip(value1, value2)
        )
    if hasattr(value1, '__dict__'):
        return compare_ast(value1, value2)
    return value1 == value2


def compare_ast(node1, node2):
    """Recursively decide whether two AST nodes are equivalent.

    Two nodes are equivalent when all of the following hold:
    - they have the same type;
    - every attribute in ``node1.__dict__`` other than ``children`` and
      ``lineno`` equals the same attribute of ``node2``. Sub-nodes are
      compared recursively and plain values such as identifier names and
      constants with ``==``;
    - their ``children()`` sequences have the same length and pair up
      recursively.

    ``lineno`` is ignored so that the same design laid out differently still
    matches.

    Args:
        node1: First node (typically a pyverilog AST node).
        node2: Second node.

    Returns:
        True if the nodes are equivalent, otherwise False.
    """
    if type(node1) != type(node2):
        return False

    if hasattr(node1, '__dict__') and hasattr(node2, '__dict__'):
        for attr, value1 in node1.__dict__.items():
            if attr in ('children', 'lineno'):
                continue
            if not _ast_values_equal(value1, getattr(node2, attr)):
                return False

    if hasattr(node1, 'children') and hasattr(node2, 'children'):
        children1 = list(node1.children())
        children2 = list(node2.children())
        if len(children1) != len(children2):
            return False
        # NOTE(review): when children are also stored as attributes, this pass
        # re-walks subtrees already compared above, so cost grows quickly with
        # tree depth. It is kept in case some node exposes children that are
        # not in its __dict__.
        return all(compare_ast(c1, c2) for c1, c2 in zip(children1, children2))

    return True

def count_matched_nodes(node1, node2):
    """Count the nodes of ``node1``'s tree that match ``node2``'s tree.

    Returns 0 as soon as the two roots differ. They differ when:
    - either node is None;
    - they have different types;
    - any attribute other than ``children``/``lineno`` differs. Sub-node
      attributes are compared as whole subtrees with ``compare_ast``, and
      plain values such as names and constants with type + ``==``;
    - they have a different number of children.

    Otherwise the root counts as 1, plus the matched counts of its paired
    children.

    Args:
        node1: Root of the first tree (may be None).
        node2: Root of the second tree (may be None).

    Returns:
        Number of matched nodes, or 0 if the roots do not match.
    """
    if node1 is None or node2 is None:
        return 0
    if type(node1) != type(node2):
        return 0

    if hasattr(node1, '__dict__') and hasattr(node2, '__dict__'):
        for attr, value1 in vars(node1).items():
            # 'lineno' only reflects source layout, so it is not compared.
            if attr in ('children', 'lineno'):
                continue
            value2 = getattr(node2, attr)
            if isinstance(value1, (tuple, list)) and type(value1) == type(value2):
                if len(value1) != len(value2):
                    return 0
                pairs = zip(value1, value2)
            else:
                pairs = ((value1, value2),)
            for item1, item2 in pairs:
                if hasattr(item1, '__dict__'):
                    # Sub-node: the whole subtree must match.
                    if not compare_ast(item1, item2):
                        return 0
                elif type(item1) != type(item2) or item1 != item2:
                    # Leaf value (name, constant, None, ...).
                    return 0

    count = 1  # the current node matches
    if hasattr(node1, 'children') and hasattr(node2, 'children'):
        children1 = list(node1.children())
        children2 = list(node2.children())
        if len(children1) != len(children2):
            return 0
        count += sum(count_matched_nodes(c1, c2) for c1, c2 in zip(children1, children2))
    return count

def ast_similarity(ast1, ast2):
    """Return the fraction of ``ast1``'s nodes that are matched in ``ast2``.

    The numerator is ``count_matched_nodes(ast1, ast2)`` and the denominator
    is ``count_nodes(ast1)``. The result is 0.0 if that count is not
    positive.
    """
    total = count_nodes(ast1)
    matched = count_matched_nodes(ast1, ast2)
    if total > 0:
        return matched / total
    return 0.0


def run_vcs():
    """Compile the design with Synopsys VCS, then run its simulation.

    Compiles the fixed ``testbench.v`` and ``verilog.v`` in ``work_path``
    with ``vcs -sverilog``. It then runs ``make sim`` in ``work_path`` with a
    10-second timeout and checks stdout for "Your Design Passed".

    Side effect: the VCS library directory is prepended to
    ``os.environ['LD_LIBRARY_PATH']``, at most once per process.

    Returns:
        1 if compilation failed or vcs could not be launched,
        2 if the simulation failed, timed out or could not be launched,
        3 if the simulation output reports a pass.
    """
    work_path = "/mnt/sdb/workspace/zxy_workspace/GRPO/test_vcs"
    vcs_library_path = '/usr/synopsys/L-2016.06/linux/lib'
    current_ld_library_path = os.environ.get('LD_LIBRARY_PATH', '')
    # Prepend only once; an unconditional prepend grows the variable on every call.
    if vcs_library_path not in current_ld_library_path.split(':'):
        os.environ['LD_LIBRARY_PATH'] = f'{vcs_library_path}:{current_ld_library_path}'
    verilog_path = "/mnt/sdb/workspace/zxy_workspace/GRPO/test_vcs/verilog.v"
    testbench_file = "/mnt/sdb/workspace/zxy_workspace/GRPO/test_vcs/testbench.v"
    # vcs runs in the foreground (no trailing '&'), so bash's exit status is
    # vcs's exit status and the returncode check below is meaningful.
    vcs_command = f'bash -i -c "vcs -sverilog {testbench_file} {verilog_path}"'

    # --- compile ---
    try:
        proc = subprocess.Popen(vcs_command, shell=True, stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE, cwd=work_path)
        stdout, _ = proc.communicate()  # wait for vcs to finish
    except Exception as e:
        print(f"执行 vcs 命令时发生错误: {str(e)}")
        # Do not continue to "make sim": it would presumably run a simulator
        # left over from an earlier successful compile.
        return 1
    compile_output = stdout.decode(errors='replace')
    print(compile_output)
    # Non-zero exit status or any "Error" in the compile log => failure.
    if proc.returncode != 0 or 'Error' in compile_output:
        print("编译失败，仿真被跳过")
        return 1
    print("编译成功")

    # --- simulate ---
    make_sim_command = "make sim"
    try:
        # A new session puts make (and whatever it launches) in its own
        # process group, so a timeout can kill the whole group rather than
        # only the intermediate shell.
        process = subprocess.Popen(make_sim_command, shell=True, cwd=work_path,
                                   stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                   start_new_session=True)
        stdout, _ = process.communicate(timeout=10)  # 10-second limit
    except subprocess.TimeoutExpired:
        try:
            os.killpg(process.pid, signal.SIGKILL)
        except ProcessLookupError:
            pass  # the group already exited
        process.wait()  # reap the child so it does not linger as a zombie
        print("仿真超时，跳过该次仿真")
        return 2
    except Exception as e:
        print(f"执行 make sim 命令时发生错误: {str(e)}")
        return 2

    output = stdout.decode(errors='replace')
    print(output)
    if "Your Design Passed" in output:
        print("√√√√√√√√√√√仿真成功√√√√√√√√√√√√")
        return 3
    print("×××××××××××仿真失败××××××××××××")
    return 2



def extract_think(text: str) -> str:
    """Return the stripped text between the last ``<think>`` and the next ``</think>``.

    If ``<think>`` is absent, the text is taken from the beginning. If
    ``</think>`` is absent, it runs to the end.
    """
    after_open = text.rpartition("<think>")[2]
    inside = after_open.partition("</think>")[0]
    return inside.strip()


def extract_xml_answer(text: str) -> str:
    """Return the stripped body of the last ``<total_design>`` block in ``text``.

    If ``<total_design>`` is absent, the text is taken from the beginning. If
    the closing tag is absent, it runs to the end.
    """
    tail = text.rpartition("<total_design>")[2]
    body = tail.partition("</total_design>")[0]
    return body.strip()


def count_total_design_blocks(text):
    """Return how many complete ``<total_design>...</total_design>`` blocks ``text`` contains.

    Matching is non-greedy and spans newlines, so adjacent blocks are counted
    separately. An unterminated opening tag is not counted.
    """
    block_pattern = r'<total_design>.*?</total_design>'
    return sum(1 for _ in re.finditer(block_pattern, text, re.DOTALL))


#RTL_FILE_PATH = "/mnt/sdb/workspace/zxy_workspace/GRPO/data/rtl_data/RTLLM_2level"

def run_tcl_script(tcl_script_path, output_file_path, timeout=20):
    """Run ``dc_shell -f <tcl_script_path>`` and write its stdout to a file.

    Args:
        tcl_script_path: Path of the Tcl script passed to ``dc_shell -f``.
        output_file_path: File that receives dc_shell's stdout (overwritten).
        timeout: Seconds to wait before abandoning the run. The default of 20
            is the previously hard-coded value. On timeout, ``subprocess.run``
            kills the child.

    Returns:
        0 on success, dc_shell's non-zero exit status on failure, or 124 on
        timeout.

    Raises:
        OSError: if the output file cannot be opened or dc_shell cannot be
            started.
    """
    try:
        with open(output_file_path, 'w') as output_file:
            result = subprocess.run(
                ['dc_shell', '-f', tcl_script_path],
                check=True,
                text=True,
                stdout=output_file,
                stderr=subprocess.PIPE,
                timeout=timeout,
            )

        # Show any stderr output even when the run succeeded.
        if result.stderr:
            print("标准错误:")
            print(result.stderr)

        return result.returncode
    except subprocess.TimeoutExpired:
        # The message now reports the real limit (it used to say 10 s while
        # the timeout was 20 s).
        print(f"命令运行超时，超过{timeout}秒未完成")
        return 124
    except subprocess.CalledProcessError as e:
        # check=True turns a non-zero exit status into this exception.
        print(f"命令执行失败，返回码: {e.returncode}")
        print(f"错误信息: {e.stderr}")
        return e.returncode

def extract_values_from_file(file_path):
    """Extract area, delay and power figures from a synthesis report.

    The file is scanned line by line, taking the first decimal number on
    each marker line:
    - "Total cell area" gives the area;
    - "data arrival time" gives the delay (any sign is dropped);
    - "Total Dynamic Power" gives the power, optionally with an exponent
      such as ``1.2e-03``.

    If a marker appears on several lines, the last one wins. Units are not
    interpreted.

    Args:
        file_path: Path of the report text file.

    Returns:
        Tuple ``(area, delay, power)`` of floats. Each entry is None if no
        marker line containing a decimal number was found.
    """
    decimal_re = re.compile(r'(\d+\.\d+)')
    # The exponent is optional but must be complete ("e-03"), so a number
    # followed by a stray "e" or "-" no longer makes float() raise.
    power_re = re.compile(r'(\d+\.\d+(?:[eE][-+]?\d+)?)')

    area = delay = power = None
    with open(file_path, 'r') as file:
        for line in file:
            if "Total cell area" in line:
                match = decimal_re.search(line)
                if match:
                    area = float(match.group(1))

            if "data arrival time" in line:
                match = decimal_re.search(line)
                if match:
                    delay = float(match.group(1))

            if "Total Dynamic Power" in line:
                match = power_re.search(line)
                if match:
                    power = float(match.group(1))

    return area, delay, power



# Response-layout instruction: a <think> section followed by one or more
# <total_design> blocks. It was previously defined twice with identical text;
# one definition is kept. It is currently unused in the visible script (the
# generation loop sends only a user message).
SYSTEM_PROMPT = """
Respond in the following format:
<think>
...
</think>
<total_design>
...
</total_design>

...

<total_design>
...
</total_design>
"""




# Instruction placed first in every user prompt, asking the model to reason
# step by step inside <think> ... </think>. The string ends right after
# "</think>" with no trailing newline; the generation loop appends the task
# text directly after it.
CREATE_THINKING_PROMPT = """To ensure comprehensive reasoning and accurate responses, please carefully analyze the following query.   Break down the problem into its fundamental components, evaluate potential solutions, and provide a detailed, step-by-step explanation of your thought process.   Avoid skipping any critical thinking steps or providing premature conclusions.   Your response should reflect a deep understanding of the topic and demonstrate logical coherence throughout.
Place your thought process between <think> and </think>"""


# Verilog syntax reminders (wire vs. reg assignment; define modules before
# use), appended after the task text in the user prompt. It starts with a
# newline, which separates it from the preceding task text.
SYNTAX_PROMPT = """
You should note that in Verilog, signals of type wire cannot be assigned directly in the always block. A wire type signal is usually used for continuous assignment (assign statement) or the output of a module. A signal of type reg can be assigned a value in the always block.
It is important to note in the design that undefined modules need to be defined first."""

# Alternative wording of the "generate multiple designs" instruction.
# Not used by the generation loop, which uses GENERATE_MORE_RTL_PROMPT.
GENERATE_MORE_RTL_PROMPT_2 = """
Generate as many structurally diverse Verilog code as possible, optimized for area, power, and minimizing the worst path delay.The generated verilog code should be placed between <total_design> and </total_design>
Respond in the following format:
<think>
...
</think>
<total_design>
...
</total_design>
...
<total_design>
...
</total_design>
The code for each design should have only one module, and you need to avoid multiple modules in a design.
"""


# Final part of every user prompt: asks for several structurally different,
# single-module designs, each wrapped in its own <total_design> tags. The
# generation loop later extracts those tagged blocks from the response.
GENERATE_MORE_RTL_PROMPT = """
Generating as much structurally diverse Verilog code as possible can be optimized from the perspective of area, power, and minimizing the worst path delays. The generated verilog code should be placed between <total_design> and </total_design>, and the generated code optimized from different angles should be placed in different <total_design> and </total_design> tags.
Respond in the following format:
<think>
...
</think>
<total_design>
...
</total_design>
...
<total_design>
...
</total_design>
The code for each design should have only one module, and you need to avoid multiple modules in a design.
You should strictly abide by my requirements and restrictions!
"""





# Pass counters (not updated anywhere in the visible script).
compile_pass_count = sim_pass_count = 0

def extract_total_design(text):
    """Return the bodies of all ``<total_design>`` blocks in ``text``, in order.

    Bodies are returned verbatim (not stripped). Matching is non-greedy and
    spans newlines. An unterminated trailing tag is ignored.
    """
    tagged = re.finditer(r'<total_design>(.*?)</total_design>', text, re.DOTALL)
    return [m.group(1) for m in tagged]

# Load the checkpoint under evaluation (first entry of model_path).
model = AutoModelForCausalLM.from_pretrained(
    model_path[0],
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,  # allow custom model code shipped with the checkpoint
)
model.eval()
# Inference only: freeze every parameter.
model.requires_grad_(False)

tokenizer = AutoTokenizer.from_pretrained(model_path[0])

#csv_path = "/mnt/sdb/workspace/zxy_workspace/GRPO/data/rtl_data/RTLLM_pass_2.csv"
#csv_path = "/mnt/sdb/workspace/zxy_workspace/GRPO/data/rtl_data/RTLLM_r1_generate_think_test_.csv"
#csv_path = "/mnt/sdb/workspace/zxy_workspace/GRPO/data/rtl_data/RTLLM_r1_generate_think_test_nohead.csv"
#csv_path = "/mnt/sdb/workspace/zxy_workspace/GRPO/data/rtl_data/RTLLM_pass_nohead.csv"
#csv_path = "/mnt/sdb/workspace/zxy_workspace/Retest/regenerate/data/RTLLM_pass_nohead.csv"

# Per-response statistics appended by the generation loop: reasoning-token
# counts, total response-token counts and number of kept designs. The two
# *_pass_num lists are not filled in the visible script.
(length, total_token, generate_num,
 compile_pass_num, sim_pass_num) = ([], [], [], [], [])



def _dedupe_designs(design_list):
    """Drop repeated designs from one response, keeping first occurrences.

    A design that parses is treated as a duplicate when either:
    - its text equals an already-kept design, or
    - ``ast_similarity`` against any already-kept parsed design is non-zero.
      ``count_matched_nodes`` returns 0 on any attribute mismatch at the
      root, so in practice this means the trees matched.

    Designs that fail to parse are always kept.

    Args:
        design_list: Verilog snippets extracted from ``<total_design>`` tags.

    Returns:
        The kept snippets, in their original order.
    """
    kept_verilog = []
    kept_asts = []
    for design in design_list:
        design_ast = validate_verilog_code(design)
        if design_ast is None:
            # No AST to compare against, so keep it as-is.
            kept_verilog.append(design)
            continue
        duplicate = design in kept_verilog or any(
            ast_similarity(seen_ast, design_ast) != 0 for seen_ast in kept_asts
        )
        if not duplicate:
            kept_verilog.append(design)
            kept_asts.append(design_ast)
    return kept_verilog


# Root directory of the benchmark (change to evaluate another location).
# Every directory below it that has both a prompt.txt file and an rtl/
# sub-directory is treated as one task.
main_dir = "/mnt/sdb/workspace/zxy_workspace/GRPO/czt_workspace/cvdp"

# Outputs go to <task>/rtl/<model_name>/<i>/ for i = 1..5.
model_name = os.path.basename(os.path.normpath(model_path[0]))

for root, dirs, files in os.walk(main_dir):
    if 'prompt.txt' not in files or 'rtl' not in dirs:
        continue

    with open(os.path.join(root, 'prompt.txt'), 'r', encoding='utf-8') as f:
        prompt_content = f.read()

    model_dir = os.path.join(root, 'rtl', model_name)
    os.makedirs(model_dir, exist_ok=True)

    # The prompt is identical for every response of a task, so build and
    # tokenize it once.
    # NOTE(review): CREATE_THINKING_PROMPT ends with "</think>" and the task
    # text is appended with no separator. This is left unchanged so results
    # stay comparable with earlier runs.
    messages = [
        {"role": "user", "content": CREATE_THINKING_PROMPT + prompt_content + SYNTAX_PROMPT + GENERATE_MORE_RTL_PROMPT}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,              # return the formatted string, not ids
        add_generation_prompt=True,  # append the assistant turn header
    )
    model_inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True).to(model.device)

    # Five responses per task, each saved in its own numbered folder. They
    # differ only if the model's generation config samples.
    for i in range(1, 6):
        file_dir = os.path.join(model_dir, f'{i}')
        os.makedirs(file_dir, exist_ok=True)

        generated_ids = model.generate(
            model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            max_new_tokens=4096,
        )
        # Drop the prompt tokens so only the newly generated part is decoded.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Token statistics: reasoning section vs. whole response.
        length.append(count_tokens(extract_think(response), tokenizer))
        total_token.append(count_tokens(response, tokenizer))

        diff_verilog = _dedupe_designs(extract_total_design(response))
        generate_num.append(len(diff_verilog))

        with open(os.path.join(file_dir, "response.txt"), "w", encoding="utf-8") as f:
            f.write(response)

        for item, design in enumerate(diff_verilog):
            with open(os.path.join(file_dir, f"{item}.txt"), "w", encoding="utf-8") as f:
                f.write(design)
                    