import random
import json
import torch
from typing import List, Dict, Tuple
from torch.utils.data.dataset import Dataset
from tqdm import tqdm
import numpy as np 
from agent_prompts import SYS_PROMPT, ATTACKER_SYS_PROMPT
from utils import get_tool_dict
from gen_memory_attack_data import gen_poisonrag_data
from agent_async import AsyncAgentGraph  
import os


def generate_task_allocation_graph(num_nodes: int, sparsity: float, num_graphs: int):
    """Sample random task-dependency graphs as adjacency matrices.

    Every graph is a ``num_nodes x num_nodes`` integer matrix whose only
    possibly non-zero entries lie strictly below the main diagonal, which
    makes the encoded directed graph acyclic (a DAG). An entry below the
    diagonal is 1 when its uniform random draw is ``<= sparsity``, so a
    larger ``sparsity`` value produces a denser graph.

    Args:
        num_nodes: Number of rows/columns of each adjacency matrix.
        sparsity: Threshold in ``[0, 1]`` compared against the random draws.
        num_graphs: Number of matrices to generate.

    Returns:
        A list holding ``num_graphs`` integer NumPy arrays.

    Raises:
        ValueError: If ``sparsity`` is not within ``[0, 1]``.
    """
    if not (0 <= sparsity <= 1):
        raise ValueError("Sparsity must be between 0 and 1.")

    def _sample_dag():
        # Keeping only the strictly lower triangle (k=-1) drops self-loops and
        # every entry on or above the diagonal, so no cycle can be formed.
        edge_mask = np.random.rand(num_nodes, num_nodes) <= sparsity
        return np.tril(edge_mask, k=-1).astype(int)

    return [_sample_dag() for _ in range(num_graphs)]


def generate_agent_roles(num_nodes: int) -> List[str]:
    """Pick a role name for every agent.

    The first agent is always the ``"coordinator"``. Each of the remaining
    ``num_nodes - 1`` agents receives a role drawn at random, with
    replacement, from the role pool (which also contains ``"coordinator"``).

    Args:
        num_nodes: Number of agents to assign roles to.

    Returns:
        The list of role names, starting with ``"coordinator"``. The leading
        entry is present even when ``num_nodes`` is less than 1.
    """
    role_pool = ["coordinator", "worker", "reviewer", "retriever", "reasoner"]
    drawn_roles = [random.choice(role_pool) for _ in range(num_nodes - 1)]
    return ["coordinator", *drawn_roles]


def generate_async_agent_graph_dataset(
    num_nodes: int,
    sparsity: float,
    num_graphs: int,
    num_attackers: int
) -> List[Dict]:
    """Create one agent-graph configuration per random task graph.

    Args:
        num_nodes: Number of agents (nodes) in every graph.
        sparsity: Threshold forwarded to ``generate_task_allocation_graph``.
        num_graphs: Number of graph configurations to create.
        num_attackers: Number of distinct agent indices marked as attackers
            in each configuration.

    Returns:
        A list of dicts with the keys ``"adj_matrix"``, ``"attacker_idxes"``,
        ``"agent_roles"``, ``"normal_prompt_template"`` and
        ``"attacker_prompt_template"``. Both templates start with two
        positional ``{}`` placeholders (agent index, then role).

    Raises:
        AssertionError: If ``num_attackers`` is not smaller than
            ``num_nodes`` (only while assertions are enabled).
    """
    assert num_attackers < num_nodes, "Number of attackers must be less than total nodes"

    # The templates are the same for every graph, so they are built once.
    role_prefix = "You are agent_{}, a {} in this multi-agent system. "
    benign_template = role_prefix + SYS_PROMPT
    malicious_template = role_prefix + ATTACKER_SYS_PROMPT

    entries = []
    for graph in generate_task_allocation_graph(num_nodes, sparsity, num_graphs):
        # Attackers are sampled before the roles are drawn; keeping this order
        # keeps seeded runs reproducible.
        compromised = random.sample(range(num_nodes), num_attackers)
        roles = generate_agent_roles(num_nodes)
        entries.append({
            "adj_matrix": graph,
            "attacker_idxes": compromised,
            "agent_roles": roles,
            "normal_prompt_template": benign_template,
            "attacker_prompt_template": malicious_template
        })

    return entries


def generate_initial_async_data(ag_data: Dict, example: Tuple) -> Dict:
    """Combine one agent-graph configuration with one example.

    Args:
        ag_data: Agent graph configuration providing ``"adj_matrix"``,
            ``"attacker_idxes"``, ``"agent_roles"``,
            ``"normal_prompt_template"`` and ``"attacker_prompt_template"``.
        example: Indexable record laid out as
            ``(query, adv_texts, correct_answer, incorrect_answer)``.

    Returns:
        A new dict holding the graph fields, one system prompt per agent
        (``"system_prompts"``) and the four example fields. The adjacency
        matrix, attacker list and role list are passed through by reference,
        not copied.
    """
    adjacency = ag_data["adj_matrix"]
    attackers = ag_data["attacker_idxes"]
    roles = ag_data["agent_roles"]
    benign_template = ag_data["normal_prompt_template"]
    malicious_template = ag_data["attacker_prompt_template"]

    # One prompt per node: attackers get the attacker template, everyone else
    # the normal one; both are filled with the node index and its role.
    prompts = [
        (malicious_template if node in attackers else benign_template).format(node, roles[node])
        for node in range(adjacency.shape[0])
    ]

    return {
        "adj_matrix": adjacency,
        "attacker_idxes": attackers,
        "agent_roles": roles,
        "system_prompts": prompts,
        "query": example[0],
        "adv_texts": example[1],
        "correct_answer": example[2],
        "incorrect_answer": example[3],
    }


def generate_async_graph_dataset(args):
    """Build, execute and persist the asynchronous multi-agent dataset.

    The function pairs every generated agent-graph configuration with every
    case returned by ``gen_poisonrag_data``, shuffles the pairs, keeps the
    first ``args.samples`` of them and runs each one through an
    ``AsyncAgentGraph``. For every sample that completes, the activations are
    saved with ``torch.save`` and the execution log is written as JSON; the
    list of completed samples is finally dumped to ``args.save_filepath``.

    Args:
        args: Namespace-like object providing ``dataset_path``, ``phase``,
            ``num_nodes``, ``sparsity``, ``num_graphs``, ``num_attackers``,
            ``samples``, ``save_dir``, ``model_type``, ``max_time_steps`` and
            ``save_filepath``.

    Returns:
        The list of sample dicts that were processed without raising.
    """
    poison_cases = gen_poisonrag_data(args.dataset_path, args.phase)

    graph_configs = generate_async_agent_graph_dataset(
        num_nodes=args.num_nodes,
        sparsity=args.sparsity,
        num_graphs=args.num_graphs,
        num_attackers=args.num_attackers
    )

    # Cross every graph configuration with every case, then keep a random
    # subset of at most ``args.samples`` entries.
    candidates = []
    for graph_config in tqdm(graph_configs, desc="Generate async meta data"):
        candidates.extend(
            generate_initial_async_data(graph_config, case)
            for case in poison_cases
        )

    random.shuffle(candidates)
    selected = candidates[:args.samples]

    activation_dir = os.path.join(args.save_dir, "activations")
    execution_trace_dir = os.path.join(args.save_dir, "execution_traces")
    for out_dir in (activation_dir, execution_trace_dir):
        os.makedirs(out_dir, exist_ok=True)

    final_dataset = []

    progress = tqdm(selected, desc="Generate async execution data")
    for sample_idx, entry in enumerate(progress):
        # Any failure skips the sample. Its index is still consumed, so the
        # numbering of the per-sample files can contain gaps.
        try:
            adjacency = entry["adj_matrix"]
            attackers = entry["attacker_idxes"]
            prompts = entry["system_prompts"]
            roles = entry["agent_roles"]
            question = entry["query"]
            adversarial_context = entry["adv_texts"]

            async_ag = AsyncAgentGraph(
                adj_matrix=adjacency,
                system_prompts=prompts,
                agent_roles=roles,
                attacker_idxes=attackers,
                model_type=args.model_type
            )

            # Asynchronous execution
            execution_trace = async_ag.execute_async(
                query=question,
                context=adversarial_context,
                max_time_steps=args.max_time_steps
            )

            communication_data = []
            activation_data = []
            execution_log = []

            for time_step in range(len(execution_trace)):
                step_record = execution_trace[time_step]
                active_agents = step_record["executing_agents"]

                # Responses of the agents that executed at this step.
                communication_data.append({
                    "time_step": time_step,
                    "executing_agents": active_agents,
                    "responses": {
                        agent_idx: step_record["responses"][agent_idx]
                        for agent_idx in active_agents
                    }
                })

                # NOTE(review): activations are read after execute_async has
                # already returned; whether they still describe this specific
                # time step depends on how the agent stores them - confirm.
                activation_data.append({
                    "time_step": time_step,
                    "activations": {
                        agent_idx: async_ag.agents[agent_idx].get_activations()
                        for agent_idx in active_agents
                    }
                })

                execution_log.append({
                    "time_step": time_step,
                    "executing_agents": active_agents,
                    "task_queue_states": step_record["task_queues"]
                })

            # The adjacency matrix is converted to nested lists so the entry
            # can be written with json.dump at the end.
            entry.update({
                "communication_data": communication_data,
                "execution_log": execution_log,
                "total_time_steps": len(execution_trace),
                "adj_matrix": entry["adj_matrix"].tolist()
            })

            activation_file = os.path.join(activation_dir, f"sample_{sample_idx:04d}.pt")
            torch.save(activation_data, activation_file)
            entry["activation_file"] = activation_file

            trace_file = os.path.join(execution_trace_dir, f"sample_{sample_idx:04d}.json")
            with open(trace_file, "w") as f:
                json.dump(execution_log, f, indent=2)
            entry["trace_file"] = trace_file

            final_dataset.append(entry)

        except Exception as e:
            print(f"Error processing sample {sample_idx}: {e}")
            continue

    with open(args.save_filepath, "w") as file:
        json.dump(final_dataset, file, indent=2)

    print(f"\n{'='*60}")
    print(f"Dataset Generation Complete")
    print(f"{'='*60}")
    print(f"Total samples: {len(final_dataset)}")
    print(f"Saved to: {args.save_filepath}")
    print(f"Activations: {activation_dir}")
    print(f"Execution traces: {execution_trace_dir}")
    print(f"{'='*60}\n")

    return final_dataset


if __name__ == "__main__":
    # Script entry point: parse the command line, print the configuration and
    # generate the dataset.
    import argparse
    from datetime import datetime
    
    def parse_arguments():
        """Parse the command line and derive the output locations.

        After parsing, ``args.save_dir`` is extended to
        ``<save_dir>/<dataset>/async/<phase>`` and created if missing. When
        ``--save_filepath`` is not given, a timestamped file name inside that
        directory is generated from the main settings.

        Returns:
            The populated ``argparse.Namespace``.
        """
        parser = argparse.ArgumentParser(
            description="Generate async MAS dataset with event-driven execution"
        )
        
        # Dataset paths
        parser.add_argument(
            "--dataset_path", 
            type=str, 
            default="./datasets/msmarco.json",
            help="Path to base dataset"
        )
        parser.add_argument(
            "--dataset", 
            type=str, 
            default="memory_attack",
            help="Dataset type"
        )
        parser.add_argument(
            "--phase", 
            type=str, 
            default="test",
            choices=["train", "val", "test"],
            help="Dataset phase"
        )
        
        # Graph configuration
        parser.add_argument(
            "--num_nodes", 
            type=int, 
            default=8,
            help="Number of agents in the system"
        )
        parser.add_argument(
            "--sparsity", 
            type=float, 
            default=0.2,
            help="Task dependency sparsity (0-1, higher = denser dependencies)"
        )
        parser.add_argument(
            "--num_graphs", 
            type=int, 
            default=20,
            help="Number of random task allocation graphs"
        )
        parser.add_argument(
            "--num_attackers", 
            type=int, 
            default=3,
            help="Number of compromised agents"
        )
        
        # Async execution parameters
        parser.add_argument(
            "--max_time_steps", 
            type=int, 
            default=20,
            help="Maximum time steps for async execution (replaces num_dialogue_turns)"
        )
        
        # Sampling
        parser.add_argument(
            "--samples", 
            type=int, 
            default=40,
            help="Number of samples to generate"
        )
        
        # Model and saving
        parser.add_argument(
            "--model_type", 
            type=str, 
            default="gpt-oss-20b",
            choices=["gpt-oss-20b", "llama3-8b", "deepseek-v3", "qwen3-30b-a3b"],
            help="LLM backbone"
        )
        parser.add_argument(
            "--save_dir", 
            type=str, 
            default="./agent_graph_dataset",
            help="Base directory to save dataset"
        )
        parser.add_argument(
            "--save_filepath", 
            type=str,
            default=None,
            help="Full path to save dataset JSON (auto-generated if None)"
        )
        
        args = parser.parse_args()
        
        # Nest the output under <save_dir>/<dataset>/async/<phase>.
        args.save_dir = os.path.join(
            args.save_dir, 
            args.dataset, 
            "async",  
            args.phase
        )
        
        # exist_ok=True avoids the check-then-create race of testing
        # os.path.exists first.
        os.makedirs(args.save_dir, exist_ok=True)
        
        # Derive a timestamped file name when no explicit path was given.
        if args.save_filepath is None:
            current_time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = (
                f"{current_time_str}-async_dataset"
                f"-size_{args.samples}"
                f"-nodes_{args.num_nodes}"
                f"-attackers_{args.num_attackers}"
                f"-sparsity_{args.sparsity}"
                f"-maxsteps_{args.max_time_steps}.json"
            )
            args.save_filepath = os.path.join(args.save_dir, filename)
        
        return args
    

    args = parse_arguments()
    
    # Echo the effective configuration before the run starts.
    print(f"\n{'='*60}")
    print(f"Async MAS Dataset Generation Configuration")
    print(f"{'='*60}")
    print(f"Dataset: {args.dataset}")
    print(f"Phase: {args.phase}")
    print(f"Num nodes: {args.num_nodes}")
    print(f"Num attackers: {args.num_attackers}")
    print(f"Sparsity: {args.sparsity}")
    print(f"Max time steps: {args.max_time_steps}")
    print(f"Samples: {args.samples}")
    print(f"Model: {args.model_type}")
    print(f"Save dir: {args.save_dir}")
    print(f"{'='*60}\n")
    

    dataset = generate_async_graph_dataset(args)
    
    print(f"\n✅ Successfully generated {len(dataset)} async samples!")