from typing import List, Tuple
import torch
import time
from tqdm import tqdm
from policy_selection_for_inventories.algorithms.algorithm import Algorithm
from policy_selection_for_inventories.environments.environment import Environment

class Simulator:
    """Roll out several algorithms on one environment and record the results.

    Each algorithm is simulated independently, one after the other, starting
    from ``environment.get_initial_state()``. At every time step the
    algorithm picks a control via ``get_control(t, state)``, the environment
    prices it via ``cost(t, state, control)``, and the next state is obtained
    from ``transition(t, state, control)``.
    """

    def __init__(
        self,
        environment: Environment,
        algorithms: List[Algorithm],
        show_alg_progress_bar: bool = False,
        show_time_progress_bar: bool = False,
    ):
        """Store the configuration and reset every algorithm.

        Args:
            environment: Object providing ``get_initial_state``,
                ``transition`` and ``cost``.
            algorithms: Algorithms to simulate; each must provide ``reset``,
                ``get_control`` and a ``parameter`` attribute.
            show_alg_progress_bar: Show a tqdm bar over the algorithms.
            show_time_progress_bar: Show a tqdm bar over the time steps of
                each algorithm.
        """
        self.environment = environment
        self.algorithms = algorithms
        self.show_alg_progress_bar = show_alg_progress_bar
        self.show_time_progress_bar = show_time_progress_bar
        self.reset()

    def reset(self) -> None:
        """Call ``reset()`` on every algorithm (the environment is untouched)."""
        for alg in self.algorithms:
            alg.reset()

    @staticmethod
    def _stack(records: list) -> torch.Tensor:
        """Stack per-step records; an empty run yields an empty tensor."""
        # torch.stack rejects an empty list, which would otherwise turn a
        # zero-length horizon into an opaque RuntimeError.
        return torch.stack(records) if records else torch.empty(0)

    def run(self, horizon: int) -> Tuple[
        List[torch.Tensor],
        List[torch.Tensor],
        List[torch.Tensor],
        List[torch.Tensor],
        List[float],
    ]:
        """Simulate every algorithm for ``horizon`` time steps.

        Algorithms are NOT reset here; call ``reset()`` first if a fresh
        start is needed for a repeated run.

        Args:
            horizon: Number of time steps ``t = 0, ..., horizon - 1``. A
                non-positive horizon produces empty tensors.

        Returns:
            ``(states, controls, costs, parameters, computation_times)``.
            The first four are lists with one tensor per algorithm, each
            stacking the per-step values along a new leading dimension.
            ``computation_times`` holds, per algorithm, the elapsed seconds
            of its whole rollout loop (algorithm calls, environment calls
            and bookkeeping included).
        """
        n_algs = len(self.algorithms)
        states = [[] for _ in range(n_algs)]
        controls = [[] for _ in range(n_algs)]
        costs = [[] for _ in range(n_algs)]
        parameters = [[] for _ in range(n_algs)]
        computation_times = [-1 for _ in range(n_algs)]

        alg_enumeration = enumerate(self.algorithms)
        if self.show_alg_progress_bar:
            # enumerate() has no length, so tqdm needs the total spelled out
            # to display a percentage / ETA.
            alg_enumeration = tqdm(alg_enumeration, total=n_algs)

        for alg_idx, alg in alg_enumeration:
            t_enumeration = range(horizon)
            if self.show_time_progress_bar:
                t_enumeration = tqdm(t_enumeration)

            # perf_counter is monotonic and high-resolution, unlike the wall
            # clock, so system clock adjustments cannot distort the timing.
            timer_start = time.perf_counter()
            for t in t_enumeration:
                if t == 0:
                    state = self.environment.get_initial_state()
                else:
                    # `state` and `control` are the values from step t - 1.
                    state = self.environment.transition(t - 1, state, control)

                control = alg.get_control(t, state)
                cost = self.environment.cost(t, state, control)

                states[alg_idx].append(state)
                controls[alg_idx].append(control)
                costs[alg_idx].append(cost)
                # NOTE(review): stored by reference; if an algorithm updates
                # `parameter` in place, all recorded steps would alias the
                # final value -- confirm algorithms rebind instead.
                parameters[alg_idx].append(alg.parameter)

            computation_times[alg_idx] = time.perf_counter() - timer_start

            states[alg_idx] = self._stack(states[alg_idx])
            controls[alg_idx] = self._stack(controls[alg_idx])
            costs[alg_idx] = self._stack(costs[alg_idx])
            parameters[alg_idx] = self._stack(parameters[alg_idx])

        return states, controls, costs, parameters, computation_times

        