import numpy as np
import matplotlib.pyplot as plt
import matplotlib.collections as mcoll
import matplotlib.path as mpa
import pickle 
import sys 
import pdb
from collections import namedtuple

from utils import * 
import envs



def _sync_reward_fn(agent, env):
    """Copy ``env.r`` onto ``agent.r``, falling back to ``agent.w``.

    Does nothing if ``env`` has no ``r`` attribute, or if the agent has
    neither an ``r`` nor a ``w`` attribute.
    """
    if not hasattr(env, 'r'):
        return
    if hasattr(agent, 'r'):
        agent.r = env.r
    elif hasattr(agent, 'w'):
        agent.w = env.r


def run_experiment_episodic(
    env, agent, number_of_episodes,
    eval_only=False, max_ep_len=20,
    display_eps=None, decay_eps=None, decay_factor=0.5, respect_done=True
    ):
    """Run ``agent`` in ``env`` for a fixed number of episodes.

    Each episode does the following:
    - resets the agent (if it has ``reset``) and the environment;
    - feeds the reset transition to ``agent.step``;
    - alternates ``env.step(action)`` and ``agent.step(...)`` for up to
      ``max_ep_len`` steps.

    Args:
        env: environment exposing ``reset()`` and ``step(action)``, both
            returning ``(reward, discount, next_state, done)``, plus
            ``obs_to_state_coords(state)``. An optional ``r`` attribute is
            copied onto the agent.
        agent: object exposing ``step(reward, discount, next_state)``, which
            returns the next action. Optional ``reset``, ``initial_action``,
            ``_eval``, ``r``/``w``, ``state_values``, ``q_values`` and ``SR``
            attributes are used when present.
        number_of_episodes: number of episodes to run (0 runs none).
        eval_only: value written to ``agent._eval`` if that attribute exists.
        max_ep_len: maximum number of ``env.step`` calls per episode.
        display_eps: if truthy, print the running mean return every
            ``display_eps`` episodes.
        decay_eps, decay_factor: accepted for interface compatibility;
            currently unused.
        respect_done: if True, set ``agent._state`` to the reset state and
            end an episode as soon as ``env.step`` reports ``done``.
            If False, every episode runs exactly ``max_ep_len`` steps.

    Returns:
        dict with these keys:
        - ``"return hist"``: per-episode return, i.e. the reset reward plus
          each step's reward weighted by ``discount ** t``;
        - ``"trajectory"``: state coordinates of the last episode, or ``[]``
          if no episodes ran;
        - ``"trajectories"``: one such list per episode;
        - ``"deltas"``: currently always empty;
        - ``"state values"``, ``"q values"`` and ``"SR"``: present only when
          the agent has the corresponding attribute.
    """
    return_hist = []
    deltas = []  # never filled: per-step delta bookkeeping is disabled
    trajectories = []

    if hasattr(agent, '_eval'):
        agent._eval = eval_only
    # The pre-loop sync covers only ``agent.r`` (no ``w`` fallback).
    if hasattr(agent, 'r') and hasattr(env, 'r'):
        agent.r = env.r
    if hasattr(agent, 'reset'):
        agent.reset()

    # The action returned here is never used: each episode's first action
    # comes from ``agent.step`` after ``env.reset()``. The call is kept for
    # whatever side effects it has on the agent.
    try:
        agent.initial_action()
    except AttributeError:
        pass

    for i in range(1, number_of_episodes + 1):
        if hasattr(agent, 'reset'):
            agent.reset()
        reward, discount, next_state, done = env.reset()
        if respect_done:
            agent._state = next_state
        _sync_reward_fn(agent, env)
        action = agent.step(reward, discount, next_state)
        z = reward
        traj = [env.obs_to_state_coords(next_state)]

        for t in range(1, max_ep_len + 1):
            # Effect of the action in the environment.
            reward, discount, next_state, done = env.step(action)
            _sync_reward_fn(agent, env)
            # The agent chooses its next action.
            action = agent.step(reward, discount, next_state)
            # NOTE(review): this step's own discount is raised to the power t.
            # If env.step returns discount == 0 on termination, the final
            # reward contributes nothing to z — confirm this is intended.
            z += (discount ** t) * reward
            traj.append(env.obs_to_state_coords(next_state))
            if done and respect_done:
                break

        return_hist.append(z)
        trajectories.append(traj)

        # Truthiness guard: display_eps == 0 would otherwise divide by zero.
        if display_eps and i % display_eps == 0:
            flush_print(f"ep {i}/{number_of_episodes}: mean return = {np.mean(return_hist)}")

    results = {
        "return hist": return_hist,
        # Last episode's trajectory; [] when no episodes ran (this used to
        # raise because ``traj`` was never bound).
        "trajectory": trajectories[-1] if trajectories else [],
        "trajectories": trajectories,
        "deltas": deltas,
    }
    if hasattr(agent, "state_values"):
        results["state values"] = agent.state_values
    if hasattr(agent, "q_values"):
        results["q values"] = agent.q_values
    if hasattr(agent, "SR"):
        results['SR'] = agent.SR

    return results

