import numpy as np

class ZeroSumSolver():
    """Drive a two-player (u vs. v) environment and train the agents on it.

    The solver repeatedly rolls out sessions (episodes) in ``env`` using the
    actions supplied by ``agents`` and records aggregate statistics.

    Expected collaborator interfaces (as used by this class):

    * ``env.reset()`` returns the initial state;
      ``env.step(u_action, v_action)`` returns a 4-tuple
      ``(next_state, reward, done, info)``.
    * ``agents.reset()``, ``agents.get_u_action(state)``,
      ``agents.get_v_action(state)`` and ``agents.fit(...)``.

    ``learning_type`` selects when ``agents.fit`` is called:

    * ``'by_sixes'``    -- after every environment step, with the six values
      ``(state, u_action, v_action, reward, done, next_state)``.
    * ``'by_sessions'`` -- after every batch of ``session_n`` sessions, with
      the list of session dicts.

    Any other value performs rollouts without calling ``agents.fit``.

    Attributes:
        total_rewards: per batch, the mean over sessions of the summed reward.
        total_timesteps: per batch, the cumulative number of environment
            steps taken so far in the current ``go`` call.
    """

    def __init__(self, env, agents, session_n=1, session_len=1000, learning_type='by_sixes'):
        self.env = env
        self.agents = agents
        self.session_n = session_n
        self.learning_type = learning_type
        self.session_len = session_len
        self.total_rewards = []
        self.total_timesteps = []

    def go(self, timesteps):
        """Run batches of sessions until at least ``timesteps`` steps are taken.

        Each iteration collects ``session_n`` sessions, appends the batch
        statistics to ``total_rewards`` / ``total_timesteps`` and, in
        ``'by_sessions'`` mode, fits the agents on the collected batch.

        Raises:
            ValueError: if a batch produces no environment steps (for example
                ``session_n == 0`` or ``session_len == 0``); without this
                check the loop could never reach ``timesteps``.
        """
        current_timesteps = 0

        while current_timesteps < timesteps:
            sessions = [self.get_session() for _ in range(self.session_n)]

            batch_steps = sum(len(session['rewards']) for session in sessions)
            if batch_steps == 0:
                # No progress is possible; fail loudly instead of spinning forever.
                raise ValueError(
                    'sessions produced no timesteps; '
                    'check session_n and session_len'
                )

            total_reward = np.mean([np.sum(session['rewards']) for session in sessions])
            self.total_rewards.append(total_reward)

            current_timesteps += batch_steps
            self.total_timesteps.append(current_timesteps)

            if self.learning_type == 'by_sessions':
                # Must go through self: there is no local/global `agents` here.
                self.agents.fit(sessions)

    def get_session(self):
        """Roll out one session of at most ``session_len`` steps.

        Returns:
            dict with keys ``'states'``, ``'u_actions'``, ``'v_actions'``,
            ``'rewards'`` and ``'dones'``. ``'states'`` includes the initial
            state, so it holds one more entry than the other lists.
        """
        session = {'states': [], 'u_actions': [], 'v_actions': [], 'rewards': [], 'dones': []}

        self.agents.reset()
        state = self.env.reset()
        session['states'].append(state)

        for _ in range(self.session_len):
            # Both players choose their action from the same current state.
            u_action = self.agents.get_u_action(state)
            session['u_actions'].append(u_action)

            v_action = self.agents.get_v_action(state)
            session['v_actions'].append(v_action)

            next_state, reward, done, _ = self.env.step(u_action, v_action)
            session['rewards'].append(reward)
            session['dones'].append(done)

            if self.learning_type == 'by_sixes':
                # Per-step update on the six-element transition.
                self.agents.fit(state, u_action, v_action, reward, done, next_state)

            state = next_state
            session['states'].append(state)

            if done:
                break

        return session
