import numpy as np

class SimultaneousSolver():
    """Run two agents against each other in a shared environment.

    At every step both agents pick an action from the same state and the
    environment is advanced with ``env.step(u_action, v_action)``.  The
    ``u_agent`` is trained on the negated reward and the ``v_agent`` on the
    reward as returned by the environment.

    Attributes:
        env: Environment exposing ``reset()`` and
            ``step(u_action, v_action) -> (next_state, reward, done, info)``.
        u_agent: Agent exposing ``reset()``, ``get_action(state)`` and
            ``fit(...)``; trained on ``-reward``.
        v_agent: Agent with the same interface; trained on ``+reward``.
        session_n: Number of sessions played per iteration of ``go``.
        session_len: Maximum number of steps in one session.
        learning_type: ``'by_fives'`` fits both agents after every step with
            ``(state, action, reward, done, next_state)``; ``'by_sessions'``
            fits both agents once per batch with a list of session dicts.
            Any other value plays sessions without fitting.
        total_rewards: Sum of environment rewards for every session played.
        total_timesteps: Cumulative timestep count (within the ``go`` call
            that produced it) at the end of every session played.
    """

    def __init__(self, env, u_agent, v_agent, session_n=1,
                 session_len=10000, learning_type='by_fives'):
        self.env = env
        self.u_agent = u_agent
        self.v_agent = v_agent
        self.session_n = session_n
        self.session_len = session_len
        self.learning_type = learning_type
        self.total_rewards = []
        self.total_timesteps = []

    def go(self, timesteps):
        """Play batches of sessions until at least ``timesteps`` steps are done.

        Each iteration plays ``session_n`` sessions, appends their reward sums
        to ``total_rewards`` and their cumulative step counts to
        ``total_timesteps``.  When ``learning_type == 'by_sessions'`` both
        agents are fitted on the batch.

        Args:
            timesteps: Minimum number of environment steps to play.  The last
                batch is always completed, so the total may exceed this.

        Raises:
            ValueError: If a batch produced no steps at all (``session_n`` or
                ``session_len`` is not positive), since the loop could then
                never reach ``timesteps``.
        """
        current_timesteps = 0

        while current_timesteps < timesteps:
            sessions = [self.get_session() for _ in range(self.session_n)]
            session_lengths = [len(session['rewards']) for session in sessions]

            # Without this guard a zero-length batch would loop forever.
            if sum(session_lengths) == 0:
                raise ValueError(
                    'sessions produced no timesteps; session_n and '
                    'session_len must both be positive'
                )

            self.total_rewards.extend(
                np.sum(session['rewards']) for session in sessions
            )

            if self.learning_type == 'by_sessions':
                self._fit_by_sessions(sessions)

            # Running counter: one cumulative entry per finished session.
            for length in session_lengths:
                current_timesteps += length
                self.total_timesteps.append(current_timesteps)

    def _fit_by_sessions(self, sessions):
        """Fit both agents on a batch of sessions (u on -reward, v on +reward)."""
        u_sessions = [
            {
                'states': session['states'],
                'actions': session['u_actions'],
                'dones': session['dones'],
                # Negate per session: the u agent optimises the opposite sign.
                'rewards': [-reward for reward in session['rewards']],
            }
            for session in sessions
        ]
        self.u_agent.fit(u_sessions)

        v_sessions = [
            {
                'states': session['states'],
                'actions': session['v_actions'],
                'dones': session['dones'],
                'rewards': session['rewards'],
            }
            for session in sessions
        ]
        self.v_agent.fit(v_sessions)

    def get_session(self):
        """Play one session and return its trajectory.

        Both agents and the environment are reset first.  The session ends
        after ``session_len`` steps or as soon as the environment reports
        ``done``.  When ``learning_type == 'by_fives'`` both agents are fitted
        after every step.

        Returns:
            A dict with keys ``'states'``, ``'u_actions'``, ``'v_actions'``,
            ``'rewards'`` and ``'dones'``.  ``'states'`` holds the initial
            state plus the state after every step, so it is one element
            longer than the other lists.
        """
        session = {'states': [], 'u_actions': [], 'v_actions': [], 'rewards': [], 'dones': []}

        self.u_agent.reset()
        self.v_agent.reset()
        state = self.env.reset()
        session['states'].append(state)

        for _ in range(self.session_len):
            # Both agents act on the same state before the environment moves.
            u_action = self.u_agent.get_action(state)
            session['u_actions'].append(u_action)

            v_action = self.v_agent.get_action(state)
            session['v_actions'].append(v_action)

            next_state, reward, done, _ = self.env.step(u_action, v_action)
            session['rewards'].append(reward)
            session['dones'].append(done)

            if self.learning_type == 'by_fives':
                self.u_agent.fit(state, u_action, -reward, done, next_state)
                self.v_agent.fit(state, v_action, +reward, done, next_state)

            state = next_state
            session['states'].append(state)

            if done:
                break

        return session
    