import gc
import multiprocessing
import asyncio
import argparse
import os
import datetime
import json
import traceback
import signal
from pathlib import Path
from dotenv import load_dotenv
from filelock import FileLock
import torch
from contextlib import contextmanager

from src.environment import MultiAgentEnvironment
from src.utils import setup_output_directory, save_config, load_config, distribute_configs, CostTracker
from src.dataset import CodingDataset, MathDataset, MultiChoiceDataset
from metagpt.logs import get_logger

# Set the start method to 'spawn' for all platforms.
# force=True overrides any start method that was already selected.
multiprocessing.set_start_method('spawn', force=True)

# Global flag for graceful shutdown; set by signal_handler() and polled by
# process_task() and main().
# NOTE(review): with 'spawn', worker processes re-import this module, so each
# worker presumably builds its own Event here and will not observe a set() made
# in the parent -- confirm before relying on it to cancel running workers.
shutdown_flag = multiprocessing.Event()

def signal_handler(signum, frame):
    """Announce a termination request and raise the module-wide shutdown flag.

    Takes the ``(signum, frame)`` pair that ``signal.signal`` handlers receive;
    neither argument is used.
    """
    notice = "Received termination signal. Initiating graceful shutdown..."
    print(notice)
    shutdown_flag.set()

# Signal handler registration (currently disabled):
# signal.signal(signal.SIGINT, signal_handler)
# signal.signal(signal.SIGTERM, signal_handler)

def clear_gpu_memory():
    """Release cached CUDA memory held by this process.

    Does nothing when ``torch.cuda.is_available()`` is false. Otherwise runs a
    garbage-collection pass, empties the CUDA caching allocator and waits for
    outstanding CUDA work to finish.
    """
    if not torch.cuda.is_available():
        return
    # No sweep over gc.get_objects() with `del obj`: deleting a loop variable
    # only unbinds that name and never frees the tensor it referred to.
    # Reclaiming unreachable objects is what gc.collect() does.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

def _parse_bool(value):
    """Convert one command-line token into a bool without evaluating it as code.

    Accepts ``true``/``false``, ``t``/``f``, ``yes``/``no``, ``y``/``n`` and
    ``1``/``0`` (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: for any other token, so argparse reports a
            normal usage error instead of a traceback.
    """
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean (True/False), got {value!r}")


def parse_arguments():
    """Parse the command-line options of the multi-agent runner.

    Returns:
        argparse.Namespace with ``task_name``, ``config``, ``num_agents``,
        ``max_rounds``, ``threshold``, ``metrics`` (list of three booleans),
        ``debug``, ``processes``, ``no_cuda`` and ``failure_prob``.
    """
    parser = argparse.ArgumentParser(description="Multi-Agent System for Collaborative Task Solving")
    parser.add_argument('--task_name', type=str, default='coding', choices=['coding', 'math', 'multi_choice'])
    parser.add_argument('--config', type=str, default='config/default.yaml', help='Path to YAML configuration file')
    parser.add_argument('--num_agents', type=int, default=3, help='Number of agents in the environment')
    parser.add_argument('--max_rounds', type=int, default=5, help='Maximum number of interaction rounds')
    parser.add_argument('--threshold', type=float, default=0.1, help='Minimum performance threshold for agent profiles during warmup')
    # SECURITY: these values used to be converted with eval(), which executes
    # arbitrary expressions typed on the command line. They are now parsed
    # as plain booleans.
    parser.add_argument('--metrics', type=_parse_bool, nargs=3, default=[True, True, True],
                    help='Three boolean values indicating which metrics are used: '
                         '[clarity, differentiation, alignment]')
    parser.add_argument('--debug', action='store_true', help='Debug mode')
    parser.add_argument('--processes', type=int, default=1, help='Number of processes to use')
    parser.add_argument('--no_cuda', action='store_true', help='Disable CUDA')
    parser.add_argument('--failure_prob', type=float, default=0.0, help='Probability of agent failure')
    return parser.parse_args()

def get_dataset(args, results_dir):
    """Build the dataset object that matches ``args.task_name``.

    Args:
        args: parsed options; ``task_name`` picks the dataset and the whole
            object is handed to the dataset constructor.
        results_dir: accepted but not used by this function.

    Returns:
        A ``CodingDataset``, ``MathDataset`` or ``MultiChoiceDataset``.

    Raises:
        ValueError: if ``args.task_name`` is none of the known task names.
    """
    # Constructors are wrapped in lambdas so only the selected one is evaluated.
    loaders = (
        ("coding", lambda: CodingDataset("datasets/bigcodebench/sampled_bigcodebench_dataset.jsonl", args)),
        ("math", lambda: MathDataset("datasets/math/sampled_math.jsonl", args)),
        ("multi_choice", lambda: MultiChoiceDataset("datasets/bigbenchhard/multi_choice/sampled_bigbenchhard.jsonl", args)),
    )
    for task_name, build in loaders:
        if args.task_name == task_name:
            return build()
    raise ValueError(f"Unknown task name: {args.task_name}")

def safe_write_to_file(file_path, content):
    """Append ``content`` as a single JSON line to ``file_path``.

    The write happens while holding a ``FileLock`` on ``<file_path>.lock``.

    Args:
        file_path: path of the JSON-lines file to append to.
        content: JSON-serializable object to record.

    Raises:
        TypeError, ValueError: from ``json.dumps`` if ``content`` cannot be
            serialized; in that case nothing is written.
    """
    # Serialize before touching the file: json.dump() streams chunks into the
    # file as it goes, so a failure halfway through would leave a truncated,
    # unparseable line behind.
    line = json.dumps(content) + '\n'
    lock_path = f"{file_path}.lock"
    with FileLock(lock_path):
        with open(file_path, 'a') as f:
            f.write(line)
            
@contextmanager
def managed_process_pool(processes):
    """Yield a ``multiprocessing.Pool`` that is closed and joined on exit.

    Args:
        processes: number of worker processes for the pool.

    On leaving the ``with`` block, normally or through an exception, the pool
    stops accepting work and this function waits for the workers to exit.
    """
    worker_pool = multiprocessing.Pool(processes=processes)
    try:
        yield worker_pool
    finally:
        # close() lets already-submitted work finish; join() then waits for it.
        worker_pool.close()
        worker_pool.join()

def _record_task_failure(task_id, error, traceback_text, error_file):
    """Append a failure record to ``error_file`` and return the failure result dict."""
    safe_write_to_file(error_file, {
        "task_id": task_id,
        "error": str(error),
        "traceback": traceback_text
    })
    return {
        'success': False,
        'task_id': task_id,
        'error': str(error)
    }


def process_task(task, args, agent_llm_config, assistant_config, output_dir):
    """Run one task through a ``MultiAgentEnvironment`` and persist its outcome.

    Args:
        task: task mapping; must contain ``'task_id'``.
        args: parsed command-line options. In debug mode ``max_rounds`` and
            ``threshold`` are overwritten on this object.
        agent_llm_config: LLM configuration passed to the environment for the agents.
        assistant_config: LLM configuration passed to the environment for the assistant.
        output_dir: directory that receives ``costs.jsonl`` and ``errors.jsonl``.

    Returns:
        ``{'success': True, 'task_id', 'costs'}`` on success, otherwise
        ``{'success': False, 'task_id', 'error'}``. Exceptions raised while the
        task runs are logged, appended to ``errors.jsonl`` and reported through
        the return value instead of being propagated.

    Raises:
        RuntimeError: if called while an event loop is already running in this
            thread (raised by ``asyncio.run``).
    """
    if shutdown_flag.is_set():
        return {'success': False, 'task_id': task['task_id'], 'error': 'Task cancelled due to shutdown'}

    logger = get_logger(f"multi_agent_{task['task_id']}", output_dir, args.debug)

    if args.debug:
        args.max_rounds = 1
        args.threshold = 0.0
    # NOTE(review): results.jsonl goes to setup_output_directory(args), called
    # here without the timestamp that main() passes, while costs and errors go
    # to output_dir -- confirm that the two locations are meant to differ.
    results_file = Path(setup_output_directory(args)) / 'results.jsonl'
    costs_file = Path(output_dir) / 'costs.jsonl'
    error_file = Path(output_dir) / 'errors.jsonl'

    async def run_task():
        try:
            logger.info(f"Starting task: {task['task_id']}")
            if args.debug:
                logger.debug(f"Task details: {json.dumps(task, indent=2)}")

            env = MultiAgentEnvironment(
                num_agents=args.num_agents,
                agent_llm_config=agent_llm_config,
                assistant_llm_config=assistant_config,
                max_rounds=args.max_rounds,
                output_dir=output_dir,
                threshold=args.threshold,
                failure_prob=args.failure_prob,
                metrics=tuple(args.metrics)
            )
            result = await env.run(task)

            costs = env.assistant.llm.cost_manager.get_costs()

            logger.info(f"Task {task['task_id']} completed successfully")

            # Save result
            safe_write_to_file(results_file, result)

            # Save costs
            cost_data = {
                "task_id": task['task_id'],
                "total_prompt_tokens": costs.total_prompt_tokens,
                "total_completion_tokens": costs.total_completion_tokens,
                "total_cost": costs.total_cost,
                "total_budget": costs.total_budget
            }
            safe_write_to_file(costs_file, cost_data)

            return {
                'success': True,
                'task_id': task['task_id'],
                'costs': cost_data
            }

        except Exception as e:
            # One handler for every failure. A bare `raise` from a separate
            # `except ValueError` clause would skip any sibling `except`
            # clause and escape to the parent process, aborting the whole run
            # because of a single task.
            traceback_text = traceback.format_exc()
            if isinstance(e, ValueError) and "list.remove(x): x not in list" in str(e):
                logger.warning(f"Task {task['task_id']} encountered a memory management issue: {str(e)}")
                failure = _record_task_failure(task['task_id'], e, traceback_text, error_file)
                # clear gpu memory
                torch.cuda.empty_cache()
                return failure

            logger.error(f"Task {task['task_id']} failed: {str(e)}")
            logger.error(f"Traceback: {traceback_text}")
            return _record_task_failure(task['task_id'], e, traceback_text, error_file)

    # asyncio.run() creates a fresh event loop, runs the coroutine and closes
    # the loop afterwards, so repeated calls in one worker do not accumulate
    # open loops and nothing depends on implicit loop creation.
    return asyncio.run(run_task())


def _run_task_batch(tasks, phase, logger, args, agent_llm_config, assistant_config, output_dir):
    """Run ``tasks`` through ``process_task`` in a fresh pool; return results in submission order.

    ``phase`` only labels the log line emitted when the shutdown flag stops
    submission early ("task" for the first pass, "retry" for the retry pass).
    """
    with managed_process_pool(args.processes) as pool:
        pending = []
        for task in tasks:
            if shutdown_flag.is_set():
                logger.info(f"Shutdown flag set. Stopping {phase} processing.")
                break
            pending.append(
                pool.apply_async(process_task, (task, args, agent_llm_config, assistant_config, output_dir))
            )

        # Wait for everything that was submitted before the pool is closed.
        return [handle.get() for handle in pending]


def main():
    """Entry point: run every dataset task once, retry failures once, record the total cost.

    Steps, all taken from the code below:
      * load ``.env`` and the command-line options; hide GPUs when ``--no_cuda`` is given;
      * create the output directory and logger, load and save the configuration;
      * run all tasks in a process pool, then retry the failed ones once
        (skipped when the shutdown flag is set);
      * write tasks that failed twice to ``failed_tasks.json``;
      * append a ``TOTAL`` line to ``costs.jsonl``.

    Any exception raised inside the run is logged and swallowed; GPU memory is
    cleared in all cases.
    """
    load_dotenv()
    args = parse_arguments()

    if args.no_cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''

    # NOTE(review): the timestamp contains ':'; if setup_output_directory uses
    # it as a path component that is not valid on Windows -- confirm.
    timestamp_str = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
    output_dir = setup_output_directory(args, timestamp_str)
    logger = get_logger("multi_agent_main", output_dir, args.debug)

    try:
        config = load_config(args.config)
        save_config(args, output_dir, timestamp_str)
        agent_config = config.get('agent', {})
        agents_config = config.get('agents', [])
        assistant_config = config.get('assistant', {})
        logger.info("Starting Multi-Agent run")

        # A per-agent list under 'agents' takes precedence over the single 'agent' entry.
        if agents_config:
            agent_llm_config = distribute_configs(agents_config, args.num_agents)
        else:
            agent_llm_config = agent_config

        dataset = get_dataset(args, "results")

        results = _run_task_batch(
            dataset, "task", logger, args, agent_llm_config, assistant_config, output_dir
        )

        failed_tasks = [result['task_id'] for result in results if not result['success']]

        # Must exist even when no retry pass runs: the total-cost computation
        # below reads it unconditionally.
        retry_results = []

        if failed_tasks and not shutdown_flag.is_set():
            logger.info(f"Retrying {len(failed_tasks)} failed tasks")
            failed_ids = set(failed_tasks)  # O(1) membership while re-scanning the dataset
            retry_tasks = [task for task in dataset if task['task_id'] in failed_ids]
            retry_results = _run_task_batch(
                retry_tasks, "retry", logger, args, agent_llm_config, assistant_config, output_dir
            )

            still_failed = [result['task_id'] for result in retry_results if not result['success']]

            if still_failed:
                with open(os.path.join(output_dir, 'failed_tasks.json'), 'w') as f:
                    json.dump(still_failed, f, indent=2)
                logger.info(f"Saved {len(still_failed)} failed tasks to failed_tasks.json")

        # Calculate and save total cost (successful attempts only)
        total_cost = sum(
            result['costs']['total_cost']
            for result in results + retry_results
            if result.get('success') and 'costs' in result
        )
        with open(os.path.join(output_dir, 'costs.jsonl'), 'a') as f:
            json.dump({"task_id": "TOTAL", "total_cost": total_cost}, f, ensure_ascii=False)
            f.write('\n')

        logger.info("Multi-Agent run completed")
        logger.info(f"Total cost: {total_cost}")

    except Exception as e:
        logger.error(f"An unexpected error occurred: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
    finally:
        # No event-loop cleanup here: this function never starts a loop, and
        # asyncio.get_event_loop() would create one just to close it (or raise
        # on interpreters that no longer create loops implicitly).
        clear_gpu_memory()
        logger.info("Final GPU memory clear completed")

# Run main() only when this file is executed as a script, not when it is
# imported. The guard matters here because the 'spawn' start method selected
# at the top of the file imports this module again in every worker process.
if __name__ == "__main__":
    main()