from timeit import default_timer as timer

import numpy as np

from all2.logging import ExperimentLogger
from all2.environments.gym import GymEnvironment

from .experiment import Experiment
import matplotlib.pyplot as plt
import copy


class SingleEnvExperiment(Experiment):
    """An Experiment object for training and testing agents that interact with one environment at a time.

    Args:
        preset: Object exposing ``name``, ``agent(logger=..., train_steps=...)``
            and ``test_agent()``; used to build the training and test agents.
        env: Environment exposing ``name``, ``reset()``, ``step(action)``,
            ``render()`` and an inner ``env`` attribute.
        name: Experiment name; defaults to ``preset.name`` when ``None``.
        train_steps: Forwarded to ``preset.agent``.
        logdir: Directory handed to the ``ExperimentLogger``.
        quiet: Forwarded to the base ``Experiment`` constructor.
        render: When true, the environment is rendered during training and
            test episodes.
        save_freq: ``save()`` is called whenever the episode counter is a
            multiple of this value; ``float("inf")`` disables saving.
        verbose: Forwarded to the ``ExperimentLogger``.
    """

    def __init__(
        self,
        preset,
        env,
        name=None,
        train_steps=float("inf"),
        logdir="runs",
        quiet=False,
        render=False,
        save_freq=100,
        verbose=True,
    ):
        self._name = name if name is not None else preset.name
        super().__init__(
            self._make_logger(logdir, self._name, env.name, verbose), quiet
        )
        self._logdir = logdir
        self._preset = preset
        self._agent = self._preset.agent(logger=self._logger, train_steps=train_steps)
        self._env = env
        # Not every wrapped environment exposes get_action_meanings(); look it
        # up defensively so such environments can still be used. The attribute
        # is None when the environment does not provide action meanings.
        unwrapped = getattr(getattr(env, "env", None), "unwrapped", None)
        get_meanings = getattr(unwrapped, "get_action_meanings", None)
        self.action_meanings = get_meanings() if callable(get_meanings) else None
        self._render = render
        self._frame = 1
        self._episode = 1
        self._save_freq = save_freq

        if render:
            self._env.render(mode="human")

    @property
    def frame(self):
        """Current value of the frame counter (starts at 1)."""
        return self._frame

    @property
    def episode(self):
        """Current value of the episode counter (starts at 1)."""
        return self._episode

    def train(self, frames=np.inf, episodes=np.inf):
        """Run training episodes until the frame or episode limit is exceeded.

        Args:
            frames: Training stops once the frame counter exceeds this value.
            episodes: Training stops once the episode counter exceeds this value.
        """
        while not self._done(frames, episodes):
            self._run_training_episode()
        # self._returns100 is not defined in this class; presumably it is
        # maintained by the base Experiment.
        if len(self._returns100) > 0:
            self._logger.add_summary("returns100", self._returns100)

    def test(self, episodes=100):
        """Run ``episodes`` test episodes with a fresh test agent.

        Each episode is logged individually, followed by one aggregate log call.

        Returns:
            list: The return of every test episode, in order.
        """
        test_agent = self._preset.test_agent()
        returns = []
        episode_lengths = []
        for episode in range(episodes):
            episode_return, episode_length = self._run_test_episode(test_agent)
            returns.append(episode_return)
            episode_lengths.append(episode_length)
            self._log_test_episode(episode, episode_return, episode_length)
        self._log_test(returns, episode_lengths)
        return returns

    def visualize(self):
        """Play a single interactive test episode and log its result.

        Returns:
            list: A one-element list holding the return of the episode.
        """
        test_agent = self._preset.test_agent()
        episode_return, episode_length = self._visualize_test_episode(test_agent)
        # Record the episode so the aggregate log and the returned list are
        # not empty.
        returns = [episode_return]
        episode_lengths = [episode_length]
        self._log_test_episode(0, episode_return, episode_length)
        self._log_test(returns, episode_lengths)
        return returns

    def _run_training_episode(self):
        """Run one training episode, then log it and possibly save the model.

        On every step the agent supplies an iterable of candidate actions.
        Each candidate is stepped from the same saved state; the environment
        is then left in the state reached by the first candidate, and only
        that first transition counts towards the return and the done check.
        """
        # initialize timer
        start_time = timer()
        start_frame = self._frame

        # initialize the episode (a single state at this point)
        state = self._env.reset()
        checkstate = state

        # the agent returns an iterable of actions (see the loop below)
        action = self._agent.act(state)
        returns = 0
        episode_length = 0

        is_gym = isinstance(self._env, GymEnvironment)

        # loop until the episode is finished
        while not checkstate.done:
            if self._render:
                self._env.render()

            # save the current state so every candidate starts from it
            # NOTE(review): `state` is only assigned at reset, so on the Gym
            # path this is always a copy of the initial state object, and that
            # object is what gets written to `env.env.state` below - confirm
            # this is intended.
            if is_gym:
                snapshot = copy.deepcopy(state)
            else:
                snapshot = self._env.env.ale.cloneState()

            next_states = []
            for i, single_action in enumerate(action):
                next_state = self._env.step(single_action)
                if i == 0:
                    # remember where the first candidate leads
                    if is_gym:
                        snapshot_next = copy.deepcopy(next_state)
                    else:
                        snapshot_next = self._env.env.ale.cloneState()
                next_states.append(next_state)
                # rewind so the next candidate starts from the same point
                if is_gym:
                    self._env.env.state = snapshot
                else:
                    self._env.env.ale.restoreState(snapshot)

            # continue the episode from the outcome of the first candidate
            if is_gym:
                self._env.env.state = snapshot_next
            else:
                self._env.env.ale.restoreState(snapshot_next)

            checkstate = next_states[0]
            # the agent receives the outcomes of all candidates
            action = self._agent.act(next_states)
            returns += checkstate.reward
            episode_length += 1
            self._frame += self._agent.cur_n_update_actions

        # stop the timer
        end_time = timer()
        fps = (self._frame - start_frame) / (end_time - start_time)

        # log the results
        self._log_training_episode(returns, episode_length, fps)
        self._save_model()

        # update experiment state
        self._episode += 1

    def _run_test_episode(self, test_agent):
        """Run one test episode with ``test_agent``.

        Returns:
            tuple: ``(returns, episode_length)`` - the summed reward and the
            number of environment steps taken.
        """
        # NOTE(review): an earlier revision also returned
        # test_agent.total_energy as a third element, but test() unpacks
        # exactly two values, so the two-value contract is kept here.
        # initialize the episode
        state = self._env.reset()
        action = test_agent.act(state)
        returns = 0
        episode_length = 0

        # loop until the episode is finished
        while not state.done:
            if self._render:
                self._env.render()
            state = self._env.step(action)
            action = test_agent.act(state)
            returns += state.reward
            episode_length += 1

        return returns, episode_length

    def _visualize_test_episode(self, test_agent):
        """Run one test episode interactively, one step per key press.

        For every step the value returned by ``render()`` is written to
        ``frame.png``, the executed action is printed, and the loop waits for
        user input before continuing.

        Returns:
            tuple: ``(returns, episode_length)``.
        """
        # initialize the episode
        action_meanings = self._env.env.unwrapped.get_action_meanings()
        state = self._env.reset()
        action = test_agent.act(state)
        returns = 0
        episode_length = 0

        # loop until the episode is finished
        while not state.done:
            frame = self._env.render()
            # keep the action applied in this step; `action` is overwritten
            # with the next choice before the report below is printed
            played_action = action
            state = self._env.step(action)
            action = test_agent.act(state)
            print(action_meanings)
            plt.imsave("frame.png", frame, dpi=400)
            print("Frame': ", episode_length)
            print("Played action: ", action_meanings[played_action])
            input("PRESS ANY BUTTON TO CONTINUE ")
            returns += state.reward
            episode_length += 1

        return returns, episode_length

    def _done(self, frames, episodes):
        """Return True once the frame or the episode counter exceeds its limit."""
        return self._frame > frames or self._episode > episodes

    def _make_logger(self, logdir, agent_name, env_name, verbose):
        """Build the ExperimentLogger for this experiment."""
        return ExperimentLogger(
            self, agent_name, env_name, verbose=verbose, logdir=logdir
        )

    def _save_model(self):
        """Call ``save()`` when the episode counter is a multiple of ``save_freq``."""
        if self._save_freq != float("inf") and self._episode % self._save_freq == 0:
            self.save()
