import fire
import gymnasium as gym
import mo_gymnasium as mo_gym
import numpy as np
# from gymnasium.wrappers.record_video import RecordVideo

import wandb as wb
from gpi.successor_features.usfa_jax import USFA
from gpi.utils.utils import extrema_weights, seed_everything


def best_vector(values, w):
    """Return the entry of ``values`` whose product ``entry @ w`` is largest.

    The scan runs front to back and only a strictly greater product displaces
    the current winner, so on ties the earliest entry is returned. ``values``
    must be indexable and non-empty (an empty sequence raises ``IndexError``).
    """
    winner_idx = 0
    for idx in range(1, len(values)):
        # Strict comparison: an equal score never replaces the earlier winner.
        if values[idx] @ w > values[winner_idx] @ w:
            winner_idx = idx
    return values[winner_idx]


def main(timesteps_per_iter: int = 50000, seed: int = 0, log: bool = True):
    """Train a USFA agent on ``mo-reacher-v4``, one one-hot weight vector per iteration.

    Each iteration builds a one-hot weight vector ``w``, calls ``agent.learn``
    with it, and then saves the agent to ``usfa-reacher-{seed}``.

    Args:
        timesteps_per_iter: Passed as ``total_timesteps`` to every
            ``agent.learn`` call.
        seed: Given to ``seed_everything`` and to ``USFA``; also embedded in
            the saved filename.
        log: Forwarded to ``USFA`` as its ``log`` argument.
    """
    seed_everything(seed)

    def make_env(record_video=False):
        # NOTE: ``record_video`` is currently ignored; wrapping the environment
        # in gymnasium's RecordVideo is disabled in this script.
        env = mo_gym.make("mo-reacher-v4")
        env = mo_gym.LinearReward(env)
        return env

    env = make_env()
    eval_env = make_env()

    def model_training_schedule(timestep):
        # Constant for every timestep; kept as a callable because it is passed
        # as ``dynamics_train_freq`` below.
        return 250

    agent = USFA(
        env,
        num_nets=1,
        max_grad_norm=None,
        learning_rate=3e-4,
        gamma=0.9,
        batch_size=128,
        net_arch=[256, 256, 256, 256],
        buffer_size=250000,
        initial_epsilon=0.05,
        final_epsilon=0.05,
        epsilon_decay_steps=1,
        learning_starts=100,
        alpha_per=0.6,
        min_priority=0.01,
        per=True,
        use_gpi=True,
        h_step=1,
        gpi_type='gpi',
        gradient_updates=10,
        target_net_update_freq=200,
        tau=1,
        dyna=True,
        dynamics_uncertainty_threshold=1.5,
        dynamics_net_arch=[200, 200, 200],
        dynamics_buffer_size=int(1e5),
        dynamics_rollout_batch_size=25000,
        dynamics_train_freq=model_training_schedule,
        dynamics_rollout_freq=250,
        dynamics_rollout_starts=5000,
        dynamics_rollout_len=1,
        real_ratio=0.5,
        log=log,
        project_name="h-GPI",
        experiment_name="USFA",
        seed=seed,
    )

    # ``w`` has ``env.reward_dim`` entries and is indexed by ``iteration - 1``,
    # so ``max_iter`` must not exceed ``env.reward_dim``.
    max_iter = 4
    for iteration in range(1, max_iter + 1):
        # One-hot weight vector with the 1.0 at position ``iteration - 1``.
        w = np.zeros(env.reward_dim)
        w[iteration - 1] = 1.0

        print('Next weight vector:', w)
        M = extrema_weights(env.reward_dim)

        agent.learn(
            total_timesteps=timesteps_per_iter,
            w=w,
            M=M,
            change_w_each_episode=True,
            eval_env=eval_env,
            eval_freq=1000,
            # Timestep counter and learning-start offset are not reset between
            # iterations.
            reset_num_timesteps=False,
            reset_learning_starts=False,
        )

        # Same filename every iteration, so each save targets the same name.
        agent.save(filename=f'usfa-reacher-{seed}')

    agent.close_wandb()


# When executed as a script, hand ``main`` to python-fire, which builds the
# command-line interface from it.
if __name__ == "__main__":
    fire.Fire(main)
