import numpy as np
from tqdm import trange
from option_utils import (
    execute_macro_action_with_trajectory, 
)


def train_agent(
    env,
    agent,
    num_episodes: int,
    evaluation_freq: int = 1,
    max_steps_per_episode: int = 1000,
):
    """Train ``agent`` on ``env`` with per-step updates and periodic greedy evaluation.

    Each training episode resets the environment, then repeatedly asks the
    agent for an action (``agent.select_action``), steps the environment with
    ``env.step(state, action)`` and feeds the transition to ``agent.update``,
    until the environment reports ``done`` or ``max_steps_per_episode`` steps
    have elapsed.

    Args:
        env: Environment exposing ``reset()`` and ``step(state, action)``; the
            latter is unpacked as ``(next_state, reward, done, info)``.
        agent: Agent exposing ``select_action(state)``,
            ``select_greedy_action(state)`` (used by evaluation) and
            ``update(state, action, reward, next_state, done)``.
        num_episodes: Number of training episodes to run.
        evaluation_freq: Run a greedy evaluation after every
            ``evaluation_freq``-th episode, and also after the first and the
            last episode. A value ``<= 0`` disables evaluation entirely.
        max_steps_per_episode: Step cap applied to both training and
            evaluation episodes.

    Returns:
        A tuple ``(agent, reward_per_episode, steps_per_episode, states_visited)``.
        ``reward_per_episode`` and ``steps_per_episode`` hold one entry per
        *evaluation* (not per training episode). ``states_visited`` holds, for
        every training episode, the list of states visited, starting with the
        state returned by ``env.reset()``.
    """
    reward_per_episode = []
    steps_per_episode = []
    states_visited = []
    with trange(num_episodes) as pbar:
        for episode in pbar:
            state = env.reset()
            done = False

            states_visited_episode = [state]

            train_steps = 0
            while not done and train_steps < max_steps_per_episode:
                action = agent.select_action(state)
                next_state, reward, done, _ = env.step(state, action)
                agent.update(state, action, reward, next_state, done)
                state = next_state
                train_steps += 1
                states_visited_episode.append(state)

            states_visited.append(states_visited_episode)

            # Check evaluation_freq first so the modulo never divides by zero;
            # a non-positive frequency means "never evaluate".
            should_evaluate = evaluation_freq > 0 and (
                (episode + 1) % evaluation_freq == 0
                or episode == 0
                or episode == num_episodes - 1
            )
            if should_evaluate:
                total_reward, steps = evaluate_agent(env, agent, max_steps=max_steps_per_episode)
                pbar.set_description(
                    f"Episode {episode+1}: Eval Reward={total_reward:.2f}, Steps={steps}"
                )
                reward_per_episode.append(total_reward)
                steps_per_episode.append(steps)

    return agent, reward_per_episode, steps_per_episode, states_visited


def evaluate_agent(
    env,
    agent,
    max_steps: int = 1000,
):
    """Roll out a single greedy episode and report its return and length.

    Actions come from ``agent.select_greedy_action``. Rewards are summed
    without discounting. The rollout stops as soon as ``env.step`` reports
    termination or ``max_steps`` steps have been taken.

    Args:
        env: Environment exposing ``reset()`` and ``step(state, action)``; the
            latter is unpacked as ``(next_state, reward, done, info)``.
        agent: Agent exposing ``select_greedy_action(state)``.
        max_steps: Maximum number of environment steps in the rollout.

    Returns:
        ``(total_reward, steps)``: the summed reward (starting from ``0.0``)
        and the number of steps actually taken.
    """
    observation = env.reset()
    episode_return = 0.0
    n_steps = 0

    while n_steps < max_steps:
        greedy_action = agent.select_greedy_action(observation)
        observation, step_reward, finished, _info = env.step(observation, greedy_action)
        episode_return += step_reward
        n_steps += 1
        if finished:
            break

    return episode_return, n_steps

def train_eigenoption_agent(
    env,
    agent,
    options,
    num_episodes=100,
    evaluation_freq=1,
    max_steps_per_episode=1000,
    time_penalty=0.0,
    reward_multiplier=1.0,
):
    """Train an agent whose actions may be primitive actions or options.

    Each chosen action is executed through
    ``execute_macro_action_with_trajectory``. It returns
    ``(end_state, total_reward, k, done, trajectory)``, where ``k`` is the
    number of primitive steps consumed. The agent is then updated once per
    macro-action via ``agent.update_smdp``.

    Args:
        env: Environment passed through to the macro-action executor; must
            expose ``reset()``.
        agent: Agent exposing ``select_action``, ``select_greedy_action``,
            ``update_smdp(...)`` and the attribute ``num_primitive_actions``.
            Presumably action indices ``>= num_primitive_actions`` select an
            option — TODO confirm against option_utils.
        options: Option definitions forwarded to the macro-action executor.
        num_episodes: Number of training episodes.
        evaluation_freq: Evaluate after every ``evaluation_freq``-th episode,
            and also after the first and the last episode. A value ``<= 0``
            disables evaluation.
        max_steps_per_episode: Per-episode budget measured in *primitive*
            steps (shared by training and evaluation).
        time_penalty: Subtracted once per primitive step consumed by a
            macro-action before the reward is passed to ``update_smdp``.
        reward_multiplier: Scale applied to the penalised reward before it is
            passed to ``update_smdp``.

    Returns:
        ``(agent, reward_per_episode, steps_per_episode)``. The two lists hold
        one entry per evaluation, as returned by ``evaluate_eigenoption_agent``.
        That function sums the executor's raw rewards, so ``time_penalty`` and
        ``reward_multiplier`` are not applied to them.
    """
    reward_per_episode = []
    steps_per_episode = []

    with trange(num_episodes) as pbar:
        for ep in pbar:
            s = env.reset()
            done = False
            steps_used = 0

            while (not done) and (steps_used < max_steps_per_episode):
                a = agent.select_action(s)

                s2, total_r, k, done, traj = execute_macro_action_with_trajectory(
                    env=env,
                    state=s,
                    action=a,
                    options=options,
                    num_primitive_actions=agent.num_primitive_actions,
                    max_steps_remaining=max_steps_per_episode - steps_used,
                )

                # Reward shaping for learning only: penalise each of the k
                # primitive steps the macro-action took, then rescale.
                total_r = (total_r - time_penalty * k) * reward_multiplier

                agent.update_smdp(
                    start_state=s,
                    action=a,
                    total_reward=total_r,
                    end_state=s2,
                    done=done,
                    k=k,
                    trajectory_states=traj,
                )

                s = s2
                # NOTE(review): if the executor can return k == 0 without
                # done, the step budget does not advance here — confirm it
                # always returns k >= 1.
                steps_used += k

            # Also evaluate on the last episode (as train_agent does) so the
            # curve always ends at the final policy; short-circuit on
            # evaluation_freq > 0 to avoid a modulo by zero.
            should_evaluate = evaluation_freq > 0 and (
                (ep + 1) % evaluation_freq == 0
                or ep == 0
                or ep == num_episodes - 1
            )
            if should_evaluate:
                eval_r, eval_steps = evaluate_eigenoption_agent(
                    env=env,
                    agent=agent,
                    options=options,
                    max_steps_per_episode=max_steps_per_episode,
                )
                reward_per_episode.append(eval_r)
                steps_per_episode.append(eval_steps)
                pbar.set_description(f"Ep {ep+1}: evalR={eval_r:.2f}, evalSteps={eval_steps}")

    return agent, reward_per_episode, steps_per_episode


def evaluate_eigenoption_agent(
    env,
    agent,
    options,
    max_steps_per_episode=1000,
):
    """Run one greedy episode over primitive actions and options.

    Each greedy action is executed with ``execute_macro_action_with_trajectory``.
    Its raw reward is accumulated (no time penalty or scaling), and its
    primitive-step count ``k`` is charged against ``max_steps_per_episode``.

    Args:
        env: Environment exposing ``reset()``; also forwarded to the executor.
        agent: Agent exposing ``select_greedy_action(state)`` and the attribute
            ``num_primitive_actions``.
        options: Option definitions forwarded to the executor.
        max_steps_per_episode: Episode budget measured in primitive steps.

    Returns:
        ``(total_reward, steps_used)``: the summed raw reward (starting from
        ``0.0``) and the total number of primitive steps consumed.
    """
    current_state = env.reset()
    episode_return = 0.0
    primitive_steps = 0

    # NOTE(review): termination relies on each macro-action either ending the
    # episode or consuming k >= 1 steps — confirm in option_utils.
    while primitive_steps < max_steps_per_episode:
        greedy_action = agent.select_greedy_action(current_state)
        budget_left = max_steps_per_episode - primitive_steps

        next_state, macro_reward, duration, terminated, _trajectory = (
            execute_macro_action_with_trajectory(
                env=env,
                state=current_state,
                action=greedy_action,
                options=options,
                num_primitive_actions=agent.num_primitive_actions,
                max_steps_remaining=budget_left,
            )
        )

        episode_return += macro_reward
        primitive_steps += duration
        current_state = next_state
        if terminated:
            break

    return episode_return, primitive_steps