import numpy as np

class OneAgentSolver():
    """Run a single agent in an environment and train it from collected experience.

    The env must expose ``reset()`` returning an initial state and
    ``step(action)`` returning a 4-tuple ``(next_state, reward, done, info)``.
    The agent must expose ``reset()``, ``get_action(state)`` and ``fit(...)``.
    How ``fit`` is called depends on ``learning_type``:

    * ``'by_fives'``    -- ``fit(state, action, reward, done, next_state)`` after
      every environment step.
    * ``'by_sessions'`` -- ``fit(sessions)`` once per batch of ``session_n``
      sessions, where each session is a dict as returned by ``get_session``.
    * any other value   -- ``fit`` is never called.
    """

    def __init__(self, env, agent, timesteps, session_n=1, session_len=10000, learning_type='by_fives'):
        """
        Args:
            env: environment to interact with (see class docstring).
            agent: agent that selects actions and is trained via ``fit``.
            timesteps: environment-step budget for ``go``; the last batch may
                overshoot it because whole sessions are always collected.
            session_n: number of sessions collected per iteration of ``go``.
            session_len: maximum number of steps in one session.
            learning_type: ``'by_fives'``, ``'by_sessions'``, or any other
                value to disable fitting.
        """
        self.env = env
        self.agent = agent
        self.timesteps = timesteps
        self.session_n = session_n
        self.learning_type = learning_type
        self.session_len = session_len
        # One entry per batch in go(): mean over the batch's sessions of each
        # session's summed (undiscounted) reward.
        self.total_rewards = []

    def go(self):
        """Collect batches of sessions until at least ``timesteps`` steps have run.

        After each batch, the mean summed reward is appended to
        ``self.total_rewards``. With ``learning_type == 'by_sessions'``, the
        agent is also fitted on the batch.

        Raises:
            ValueError: if ``timesteps > 0`` but ``session_n`` or ``session_len``
                is below 1. Such batches contain no steps, so the budget could
                never be consumed and the loop would never terminate.
        """
        remaining_timesteps = self.timesteps
        if remaining_timesteps > 0 and (self.session_n < 1 or self.session_len < 1):
            raise ValueError(
                'session_n and session_len must be >= 1 to consume timesteps '
                '(got session_n={}, session_len={})'.format(self.session_n, self.session_len)
            )

        while remaining_timesteps > 0:
            sessions = [self.get_session() for _ in range(self.session_n)]

            total_reward = np.mean([np.sum(session['rewards']) for session in sessions])
            self.total_rewards.append(total_reward)

            if self.learning_type == 'by_sessions':
                self.agent.fit(sessions)

            # Each recorded reward corresponds to exactly one env.step call.
            remaining_timesteps -= sum(len(session['rewards']) for session in sessions)

    def get_session(self):
        """Play one session of at most ``session_len`` steps and return its trajectory.

        Returns:
            dict with lists ``'states'``, ``'actions'``, ``'rewards'`` and
            ``'dones'``. ``'states'`` holds one more element than the other lists
            because it includes the initial state and every state reached after
            each step.
        """
        session = {'states': [], 'actions': [], 'rewards': [], 'dones': []}

        self.agent.reset()
        state = self.env.reset()
        session['states'].append(state)

        for _ in range(self.session_len):
            action = self.agent.get_action(state)
            session['actions'].append(action)

            next_state, reward, done, _ = self.env.step(action)
            session['rewards'].append(reward)
            session['dones'].append(done)

            if self.learning_type == 'by_fives':
                # Online update on the single (s, a, r, done, s') transition.
                self.agent.fit(state, action, reward, done, next_state)

            state = next_state
            session['states'].append(state)

            if done:
                break

        return session
