import numpy as np
import scipy.signal
from modules.agents.Agent import Agent
from modules.utils.Log import Logger
from modules.utils.InputOutput import save_to_file
from modules.utils.Utils import printProgressBar
from typing import Tuple

def train(env, logger, training_parameters, agent, oracle, verbose=True, save=True, eval_callback=None) -> None:
    """Train ``agent`` in place using advantages and ratios supplied by ``oracle``.

    Each iteration simulates ``training_parameters["batch_episode"]`` episodes
    with the current policy (via ``run_policy``). It then asks
    ``oracle.predict`` for ``(A, rho_k)`` and calls ``agent.update_policy``
    with them. The loop stops once at least
    ``training_parameters["number_episodes"]`` episodes have been simulated.

    Args:
        env: environment forwarded to ``run_policy``.
        logger: project logger; this function uses ``save_path()``, ``log()``
            and ``write()``.
        training_parameters: dict read for the keys ``"number_episodes"`` and
            ``"batch_episode"``; also saved to disk when ``save`` is True.
        agent: agent being trained; it is mutated in place.
        oracle: object providing ``save()`` and ``predict(...)``, where
            ``predict`` returns a pair ``(A, rho_k)``.
        verbose: forwarded to ``logger.write(display=...)``.
        save: when True, the oracle and the training parameters are saved once,
            and the agent is saved whenever the simulated average reward
            improves on the best seen so far.
        eval_callback: optional callable. It is forwarded to ``run_policy``
            (where it is called with the agent during simulation) and is
            called as ``eval_callback(agent, True)`` after each policy update.
            Ignored when None.

    Raises:
        ValueError: if ``training_parameters["batch_episode"]`` is below 1;
            the loop could never advance in that case.

    Note:
        Because the agent is updated in place, no in-memory copy of the
        best-scoring policy is kept. The best agent only survives as the
        ``'agent'`` file written when ``save`` is True.
    """
    # parameters
    total_episodes = training_parameters["number_episodes"]
    n_trajectories = training_parameters["batch_episode"]
    if n_trajectories < 1:
        # `episode` only advances by the number of simulated trajectories
        raise ValueError(
            f"training_parameters['batch_episode'] must be >= 1, got {n_trajectories}")

    agent.is_training(True)

    # save run configuration once
    if save:
        save_to_file(logger.save_path(), 'oracle', oracle.save())
        save_to_file(logger.save_path(), 'training_parameters', training_parameters)

    # train
    best_r = -np.inf
    episode = 0
    while episode < total_episodes:
        trajectories, r = run_policy(env, agent, n_trajectories, logger, eval_callback)
        if best_r < r:
            best_r = r
            if save:
                # persist the new best agent: it is the only copy of it we keep
                save_to_file(logger.save_path(), 'agent', agent.save())

        episode += len(trajectories)

        logger.log({'_Episode': episode})
        A, rho_k = oracle.predict(episode=episode,
                                  policy=agent.policy(copy=False),
                                  trajectories=trajectories,
                                  logger=logger)

        # update policy with the samples that were just simulated
        states, actions, _ = unpack_trajectories(trajectories)
        agent.states_played = states
        agent.actions_played = actions
        agent.update_policy(advantage=A,
                            rho=rho_k,
                            logger=logger)

        # write logger results to file and stdout
        logger.write(display=verbose)

        # eval (callback is optional; the default None must not be called)
        if eval_callback is not None:
            eval_callback(agent, True)

    agent.is_training(False)


def run_episode(env, agent: Agent, animate=False, max_cnt=5000, eval_callback=None):
    """Roll out a single episode of ``agent`` in ``env``.

    The episode ends when ``env.step`` reports done, or after ``max_cnt``
    steps, whichever comes first.

    Args:
        env: environment exposing ``reset()``, ``render()`` and a
            ``step(action)`` that returns a 4-tuple
            ``(state, reward, done, info)``.
        agent: agent whose ``take_action(state)`` picks each action.
        animate: when True, ``env.render()`` is called before every step.
        max_cnt: upper bound on the number of steps.
        eval_callback: optional callable invoked as ``eval_callback(agent)``
            after every step.

    Returns:
        Tuple ``(states, actions, rewards)``. The first two come from
        ``np.asarray``; rewards is a float32 array. Entry ``t`` of each is
        the state observed, the action taken, and the reward received at
        step ``t``.
    """
    observation = env.reset()
    visited = []
    chosen = []
    gains = []
    finished = False
    steps_taken = 0
    while not finished and steps_taken < max_cnt:
        steps_taken += 1
        if animate:
            env.render()
        # record the state *before* acting on it
        visited.append(observation)
        move = agent.take_action(observation)
        chosen.append(move)
        observation, gain, finished, _info = env.step(move)
        gains.append(gain)
        if eval_callback is not None:
            eval_callback(agent)
    states_arr = np.asarray(visited)
    actions_arr = np.asarray(chosen)
    rewards_arr = np.array(gains, dtype=np.float32)
    return states_arr, actions_arr, rewards_arr


def run_policy(env, agent: Agent, episodes: int, logger: Logger, eval_callback=None) -> Tuple[list, np.ndarray]:
    """Simulate ``episodes`` episodes with ``agent`` and log return statistics.

    Every episode is produced by ``run_episode`` and stored with
    ``pack_trajectory``. The undiscounted return (sum of rewards) of each
    episode is recorded. The following are logged:

    * ``_AvgRewardSum``: mean return over all episodes.
    * ``_StdRewardSum``: std of per-batch mean returns. Batches hold
      ``int(sqrt(N))`` consecutive episodes; the last batch may be shorter.
    * ``_MinRewardSum`` / ``_MaxRewardSum``: extreme single-episode returns.

    Args:
        env: environment forwarded to ``run_episode``.
        agent: agent used to act in every episode.
        episodes: number of episodes to simulate; must be >= 1.
        logger: logger that receives the statistics above via ``log()``.
        eval_callback: optional callable forwarded to ``run_episode``.

    Returns:
        ``(trajectories, r)``: the list of per-episode trajectory dicts and
        the mean return over all episodes (a numpy scalar).

    Raises:
        ValueError: if ``episodes`` is below 1. Return statistics are
            undefined without at least one episode.
    """
    if episodes < 1:
        raise ValueError(f"run_policy needs at least one episode, got {episodes}")

    trajectories = []
    cumrewards = []
    for e in range(episodes):
        states, actions, rewards = run_episode(env, agent, eval_callback=eval_callback)
        cumrewards.append(np.sum(rewards))
        trajectories.append(pack_trajectory(states, actions, rewards))
        printProgressBar(e+1, episodes, prefix='Policy simulation: ', suffix='', decimals=1, length=100, fill='█', printEnd="\r")

    # Group returns into batches of ~sqrt(N) episodes; the logged std is the
    # spread of these batch means, not of individual returns.
    batch_size = max(1, int(len(cumrewards) ** 0.5))
    avg_rewards = [np.mean(cumrewards[i:i + batch_size])
                   for i in range(0, len(cumrewards), batch_size)]

    # Average over all episodes directly. Averaging the batch means would give
    # a trailing partial batch the same weight as a full one and bias `r`,
    # which the caller uses to pick the best agent.
    r = np.mean(cumrewards)

    logger.log({"_AvgRewardSum": r,
                "_StdRewardSum": np.std(avg_rewards),
                "_MinRewardSum": np.min(cumrewards),
                "_MaxRewardSum": np.max(cumrewards)})
    return trajectories, r

def discount(x, gamma):
    """Return the discounted forward sum of ``x`` at every position.

    Output element ``t`` equals ``sum_k gamma**k * x[t + k]``. It is computed
    by running the IIR recurrence ``y[n] = x[n] + gamma * y[n - 1]`` over the
    reversed sequence and reversing the result.
    """
    backwards = x[::-1]
    numerator = [1.0]
    denominator = [1.0, -gamma]
    accumulated = scipy.signal.lfilter(numerator, denominator, backwards)
    return accumulated[::-1]

def unpack_trajectories(trajectories):
    """Concatenate per-episode arrays from ``trajectories`` into flat arrays.

    ``trajectories`` is a sequence of dicts with ``"states"``, ``"actions"``
    and ``"rewards"`` entries. Each field is concatenated along its first axis
    across all dicts, in order.

    Returns:
        Tuple ``(states, actions, rewards)`` of concatenated numpy arrays.
    """
    def _join(field):
        # stack one field across every episode, preserving episode order
        return np.concatenate([episode[field] for episode in trajectories])

    all_states = _join("states")
    all_actions = _join("actions")
    all_rewards = _join("rewards")
    return all_states, all_actions, all_rewards

def pack_trajectory(states, actions, rewards):
    """Bundle one episode's data into a dict keyed "states", "actions", "rewards".

    This is the layout that ``unpack_trajectories`` reads back.
    """
    trajectory = dict(states=states, actions=actions, rewards=rewards)
    return trajectory