import numpy as np
from envs.EnvironmentsWithFixedAgent import EnvironmentsWithFixedAgent as OneAgentEnv
from solvers.OneAgentSolver import OneAgentSolver

class AlternateSolver:
    """Train two agents by alternating which one learns while the other is held fixed.

    Each round of :meth:`go` first trains ``u_agent`` with ``v_agent`` fixed,
    then trains ``v_agent`` with ``u_agent`` fixed. Each phase lasts
    ``subtimesteps`` training steps. After every phase, the ``total_rewards``
    and ``total_timesteps`` collected by that phase's wrapped single-agent
    environment are appended to this solver's lists of the same names.
    """

    def __init__(self, env, u_agent, v_agent, subalg_name, subtimesteps):
        """
        Args:
            env: the shared environment; it is wrapped by ``OneAgentEnv``
                for each training phase.
            u_agent: first agent (held fixed under index ``'u'``).
            v_agent: second agent (held fixed under index ``'v'``).
            subalg_name: name of the per-phase training algorithm. Names
                starting with ``'SB3'`` are trained through the agent's own
                ``set_env``/``learn`` methods. Any other name is trained via
                ``OneAgentSolver``.
            subtimesteps: training-step budget for a single phase.
        """
        self.env = env
        self.u_agent = u_agent
        self.v_agent = v_agent
        self.subtimesteps = subtimesteps
        self.subalg_name = subalg_name
        # Accumulated across all phases, in the order the phases ran.
        self.total_rewards = []
        self.total_timesteps = []

    def go(self, timesteps):
        """Run alternating training rounds until the ``timesteps`` budget is used.

        Each round trains both agents once and deducts ``2 * subtimesteps``
        from the budget. A round that starts always runs to completion, so the
        total number of steps may exceed ``timesteps`` by less than one round.

        Raises:
            ValueError: if there is budget to spend (``timesteps > 0``) but
                ``subtimesteps`` is not positive. Such a budget would never
                shrink, and the loop would never terminate.
        """
        round_cost = 2 * self.subtimesteps
        if timesteps > 0 and round_cost <= 0:
            raise ValueError(
                f"subtimesteps must be positive to make progress, got {self.subtimesteps!r}"
            )
        while timesteps > 0:
            # Phase 1: u learns while v is fixed.
            self.u_agent, self.v_agent = self.one_agent_go(self.u_agent, self.v_agent, 'v')
            # Phase 2: v learns while u is fixed. The return order is
            # (learning, fixed), hence the swapped unpacking.
            self.v_agent, self.u_agent = self.one_agent_go(self.v_agent, self.u_agent, 'u')
            timesteps -= round_cost

    def one_agent_go(self, learning_agent, fixed_agent, fixed_agent_index):
        """Train ``learning_agent`` for one phase with ``fixed_agent`` held fixed.

        Args:
            learning_agent: the agent trained during this phase.
            fixed_agent: the agent handed to ``OneAgentEnv``
                (``EnvironmentsWithFixedAgent``) as the fixed agent.
            fixed_agent_index: the slot of the fixed agent, passed through to
                ``OneAgentEnv``. ``go`` uses ``'u'`` or ``'v'``.

        Returns:
            ``(learning_agent, fixed_agent)``, in that order.
        """
        fixed_agent_env = OneAgentEnv(self.env, fixed_agent_index, fixed_agent)

        if self.subalg_name.startswith('SB3'):
            learning_agent.set_env(fixed_agent_env)
            # NOTE(review): SB3's learn() defaults to reset_num_timesteps=True.
            # That restarts the model's step counter (and any schedule tied to
            # it) every phase. Confirm this is intended for alternating
            # training.
            learning_agent.learn(self.subtimesteps)
        else:
            solver = OneAgentSolver(fixed_agent_env, learning_agent, self.subtimesteps)
            solver.go()

        self.total_rewards.extend(fixed_agent_env.total_rewards)
        self.total_timesteps.extend(fixed_agent_env.total_timesteps)
        return learning_agent, fixed_agent
