import datetime
import math
import os
# Environment configuration. It is placed ahead of the remaining imports,
# presumably so that libraries imported below observe these values at import
# time -- TODO confirm that the ordering is required.
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["http_proxy"] = "http://localhost:7890"
os.environ["https_proxy"] = "http://localhost:7890"  # add a proxy
import pathlib
from functools import partial
import warnings
import traceback
import json

import pandas as pd
import torch.multiprocessing as mp
from joblib import Memory
from num2words import num2words
import numpy as np
from omegaconf import OmegaConf
from rich.console import Console
from torch.utils.data import DataLoader
from tqdm import tqdm

from configs import config
from utils import seed_everything
import datasets
from code_analysis import is_followup_more_complex

import time
import matplotlib.pyplot as plt
from PIL import Image
from rich.pretty import pretty_repr
from torch import Tensor

import inspect
import ast
from collections import defaultdict, deque
import gc
import psutil
import copy
import base64
from io import BytesIO

def pil_to_base64(image):
    """Encode a PIL image as a Base64 string of its JPEG representation.

    Args:
        image: The object to encode; anything that is not a ``PIL.Image.Image``
            is rejected.

    Returns:
        The Base64 text (UTF-8 decoded) of the JPEG bytes, or ``None`` when the
        input is not a PIL image or encoding fails. Failures are reported on
        the module-level ``console`` and never raised.
    """
    if not isinstance(image, Image.Image):
        console.print(f"[red]错误：输入不是 PIL.Image，类型为 {type(image)}[/red]")
        return None
    try:
        try:
            buffered = BytesIO()
            image.save(buffered, format="JPEG")
        except OSError:
            # Pillow refuses to write some modes (e.g. RGBA, P, LA) as JPEG.
            # Retry on an RGB copy so such images are still encoded; the alpha
            # channel, if any, is discarded by the conversion.
            buffered = BytesIO()
            image.convert("RGB").save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        console.print(f"[green]图像转换为 Base64，长度：{len(img_str)}[/green]")
        return img_str
    except Exception as e:
        console.print(f"[red]Base64 编码失败：{e}[/red]")
        return None
def report_memory(prefix=""):
    """Print the resident set size of the current process in megabytes.

    Args:
        prefix: Text placed in front of the "Memory Usage" label.
    """
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    mem = rss_bytes / 1024 / 1024
    print(f"💾 {prefix} Memory Usage: {mem:.2f} MB")

class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that also understands the value types produced by this script.

    Handled in this order: ``pathlib.Path`` (as string), ``set`` (as list),
    NumPy integer/floating/boolean scalars (as the matching Python scalar),
    ``numpy.ndarray`` (as nested lists), any object exposing ``to_dict()``,
    ``torch.Tensor`` (as nested lists) and objects whose class is named
    ``ImagePatch`` (as ``str(obj)``). Anything else is delegated to the base
    class, which raises ``TypeError``.
    """

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``.

        Raises:
            TypeError: If ``obj`` is none of the supported types.
        """
        if isinstance(obj, pathlib.Path):
            return str(obj)
        elif isinstance(obj, set):
            return list(obj)
        elif isinstance(obj, (np.integer, np.floating, np.bool_)):
            # np.bool_ is not a subclass of bool, so the stock encoder rejects
            # it; .item() yields the plain Python scalar for all three kinds.
            return obj.item()
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif hasattr(obj, "to_dict"):
            return obj.to_dict()
        elif isinstance(obj, Tensor):
            return obj.tolist()
        elif obj.__class__.__name__ == "ImagePatch":
            # Matched by class name because ImagePatch is only imported lazily
            # inside run_program.
            return str(obj)
        return super().default(obj)

# Module-level state shared by the functions below.

# Select the 'file_system' sharing strategy of torch.multiprocessing.
mp.set_sharing_strategy('file_system')
# Per-process results queue; worker_init() assigns it inside pool workers.
queue_results = None

# joblib cache; the location is None when config.use_cache is falsy
# (presumably this disables on-disk caching -- confirm against joblib docs).
cache = Memory('cache/' if config.use_cache else None, verbose=0)
runs_dict = {}
# Called with no arguments at import time; the seed used is decided in utils.
seed_everything()
# Values recorded by show_all()/save_intermediate_variable() while a generated
# program runs; run_program() snapshots and clears it after a successful run.
intermediate_variables_dict = {}
# Accumulated by extract_dependencies_from_function().
dependency_graph = {}
# Shared rich console used for status output throughout the module.
console = Console(highlight=False)

def save_intermediate_variable(name, value):
    """Record ``value`` under ``name`` in the module-level intermediate-variable store."""
    intermediate_variables_dict.update({name: value})

def rich_escape(text):
    """Return ``str(text)`` with every square bracket prefixed by a backslash."""
    bracket_escapes = str.maketrans({'[': '\\[', ']': '\\]'})
    return str(text).translate(bracket_escapes)

def topological_sort(dep_graph):
    """Order the variables of a dependency graph so dependencies come first.

    Uses Kahn's algorithm with a FIFO queue.

    Args:
        dep_graph: Mapping ``variable -> list of variables it depends on``.
            ``None`` is treated as an empty graph. As a side effect, every
            variable that only appears as a dependency is added to the mapping
            with an empty dependency list.

    Returns:
        A list of variable names in which each variable appears after all of
        its dependencies. Variables that are part of a dependency cycle (or
        that depend on one) are omitted.
    """
    if dep_graph is None:
        return []

    # Register dependency-only variables in first-seen order. Iterating a set
    # here would make the result depend on string hash randomisation.
    for dep in dict.fromkeys(d for deps in dep_graph.values() for d in deps):
        dep_graph.setdefault(dep, [])

    indegree = {var: 0 for var in dep_graph}
    dependents = defaultdict(list)
    for var, deps in dep_graph.items():
        for dep in deps:
            dependents[dep].append(var)
            indegree[var] += 1

    ready = deque(var for var, count in indegree.items() if count == 0)
    sorted_order = []
    while ready:
        node = ready.popleft()
        sorted_order.append(node)
        for dependent in dependents[node]:
            indegree[dependent] -= 1
            if indegree[dependent] == 0:
                ready.append(dependent)
    return sorted_order

def compute_depth_width_metrics_graph_based(dep_graph_old, dep_graph_new):
    """Compare two dependency graphs in terms of depth and width growth.

    Args:
        dep_graph_old: Mapping ``variable -> dependencies`` of the baseline.
        dep_graph_new: Mapping ``variable -> dependencies`` of the follow-up.

    Returns:
        A dict with ``depth_old``/``depth_new`` (longest dependency chain,
        counted in nodes), ``depth_increase`` (never negative),
        ``width_increase`` (number of variables in the new graph with more
        distinct dependencies than in the old one) and ``width_coef`` /
        ``depth_coef`` (each increase divided by their sum, rounded to two
        decimals; the divisor is 1 when both increases are zero).
    """
    def longest_path_length(dep_graph):
        if not dep_graph:
            return 0
        depth_of = {}
        on_path = set()

        def walk(node):
            if node in depth_of:
                return depth_of[node]
            on_path.add(node)
            # Neighbours currently on the DFS path are skipped, which keeps
            # the recursion finite on cyclic graphs.
            deepest_child = max(
                (walk(child) for child in dep_graph.get(node, []) if child not in on_path),
                default=0,
            )
            on_path.remove(node)
            depth_of[node] = deepest_child + 1
            return depth_of[node]

        return max(walk(start) for start in dep_graph)

    depth_old = longest_path_length(dep_graph_old)
    depth_new = longest_path_length(dep_graph_new)
    depth_increase = max(depth_new - depth_old, 0)

    width_increase = sum(
        len(set(new_deps)) > len(set(dep_graph_old.get(var, [])))
        for var, new_deps in dep_graph_new.items()
    )

    total = (width_increase + depth_increase) or 1
    return {
        "width_increase": width_increase,
        "depth_increase": depth_increase,
        "depth_old": depth_old,
        "depth_new": depth_new,
        "width_coef": round(width_increase / total, 2),
        "depth_coef": round(depth_increase / total, 2),
    }

def find_dependencies_from_source(source_code):
    dependencies = {}
    try:
        tree = ast.parse(source_code)
        for node in ast.walk(tree):
            if isinstance(node, ast.Assign):
                for target in node.targets:
                    if isinstance(target, ast.Name):
                        target_var = target.id
                        if target_var not in dependencies:
                            dependencies[target_var] = set()
                        for child in ast.walk(node.value):
                            if isinstance(child, ast.Name):
                                dependencies[target_var].add(child.id)
                            elif isinstance(child, ast.Subscript):
                                if isinstance(child.value, ast.Name):
                                    dependencies[target_var].add(child.value.id)
                            elif isinstance(child, ast.Attribute):
                                if isinstance(child.value, ast.Name):
                                    dependencies[target_var].add(child.value.id)
                            elif isinstance(child, ast.Call):
                                if isinstance(child.func, ast.Name):
                                    dependencies[target_var].add(child.func.id)
            elif isinstance(node, ast.FunctionDef):
                for arg in node.args.args:
                    if arg.arg not in dependencies:
                        dependencies[arg.arg] = set()
    except Exception as e:
        print(f"Error analyzing source code: {e}")
    for var, dep_set in dependencies.items():
        dependencies[var] = list(dep_set)
    return dependencies

def extract_dependencies_from_function(source_code):
    """Merge the dependencies found in ``source_code`` into the global graph.

    Entries for variables already present in ``dependency_graph`` are replaced.
    """
    global dependency_graph
    for variable, deps in find_dependencies_from_source(source_code).items():
        dependency_graph[variable] = deps

def show_all(lineno, value, valuename=None, fig=None, usefig=True, disp=True, console_in=None, time_wait_between_lines=None, lastlineno=[-1]):
    """Record a traced value in ``intermediate_variables_dict``.

    Images are stored as the string ``"Image"``; lists and dicts are stored as
    a short length summary after each element has been recorded recursively
    under ``name[i]`` / ``name['key']``; every other value is stored as is.

    Args:
        lineno: Source line being traced, or ``None``. A rule is printed when
            it differs from the previously seen line number.
        value: The value to record.
        valuename: Key to record under. When omitted, the caller's local
            variables are searched for one that *is* ``value``; ``'unknown'``
            is used if none is found.
        fig, usefig, disp: Forwarded to recursive calls, otherwise unused here.
        console_in: Console used for the line rule; a new one is created when
            omitted.
        time_wait_between_lines: Unused here.
        lastlineno: One-element list that remembers the last line number
            across calls. The mutable default is intentional persistent state.
    """
    time.sleep(0.1)
    global intermediate_variables_dict, dependency_graph
    if console_in is None:
        from rich.console import Console
        console_in = Console(highlight=False)

    if valuename is None:
        try:
            # Must stay inline: f_back has to refer to show_all's caller.
            caller_locals = inspect.currentframe().f_back.f_locals
            valuename = 'unknown'
            for local_name, local_value in caller_locals.items():
                if local_value is value:
                    valuename = local_name
                    break
        except Exception:
            valuename = "unknown"

    line_changed = lineno is not None and lineno != lastlineno[0]
    if line_changed:
        console_in.rule(f"[bold]Line {lineno}[/bold]", style="chartreuse2")
        lastlineno[0] = lineno

    if isinstance(value, Image.Image):
        recorded = "Image"
    elif isinstance(value, list):
        for index, element in enumerate(value):
            show_all(None, element, f"{valuename}[{index}]", fig=fig, disp=disp, usefig=usefig)
        recorded = f"List of len {len(value)}"
    elif isinstance(value, dict):
        for key, element in value.items():
            show_all(None, element, f"{valuename}['{key}']", fig=fig, disp=disp, usefig=usefig)
        recorded = f"Dict of len {len(value)}"
    else:
        recorded = value
    intermediate_variables_dict[valuename] = recorded

def split_codex_output(text):
    """Strip the optional ``Original Code:`` and ``<code>`` prefixes from model output.

    The prefixes are only removed when they appear in that order at the start
    of the (whitespace-stripped) text.

    Returns:
        A 3-tuple ``(code, "", "")``; the last two items are always empty.
    """
    remainder = text.strip()
    for marker in ("Original Code:", "<code>"):
        if remainder.startswith(marker):
            remainder = remainder[len(marker):].strip()
    return remainder, "", ""

def my_collate(batch):
    """Collate a list of sample dicts into one dict of lists.

    The keys are taken from the first sample; each value is the list of that
    key's values across all samples, in batch order.
    """
    collated = {}
    for key in batch[0].keys():
        collated[key] = [sample[key] for sample in batch]
    return collated

def run_program(parameters, queues_in_, input_type_, retrying=False):
    """Compile and execute one generated program for a single sample.

    The generated body is wrapped in a function ``execute_command_<sample_id>``
    that is defined in this module's globals via ``exec`` and then called. If
    the body does not compile, the contents of ``config.fixed_code_file`` are
    used instead. If execution raises, the call is retried once with the fixed
    code.

    NOTE: ``exec`` runs model-generated code with full access to this module;
    only use it with trusted model output.

    Args:
        parameters: Sequence ``(code, sample_id, image, query)``.
        queues_in_: Input queues handed to ImagePatch/VideoSegment/llm_query.
        input_type_: Name of the first parameter of the generated function.
        retrying: True for the internal retry; a failure is then final.

    Returns:
        ``(result, code, intermediate_variables, dependency_graph)``. On a
        clean run the last two are the recorded variables and the statically
        extracted dependency graph. On any failure (including a successful
        retry) they are ``None``; ``result`` is ``None`` unless the retry
        succeeded.
    """
    # coerce_to_numeric, process_guesses and currentframe are not referenced
    # below; the imports are kept in case they matter for their side effects.
    from image_patch import ImagePatch, llm_query, best_image_match, distance, bool_to_yesno, coerce_to_numeric, process_guesses
    from video_segment import VideoSegment
    from inspect import currentframe
    global queue_results, intermediate_variables_dict, dependency_graph

    code, sample_id, image, query = parameters
    function_name = f'execute_command_{sample_id}'
    code_header = f'def execute_command_{sample_id}(' \
                  f'{input_type_}, query, ' \
                  f'ImagePatch, VideoSegment, ' \
                  'llm_query, bool_to_yesno, distance, best_image_match):\n' \
                  f'    # Answer is:'
    code = code_header + code
    print(f'Running sample {sample_id} with code: {code}')

    try:
        exec(compile(code, 'Codex', 'exec'), globals())
    except Exception as e:
        print(f'Sample {sample_id} failed at compilation time with error: {e}')
        try:
            with open(config.fixed_code_file, 'r') as f:
                fixed_code = f.read()
            code = code_header + fixed_code
            exec(compile(code, 'Codex', 'exec'), globals())
        except Exception as e2:
            print(f'Not even the fixed code worked. Sample {sample_id} failed at compilation time with error: {e2}')
            # Do not let anything recorded so far leak into the next sample.
            intermediate_variables_dict.clear()
            return None, code, None, None

    queues = [queues_in_, queue_results]
    image_patch_partial = partial(ImagePatch, queues=queues)
    video_segment_partial = partial(VideoSegment, queues=queues)
    llm_query_partial = partial(llm_query, queues=queues)

    try:
        result = globals()[function_name](
            image, query,
            image_patch_partial, video_segment_partial,
            llm_query_partial, bool_to_yesno, distance, best_image_match)
    except Exception as e:
        traceback.print_exc()
        # Variables recorded before the crash belong to a failed run; without
        # this they would be reported as part of the next sample's variables.
        intermediate_variables_dict.clear()
        # Remove the generated function so failed samples do not accumulate
        # in the module namespace.
        globals().pop(function_name, None)
        if retrying:
            return None, code, None, None
        print(f'Sample {sample_id} failed with error: {e}. Retrying with fixed code.')
        # Deliberately invalid code: it fails to compile in the retry, which
        # makes the retry fall back to the fixed code file.
        new_code = "["
        result = run_program((new_code, sample_id, image, query), queues_in_, input_type_,
                             retrying=True)[0]
        return result, code, None, None

    globals().pop(function_name, None)
    print(f'Finished sample {sample_id} with result: {result}')

    # Snapshot and reset the recorder so the next sample starts empty.
    intermediate_variables_temp = intermediate_variables_dict.copy()
    intermediate_variables_dict.clear()
    dependency_graph_temp = find_dependencies_from_source(code)
    return result, code, intermediate_variables_temp, dependency_graph_temp

def worker_init(queue_results_):
    """Pool initializer: pick this worker's results queue.

    The first element of the current process's ``_identity`` tuple, taken
    modulo the number of queues, selects the entry of ``queue_results_`` that
    is stored in the module-level ``queue_results``.
    """
    global queue_results
    worker_number = mp.current_process()._identity[0]
    queue_results = queue_results_[worker_number % len(queue_results_)]

def main():
    """Generate, run and analyse programs for every batch of the dataset.

    For each batch the samples are split into ``original`` and ``expanded``
    ones (by their ``source`` field), code is generated (or read from a cached
    CSV), optionally executed, and the results, recorded variables, dependency
    graphs and their topological orders are accumulated. Afterwards optional
    follow-up complexity metrics are computed and everything is written to a
    CSV and a JSON file in the results directory.
    """
    from contextlib import nullcontext

    try:
        mp.set_start_method('spawn')
    except RuntimeError:
        # The start method has already been set.
        pass

    from vision_processes import queues_in, finish_all_consumers, forward, manager
    from datasets import get_dataset

    def _graph_or_empty(graph):
        # run_program returns None instead of a graph for failed samples.
        return {} if graph is None else graph

    batch_size = 1
    num_processes = min(batch_size, 50)

    if config.multiprocessing:
        queue_results_main = manager.Queue()
        queues_results = [manager.Queue() for _ in range(batch_size)]
    else:
        queue_results_main = None
        queues_results = [None for _ in range(batch_size)]

    model_name_codex = 'codellama' if config.codex.model == 'codellama' else 'codex'
    codex = partial(forward, model_name=model_name_codex, queues=[queues_in, queue_results_main])

    if config.clear_cache:
        cache.clear()

    if config.wandb:
        import wandb
        wandb.init(project="viper", config=OmegaConf.to_container(config))
        wandb.save(config.codex.prompt)

    dataset = get_dataset(config.dataset)

    with open(config.codex.prompt) as f:
        base_prompt = f.read().strip()

    codes_all = None
    if config.use_cached_codex:
        results = pd.read_csv(config.cached_codex_path)
        codes_all = [r.split('# Answer is:')[1] for r in results['code']]

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True,
                            collate_fn=my_collate)
    input_type = dataset.input_type

    all_results = []
    all_answers = []
    all_codes = []
    all_ids = []
    all_queries = []
    all_expanded_questions = []
    all_expanded_codes = []
    all_expanded_results = []
    all_keep = []
    all_halstead_o = []
    all_halstead_f = []
    all_cyclomatic_complexity_o = []
    all_cyclomatic_complexity_f = []
    all_codebert_similarity = []
    all_intermediate_variables_dict_original = []
    all_intermediate_variables_dict_expanded = []
    all_dependency_graph_original = []
    all_dependency_graph_expanded = []
    all_topological_sort_original = []
    all_topological_sort_expanded = []
    all_expansion_index = []
    all_question_ids = []
    all_original_img_paths = []
    all_followup_img_paths = []

    # Without multiprocessing no pool is needed; nullcontext() keeps the
    # `with` structure without opening a dummy file.
    pool_context = (
        mp.Pool(processes=num_processes, initializer=worker_init, initargs=(queues_results,))
        if config.multiprocessing else nullcontext()
    )
    with pool_context as pool:
        try:
            n_batches = len(dataloader)
            # Rounded up. NOTE(review): this only sets the progress-bar total;
            # the loop below still iterates over every batch -- confirm intent.
            half_batches = (n_batches + 1) // 2
            for i, batch in tqdm(enumerate(dataloader), total=half_batches):
                gc.collect()

                # Split the batch into original and expanded samples.
                original_indices = [idx for idx, src in enumerate(batch['source']) if src == 'original']
                expanded_indices = [idx for idx, src in enumerate(batch['source']) if src == 'expanded']

                original_queries = [batch['query'][idx] for idx in original_indices]
                original_images = [batch['image'][idx] for idx in original_indices]
                original_sample_ids = [batch['sample_id'][idx] for idx in original_indices]
                original_answers = [batch['answer'][idx] for idx in original_indices]
                original_question_ids = [batch['question_id'][idx] for idx in original_indices]
                original_img_paths = [batch['image_name'][idx] if batch['image_name'][idx] else "" for idx in
                                      original_indices]

                expanded_queries = [batch['query'][idx] for idx in expanded_indices]
                expanded_images = [batch['image'][idx] for idx in expanded_indices]
                expanded_sample_ids = [batch['sample_id'][idx] for idx in expanded_indices]
                expanded_img_paths = [batch['image_name'][idx] if batch['image_name'][idx] else "" for idx in
                                      expanded_indices]

                # Generate the code.
                original_codes = []
                expanded_codes = []
                if not config.use_cached_codex:
                    if original_queries:
                        # Build inputs that carry both the text and the image.
                        original_inputs = [
                            {
                                "query": query,
                                "image": pil_to_base64(img),
                                "image_name": img_path
                            }
                            for query, img, img_path in zip(original_queries, original_images, original_img_paths)
                        ]
                        original_codes = codex(prompt=original_inputs, base_prompt=base_prompt, input_type=input_type)
                        original_codes = [split_codex_output(code)[0] for code in original_codes]
                    if expanded_queries:
                        expanded_inputs = [
                            {
                                "query": query,
                                "image": pil_to_base64(img),
                                "image_name": img_path
                            }
                            for query, img, img_path in zip(expanded_queries, expanded_images, expanded_img_paths)
                        ]
                        expanded_codes = codex(prompt=expanded_inputs, base_prompt=base_prompt, input_type=input_type)
                        expanded_codes = [split_codex_output(code)[0] for code in expanded_codes]
                else:
                    codes = codes_all[i * len(original_queries):(i + 1) * len(original_queries)]
                    original_codes = [split_codex_output(code)[0] for code in codes]
                    # The cache holds no expanded code, so the original code is
                    # reused; regenerate if real expanded code is required.
                    expanded_codes = original_codes

                console.print(f'Batch {i+1}/{n_batches} - Codex generated')
                console.print(f'Original queries: {original_queries}')
                console.print(f'Original codes: {original_codes}')
                console.print(f'Expanded queries: {expanded_queries}')
                console.print(f'Expanded codes: {expanded_codes}')

                original_results = []
                expanded_results = []
                if config.execute_code:
                    print("yes")
                    if not config.multiprocessing:
                        # Run the original samples sequentially.
                        for oc, sample_id, img, query in \
                                zip(original_codes, original_sample_ids, original_images, original_queries):
                            result_o = run_program([oc, sample_id, img, query], queues_in, input_type)
                            print(f"Original result: {result_o[0]}\nIntermediate variables: {result_o[2]}\n"
                                  f"Dependency graph: {json.dumps(result_o[3], indent=2)}")
                            original_results.append(result_o)

                        # Run the expanded samples sequentially.
                        for ec, sample_id, img, query in \
                                zip(expanded_codes, expanded_sample_ids, expanded_images, expanded_queries):
                            result_e = run_program([ec, sample_id, img, query], queues_in, input_type)
                            print(f"Expanded result: {result_e[0]}\nIntermediate variables: {result_e[2]}\n"
                                  f"Dependency graph: {json.dumps(result_e[3], indent=2)}")
                            expanded_results.append(result_e)
                    else:
                        # Run the original samples in the worker pool.
                        original_results = list(pool.imap(partial(
                            run_program, queues_in_=queues_in, input_type_=input_type),
                            zip(original_codes, original_sample_ids, original_images, original_queries)))
                        for result_o in original_results:
                            print(f"Original result: {result_o[0]}\nIntermediate variables: {result_o[2]}\n"
                                  f"Dependency graph: {json.dumps(result_o[3], indent=2)}")

                        # Run the expanded samples in the worker pool.
                        expanded_results = list(pool.imap(partial(
                            run_program, queues_in_=queues_in, input_type_=input_type),
                            zip(expanded_codes, expanded_sample_ids, expanded_images, expanded_queries)))
                        for result_e in expanded_results:
                            print(f"Expanded result: {result_e[0]}\nIntermediate variables: {result_e[2]}\n"
                                  f"Dependency graph: {json.dumps(result_e[3], indent=2)}")
                else:
                    original_results = [(None, c, None, find_dependencies_from_source(c)) for c in original_codes]
                    expanded_results = [(None, c, None, find_dependencies_from_source(c)) for c in expanded_codes]
                    warnings.warn("Not executing code! Set 'execute_code' to True to run it.")

                # Compute the expansion index. Failed samples carry a None
                # graph, which is treated as empty so that one failure does
                # not abort the whole run.
                for result_o, result_e, oc, ec in zip(original_results, expanded_results, original_codes, expanded_codes):
                    comparison = compute_depth_width_metrics_graph_based(
                        _graph_or_empty(result_o[3]), _graph_or_empty(result_e[3]))
                    all_expansion_index.append(comparison)
                    print(f"Expansion index: {json.dumps(comparison, indent=2)}")

                all_results += [r[0] for r in original_results]
                all_codes += [r[1] for r in original_results]
                all_ids += original_sample_ids
                all_answers += original_answers
                all_queries += original_queries
                all_intermediate_variables_dict_original += [r[2] for r in original_results]
                all_dependency_graph_original += [r[3] for r in original_results]
                all_topological_sort_original += [topological_sort(_graph_or_empty(r[3])) for r in original_results]
                all_question_ids += original_question_ids
                all_original_img_paths += original_img_paths
                all_expanded_questions += expanded_queries
                all_expanded_codes += [r[1] for r in expanded_results]
                all_expanded_results += [r[0] for r in expanded_results]
                all_followup_img_paths += expanded_img_paths
                all_intermediate_variables_dict_expanded += [r[2] for r in expanded_results]
                all_dependency_graph_expanded += [r[3] for r in expanded_results]
                all_topological_sort_expanded += [topological_sort(_graph_or_empty(r[3])) for r in expanded_results]
                console.print(f"[yellow]暂停 5 秒以便检查批次 {i + 1}/{n_batches} 的输出...[/yellow]")
                time.sleep(5)
        except Exception as e:
            traceback.print_exc()
            console.print(f'Exception: {e}')
            console.print("Completing logging and exiting...")

        try:
            # NOTE(review): `accuracy` is never assigned in this function, so
            # this always ends up in the except branch. The scoring call that
            # should produce it is not visible here -- restore it if needed.
            console.print(f'Final accuracy: {accuracy}')
        except Exception as e:
            print(f'Error computing accuracy: {e}')

        if config.followup:
            for oc, ec in zip(all_codes, all_expanded_codes):
                keep, details = is_followup_more_complex(oc, ec)
                all_keep.append(keep)
                all_halstead_o.append(details.get("halstead_o", -1))
                all_halstead_f.append(details.get("halstead_f", -1))
                all_cyclomatic_complexity_o.append(details.get("cc_o", -1))
                all_cyclomatic_complexity_f.append(details.get("cc_f", -1))
                all_codebert_similarity.append(details.get("similarity", -1))

        if config.save:
            results_dir = pathlib.Path(config['results_dir']) / config.dataset.split
            results_dir.mkdir(parents=True, exist_ok=True)

            if not config.save_new_results:
                filename = 'results.csv'
            else:
                existing_files = list(results_dir.glob('results_*.csv'))
                numeric_suffixes = [int(ef.stem.split('_')[-1]) for ef in existing_files
                                    if str.isnumeric(ef.stem.split('_')[-1])]
                # default=-1 yields results_0.csv both when no file exists and
                # when only non-numeric results_*.csv files exist (the latter
                # used to raise ValueError from max() on an empty list).
                filename = 'results_' + str(max(numeric_suffixes, default=-1) + 1) + '.csv'

            df = pd.DataFrame({
                'id': all_ids,
                'question_id': all_question_ids,
                'query': all_queries,
                'original_img_path': all_original_img_paths,
                'answer': all_answers,
                'original_code': all_codes,
                'original_result': all_results,
                'intermediate_variables_original': all_intermediate_variables_dict_original,
                'dependency_graph_original': all_dependency_graph_original,
                'topological_sort_original': all_topological_sort_original
            })
            df['original_result'] = df['original_result'].apply(str)
            df['intermediate_variables_original'] = df['intermediate_variables_original'].apply(
                lambda x: json.dumps(x, cls=CustomJSONEncoder))

            if config.followup:
                # NOTE(review): these assignments require as many expanded as
                # original samples -- verify against the dataset.
                df['followup_question'] = all_expanded_questions
                df['followup_code'] = all_expanded_codes
                df['followup_result'] = all_expanded_results
                df['followup_img_path'] = all_followup_img_paths
                df['halstead_original'] = all_halstead_o
                df['halstead_followup'] = all_halstead_f
                df['cyclomatic_complexity_original'] = all_cyclomatic_complexity_o
                df['cyclomatic_complexity_followup'] = all_cyclomatic_complexity_f
                df['codebert_similarity'] = all_codebert_similarity
                df['whether_to_keep'] = all_keep
                df['intermediate_variables_followup'] = all_intermediate_variables_dict_expanded
                df['dependency_graph_followup'] = all_dependency_graph_expanded
                df['topological_sort_followup'] = all_topological_sort_expanded

                df['intermediate_variables_followup'] = df['intermediate_variables_followup'].apply(
                    lambda x: json.dumps(x, cls=CustomJSONEncoder))
                df['followup_result'] = df['followup_result'].apply(str)

            df.to_csv(results_dir / filename, header=True, index=False, encoding='utf-8')
            json_data = df.to_dict(orient='records')
            with open(results_dir / (filename.replace('.csv', '.json')), 'w', encoding='utf-8') as f:
                json.dump(json_data, f, ensure_ascii=False, indent=2, cls=CustomJSONEncoder)


        finish_all_consumers()

def test_command(image):
    """Exercise ``show_all`` with a stub ImagePatch over a list of boxes.

    Args:
        image: Iterable of indexable boxes; elements 0..3 of each box are read
            (presumably left, top, right, bottom -- confirm with callers).

    Returns:
        The mean ``width * height`` over all objects reported as "large" by
        the stub (which is every box), or 0 when there are none.
    """
    class ImagePatch:
        """Minimal stand-in: finds only itself and verifies every property."""

        def __init__(self, img):
            self.img = img
            self.width = img[2] - img[0]
            self.height = img[3] - img[1]
            self.vertical_center = (img[1] + img[3]) / 2

        def find(self, object_name):
            return [self]

        def verify_property(self, object_name, prop):
            return True

    total_size = 0
    large_count = 0
    for box in image:
        objects = ImagePatch(box).find("object")
        show_all(None, objects, "objects")
        for obj in objects:
            if not obj.verify_property("object", "large"):
                continue
            large_count += 1
            size = obj.width * obj.height
            show_all(None, size, "size")
            total_size += size
    show_all(None, total_size, "total_size")
    show_all(None, large_count, "large_count")
    if large_count > 0:
        average_size = total_size / large_count
    else:
        average_size = 0
    show_all(None, average_size, "average_size")
    return average_size

# Run the batch pipeline only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()