import numpy as np
import torch
import torch.nn as nn

import copy
import random

import lfrl.torch.pytorch_util as ptu


class PBTTrainerWrapper:
    """Adapt a (model, optimizer, training routine) triple to the interface PBT expects.

    ``train_model`` runs the wrapped training routine once; ``exploit`` overwrites
    this member with a copy of a better member and then perturbs the optimizer
    hyperparameters listed in ``perturb_factors``. ``identity`` records the chain
    of members this one was copied from, so the population history can be
    reconstructed as a tree.
    """

    def __init__(self, model, optimizer, train_function, perturb_factors, identity_ind):
        """
        Args:
            model: the underlying torch model.
            optimizer: the underlying torch optimizer; perturbed hyperparameters are
                read from and written to its ``param_groups``.
            train_function: zero-argument callable that actually trains the model.
            perturb_factors: mapping ``hyperparameter name -> {'mode': ..., 'options': ...}``.
                ``'scale'`` multiplies the current value by a randomly chosen option;
                ``'discrete'`` replaces it with a randomly chosen option.
            identity_ind: identifier of this population member.
        """
        self.model = model                          # the underlying torch model
        self.optimizer = optimizer                  # the underlying torch optimizer
        self.train_function = train_function        # actual function that trains the model
        self.perturb_factors = perturb_factors      # hyperparameter perturbation settings

        # keep track of who we exploited from; we can reconstruct this as a tree
        self.identity_ind = identity_ind
        self.identity = [identity_ind]

    def train_model(self):
        """Run the wrapped training function once."""
        self.train_function()

    def exploit(self, better_trainer):
        """Copy ``better_trainer``'s model and optimizer state, then perturb hyperparameters.

        Args:
            better_trainer: another ``PBTTrainerWrapper`` whose model and optimizer
                have the same structure as this member's.

        Raises:
            NotImplementedError: if a perturbation mode is neither ``'scale'`` nor
                ``'discrete'``.
        """
        # deepcopy so this member never shares tensors with the donor
        better_state_dict = copy.deepcopy(better_trainer.model.state_dict())
        self.model.load_state_dict(better_state_dict)

        # torch optimizers restore state with ``load_state_dict`` (there is no ``load``);
        # this also brings the donor's hyperparameters into our param_groups, which
        # are then perturbed below
        better_optimizer_state = copy.deepcopy(better_trainer.optimizer.state_dict())
        self.optimizer.load_state_dict(better_optimizer_state)

        for hyperparameter, h_dict in self.perturb_factors.items():
            h_mode = h_dict['mode']
            h_options = h_dict['options']
            # a single random draw per hyperparameter is applied to every param group
            if h_mode == 'scale':
                scale_factor = np.random.choice(h_options)
                for param_group in self.optimizer.param_groups:
                    param_group[hyperparameter] *= scale_factor
            elif h_mode == 'discrete':
                new_value = random.choice(h_options)
                for param_group in self.optimizer.param_groups:
                    param_group[hyperparameter] = new_value
            else:
                raise NotImplementedError(
                    'PBT hyperparameter perturb mode for %s not recognized' % hyperparameter
                )

        # our lineage is now the donor's lineage followed by ourselves
        self.identity = copy.deepcopy(better_trainer.identity) + [self.identity_ind]


class PBT(nn.Module):

    """
    A synchronous vanilla implementation of Population-Based Training, which has some
    differences with the asynchronous version.

    Every call to ``train_models`` trains each member once. Every ``exploit_every``
    calls, all members are scored with ``fitness_function`` (higher is better). Each
    member in the bottom ``bottom_pct`` of the population then copies a random member
    of the top ``top_pct`` via that trainer's ``exploit`` method.

    ``model_trainers[i]`` is assumed to wrap ``models[i]``; scores are computed over
    ``models`` and used to index ``model_trainers``.
    """

    def __init__(
            self,
            models,                               # list of the models in the PBT trainer
            model_trainers,                       # trainers for models
            fitness_function,                     # function used to determine best members of population
            exploit_function,                     # stored for compatibility; exploitation is done by each trainer's ``exploit``
            exploit_every,                        # how often to evaluate and exploit models
            bottom_pct=0.2,                       # bottom percent of models in ensemble to be replaced
            top_pct=0.2,                          # top models to exploit
    ):
        super().__init__()

        self.models = nn.ModuleList(models)
        self.model_trainers = model_trainers
        self.fitness_function = fitness_function
        self.exploit_function = exploit_function
        self.exploit_every = exploit_every
        self.bottom_pct = bottom_pct
        self.top_pct = top_pct

        self._num_steps_trained = 0

        # here we store the identity of every model ever birthed in the training process
        self.population_hist = [
            copy.deepcopy(trainer.identity) for trainer in self.model_trainers
        ]

    def train_models(self):
        """Train every member once; every ``exploit_every`` calls, evaluate and exploit."""
        for model_trainer in self.model_trainers:
            model_trainer.train_model()
        self._num_steps_trained += 1
        if self._num_steps_trained % self.exploit_every == 0:
            scores = self.evaluate_models()
            self.exploit_models(scores)

    def evaluate_models(self):
        """Return ``fitness_function(model)`` for each model, in population order."""
        return [self.fitness_function(model) for model in self.models]

    @staticmethod
    def _num_selected(pct, population_size):
        """Size of a ``pct`` slice of the population, rounded up (0 when ``pct`` is 0)."""
        return population_size - int((1 - pct) * population_size)

    def exploit_models(self, evaluation_scores):
        """Replace the worst members with perturbed copies of randomly chosen top members.

        Args:
            evaluation_scores: one score per member, higher is better.
        """
        population_size = len(evaluation_scores)
        # ascending order: ranks[0] is the worst member, ranks[-1] the best
        ranks = np.argsort(evaluation_scores)

        num_top = self._num_selected(self.top_pct, population_size)
        # exactly the bottom ``bottom_pct`` (the old ``range(bi+1)`` replaced one extra
        # member); capped so no member is both a donor and a recipient
        num_bottom = min(
            self._num_selected(self.bottom_pct, population_size),
            population_size - num_top,
        )
        if num_top <= 0 or num_bottom <= 0:
            # nobody to copy from, or nobody to replace
            return

        better_model_trainers = [
            self.model_trainers[i] for i in ranks[population_size - num_top:]
        ]
        for i in ranks[:num_bottom]:
            worse_model_trainer = self.model_trainers[i]
            worse_model_trainer.exploit(random.choice(better_model_trainers))
            self.population_hist.append(copy.deepcopy(worse_model_trainer.identity))
