import argparse
from collections import defaultdict
import datetime
import json
import math
import os
import random
import string
import threading
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed, TimeoutError
from statistics import stdev
import traceback
import numpy as np
import time
from datasets import load_dataset


from utils.common_utils import load_json_file
from utils.docker_utils import setup_logger, copy_src_files
from utils.evo_utils import load_dgm_metadata
import MCTS_utils
# from MCTS_utils import Node
from global_tree_search import Node

def update_metadata(output_dir, n_task_evals):
    """Persist the current search state into *output_dir*.

    Appends one pretty-printed (``indent=2``) JSON object to
    ``dgm_metadata.jsonl`` on every call, so the file accumulates a history of
    snapshots. Each snapshot holds the evaluation count and the serialized
    non-root nodes. ``init_evaluated_tasks.json`` is overwritten with the
    current ``MCTS_utils.init_evaluated_tasks``.

    Args:
        output_dir: Run output directory the files are written into.
        n_task_evals: Task-evaluation count recorded in the snapshot.
    """
    snapshot = {
        "n_task_evals": n_task_evals,
        # The 'initial' root node is recreated by initialize_run on startup,
        # so it is not persisted.
        'nodes': [node.save_as_dict() for node in MCTS_utils.nodes.values() if node.commit_id != 'initial'],
    }
    with open(os.path.join(output_dir, "dgm_metadata.jsonl"), "a") as f:
        f.write(json.dumps(snapshot, indent=2) + "\n")
    # Close the handle deterministically; previously it was left to the
    # garbage collector, so the JSON might not be flushed to disk promptly.
    with open(os.path.join(output_dir, "init_evaluated_tasks.json"), "w") as f:
        json.dump(MCTS_utils.init_evaluated_tasks, f)

def initialize_run(output_dir, self_improve_llm, tasks, initial_agent_name, prevrun_dir=None, polyglot=False, timeout=3600):
    """Prepare *output_dir* and rebuild the search tree before the search loop.

    On a fresh run (no *prevrun_dir*), the cached initial agent directory
    ``initial_swe/<initial_agent_name>`` (``initial_polyglot/...`` when
    *polyglot*) is copied to ``<output_dir>/initial``. When *prevrun_dir* is
    given, ``MCTS_utils.init_evaluated_tasks`` and the nodes from the last
    snapshot in its ``dgm_metadata.jsonl`` are restored, and the nodes are
    linked to their parents.

    Every node's ``utility_measures`` is then rebuilt from
    ``<output_dir>/<commit_id>/metadata.json``. Finally ``MCTS_utils.init`` is
    called with the number of task evaluations already spent by non-root nodes.

    Returns:
        A tuple ``(src_path, submitted_ids)``: the path of the initial agent's
        ``src`` directory, and a ``defaultdict(set)`` mapping node id to the set
        of task ids already submitted for that node.

    Raises:
        RuntimeError: on a fresh run, if the cached initial agent is missing.
        subprocess.CalledProcessError: if copying the initial agent fails.
    """
    import subprocess

    initial_folder = 'initial_swe/' if not polyglot else 'initial_polyglot/'
    if not prevrun_dir:
        initial_src = f"{initial_folder}/{initial_agent_name}"
        if not os.path.exists(initial_src):
            raise RuntimeError("Error: Need to properly configure evaluation results for the initial version.")
        # Pass an argument list (no shell), so names with spaces or shell
        # metacharacters are copied verbatim. check=True raises on failure
        # instead of silently ignoring cp's exit status.
        subprocess.run(["cp", "-r", initial_src, f"{output_dir}/initial"], check=True)

    # Root of the tree. The loops below iterate MCTS_utils.nodes without keeping
    # their own references, so Node(...) presumably registers itself there;
    # TODO confirm in global_tree_search.
    root = Node(commit_id='initial')

    if prevrun_dir:
        MCTS_utils.init_evaluated_tasks = load_json_file(os.path.join(prevrun_dir, "init_evaluated_tasks.json"))
        prev_metadata = load_dgm_metadata(os.path.join(prevrun_dir, "dgm_metadata.jsonl"), last_only=True)
        for saved in prev_metadata['nodes']:
            Node(saved['commit_id'], parent_id=saved['parent_id'], id=saved['id'])
        # Link children only after every node exists, so the order of nodes
        # in the snapshot does not matter.
        for node in MCTS_utils.nodes.values():
            if node.parent_id is not None:
                MCTS_utils.nodes[node.parent_id].add_child(node)

    n_task_evals = 0
    submitted_ids = defaultdict(set)  # node id -> set of submitted task ids
    for node in MCTS_utils.nodes.values():
        node_metadata = load_json_file(os.path.join(output_dir, node.commit_id, "metadata.json"))
        performance = node_metadata['overall_performance']
        n_resolved = performance['total_resolved_instances']
        n_submitted = performance['total_submitted_instances']
        submitted_ids[node.id] = set(performance['total_submitted_ids'])
        # One 1 per resolved instance, then one 0 per unresolved submission.
        node.utility_measures = [1] * n_resolved + [0] * (n_submitted - n_resolved)
        # Evaluations of the initial agent are excluded from n_task_evals.
        if node.commit_id != 'initial':
            n_task_evals += n_submitted
    MCTS_utils.init(polyglot, output_dir, tasks, n_task_evals, self_improve_llm, timeout)
    return os.path.join(initial_folder, initial_agent_name, 'src'), submitted_ids

def _cooldown_scale(max_task_evals, n_task_evals, beta):
    """Return the factor applied to the Beta parameters when ``--cool_down`` is set.

    Multiplying both Beta parameters by the same factor keeps each posterior's
    mean but shrinks its variance. The factor grows as the evaluation budget is
    consumed, so sampling becomes greedier late in the run.

    Once the budget is reached *or exceeded*, the fixed cap 10000 is returned.
    ``n_task_evals`` can overshoot ``max_task_evals`` because the budget check
    happens before an evaluation runs. A zero or negative remainder would
    otherwise make the formula divide by zero or produce invalid (negative/NaN)
    Beta parameters.
    """
    remaining = max_task_evals - n_task_evals
    if remaining <= 0:
        return 10000
    return max_task_evals ** beta / remaining ** beta


def main():
    """Run the tree search from the command line.

    Parses the CLI arguments and creates (or reuses, with ``--continue_from``)
    the output directory. It then configures the downstream harness modules and
    restores or initializes the tree. Finally it runs two thread-pool phases:
    a few initial expansions, then repeated sample steps that either expand a
    node or evaluate one node on one task, until ``--max_task_evals`` is reached.
    """
    parser = argparse.ArgumentParser(description="Optimistic Tree Search")
    parser.add_argument("--max_task_evals", type=int, default=800, help="Maximum number of evolution iterations.")
    parser.add_argument("--max_workers", type=int, default=10, help="Number of parallel workers for self-improvement attempts.")
    parser.add_argument("--continue_from", type=str, default=None, help="Directory to continue the run from.")
    parser.add_argument("--polyglot", default=False, action='store_true', help="Run single shallow evaluation for self-improvement on swe.")
    parser.add_argument("--self_improve_llm", default='gpt-5-nano', type=str, help='LLM model to use for self-improvement')
    parser.add_argument("--downstream_llm", default='gpt-5-nano', type=str, help='LLM model to use for downstream tasks')
    parser.add_argument("--diagnose_llm", default='gpt-5-nano', type=str, help='LLM model to use for diagnosis')
    parser.add_argument("--alpha", type=float, default=0.7, help="Alpha parameter for node expansion.")
    parser.add_argument('--cool_down', default=False, action='store_true', help='whether to use a decreasing temperature over iterations')
    parser.add_argument("--beta", type=float, default=1, help="cooling down factor beta.")
    parser.add_argument("--no_full_eval", default=False, action='store_true', help="Do not run full evaluation on swe if a node is the top N highest performing.")
    parser.add_argument("--self_improve_timeout", type=int, default=3600, help="Timeout for self-improvement attempts.")
    parser.add_argument("--evaluation_timeout", type=int, default=1800, help="Timeout for evaluation attempts.")
    parser.add_argument("--n_pseudo_descendant_evals", type=int, default=10, help="Number of pseudo descendant evaluations.")
    parser.add_argument("--eval_random_level", type=float, default=0.5, help="Randomness level for evaluation task selection.")
    parser.add_argument("--initial_agent_name", required=True, help="Name of the initial agent.")

    args = parser.parse_args()

    # A fresh run gets a timestamped id; a resumed run reuses the previous
    # run's directory name, so it writes into the same output directory.
    if not args.continue_from:
        run_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S_%f")
    else:
        run_id = os.path.basename(args.continue_from)

    output_dir = os.path.abspath(os.path.join("./output_mcts", run_id))

    # Ensure the output directory exists and log path info.
    os.makedirs(output_dir, exist_ok=True)
    print(f"Working directory: {os.getcwd()}")
    print(f"Output directory: {output_dir}")
    print(f"Output directory exists: {os.path.exists(output_dir)}")

    # Project modules are imported here and configured via module-level globals.
    import self_improve_step
    import polyglot.harness
    polyglot.harness.llm = args.downstream_llm  # LLM model for downstream tasks
    import swe_bench.harness
    swe_bench.harness.llm = args.downstream_llm  # LLM model for downstream tasks
    polyglot.harness.timeout = args.evaluation_timeout
    swe_bench.harness.timeout = args.evaluation_timeout
    self_improve_step.diagnose_llm = args.diagnose_llm

    # Initialize the logger early.
    logger = setup_logger(os.path.join(output_dir, "mcts_outer.log"))

    # Task pool. The SWE subsets are only read when full evaluation is
    # disabled; previously they were loaded and then discarded.
    if args.polyglot:
        tasks = load_json_file("./polyglot/subsets/medium.json") + load_json_file("./polyglot/subsets/small.json")
    else:
        if args.no_full_eval:
            tasks = load_json_file("./swe_bench/subsets/medium.json") + load_json_file("./swe_bench/subsets/small.json")
        else:
            tasks = [task['instance_id'] for task in load_dataset("princeton-nlp/SWE-bench_Verified")['test']]
        random.seed(42)
        random.shuffle(tasks)

    src_path, submitted_ids = initialize_run(output_dir, args.self_improve_llm, tasks, args.initial_agent_name, prevrun_dir=args.continue_from, polyglot=args.polyglot, timeout=args.self_improve_timeout)
    total_num_tasks = len(MCTS_utils.total_tasks)

    logger.info(f"Starting HOO run {run_id} with arguments: {vars(args)}")

    def TS_sample(evals):
        """Thompson-sample one entry of *evals* and return its index.

        Each entry is a sequence of outcomes (presumably 0/1). It is modelled as
        Beta(0.01 + sum, 0.01 + len - sum); one value is drawn per entry and
        the argmax is returned. With ``--cool_down``, both parameters are scaled
        by ``_cooldown_scale``. Callers must pass a non-empty *evals*.
        """
        alphas = [1e-2 + np.sum(de) for de in evals]
        betas = [1e-2 + len(de) - np.sum(de) for de in evals]
        if args.cool_down:
            scale = _cooldown_scale(args.max_task_evals, MCTS_utils.n_task_evals, args.beta)
            alphas = np.array(alphas) * scale
            betas = np.array(betas) * scale
        thetas = np.random.beta(alphas, betas)
        return np.argmax(thetas)

    # Shared state for the worker threads below; every access is under `lock`.
    lock = threading.Lock()
    n_pending_expands = 0   # expansions decided but not yet finished
    n_pending_measures = 0  # evaluations in flight (bookkeeping only; not read by selection)

    def expand():
        """Pick a node with positive finite mean utility and try to create a child of it.

        The slow ``MCTS_utils.sample_child`` call runs outside the lock. A child
        is added and metadata is persisted unless it returns ``'failed'``.
        """
        with lock:
            candidates = [
                node for node in MCTS_utils.nodes[0].get_sub_tree(fn=lambda node: node)
                if np.isfinite(node.mean_utility) and node.mean_utility > 0
            ]
            descendant_evals = [node.get_decendant_evals(num_pseudo=args.n_pseudo_descendant_evals) for node in candidates]
            selected_node = candidates[TS_sample(descendant_evals)]
        child_commit = MCTS_utils.sample_child(selected_node.commit_id, image_name=args.initial_agent_name + ':latest')
        with lock:
            if child_commit != 'failed':
                selected_node.children.append(Node(child_commit, parent_id=selected_node.id))
                update_metadata(output_dir, MCTS_utils.n_task_evals)

    def sample():
        """Do one search step: either expand the tree or evaluate one node on one task.

        Expansion is chosen while ``n_task_evals ** alpha`` is at least the
        number of non-root nodes plus pending expansions. Otherwise a node that
        still has unsubmitted tasks is Thompson-sampled on its own measurements
        and evaluated on one task. Pending counters are restored even when the
        expansion or evaluation raises.
        """
        nonlocal n_pending_expands, n_pending_measures
        time.sleep(random.random())  # stagger workers started together
        with lock:
            if MCTS_utils.n_task_evals >= args.max_task_evals:
                return
            is_expand = MCTS_utils.n_task_evals ** args.alpha >= len(MCTS_utils.nodes) - 1 + n_pending_expands
            if is_expand:
                n_pending_expands += 1

        if is_expand:
            try:
                expand()
            finally:
                # Always release the reservation. A leaked increment would bias
                # every later expansion decision, and the executor keeps running
                # the remaining futures after one of them fails.
                with lock:
                    n_pending_expands -= 1
            return

        with lock:
            candidates = [
                node for node in MCTS_utils.nodes[0].get_sub_tree(fn=lambda node: node)
                if len(submitted_ids[node.id]) < total_num_tasks
            ]
            if not candidates:
                return
            selected_node = candidates[TS_sample([node.utility_measures for node in candidates])]
            available_tasks = [task for task in MCTS_utils.total_tasks if task not in submitted_ids[selected_node.id]]
            if not available_tasks:
                return
            if random.random() < args.eval_random_level:
                task_id = random.choice(available_tasks)
            else:
                task_id = available_tasks[0]
            # Reserve the task before releasing the lock so that concurrent
            # workers do not submit the same (node, task) pair.
            submitted_ids[selected_node.id].add(task_id)
            n_pending_measures += 1

        try:
            evals = MCTS_utils.eval_agent(selected_node.commit_id, tasks=[task_id], init_agent_path=src_path)
        except BaseException:
            with lock:
                n_pending_measures -= 1
            raise
        with lock:
            selected_node.utility_measures += evals
            n_pending_measures -= 1
            update_metadata(output_dir, MCTS_utils.n_task_evals)

    try:
        # Phase 1: grow the tree to min(5, int(max_workers ** alpha)) non-root
        # nodes, counting nodes already restored from a previous run.
        with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
            futures = [executor.submit(expand)
                for _ in range(len(MCTS_utils.nodes) - 1, min(5, int(args.max_task_evals if False else args.max_workers ** args.alpha)))]
            for future in as_completed(futures):
                future.result()

        # Phase 2: generous upper bound on sample steps. Each step returns
        # immediately once the evaluation budget is reached.
        with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
            futures = [executor.submit(sample)
                for _ in range(int(args.max_task_evals * 100))]
            for future in as_completed(futures):
                future.result()

    except Exception as e:
        logger.error(f"Error: {e}")
        logger.error(traceback.format_exc())
        print(repr(e))
        # NOTE(review): drops into an interactive debugger. An unattended run
        # with an open stdin will wait here indefinitely; consider re-raising.
        import pdb; pdb.set_trace()

if __name__ == "__main__":
    # Start the search only when executed as a script, not when imported.
    main()