from functools import partial
from sys import argv

import torch as th

from algorithms import EUPG
from common import *
from common import WandbLogger
from environments import MOResourceGathering

def _parse_seed(args):
    """Return the integer seed given as the first command-line argument.

    Args:
        args: A ``sys.argv``-style argument vector; ``args[0]`` is the program
            name and ``args[1]`` must be the seed.

    Returns:
        The seed as an ``int``.

    Raises:
        SystemExit: With a usage message when the seed is missing or is not
            an integer (instead of a bare IndexError/ValueError traceback).
    """
    prog = args[0] if args else "script"
    if len(args) < 2:
        raise SystemExit(f"usage: {prog} SEED")
    try:
        return int(args[1])
    except ValueError:
        raise SystemExit(
            f"usage: {prog} SEED (SEED must be an integer, got {args[1]!r})"
        ) from None


def _make_env(n_agents, reward_dim, partial_observability, centralized_controller):
    """Build one MOResourceGathering environment with this experiment's settings.

    Both the training and the evaluation environment are created through this
    single function so their configurations cannot drift apart.
    """
    return MOResourceGathering(
        env_size="small",
        sampling_strategy="uniform",
        n_agents=n_agents,
        agents_finite_size_bags=True,
        reward_dim=reward_dim,
        agents_bags_size=[4] * n_agents,
        partial_observability=partial_observability,
        centralized_controller=centralized_controller,
    )


def _algorithm_name(standardize, baseline, partial_observability, recurrent_policy, scalarization):
    """Return the run label used for wandb naming and the results directory.

    The "Standardized-" prefix is only added when no baseline is used.
    NOTE(review): "centralized" is hard-coded in the label even though the
    controller type is configurable in the script; adjust if that flag changes.
    """
    prefix = "Standardized-" if standardize and not baseline else ""
    partial_tag = "-partial" if partial_observability else ""
    gru_tag = "-gru" if recurrent_policy else ""
    baseline_tag = "-with-baseline" if baseline else ""
    return f"{prefix}centralized{partial_tag}{gru_tag}-EUPG{baseline_tag}-{scalarization}"


if __name__ == "__main__":
    seed = _parse_seed(argv)
    seed_everything(seed)

    # Environment / policy configuration.
    reward_dim = 4
    n_agents = 2
    partial_observability = True
    centralized_controller = True
    recurrent_policy = True

    env = _make_env(n_agents, reward_dim, partial_observability, centralized_controller)
    eval_env = _make_env(n_agents, reward_dim, partial_observability, centralized_controller)
    env_name = env.env_name

    device = th.device("cuda" if th.cuda.is_available() else "cpu")

    # Algorithm configuration.
    baseline = False
    standardize = False
    timesteps = int(1e7)
    # NOTE(review): this label must match the scalarization function passed to
    # EUPG below (nash_scalarization); they are not linked automatically.
    scalarization = "nash"
    algorithm = _algorithm_name(
        standardize, baseline, partial_observability, recurrent_policy, scalarization
    )
    config = {
        "env": env_name,
        "timesteps": timesteps,
        "Algorithm": algorithm,
        "scalarization": scalarization,
        "device": device,
    }

    # With partial observability the centralized controller consumes the
    # concatenation of every agent's local observation.
    obs_size = (
        sum(env.observation_space[i].shape[0] for i in range(n_agents))
        if partial_observability
        else env.observation_space.shape[0]
    )
    a = EUPG(
        reward_dim,
        env,
        obs_size,
        [64],
        env.action_space.n,
        nash_scalarization,
        # Weights 1, 1/2, 1/4, ... for the reward_dim objectives.
        scalarization_weights=th.tensor([1 / 2**i for i in range(reward_dim)]),
        device=device,
        use_baseline=baseline,
        standardization=standardize,
        recurrent_policy=recurrent_policy,
    )
    logger = WandbLogger(
        f"AAMAS 2025-{env_name}",
        f"{algorithm}-{seed}",
        config,
    )
    print(f"Training started on {device}")
    # Use the parsed seed (not the raw argv string) so the results directory
    # always matches the wandb run name.
    results_dir = f"results/{env_name}/{algorithm}/{seed}"
    a.train(
        timesteps,
        eval_env,
        eval_freq=50000,
        n_evals=100,
        log=True,
        logger=logger,
        save_results=True,
        results_dir=results_dir,
    )
    a.save_best_policy(f"{results_dir}/policy_weights.pt")
