import sys, os
# Append the parent directory of this file's directory to the import path
# (presumably so the `hegflow`, `dataset` and `experiments` packages imported
# below resolve when this file is run as a script — verify against repo layout).
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# Force UTF-8 encoding on stdout for everything printed by this script.
sys.stdout.reconfigure(encoding='utf-8')

import json
import math
from tqdm import tqdm
import time
from hegflow.utils.globals import Time
from pathlib import Path

import asyncio
from typing import Union,Literal,Optional,Iterator,List,Any,Dict
import argparse
import random
import torch
import torch.nn.functional as F
from typing import Iterator
import pandas as pd
import numpy as np
import time
import asyncio
from typing import List
import copy
from hegflow.graph.graph import Graph
from hegflow.graph.edge import Edge
from dataset.MATH_dataset import MATHDataset
from dataset.GSM8K_dataset import GSM8KDataset

from hegflow.utils.globals import PromptTokens, CompletionTokens
from experiments.accuracy import Accuracy,MATH_Accuracy
from hegflow.utils.globals import Cost, PromptTokens, CompletionTokens
from hegflow.utils.utils import nuclear_norm,frobenius_norm

async def train(graph:Graph,     
            dataset,  
            num_iters:int=100,  
            num_rounds:int=2,     
            lr:float=0.1,
            lr_1:float=0.5,
            batch_size:int = 4,    
            imp_per_iters: int = 5,
            pruning_rate: float = 0.05,
            num_node: int = 10,
            lambda_node_sparsity: float = 0.1,
            num_nodes_to_keep: int = 6,
            args=None,
            kwargs=None,   
          ) -> None:
    """Optimise the graph's edge logits in up to two stages.

    Stage one runs only when ``args.dec`` is truthy. It optimises the
    ``*_logits_1`` / ``critic_logits`` parameters with Adam at ``lr_1`` for
    ``num_iters`` iterations and then calls
    ``graph.update_masks_stage_one(num_node, num_nodes_to_keep)``.

    Stage two always runs. It optimises ``spatial_logits`` /
    ``temporal_logits`` / ``critic_logits`` with Adam at ``lr``.

    In both stages every iteration draws ``imp_per_iters`` records, runs a
    deep copy of the graph on each record concurrently, and minimises the mean
    of ``-log_prob * utility + add_loss``. ``utility`` is the ``MATH_Accuracy``
    score of the post-processed answer and ``add_loss`` is the regulariser
    built from nuclear-norm and Frobenius-norm terms (plus a node-sparsity
    term in stage one).

    Args:
        graph: Graph whose logits are optimised in place. Its
            ``optimized_spatial`` / ``optimized_temporal`` flags are also
            rewritten by this function.
        dataset: Indexable dataset providing ``record_to_input``,
            ``record_to_target_answer`` and ``postprocess_answer``.
        num_iters: Number of stage-one iterations.
        num_rounds: Passed through to ``graph.arun``.
        lr: Adam learning rate for stage two.
        lr_1: Adam learning rate for stage one.
        batch_size: Used as the number of stage-two iterations
            (NOTE(review): the name suggests a batch size — confirm intent).
        imp_per_iters: Number of records processed per iteration.
        pruning_rate: Not read anywhere in this function.
        num_node: Indexed as ``num_node[0]``, so callers must pass a sequence
            despite the ``int`` annotation and default.
        lambda_node_sparsity: Weight of the stage-one node-sparsity loss.
        num_nodes_to_keep: Forwarded to ``graph.update_masks_stage_one``.
            The sparsity penalty itself reads ``args.num_nodes_to_keep``.
        args: Parsed CLI namespace. Required in practice: ``args.dec``,
            ``args.node_nums``, ``args.delta`` and ``args.num_nodes_to_keep``
            are read.
        kwargs: Dict with ``"fixed_spatial_masks"`` and
            ``"fixed_temporal_masks"``. Required in practice.
    """
    
    def infinite_data_loader() -> Iterator[Dict[str, Any]]:  
        # Cycle over the dataset forever, in index order.
        while True:
            for idx in range(len(dataset)):
                record = dataset[idx]
                yield record

    loader = infinite_data_loader()  
    correct_count = 0
    total_count = 0
    # NOTE(review): subscripting means `num_node` must be a sequence; the
    # default value 10 would raise TypeError here.
    total_num_agents = num_node[0]
    if args.dec:  
        # ------------------------- Stage one -------------------------
        # NOTE(review): both flags are cleared *before* the parameter lists
        # below are built, so only `critic_logits` can end up in the stage-one
        # optimizer (and only if `graph.optimized_reflection` is set). If that
        # list is empty, torch.optim.Adam raises. Confirm this is intended.
        graph.optimized_spatial=False   
        graph.optimized_temporal=False
       
        if not graph.diff:
           
            params_to_optimize_stage1 = []
            if graph.optimized_spatial:
                params_to_optimize_stage1.append(graph.spatial_logits_1)
            if graph.optimized_temporal:
                params_to_optimize_stage1.append(graph.temporal_logits_1)
           
            if graph.optimized_reflection:
                params_to_optimize_stage1.append(graph.critic_logits) 
            optimizer = torch.optim.Adam(params_to_optimize_stage1, lr=lr_1)
        else:
            # In `diff` mode the logits are iterables of parameters, so they
            # are flattened into the optimizer's parameter list.
            params_to_optimize_stage1_diff = []
            if graph.optimized_spatial:
               
                params_to_optimize_stage1_diff.extend(list(graph.spatial_logits_1))
            if graph.optimized_temporal:
                params_to_optimize_stage1_diff.extend(list(graph.temporal_logits_1))
           
            if graph.optimized_reflection:
                params_to_optimize_stage1_diff.extend(list(graph.critic_logits)) 
            optimizer = torch.optim.Adam(params_to_optimize_stage1_diff,lr=lr_1)
        
        for i_iter in range(num_iters):  
            print(f"Train {i_iter}", 80*'-')
            start_ts = time.time() 
            correct_answers = []
            answer_log_probs = []  
            add_losses = []
            loss_list: List[torch.Tensor] = []
            utilities: List[float] = []
            answers: List[str] = []

            # Take the next `imp_per_iters` records from the endless loader.
            for i_record, record in zip(range(imp_per_iters), loader):  
                # Run on a copy of the graph, but re-bind the logits to the
                # originals so gradients reach the parameters being optimised.
                realized_graph = copy.deepcopy(graph)
                realized_graph.spatial_logits_1 = graph.spatial_logits_1
                
                realized_graph.temporal_logits_1 = graph.temporal_logits_1
                
                # if graph.dec_1:
                #     realized_graph.decision_logits = graph.decision_logits
                
                # print(graph.spatial_logits)
                # View the logits as square (N, N) matrices, N = sum(args.node_nums).
                if not graph.diff:
                    spatial_matrix_train = realized_graph.spatial_logits_1. reshape((sum(args.node_nums),sum(args.node_nums))) 
                    temporal_matrix_train = realized_graph.temporal_logits_1. reshape((sum(args.node_nums),sum(args.node_nums)))
                else:
                    spatial_matrix_train = [param.reshape((sum(args.node_nums), sum(args.node_nums))) for param in realized_graph.spatial_logits_1]
                    temporal_matrix_train = [param.reshape((sum(args.node_nums), sum(args.node_nums))) for param in realized_graph.temporal_logits_1]
                    

                # Reference topology masks, reshaped to the same (N, N) layout.
                spatial_matrix_fixed = torch.tensor(kwargs["fixed_spatial_masks"],dtype=torch.float32).reshape((sum(args.node_nums),sum(args.node_nums))) 
                temporal_matrix_fixed = torch.tensor(kwargs["fixed_temporal_masks"],dtype=torch.float32).reshape((sum(args.node_nums),sum(args.node_nums)))
               
                # Regularisation terms. In `diff` mode each term is averaged
                # over the list of matrices.
                # NOTE(review): loss_cs / loss_ct are computed but never added
                # to `add_loss` below.
                if not graph.diff:
                    loss_s = nuclear_norm(spatial_matrix_train)
                    loss_t = nuclear_norm(temporal_matrix_train)
                    frob_loss_s = frobenius_norm(spatial_matrix_fixed, spatial_matrix_train)
                    frob_loss_t = frobenius_norm(temporal_matrix_fixed, temporal_matrix_train)
                    loss_cs = connectivity_loss(spatial_matrix_train)  
                    loss_ct = connectivity_loss(temporal_matrix_train)
                    
                else:
                    loss_s = torch.mean(torch.stack([nuclear_norm(matrix) for matrix in spatial_matrix_train]))
                    loss_t = torch.mean(torch.stack([nuclear_norm(matrix) for matrix in temporal_matrix_train]))
                    loss_cs = torch.mean(torch.stack([connectivity_loss(matrix) for matrix in spatial_matrix_train]))
                    loss_ct = torch.mean(torch.stack([connectivity_loss(matrix) for matrix in temporal_matrix_train]))
                    frob_loss_s = torch.mean(torch.stack([frobenius_norm(spatial_matrix_fixed, matrix) for matrix in spatial_matrix_train]))
                    frob_loss_t = torch.mean(torch.stack([frobenius_norm(temporal_matrix_fixed, matrix) for matrix in temporal_matrix_train]))
                
                # Per-node "strength": summed edge probabilities touching the
                # node, computed per entry of `spatial_matrix_train`.
                # NOTE(review): in non-diff mode `spatial_matrix_train` is a
                # single 2-D tensor, so len() is its row count and each indexed
                # element is 1-D; the dim=1 sum below then looks invalid. This
                # section appears to assume `diff` mode — confirm.
                node_strengths_per_round = [] 
                num_rounds_in_graph = len(spatial_matrix_train)  
                for r_idx in range(num_rounds_in_graph):
                    spatial_logits_round = spatial_matrix_train[r_idx]
                    spatial_probs_round = torch.sigmoid(spatial_logits_round)
                    # Column sums plus row sums of the spatial probabilities.
                    node_activity_spatial = torch.sum(spatial_probs_round, dim=0) + \
                                            torch.sum(spatial_probs_round, dim=1)
                    node_activity_temporal = torch.zeros_like(node_activity_spatial)

                    # Temporal matrix with the same index: add its row sums.
                    if r_idx < len(temporal_matrix_train): 
                        temporal_logits_outgoing = temporal_matrix_train[r_idx]
                        temporal_probs_outgoing = torch.sigmoid(temporal_logits_outgoing)
                        node_activity_temporal += torch.sum(temporal_probs_outgoing, dim=1)

                    # Temporal matrix of the previous index: add its column sums.
                    if r_idx > 0: 
                        temporal_logits_incoming = temporal_matrix_train[r_idx-1]
                        temporal_probs_incoming = torch.sigmoid(temporal_logits_incoming)
                        node_activity_temporal += torch.sum(temporal_probs_incoming, dim=0)
                    
                    node_strengths_per_round.append(node_activity_spatial + node_activity_temporal)

                if node_strengths_per_round:
                    avg_node_strength_across_rounds = torch.mean(torch.stack(node_strengths_per_round), dim=0)
                else:
                    avg_node_strength_across_rounds = torch.zeros(total_num_agents, device=spatial_matrix_train[0].device)                

                sorted_strengths, _ = torch.sort(avg_node_strength_across_rounds, descending=True)


                # Rank-dependent weights: 1 for the first `exp_shift` ranks,
                # growing as exp(exp_slope * (rank - exp_shift)) afterwards.
                # Note that the shift is read from `args.num_nodes_to_keep`,
                # not from the `num_nodes_to_keep` parameter.
                exp_slope = 0.5 
                exp_shift = float(args.num_nodes_to_keep) 
                epsilon = 1e-8 

                indices_k = torch.arange(total_num_agents, device=sorted_strengths.device, dtype=torch.float32)
                penalty_weights = torch.exp(exp_slope * torch.relu(indices_k - exp_shift))

                # NOTE(review): the strengths are sums of sigmoid outputs and
                # are not bounded by 1, so `1 - sorted_strengths + epsilon` can
                # be negative and the log can produce NaN — confirm whether a
                # normalisation step is missing.
                raw_penalties = -torch.log(1 - sorted_strengths + epsilon)


                node_sparsity_loss = torch.sum(penalty_weights * raw_penalties)
                
                # Frobenius terms only contribute once they exceed args.delta.
                add_loss = loss_s + loss_t + F.relu(frob_loss_s - args.delta) + F.relu(frob_loss_t - args.delta) + \
                       lambda_node_sparsity * node_sparsity_loss
               
                input_dict = dataset.record_to_input(record)
                #print(input_dict)
                # `args.dec` is always truthy inside this stage, so only the
                # `skip=True` branch can run here.
                if args.dec:
                    answer_log_probs.append(asyncio.create_task(realized_graph.arun(input_dict,num_rounds,skip=True)))
                else:
                    answer_log_probs.append(asyncio.create_task(realized_graph.arun(input_dict,num_rounds)))
                correct_answer = dataset.record_to_target_answer(record)
                correct_answers.append(correct_answer)
                add_losses.append(add_loss)
                
            # Wait for all graph runs; each result unpacks into
            # (raw_answer, log_prob, selected_spatial_edges, selected_temporal_edges).
            raw_results = await asyncio.gather(*answer_log_probs)   
            raw_answers, log_probs, selected_spatial_edges_all, selected_temporal_edges_all = zip(*raw_results)
            
            for raw_answer, log_prob, add_loss, correct_answer, selected_spatial_edges,selected_temporal_edges in zip(raw_answers, log_probs, add_losses, correct_answers, selected_spatial_edges_all, selected_temporal_edges_all):
                answer = dataset.postprocess_answer(raw_answer)  
                
                answers.append(answer)
                assert isinstance(correct_answer, str), \
                        f"String expected but got {correct_answer} of type {type(correct_answer)} (1)"
                # A fresh accuracy object per sample, so `get()` reflects only
                # this one answer.
                accuracy = MATH_Accuracy()
                

                accuracy.update(answer, correct_answer)
                utility = accuracy.get()
                if(utility== 1.0):
                    correct_count += 1
                    # Credit the edges that were sampled for a correct answer.
                    update_edge_correct_count(graph,selected_spatial_edges,selected_temporal_edges)

                total_count += 1
                utilities.append(utility)
                # Utility-weighted negative log-probability plus regulariser.
                single_loss = - log_prob * utility
                loss_list.append(single_loss+add_loss)
                print(f"correct answer:{correct_answer}")
        
            total_loss = torch.mean(torch.stack(loss_list))  
            optimizer.zero_grad() 
            total_loss.backward()
            optimizer.step()
            # NOTE(review): these probabilities are computed but not used.
            if not graph.diff:
                spatial_probs = torch.sigmoid(graph.spatial_logits_1)
                temporal_probs = torch.sigmoid(graph.temporal_logits_1)
            else:
                spatial_probs = [torch.sigmoid(logit) for logit in graph.spatial_logits_1]
                temporal_probs = [torch.sigmoid(logit) for logit in graph.temporal_logits_1]
            
            print("raw_answers:",raw_answers)
            print("answers:",answers)
            print(f"Batch time {time.time() - start_ts:.3f}")
            print("utilities:", utilities)
            print("loss:", total_loss.item())
 

      
        # End of stage one: let the graph update its masks.
        graph.update_masks_stage_one(num_node, num_nodes_to_keep)

    # --------------------------- Stage two ---------------------------
    # Restart the data stream from the first record.
    loader = infinite_data_loader()

    # NOTE(review): the flags are only re-enabled *after* this list is built
    # (see the two assignments following the optimizer). When stage one ran,
    # they are still False here, so spatial/temporal logits would be left out
    # of the stage-two optimizer unless `update_masks_stage_one` resets the
    # flags — confirm.
    params_to_optimize_stage2 = []
    if graph.optimized_spatial:
        if not graph.diff:
            params_to_optimize_stage2.append(graph.spatial_logits)
        else:
            params_to_optimize_stage2.extend(list(graph.spatial_logits))
    if graph.optimized_temporal:
        if not graph.diff:
            params_to_optimize_stage2.append(graph.temporal_logits)
        else:
            params_to_optimize_stage2.extend(list(graph.temporal_logits))
    if graph.optimized_reflection:
        if not graph.diff:
            params_to_optimize_stage2.append(graph.critic_logits)
        else:
            params_to_optimize_stage2.extend(list(graph.critic_logits))

    optimizer = torch.optim.Adam(params_to_optimize_stage2, lr=lr)
    
    graph.optimized_spatial=True
    graph.optimized_temporal=True
    

    # NOTE(review): the iteration count is `batch_size`; the per-iteration
    # sample count is `imp_per_iters`.
    for i_iter in range(batch_size):
        print(f"Train {i_iter}", 80*'-')
        start_ts = time.time()
        correct_answers = []
        answer_log_probs = []
        add_losses = []
        for i_record, record in zip(range(imp_per_iters), loader): 
            # Copy the graph but share the trainable logits with the original.
            realized_graph = copy.deepcopy(graph)
            realized_graph.spatial_logits = graph.spatial_logits
            realized_graph.temporal_logits = graph.temporal_logits
            # if graph.dec_1:
            #     realized_graph.decision_logits = graph.decision_logits
            
            # print(graph.spatial_logits)
            if not graph.diff:
                spatial_matrix_train = realized_graph.spatial_logits.reshape((sum(args.node_nums),sum(args.node_nums)))
                temporal_matrix_train = realized_graph.temporal_logits.reshape((sum(args.node_nums),sum(args.node_nums)))
                # spatial_matrix_train = remove_zero_rows_and_columns(spatial_matrix_train)
                # temporal_matrix_train = remove_zero_rows_and_columns(temporal_matrix_train)
            else:
                spatial_matrix_train = [param.reshape((sum(args.node_nums), sum(args.node_nums))) for param in realized_graph.spatial_logits]
                temporal_matrix_train = [param.reshape((sum(args.node_nums), sum(args.node_nums))) for param in realized_graph.temporal_logits]
            

            spatial_matrix_fixed = torch.tensor(kwargs["fixed_spatial_masks"],dtype=torch.float32).reshape((sum(args.node_nums),sum(args.node_nums)))
            temporal_matrix_fixed = torch.tensor(kwargs["fixed_temporal_masks"],dtype=torch.float32).reshape((sum(args.node_nums),sum(args.node_nums)))
            # spatial_matrix_fixed = spatial_matrix_fixed[:4,:4]
            # temporal_matrix_fixed = temporal_matrix_fixed[:4,:4]
            # Same regularisers as stage one, without the node-sparsity term.
            if not graph.diff:
                loss_s = nuclear_norm(spatial_matrix_train)
                loss_t = nuclear_norm(temporal_matrix_train)
                frob_loss_s = frobenius_norm(spatial_matrix_fixed, spatial_matrix_train)
                frob_loss_t = frobenius_norm(temporal_matrix_fixed, temporal_matrix_train)
                loss_cs = connectivity_loss(spatial_matrix_train)
                loss_ct = connectivity_loss(temporal_matrix_train)
                # print(loss_cs)
            else:
                # loss_s = sum(nuclear_norm(matrix) for matrix in spatial_matrix_train)
                # loss_t = sum(nuclear_norm(matrix) for matrix in temporal_matrix_train)
                # frob_loss_s = sum(frobenius_norm(spatial_matrix_fixed, matrix) for matrix in spatial_matrix_train)
                # frob_loss_t = sum(frobenius_norm(temporal_matrix_fixed, matrix) for matrix in temporal_matrix_train)
                loss_s = torch.mean(torch.stack([nuclear_norm(matrix) for matrix in spatial_matrix_train]))
                loss_t = torch.mean(torch.stack([nuclear_norm(matrix) for matrix in temporal_matrix_train]))
                loss_cs = torch.mean(torch.stack([connectivity_loss(matrix) for matrix in spatial_matrix_train]))
                loss_ct = torch.mean(torch.stack([connectivity_loss(matrix) for matrix in temporal_matrix_train]))
                frob_loss_s = torch.mean(torch.stack([frobenius_norm(spatial_matrix_fixed, matrix) for matrix in spatial_matrix_train]))
                frob_loss_t = torch.mean(torch.stack([frobenius_norm(temporal_matrix_fixed, matrix) for matrix in temporal_matrix_train]))
            add_loss = loss_s + loss_t + F.relu(frob_loss_s - args.delta) + F.relu(frob_loss_t - args.delta)
            # if graph.diff:
            # add_loss = 0
            input_dict = dataset.record_to_input(record)
            print(input_dict)
            # Both branches issue the same call in this stage.
            if args.dec:
                answer_log_probs.append(asyncio.create_task(realized_graph.arun(input_dict,num_rounds)))
            else:
                answer_log_probs.append(asyncio.create_task(realized_graph.arun(input_dict,num_rounds)))
            correct_answer = dataset.record_to_target_answer(record)
            correct_answers.append(correct_answer)
            add_losses.append(add_loss)
            
        # Each result unpacks into
        # (raw_answer, log_prob, selected_spatial_edges, selected_temporal_edges).
        raw_results = await asyncio.gather(*answer_log_probs)
        raw_answers, log_probs, selected_spatial_edges_all, selected_temporal_edges_all = zip(*raw_results)
        loss_list: List[torch.Tensor] = []
        utilities: List[float] = []
        answers: List[str] = []
        
        for raw_answer, log_prob, add_loss, correct_answer, selected_spatial_edges,selected_temporal_edges in zip(raw_answers, log_probs, add_losses, correct_answers, selected_spatial_edges_all, selected_temporal_edges_all):
            answer = dataset.postprocess_answer(raw_answer)
            
            answers.append(answer)
            assert isinstance(correct_answer, str), \
                    f"String expected but got {correct_answer} of type {type(correct_answer)} (1)"
            # Fresh accuracy object per sample: `get()` is this sample's score.
            accuracy = MATH_Accuracy()
            
            accuracy.update(answer, correct_answer)
            utility = accuracy.get()
            if(utility== 1.0):
                    correct_count += 1
                    update_edge_correct_count(graph,selected_spatial_edges,selected_temporal_edges)
            total_count += 1
            utilities.append(utility)
            single_loss = - log_prob * utility
            loss_list.append(single_loss+add_loss)
            print(f"correct answer:{correct_answer}")
    
        total_loss = torch.mean(torch.stack(loss_list))
        optimizer.zero_grad() 
        total_loss.backward()
        optimizer.step()
        # NOTE(review): these probabilities are computed but not used.
        if not graph.diff:
            spatial_probs = torch.sigmoid(graph.spatial_logits)
            temporal_probs = torch.sigmoid(graph.temporal_logits)
        else:
            spatial_probs = [torch.sigmoid(logit) for logit in graph.spatial_logits]
            temporal_probs = [torch.sigmoid(logit) for logit in graph.temporal_logits]
        
        print("raw_answers:",raw_answers)
        print("answers:",answers)
        print(f"Batch time {time.time() - start_ts:.3f}")
        print("utilities:", utilities)
        print("loss:", total_loss.item()) 

     
        
   
      
def update_edge_correct_count(graph: Graph, selected_spatial_edges: List[tuple[str, str]],selected_temporal_edges: List[tuple[str, str]]):
    """Call ``mark_correct()`` on every selected edge that the graph knows.

    Args:
        graph: Object exposing ``spatial_edges`` and ``temporal_edges``
            mappings keyed by ``(source_id, target_id)`` tuples.
        selected_spatial_edges: ``(source_id, target_id)`` pairs to credit in
            ``graph.spatial_edges``.
        selected_temporal_edges: ``(source_id, target_id)`` pairs to credit in
            ``graph.temporal_edges``.

    Pairs that are not present in the corresponding mapping are ignored.
    """
    # Spatial edges first, then temporal ones, preserving the given order.
    for source_id, target_id in selected_spatial_edges:
        edge_key = (source_id, target_id)
        if edge_key not in graph.spatial_edges:
            continue
        graph.spatial_edges[edge_key].mark_correct()

    for source_id, target_id in selected_temporal_edges:
        edge_key = (source_id, target_id)
        if edge_key not in graph.temporal_edges:
            continue
        graph.temporal_edges[edge_key].mark_correct()


def connectivity_loss(A: torch.Tensor) -> torch.Tensor:
    """Return ``trace(exp(A)) - n`` for a square matrix ``A`` with ``n`` rows.

    ``exp`` is the matrix exponential. The result is 0 when ``exp(A)`` has
    the same trace as the identity (for example the zero matrix).
    """
    num_rows = A.shape[0]
    return torch.trace(torch.matrix_exp(A)) - num_rows

def set_seed(seed):
    """Seed the PyTorch, Python ``random`` and NumPy generators with ``seed``.

    Also switches cuDNN to deterministic mode and disables its benchmark
    auto-tuner.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seed the current device, then every visible device.
        for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_cuda(seed)

    random.seed(seed)
    np.random.seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def remove_zero_rows_and_columns(matrix,row_to_delete,col_to_delete):
    """Return ``matrix`` without row ``row_to_delete`` and column ``col_to_delete``.

    Despite the name, no zero detection happens here: the caller chooses the
    row and column indices. The input is a 2-D tensor and is not modified.
    """
    # Stitch together the rows above and below the deleted row.
    rows_before = matrix[:row_to_delete]
    rows_after = matrix[row_to_delete + 1:]
    without_row = torch.cat([rows_before, rows_after])

    # Same for the columns left and right of the deleted column.
    cols_before = without_row[:, :col_to_delete]
    cols_after = without_row[:, col_to_delete + 1:]
    return torch.cat([cols_before, cols_after], dim=1)

async def evaluate(
        graph:Graph,
        dataset,
        num_rounds:int = 1,
        limit_questions: Optional[int] = None,
        eval_batch_size: int = 4,
        dec: bool = False,
        args=None,
        ) -> float:
    """Run the graph over ``dataset`` in batches and return the accuracy.

    For every record a deep copy of ``graph`` (sharing the original logits) is
    executed via ``arun(..., case=True)``. The post-processed answer is scored
    with ``MATH_Accuracy``. After each batch all results collected so far are
    rewritten to a JSON file, so partial results survive an interruption.

    Args:
        graph: Trained graph. Gradient tracking on its spatial and temporal
            logits is switched off by this function.
        dataset: Iterable dataset providing ``record_to_input``,
            ``record_to_target_answer`` and ``postprocess_answer``.
        num_rounds: Passed through to ``graph.arun``.
        limit_questions: Evaluate at most this many records (``None`` = all).
        eval_batch_size: Number of records run concurrently per batch.
        dec: Accepted for interface compatibility; not read in the body.
        args: Parsed CLI namespace. Only ``args.domain`` is read, for the
            result file name; ``"eval"`` is used when ``args`` is ``None``.

    Returns:
        The value of ``MATH_Accuracy.get()`` after all batches.
    """
    print(f"Evaluating AgentGraph on {dataset.__class__.__name__}")

    # Fix: the original *assigned* False to the ``requires_grad_`` attribute,
    # which only shadows the method and leaves gradient tracking on.
    # ``requires_grad_(False)`` is the in-place call; it exists on tensors and
    # on nn.Module containers such as ParameterList.
    graph.spatial_logits.requires_grad_(False)
    graph.temporal_logits.requires_grad_(False)

    accuracy = MATH_Accuracy()

    def eval_loader(batch_size: int) -> Iterator[List[Any]]:
        """Yield lists of up to ``batch_size`` records, honouring the limit."""
        records = []
        for i_record, record in enumerate(dataset):
            if limit_questions is not None and i_record >= limit_questions:
                break
            records.append(record)
            if len(records) >= batch_size:
                yield records
                records = []
        # Flush the final, possibly short, batch.
        if records:
            yield records

    data_len = min(len(dataset), limit_questions) if limit_questions is not None else len(dataset)
    num_batches = int(math.ceil(data_len / eval_batch_size))

    data = []
    current_time = Time.instance().value or time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    # NOTE(review): this is an absolute path at the filesystem root and is
    # named "mmlu-pro" although this script evaluates MATH — confirm the
    # intended output location.
    result_dir = Path("/result/mmlu-pro")
    result_dir.mkdir(parents=True, exist_ok=True)
    # `args` defaults to None; fall back to a generic name instead of crashing.
    domain = args.domain if args is not None else "eval"
    result_file = result_dir / f"{domain}_{current_time}.json"

    for record_batch in tqdm(eval_loader(batch_size=eval_batch_size), total=num_batches):
        print(80*'-')

        start_ts = time.time()
        answer_log_probs = []

        for record in record_batch:
            # Run on a copy, re-bound to the original logits.
            realized_graph = copy.deepcopy(graph)
            realized_graph.spatial_logits = graph.spatial_logits
            realized_graph.temporal_logits = graph.temporal_logits
            input_dict = dataset.record_to_input(record)

            answer_log_probs.append(asyncio.create_task(realized_graph.arun(input_dict,num_rounds,case=True)))
        raw_results = await asyncio.gather(*answer_log_probs)
        # With case=True each result unpacks into three items; the middle one
        # (log-probability) is not needed for evaluation.
        raw_answers, _log_probs, all_answers = zip(*raw_results)

        print(f"Batch time {time.time() - start_ts:.3f}")
        for raw_answer, record, all_answer in zip(raw_answers, record_batch, all_answers):
            print("Raw answer:", raw_answer)
            answer = dataset.postprocess_answer(raw_answer)

            print("Postprocessed answer:", answer)
            correct_answer = dataset.record_to_target_answer(record)
            print("Correct answer:", correct_answer)

            accuracy.update(answer, correct_answer)
            accuracy.print()
            updated_item = {
                "Question": dataset.record_to_input(record)['task'],
                "Answer": correct_answer,
                "All_answers": all_answer,
                "Response": raw_answer,
            }
            data.append(updated_item)
        # Rewrite the whole result file after every batch (checkpointing).
        with open(result_file, 'w',encoding='utf-8') as file:
            json.dump(data, file, indent=4)


    accuracy.print()
    print("Done!")

    return accuracy.get()

def dump_eval_results(self, dct: Dict[str, Any]) -> None:
    """Write ``dct`` as JSON to ``<self._art_dir_name>/evaluation.json``.

    Does nothing when ``self._art_dir_name`` is ``None``. This is a
    module-level function that takes the owning object as its first argument.
    """
    artifact_dir = self._art_dir_name
    if artifact_dir is None:
        return

    target_path = os.path.join(artifact_dir, "evaluation.json")
    with open(target_path, "w") as out_file:
        json.dump(dct, out_file)

def _str_to_bool(value):
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool."""
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if lowered in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options of this script.

    Side effect: creates a ``result`` directory in the current working
    directory if it does not exist.

    Returns:
        The populated ``argparse.Namespace``.

    Exits (via ``parser.error``) when ``--agent_names`` and ``--node_nums``
    have different lengths.
    """
    parser = argparse.ArgumentParser(description="Process some parameters.")

    # Flags that default to True use BooleanOptionalAction so that they can
    # actually be switched off with ``--no-<flag>``; passing ``--<flag>`` keeps
    # working as before.
    on_by_default = argparse.BooleanOptionalAction

    parser.add_argument('--mode', type=str, default='FullConnected',
                        choices=['DirectAnswer', 'FullConnected', 'Random', 'Chain', 'Debate', 'Layered','Star', 'Mesh',
                                 'FakeFullConnected','FakeRandom','FakeChain','FakeStar','FakeMesh','FakeAGRandom','FakeAGFull'],
                        help="Mode of operation. Default is 'FullConnected'.")
    parser.add_argument('--lr', type=float, default=0.1,
                        help="learning rate")
    parser.add_argument('--delta', type=float, default=0.1,
                        help="noise level")
    parser.add_argument('--batch_size', type=int, default=4,
                        help="batch size")
    parser.add_argument('--agent_names', nargs='+', type=str, default=['AgentGraph_MATH'],
                        help='Specify agent names as a list of strings')
    parser.add_argument('--node_nums', nargs='+', type=int, default=[10],
                        help='Specify the number of agents for each name in agent_names')
    # NOTE(review): the default is a one-element list while a value given on
    # the command line is parsed as a single int — confirm which shape the
    # consumers expect.
    parser.add_argument('--tool_nums', type=int, default=[10],
                        help="Number of tools.")
    # Fix: this option used to be registered twice, which made argparse raise
    # a "conflicting option string" error on every invocation.
    parser.add_argument('--optimized_reflection', action=on_by_default, default=True,
                        help="Whether to optimize reflection probabilities.")
    parser.add_argument('--initial_critic_probability', type=float, default=0.5,
                        help="Initial probability for enabling reflection.")
    # Fix: without a type, any supplied string (including "False") was truthy.
    parser.add_argument('--reflection', type=_str_to_bool, default=False,
                        help="whether to use reflection")
    parser.add_argument('--num_iterations', type=int, default=10,
                        help="Number of optimization iterations. Default 10.")
    parser.add_argument('--imp_per_iterations', type=int, default=5)
    parser.add_argument('--num_rounds',type=int , default=2,
                        help="Number of optimization/inference rounds for one query")
    parser.add_argument('--pruning_rate', type=float, default=0.15,
                        help="The Rate of Pruning. Default 0.15.")
    parser.add_argument('--llm_name', type=str, default="gpt-4o-mini",
                        help="Model name, None runs the default gpt-4o-mini")
    parser.add_argument('--domain', type=str, default="math",
                        help="Domain (the same as dataset name), default 'math'")
    parser.add_argument('--decision_method', type=str, default="FinalRefer",
                        help="the decision method of the final node")
    parser.add_argument('--optimized_spatial', action=on_by_default, default=True)
    parser.add_argument('--optimized_temporal', action=on_by_default, default=True)
    parser.add_argument('--diff', action=on_by_default, default=True)
    parser.add_argument('--dec', action=on_by_default, default=True)
    parser.add_argument('--cot', action=on_by_default, default=True)

    parser.add_argument('--lambda_node_sparsity', type=float, default=0.1,
                    help='Weight for the node sparsity regularization loss.')
    parser.add_argument('--num_nodes_to_keep', type=int, default=6,
                    help='Target number of important nodes to keep.')

    parser.add_argument('--data_dir', type=str, default="dataset/MATH",
                        help="Path to the MATH dataset directory.")

    args = parser.parse_args()
    result_path = "result"
    os.makedirs(result_path, exist_ok=True)
    if len(args.agent_names) != len(args.node_nums):
        parser.error("The number of agent names must match the number of agent counts.")
    return args

async def main():
    """Script entry point: build the graph, train it, report, then evaluate.

    Steps, in order: parse the CLI, expand agent names, build the topology
    kwargs and the ``Graph``, load the MATH test and train splits, run
    ``train`` (when spatial or temporal optimisation is enabled), print the
    final logits/masks and their density, reset the token counters, and run
    ``evaluate`` on at most 300 test questions.
    """
    args = parse_args()

    # One entry per node: every agent name is repeated by its node count.
    agent_names = []
    for agent_name, node_count in zip(args.agent_names, args.node_nums):
        agent_names.extend(agent_name for _ in range(node_count))
    # print(agent_names)['AgentGraph','AgentGraph',...]

    kwargs = get_kwargs(args.mode, len(agent_names))
    limit_questions = 300

    graph = Graph(
        domain=args.domain,
        llm_name=args.llm_name,
        agent_names=agent_names,
        decision_method=args.decision_method,
        optimized_spatial=args.optimized_spatial,
        optimized_temporal=args.optimized_temporal,
        rounds=args.num_rounds,
        diff=args.diff,
        dec=args.dec,
        reflection=args.reflection,
        **kwargs,
    )

    dataset_test = MATHDataset(split='test', data_dir=args.data_dir)
    dataset_train = MATHDataset(split='train', data_dir=args.data_dir)

    needs_training = args.optimized_spatial or args.optimized_temporal
    if needs_training:
        # Note: `batch_size` is fixed at 100 here rather than taken from the CLI.
        await train(
            graph=graph,
            dataset=dataset_train,
            num_iters=args.num_iterations,
            num_rounds=args.num_rounds,
            lr=args.lr,
            batch_size=100,
            imp_per_iters=args.imp_per_iterations,
            pruning_rate=args.pruning_rate,
            num_node=args.node_nums,
            lambda_node_sparsity=args.lambda_node_sparsity,
            num_nodes_to_keep=args.num_nodes_to_keep,
            args=args,
            kwargs=kwargs,
        )

    print_parameter_list("Final spatial logits", graph.spatial_logits)
    print_parameter_list("Final temporal logits", graph.temporal_logits)
    print_parameter_list("Final spatial masks", graph.spatial_masks)
    print_parameter_list("Final temporal masks", graph.temporal_masks)

    # Fraction of active mask entries (labelled "sparsity" in the output).
    if args.diff:
        spatial_ratios = [mask.sum() / mask.numel() for mask in graph.spatial_masks]
        spatial_sparsity = torch.mean(torch.stack(spatial_ratios))
        print("Spatial sparsity (mean):", spatial_sparsity)

        temporal_ratios = [mask.sum() / mask.numel() for mask in graph.temporal_masks]
        temporal_sparsity = torch.mean(torch.stack(temporal_ratios))
        print("Temporal sparsity (mean):", temporal_sparsity)
    else:
        print("Final spatial sparsity:",graph.spatial_masks.sum()/graph.spatial_masks.numel())
        print("Final temporal sparsity:",graph.temporal_masks.sum()/graph.temporal_masks.numel())

    # Start token accounting from zero for the evaluation run.
    PromptTokens.instance().reset()
    CompletionTokens.instance().reset()

    # `dec` is True exactly when args.dec is truthy, otherwise False.
    score = await evaluate(
        graph=graph,
        dataset=dataset_test,
        num_rounds=args.num_rounds,
        limit_questions=limit_questions,
        eval_batch_size=args.batch_size,
        dec=bool(args.dec),
        args=args,
    )
    print(f"Score: {score}")

def print_parameter_list(name, param_list):
    """Print ``name`` followed by the contents of ``param_list`` as NumPy arrays.

    A ``torch.nn.ParameterList`` is printed one matrix at a time with a
    1-based ``Matrix i:`` label. A single ``torch.Tensor`` is printed
    directly. Any other type produces a ``Type ERROR`` line.
    """
    print(f"{name}:")

    if isinstance(param_list, torch.nn.ParameterList):
        for position, matrix in enumerate(param_list, start=1):
            print(f"  Matrix {position}:")
            print(matrix.detach().cpu().numpy())
        return

    if not isinstance(param_list, torch.Tensor):
        print(f"Type ERROR {type(param_list)}")
        return

    print(param_list.detach().cpu().numpy())

def get_kwargs(mode:Union[Literal['DirectAnswer'],Literal['FullConnected'],Literal['Random'],Literal['Chain'],Literal['Debate'],Literal['Layered'],Literal['Star'],Literal['Mesh'],
                          Literal['FakeFullConnected'],Literal['FakeRandom'],Literal['FakeChain'],Literal['FakeStar'],Literal['FakeMesh'],Literal['FakeAGRandom'],Literal['FakeAGFull']],
               N:int):
    """Build the topology keyword arguments for ``mode`` with ``N`` nodes.

    Returns a dict with the initial spatial/temporal probabilities (both
    0.5), the fixed spatial/temporal adjacency masks as ``N x N`` nested lists
    of 0/1 (``[[0]]`` for ``DirectAnswer``, ``None`` for an unknown mode), and
    ``node_kwargs``.

    ``node_kwargs`` is ``None`` unless the mode is ``DirectAnswer`` (a single
    ``Normal`` node) or contains ``Fake``. In the latter case nodes whose
    index parity equals the parity of ``N`` get role ``Fake``. The remaining
    nodes get ``Normal``, or ``None`` when the mode also contains ``AG``.
    """
    def uniform(value):
        # N x N matrix filled with `value`, each row a separate list.
        return [[value for _ in range(N)] for _ in range(N)]

    def layered(layer_num=2):
        # Assign nodes to consecutive layers as evenly as possible (earlier
        # layers take the remainder), then connect each node to every node in
        # the next layer.
        base_size, remainder = divmod(N, layer_num)
        layer_of = []
        for layer in range(layer_num):
            layer_of += [layer] * (base_size + (1 if layer < remainder else 0))
        # random.shuffle(layers)
        return [[1 if layer_of[col] == layer_of[row] + 1 else 0 for col in range(N)]
                for row in range(N)]

    def mesh():
        # Edge from every node to each higher-indexed node.
        return [[1 if col > row else 0 for col in range(N)] for row in range(N)]

    def star():
        # Node 0 points at every other node.
        return [[1 if row == 0 and col >= 1 else 0 for col in range(N)] for row in range(N)]

    spatial = None
    temporal = None
    node_kwargs = None

    if mode == 'DirectAnswer':
        spatial = [[0]]
        temporal = [[0]]
        node_kwargs = [{'role':'Normal'}]
    elif mode in ('FullConnected', 'FakeFullConnected', 'FakeAGFull'):
        spatial = [[0 if col == row else 1 for col in range(N)] for row in range(N)]
        temporal = uniform(1)
    elif mode in ('Random', 'FakeRandom', 'FakeAGRandom'):
        # Random draws happen row by row, skipping the diagonal, spatial first.
        spatial = [[random.randint(0, 1) if col != row else 0 for col in range(N)] for row in range(N)]
        temporal = [[random.randint(0, 1) for _ in range(N)] for _ in range(N)]
    elif mode in ('Chain', 'FakeChain'):
        spatial = [[1 if col == row + 1 else 0 for col in range(N)] for row in range(N)]
        # Single temporal edge from the last node back to the first.
        temporal = [[1 if col == 0 and row == N - 1 else 0 for col in range(N)] for row in range(N)]
    elif mode == 'Debate':
        spatial = uniform(0)
        temporal = uniform(1)
    elif mode == 'Layered':
        spatial = layered()
        temporal = uniform(1)
    elif mode in ('Mesh', 'FakeMesh'):
        spatial = mesh()
        temporal = uniform(1)
    elif mode in ('Star', 'FakeStar'):
        spatial = star()
        temporal = uniform(1)

    if 'Fake' in mode:
        other_role = None if 'AG' in mode else 'Normal'
        node_kwargs = [{'role':'Fake'} if idx % 2 == N % 2 else {'role':other_role}
                       for idx in range(N)]

    return {
        "initial_spatial_probability": 0.5,
        "fixed_spatial_masks": spatial,
        "initial_temporal_probability": 0.5,
        "fixed_temporal_masks": temporal,
        "node_kwargs": node_kwargs,
    }

# Script entry point: run the async `main` coroutine to completion when this
# file is executed directly (not when it is imported).
if __name__ == "__main__":
    asyncio.run(main())