import torch
import time

from mpi4py import MPI
from core.logger import Logger
from core.qd_strategies.map_elites_rl import MapElitesRL

# Global MPI communicator plus this process's view of it: the total number
# of processes in COMM_WORLD and this process's index within it.
comm = MPI.COMM_WORLD
num_workers, rank = comm.Get_size(), comm.Get_rank()


class Trainer:
    """Drive the quality-diversity training loop on every MPI process.

    Each process constructs its own ``Trainer``. Every rank trains and logs
    each generation; only rank 0 updates the strategy and draws the next
    generation.
    """

    def __init__(self, config):
        """Set up the logger and the MAP-Elites RL strategy.

        Args:
            config: Experiment configuration. It must expose
                ``evolution_strategy_name`` and ``xp_params.max_steps``.
                The latter is read by :meth:`train`. The object is also
                passed unchanged to ``Logger`` and ``MapElitesRL``.
        """
        self.config = config
        # Restrict torch to a single intra-op thread in this process.
        # Presumably this is because parallelism comes from the MPI
        # workers; TODO confirm.
        torch.set_num_threads(1)
        # The logger receives the trainer itself and is created before the
        # strategy exists. NOTE(review): keep this order unless
        # Logger.__init__ is known not to depend on it.
        self.logger = Logger(self, config)

        # Evolution strategy name. It is not used inside this class;
        # presumably the logger reads it (verify).
        self.evolution_strategy_name = config.evolution_strategy_name

        # Quality-diversity strategy. It trains generations and exposes
        # ``total_steps``, which is used as the loop budget in train().
        self.qd_strategy = MapElitesRL(config)

        # Elapsed time in seconds of the most recently completed
        # generation; 0 until the first generation finishes.
        self.generation_duration = 0

    def train(self):
        """Run the main training loop until the step budget is spent.

        Each iteration trains one generation and logs it on every rank.
        Rank 0 then updates the strategy's parameters and draws the next
        generation. The loop exits once ``qd_strategy.total_steps`` is no
        longer below ``config.xp_params.max_steps``.
        """
        # NOTE(review): every rank evaluates this condition on its own, so
        # total_steps must stay consistent across ranks. Otherwise ranks
        # could leave the loop at different iterations; confirm that
        # MapElitesRL keeps it in sync.
        while self.qd_strategy.total_steps < self.config.xp_params.max_steps:
            # perf_counter is monotonic, so the measured duration is not
            # skewed by system clock adjustments (unlike time.time()).
            start = time.perf_counter()

            # Train the current generation on every rank.
            self.qd_strategy.train_one_generation()
            # NOTE(review): this runs before generation_duration is
            # refreshed below. If Logger reads that attribute, it reports
            # the previous generation's duration.
            self.logger.log()

            if rank == 0:
                # Update the parameters of the ES, then draw a new
                # generation.
                self.qd_strategy.update()
                self.qd_strategy.draw_generation()

            self.generation_duration = time.perf_counter() - start
