import os
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import redirect_stdout
from copy import deepcopy

import networkx as nx
import numpy as np
import torch

from Agent import Agent
from Env import Env
from param_sets import param_sets


def generate_job_name(params):
    """Build the identifier used to name result and model files for one parameter set.

    The name starts with ``job<index>`` and is followed by ``label=value`` pairs
    joined with ``-``. Labels ``epoch`` and ``iteration`` are read from the
    ``num_epoch`` and ``num_iteration`` keys respectively.
    """
    # (label shown in the name, key looked up in params), in output order.
    labelled_keys = (
        ("n", "n"),
        ("test_n", "test_n"),
        ("budget", "budget"),
        ("test_budget", "test_budget"),
        ("T", "T"),
        ("epoch", "num_epoch"),
        ("iteration", "num_iteration"),
    )
    pieces = [f"job{params['index']}"]
    pieces.extend(f"{label}={params[key]}" for label, key in labelled_keys)
    return "-".join(pieces)

def run_single_test(model, policy):
    """Run one episode of *model* under the given policy pair and return its reward.

    Args:
        model: object exposing ``run_episode``; it must return a 6-tuple whose
            last element is the cumulative reward.
        policy: ``(budget_policy, seeding_policy)`` pair.

    Beam search is enabled whenever the seeding policy name contains ``'agent'``.
    """
    budget_policy, seeding_policy = policy
    use_beam_search = 'agent' in seeding_policy
    episode = model.run_episode(
        print_info=True,
        budget_policy=budget_policy,
        seeding_policy=seeding_policy,
        beam_search=use_beam_search,
    )
    # Strict 6-way unpack: only the final element (cumulative reward) is used.
    _, _, _, _, _, reward = episode
    return reward


def run_test_rounds(original_model, params, job_name, policy_pairs):
    """Evaluate *original_model* under each policy pair, logging to one file per pair.

    For every ``(budget_policy, seeding_policy)`` in *policy_pairs*, runs
    ``params['num_test']`` independent episodes on a thread pool. Each episode
    uses its own deep copy of the model. Per-round rewards, the mean and the
    standard deviation are written to
    ``results/<job_name>-<budget_policy>-<seeding_policy>.txt``.

    stdout is redirected into that file for the duration of each pair. It is
    restored even if an episode raises, so later output is not sent to a
    closed file.
    """
    num_workers = os.cpu_count() or 1
    # open() below does not create missing directories.
    os.makedirs("results", exist_ok=True)

    for policy in policy_pairs:
        budget_policy, seeding_policy = policy
        txt_name = f"results/{job_name}-{budget_policy}-{seeding_policy}.txt"
        results = []

        # sys.stdout is process-wide, so prints made inside worker threads
        # (presumably run_episode's print_info output) also land in this file.
        with open(txt_name, 'w', encoding='utf-8') as f, redirect_stdout(f):
            print(f"Starting test for {budget_policy}-{seeding_policy}")

            with ThreadPoolExecutor(max_workers=num_workers) as executor:
                # A fresh deep copy per task keeps episodes from sharing mutable model/env state.
                futures = [
                    executor.submit(run_single_test, deepcopy(original_model), policy)
                    for _ in range(params['num_test'])
                ]

                # as_completed yields in finish order: `i` counts completed rounds,
                # not the submission index.
                for i, future in enumerate(as_completed(futures)):
                    cumulative_reward = future.result()
                    results.append(cumulative_reward)
                    print(
                        f"test round: {i}, cumulative_reward: {cumulative_reward}")

            print("\nAll test rounds completed. Results:")

            if not results:
                # np.mean/np.std of an empty list would warn and yield nan.
                print("No test rounds were run (num_test is 0).")
                continue

            avg_reward = np.mean(results)
            std_dev = np.std(results)
            print("\nSummary:")
            print(f"Average cumulative reward: {avg_reward}")
            print(f"Standard deviation: {std_dev}")

def create_model(params):
    """Build a fresh Agent on a random directed Erdős–Rényi training graph.

    The graph has ``params['n']`` nodes and edge probability 0.01. It is
    unseeded, so each call yields a different graph. The environment uses
    ``params['budget']`` and ``params['T']``.
    """
    random_graph = nx.erdos_renyi_graph(params['n'], 0.01, directed=True)
    return Agent(env=Env(graph=random_graph, budget=params['budget'], T=params['T']))

def save_model(model, params):
    """Save *model*'s ``state_dict`` to ``models/<job_name>_model.pth``.

    Creates the ``models/`` directory if needed. This matters because the
    function is called after training: a failing save would discard the
    trained weights.
    """
    save_path = f"models/{generate_job_name(params)}_model.pth"
    # torch.save opens the target file directly and does not create parent dirs.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

def load_model(params):
    """Return a model restored from ``models/<job_name>_model.pth``, or None if absent.

    A new model is built with ``create_model`` (which generates a fresh random
    graph), and the saved ``state_dict`` is loaded into it.
    """
    load_path = f"models/{generate_job_name(params)}_model.pth"
    if not os.path.exists(load_path):
        print(f"No saved model found at {load_path}")
        return None

    restored = create_model(params)
    restored.load_state_dict(torch.load(load_path))
    print(f"Model loaded from {load_path}")
    return restored

def run_training(model, params):
    """Train *model*, log its reward history, save it, and return it.

    Training output and the final reward lists are written to
    ``results/<job_name>-training.txt``. stdout is restored even if
    ``train_model`` raises, so the remaining experiments in a sweep can still
    print. The trained model is then persisted via ``save_model``.
    """
    job_name = generate_job_name(params)
    txt_name = f"results/{job_name}-training.txt"
    # open() does not create missing directories.
    os.makedirs(os.path.dirname(txt_name), exist_ok=True)

    with open(txt_name, 'w', encoding='utf-8') as f, redirect_stdout(f):
        train_cumulative_reward_list, test_cumulative_reward_list = model.train_model(
            num_epochs=params['num_epoch'], num_iterations=params['num_iteration']
        )

        print("------------------------------ Total Results ------------------------------")
        print("train_cumulative_reward_list: ", train_cumulative_reward_list)
        print("test_cumulative_reward_list: ", test_cumulative_reward_list)

    save_model(model, params)
    return model

def run_testing(model, params, policy_pairs=None):
    """Move *model* onto a fresh test graph and evaluate it under several policies.

    A new random directed graph with ``params['test_n']`` nodes (edge
    probability 0.01) replaces the training graph. The environment budget is
    set to ``params['test_budget']``, the env is reset, and the model is put in
    eval mode before the test rounds run.

    Args:
        policy_pairs: optional list of ``(budget_policy, seeding_policy)``.
            When omitted, the agent plus the average/static/normal baselines
            are each paired with score- and degree-based seeding.
    """
    job_name = generate_job_name(params)
    test_graph = nx.erdos_renyi_graph(params['test_n'], 0.01, directed=True)
    model.set_graph(test_graph)
    model.env.budget = params['test_budget']
    model.env.reset()
    model.eval()

    pairs = policy_pairs if policy_pairs is not None else [
        ('agent', 'agent'), ('average', 'score'), ('average', 'degree'),
        ('static', 'score'), ('static', 'degree'),
        ('normal', 'score'), ('normal', 'degree'),
    ]
    run_test_rounds(model, params, job_name, pairs)

def run_experiment(params):
    """Reuse a saved model for *params* if one exists, else train one; then test it."""
    model = load_model(params)
    if model is None:
        model = run_training(create_model(params), params)
    run_testing(model, params)

if __name__ == "__main__":
    # Execute each configured experiment one after another.
    for experiment_params in param_sets:
        run_experiment(experiment_params)
