import asyncio
import argparse
import json
from pathlib import Path
import traceback
from metagpt.logs import get_logger
import datetime
from dotenv import load_dotenv
import multiprocessing
from src.environment import MultiAgentEnvironment
from src.utils import setup_output_directory, load_config, distribute_configs
from filelock import FileLock
from contextlib import contextmanager

# Set the start method to 'spawn' for all platforms
# ('spawn' starts every worker as a fresh interpreter that re-imports this
# module; force=True overrides any start method that was already chosen.
# This must run before any multiprocessing objects, such as the Event below,
# are created.)
multiprocessing.set_start_method('spawn', force=True)

# Global flag for graceful shutdown
# NOTE(review): this Event is created at import time. Under 'spawn', each pool
# worker re-imports this module and builds its own, unrelated Event, so
# setting it in the parent is NOT visible to process_task_set running inside a
# worker unless the Event is handed to workers explicitly (e.g. via a Pool
# initializer). Also, no signal.signal(...) call installing signal_handler is
# visible in this file, so nothing here currently sets the flag.
shutdown_flag = multiprocessing.Event()

def signal_handler(signum, frame):
    """Request a graceful shutdown when a termination signal arrives.

    Has the ``(signum, frame)`` signature expected by ``signal.signal``;
    neither argument is used. Prints a notice, then sets the module-level
    ``shutdown_flag`` event.

    NOTE(review): no ``signal.signal(..., signal_handler)`` registration is
    visible in this file, so this handler is currently never installed.
    """
    notice = "Received termination signal. Initiating graceful shutdown..."
    print(notice)
    shutdown_flag.set()

def safe_write_to_file(file_path, content):
    """Append ``content`` as a single JSON line to ``file_path`` under a file lock.

    The lock file is ``f"{file_path}.lock"``, which serializes appends from
    concurrent writers (e.g. several pool workers sharing one results file).

    Args:
        file_path: Path (str or ``pathlib.Path``) of the JSONL file to append to.
        content: JSON-serializable object written as one line.

    Raises:
        TypeError / ValueError: if ``content`` cannot be JSON-encoded. This is
            raised before the lock is taken or the file is opened, so the file
            is never left with a partial record.
    """
    # Serialize up front. json.dump() streams chunks to the file while it
    # encodes, so a non-serializable value part-way through used to leave a
    # truncated, newline-less fragment behind and corrupt the next record.
    line = json.dumps(content) + '\n'
    lock_path = f"{file_path}.lock"
    with FileLock(lock_path):
        with open(file_path, 'a', encoding='utf-8') as f:
            f.write(line)

@contextmanager
def managed_process_pool(processes):
    """Yield a ``multiprocessing.Pool`` and shut it down when the block exits.

    On exit, whether normal or via an exception, the pool is closed to new
    work and then joined. ``join()`` blocks until every already-submitted
    task has finished; unlike ``Pool.__exit__``, this does not call
    ``terminate()``.

    Args:
        processes: Worker count passed to ``multiprocessing.Pool``.
    """
    worker_pool = multiprocessing.Pool(processes=processes)

    def _drain():
        # Stop accepting work, then wait for in-flight/queued tasks.
        worker_pool.close()
        worker_pool.join()

    try:
        yield worker_pool
    finally:
        _drain()

def process_task_set(task_set, num_agents, agent_llm_config, assistant_config, max_rounds, output_dir):
    """Run one evolving-task set (intended for a pool worker) and record its results.

    Builds a ``MultiAgentEnvironment``, awaits ``run_evolving_tasks(task_set)``,
    and appends each returned result, tagged with ``set_name``, as a JSON line
    to ``<output_dir>/evolving_tasks_results.jsonl``.

    Args:
        task_set: Task-set mapping; must contain a ``'set_name'`` key.
        num_agents, agent_llm_config, max_rounds, output_dir: Forwarded to
            ``MultiAgentEnvironment``.
        assistant_config: Forwarded as ``assistant_llm_config``.

    Returns:
        ``{'success': True, 'set_name': ...}`` on success, otherwise
        ``{'success': False, 'set_name': ..., 'error': <message>}``. Exceptions
        raised while building the environment or running/recording the set are
        logged and reported this way instead of being propagated.
    """
    # NOTE(review): with the 'spawn' start method this worker sees its own copy
    # of shutdown_flag, not the parent's, so this check does not observe a
    # shutdown requested in the main process.
    if shutdown_flag.is_set():
        return {'success': False, 'set_name': task_set['set_name'], 'error': 'Task cancelled due to shutdown'}

    set_name = task_set['set_name']
    logger = get_logger(f"evolving_tasks_{set_name}", output_dir)
    results_file = Path(output_dir) / 'evolving_tasks_results.jsonl'

    async def run_task_set():
        try:
            env = MultiAgentEnvironment(
                num_agents=num_agents,
                agent_llm_config=agent_llm_config,
                assistant_llm_config=assistant_config,
                max_rounds=max_rounds,
                output_dir=output_dir
            )
            results = await env.run_evolving_tasks(task_set)

            for result in results:
                result['set_name'] = set_name
                safe_write_to_file(results_file, result)

            logger.info(f"Results for task set '{set_name}' appended to {results_file}")
            return {'success': True, 'set_name': set_name}

        except Exception as e:
            logger.error(f"Task set '{set_name}' failed: {str(e)}")
            logger.error(f"Traceback: {traceback.format_exc()}")
            return {'success': False, 'set_name': set_name, 'error': str(e)}

    # asyncio.run() creates a fresh loop and always closes it. The previous
    # new_event_loop()/run_until_complete() pair never closed its loop, leaking
    # one loop (and its selector) per task in long-lived pool workers.
    return asyncio.run(run_task_set())

def main():
    """Parse CLI arguments and run the selected evolving-task dataset.

    Loads ``.env`` and the config file, reads the task sets for
    ``--task_name`` from its JSONL dataset, and runs each set through
    ``process_task_set`` in a pool of ``--processes`` workers. The names of
    any failed sets are written to ``<output_dir>/failed_sets.json``.
    """
    load_dotenv()
    parser = argparse.ArgumentParser(description="Evolving Task Multi-Agent System")
    parser.add_argument('--config', type=str, default='config/default.yaml', help='Path to YAML configuration file')
    parser.add_argument('--task_name', type=str, default='level1', choices=['profile', 'output', 'level1', 'level2'], help='Task name')
    parser.add_argument('--num_agents', type=int, default=3, help='Number of agents in the environment')
    parser.add_argument('--max_rounds', type=int, default=5, help='Maximum number of interaction rounds per task')
    parser.add_argument('--processes', type=int, default=1, help='Number of processes to use')
    # NOTE(review): --failure_prob is not read in this function; presumably it
    # is consumed via `args` in setup_output_directory — confirm.
    parser.add_argument('--failure_prob', type=float, default=0.0, help='Probability of agent failure')
    args = parser.parse_args()

    timestamp_str = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
    output_dir = setup_output_directory(args, timestamp_str)
    logger = get_logger("evolving_tasks_main", output_dir)

    config = load_config(args.config)
    agent_config = config.get('agent', {})
    agents_config = config.get('agents', [])
    assistant_config = config.get('assistant', {})

    # A non-empty 'agents' list takes precedence over the single 'agent' entry;
    # distribute_configs presumably maps it onto num_agents agents (see src.utils).
    if agents_config:
        agent_llm_config = distribute_configs(agents_config, args.num_agents)
    else:
        agent_llm_config = agent_config

    mapping = {
        "profile": "datasets/evolving_task/profile.jsonl",
        "output": "datasets/evolving_task/output.jsonl",
        "level1": "datasets/evolving_task/level1.jsonl",
        "level2": "datasets/evolving_task/level2.jsonl",
    }

    # NOTE(review): these hard-coded set names restrict level1/level2 runs to a
    # single task set each (they look like leftovers from re-running failures);
    # confirm before launching a full run.
    failed_tasks = [
        "Diverse Task Set 47"
    ]
    with open(mapping[args.task_name], 'r', encoding='utf-8') as file:
        # Ignore blank lines so a stray empty line doesn't crash json.loads.
        task_sets = [json.loads(line) for line in file if line.strip()]
    if args.task_name == "level2":
        task_sets = [task for task in task_sets if task['set_name'] in failed_tasks]
    elif args.task_name == "level1":
        task_sets = [task for task in task_sets if task['set_name'] in [
            "Diverse Task Set 43"
        ]]

    submitted = []  # (set_name, AsyncResult) pairs, in submission order
    with managed_process_pool(args.processes) as pool:
        for task_set in task_sets:
            if shutdown_flag.is_set():
                logger.info("Shutdown flag set. Stopping task processing.")
                break
            async_result = pool.apply_async(process_task_set, (task_set, args.num_agents, agent_llm_config, assistant_config, args.max_rounds, output_dir))
            submitted.append((task_set.get('set_name'), async_result))

        # Wait for all tasks to complete. An exception escaping a worker (anything
        # process_task_set does not catch itself) is recorded as a failure for
        # that set instead of aborting the run and discarding every other result.
        results = []
        for set_name, async_result in submitted:
            try:
                results.append(async_result.get())
            except Exception as e:
                logger.error(f"Task set '{set_name}' raised in worker: {e}")
                results.append({'success': False, 'set_name': set_name, 'error': str(e)})

    failed_sets = [result['set_name'] for result in results if not result['success']]

    if failed_sets:
        with open(Path(output_dir) / 'failed_sets.json', 'w') as f:
            json.dump(failed_sets, f, indent=2)
        logger.info(f"Saved {len(failed_sets)} failed task sets to failed_sets.json")

    logger.info("All evolving tasks completed.")

if __name__ == '__main__':
    # Required with the 'spawn' start method: pool workers re-import this
    # module, and this guard keeps them from re-running main().
    main()