from functools import partial
from sys import argv

import torch as th

from algorithms import DecEUPG
from common import (
    WandbLogger,
    fair_ratio,
    min_scalarization,
    nash_scalarization,
    owa_scalarization,
    seed_everything,
    min_proportion, 
)
from environments import MOResourceGathering

if __name__ == "__main__":
    # Entry point: train DecEUPG on MOResourceGathering, log to wandb, and
    # write results / best policy weights under
    # results/<env_name>/<algo>/<seed>.
    # Usage: python <this_script> SEED
    if len(argv) < 2:
        raise SystemExit(f"usage: {argv[0]} SEED")
    seed = int(argv[1])
    seed_everything(seed)

    reward_dim = 4
    n_agents = 2
    partial_observability = True
    local_reward = True
    agent_specific_reward = False
    reward_conditioned = True

    def make_env(env_size):
        """Build a MOResourceGathering env; only ``env_size`` varies.

        The list arguments are rebuilt on every call so the train and eval
        envs never share (and can never co-mutate) the same list objects.
        """
        return MOResourceGathering(
            env_size=env_size,
            sampling_strategy="uniform",
            n_agents=n_agents,
            agents_finite_size_bags=True,
            reward_dim=reward_dim,
            agents_bags_size=[4] * n_agents,
            partial_observability=partial_observability,
            centralized_controller=False,
            local_reward=local_reward,
            agent_specific_objectives=agent_specific_reward,
            agents_objectives=[[0], [1]],
        )

    # Training runs on the small map; evaluation uses the medium map.
    env = make_env("small")
    eval_env = make_env("medium")
    env_name = env.env_name

    device = th.device("cuda" if th.cuda.is_available() else "cpu")

    weight_sharing = False
    baseline = False
    standardize = False
    recurrent_policy = True

    timesteps = int(1e7)
    scalarization = "nash"

    # Resolve the scalarization function from its name so the function that
    # is trained with can never drift from the name used in the run
    # name / config / results directory.
    # NOTE(review): assumes min_scalarization / owa_scalarization share the
    # call signature DecEUPG expects from nash_scalarization — confirm
    # before switching `scalarization` to one of them.
    scalarization_fns = {
        "nash": nash_scalarization,
        "min": min_scalarization,
        "owa": owa_scalarization,
    }
    try:
        scalarization_fn = scalarization_fns[scalarization]
    except KeyError:
        raise ValueError(
            f"Unknown scalarization {scalarization!r}; "
            f"expected one of {sorted(scalarization_fns)}"
        ) from None

    if not reward_conditioned:
        algo = f"{'gru-' if recurrent_policy else ''}Dec-PG"
    else:
        algo = (
            f'{"AgentSpecificReward-" if agent_specific_reward else ""}'
            + f'{"local-" if local_reward else ""}{"gru-" if recurrent_policy else ""}'
            + f'{"shared-parameters-" if weight_sharing else ""}dec-'
            + f'{"Standardized-" if standardize and not baseline else ""}'
            + f'EUPG{"-with-baseline" if baseline else ""}-{scalarization}'
        )

    config = {
        "env": env_name,
        "timesteps": timesteps,
        "Algorithm": algo,
        "scalarization": scalarization,
        "local_reward": local_reward,
        "split-objectives": agent_specific_reward,
        "weight sharing": weight_sharing,
        "baseline": baseline,
        "recurrent_policy": recurrent_policy,
        # These two change the trained algorithm, so record them as well.
        "reward_conditioned": reward_conditioned,
        "standardize": standardize,
        "device": device,
        "max_steps_per_episode": 1e3,
    }

    a = DecEUPG(
        n_agents,
        [range(reward_dim)] * n_agents,
        env,
        [64],
        scalarization_fn,
        # Geometrically decaying weights: 1, 1/2, 1/4, ... per objective.
        scalarization_weights=th.tensor([1 / 2**i for i in range(reward_dim)]),
        device=device,
        standardization=standardize,
        use_baseline=baseline,
        weight_sharing=weight_sharing,
        recurrent_policy=recurrent_policy,
        reward_conditioned=reward_conditioned,
    )

    logger = WandbLogger(
        f"AAMAS 2025-{env_name}",
        f'{config["Algorithm"]}-{seed}',
        config,
    )
    print(f"Training started on {device}")
    # Use the parsed seed (not raw argv) so the directory matches the wandb
    # run name even for inputs like "007".
    results_dir = f'results/{env_name}/{"split-objective(0,1)-" if agent_specific_reward else ""}{config["Algorithm"]}/{seed}'

    a.train(
        timesteps,
        eval_env,
        eval_freq=50000,
        n_evals=100,
        log=True,
        logger=logger,
        save_results=True,
        results_dir=results_dir,
    )
    # With weight sharing there is a single weights file; otherwise the path
    # is a directory (presumably one file per agent — confirm in DecEUPG).
    a.save_best_policy(
        f"{results_dir}/policy_weights{'.pt' if weight_sharing else '/'}"
    )
