import fire
import gym
import mo_gymnasium as mo_gym
import numpy as np

import wandb as wb
from gpi.successor_features.gpi import GPI
from gpi.successor_features.tabular_sf import SF
from gpi.utils.wrappers import RandomAction
from gpi.utils.eval import policy_evaluation_mo
from gpi.utils.utils import equally_spaced_weights


def run(iter: int = 3, stochastic: bool = True, num_seeds: int = 20, noise=None):
    """Evaluate saved four-room GPI agents for several ``h_step`` values.

    For each seed in ``1..num_seeds`` the agent stored at
    ``weights/gpi_fourroom_stochastic={stochastic}_{seed}_iter={iter}`` is
    loaded and evaluated with ``policy_evaluation_mo`` on every weight vector
    returned by ``equally_spaced_weights(dim=3, num_weights=32)``. Evaluation
    is done first with ``mpc = False`` for ``h_step`` in
    ``[0, 1, 2, 4, 6, 8, 10]`` and then with ``mpc = True`` for ``h_step = 10``.
    One ``.npy`` file per configuration is written under ``results/``.

    Args:
        iter: Iteration tag embedded in the checkpoint and result file names.
            (The name shadows the builtin but is part of the CLI interface.)
        stochastic: If true, the evaluation environment is wrapped in
            ``RandomAction`` and each evaluation uses ``rep=10`` instead of 1.
        num_seeds: Number of seeds (checkpoints) to evaluate; must be >= 1.
        noise: If not ``None``, passed to ``add_noise`` of every policy of the
            loaded agent before evaluation.

    Raises:
        ValueError: If ``num_seeds`` is smaller than 1 (there would be nothing
            to save).
    """
    import os

    if num_seeds < 1:
        raise ValueError(f"num_seeds must be >= 1, got {num_seeds}")

    gpi_horizons = [0, 1, 2, 4, 6, 8, 10]
    mpc_horizons = [10]

    # Create the output directory before the (long) evaluation so a missing
    # folder cannot make np.save fail after all the work has been done.
    os.makedirs("results", exist_ok=True)

    base_env = mo_gym.make("four-room-v0")
    if stochastic:
        base_env = RandomAction(base_env)
    eval_env = mo_gym.LinearReward(base_env)

    test_tasks = equally_spaced_weights(dim=3, num_weights=32)
    reps = 10 if stochastic else 1

    def evaluate(agent):
        # One scalarized value per test task, in the order of test_tasks.
        return np.array([
            policy_evaluation_mo(agent, eval_env, w, rep=reps, return_scalarized_value=True)
            for w in test_tasks
        ])

    # Maps "<h>-GPI" / "<h>-MPC" to a list with one score vector per seed.
    per_seed_scores = {}
    for seed in range(1, num_seeds + 1):
        gpi_agent = GPI.load(f"weights/gpi_fourroom_stochastic={stochastic}_{seed}_iter={iter}")

        if noise is not None:
            for pi in gpi_agent.policies:
                pi.add_noise(noise)

        for use_mpc, horizons, tag in (
            (False, gpi_horizons, "GPI"),
            (True, mpc_horizons, "MPC"),
        ):
            gpi_agent.mpc = use_mpc
            for h in horizons:
                print(f"Seed {seed}, h={h}" + (", mpc" if use_mpc else ""))
                gpi_agent.h_step = h
                per_seed_scores.setdefault(f"{h}-{tag}", []).append(evaluate(gpi_agent))

    def stacked(key):
        # A single seed is saved as a 1-D vector and several seeds as a
        # (num_seeds, n_tasks) matrix, matching previously written files.
        rows = per_seed_scores[key]
        return rows[0] if len(rows) == 1 else np.vstack(rows)

    for h in gpi_horizons:
        np.save(f"results/gpi_fourroom_stochastic={stochastic}_noise={noise}_iter={iter}_h={h}", stacked(f"{h}-GPI"))

    # NOTE(review): unlike the GPI files above, this name does not contain
    # `noise=`, so MPC results from runs with different noise values overwrite
    # each other. Left unchanged because readers of these files may rely on
    # the name - confirm and add `noise={noise}` if that is safe.
    for h in mpc_horizons:
        np.save(f"results/gpi_fourroom_stochastic={stochastic}_iter={iter}_h={h}_mpc", stacked(f"{h}-MPC"))

# Script entry point: python-fire exposes `run` as a command-line interface,
# so its keyword arguments can be supplied as CLI flags.
if __name__ == "__main__":
    fire.Fire(run)
