from src.problem_quadratic import QuadraticProblem
from src.predictor import Predictor
from collections import defaultdict
import numpy as np
import torch


# Maps the ``optimizer_class`` names accepted by VanillaRunner to torch optimizer classes.
OPTIMIZER_CLASSES_DICT = dict(
    adam=torch.optim.Adam,
    sgd=torch.optim.SGD,
)


class VanillaRunner:
    """Train a ``Predictor`` end-to-end through the ``QuadraticProblem`` solver.

    For each batch the predictor maps the problem observation to predicted
    parameter values, the problem is solved with those predictions overriding
    the same-named known parameters, and the predictor is updated to minimise
    the negated reward of the resulting solution.
    """

    def __init__(self,
                 problem: QuadraticProblem,
                 predictor: Predictor,
                 lr: float,
                 optimizer_class: str = 'adam'
                 ):
        """Create the runner and its optimizer.

        Args:
            problem: supplies observations, the solver, rewards and dataset indices.
            predictor: model whose ``parameters`` are optimized.
            lr: learning rate passed to the optimizer.
            optimizer_class: key of ``OPTIMIZER_CLASSES_DICT``.

        Raises:
            AssertionError: if ``optimizer_class`` is not a known optimizer name.
        """
        self.current_epoch = 0
        self.problem = problem
        self.predictor = predictor
        # Optimizer parameters
        self.lr = lr
        assert optimizer_class in OPTIMIZER_CLASSES_DICT, 'Unknown optimizer class %s' % optimizer_class
        self.optimizer_class = optimizer_class
        # NOTE(review): ``parameters`` is passed without being called; this assumes
        # Predictor exposes it as an iterable of tensors (not a bound method) -- confirm.
        self.optimizer = OPTIMIZER_CLASSES_DICT[optimizer_class](self.predictor.parameters, lr=lr)

    def _solve_controller_problem(self, param_values_batch, prediction_dict):
        """Solve the problem with predicted values overriding the known ones.

        Entries of ``prediction_dict`` replace the same-named entries of
        ``param_values_batch``; the input dict itself is not modified.

        Returns:
            The result of ``problem.solve``: a dict holding ``'solution'`` and,
            optionally, ``'solution_true'``.
        """
        predicted_param_values_batch = {**param_values_batch, **prediction_dict}
        return self.problem.solve(predicted_param_values_batch)

    def update_params(self, **kwargs):
        """Overwrite existing runner attributes (e.g. ``lr``) from keyword arguments.

        All keys are validated before anything is changed. The optimizer is
        rebuilt (dropping its internal state) whenever ``lr`` or
        ``optimizer_class`` is updated, so ``self.optimizer`` always matches
        ``self.lr`` and ``self.optimizer_class``.

        Raises:
            AssertionError: if a key is not an existing attribute of the runner,
                or ``optimizer_class`` is not a key of ``OPTIMIZER_CLASSES_DICT``.
        """
        for key in kwargs:
            assert hasattr(self, key), 'Runner %s does not have attribute %s!' % (self, key)
        if 'optimizer_class' in kwargs:
            assert kwargs['optimizer_class'] in OPTIMIZER_CLASSES_DICT, \
                'Unknown optimizer class %s' % kwargs['optimizer_class']
        for key, val in kwargs.items():
            setattr(self, key, val)
        if 'lr' in kwargs or 'optimizer_class' in kwargs:
            self.optimizer = OPTIMIZER_CLASSES_DICT[self.optimizer_class](self.predictor.parameters, lr=self.lr)

    def compute_reward_loss(self, solution, param_values_batch):
        """Return ``(reward_loss, reward, r_max)`` for a solution.

        ``reward`` comes from ``problem.get_reward``, ``reward_loss`` is its
        negation, and ``r_max`` is ``param_values_batch['r_max']`` detached from
        the autograd graph.
        """
        r_max = param_values_batch['r_max'].detach()
        reward = self.problem.get_reward(solution, param_values_batch)
        reward_loss = -reward
        return reward_loss, reward, r_max

    def copy_predictor_and_optimizer(self, optimizer_class_to_use=None):
        """Return a copy of the predictor and a new optimizer over its parameters.

        Args:
            optimizer_class_to_use: key of ``OPTIMIZER_CLASSES_DICT``; defaults to
                ``self.optimizer_class``. When it equals ``self.optimizer_class``,
                the current optimizer state is copied into the new optimizer on a
                best-effort basis.

        Returns:
            ``(predictor, optimizer)``.
        """
        predictor = self.predictor.create_copy()
        if optimizer_class_to_use is None:
            optimizer_class_to_use = self.optimizer_class
        optimizer = OPTIMIZER_CLASSES_DICT[optimizer_class_to_use](predictor.parameters, lr=self.lr)
        if optimizer_class_to_use == self.optimizer_class:
            try:
                optimizer.load_state_dict(self.optimizer.state_dict())
            except Exception as err:
                # Best effort (for example, mismatched parameter groups): keep the
                # fresh optimizer state, but report why the copy failed.
                print('Could not load state dict: %s' % err)
        return predictor, optimizer

    def estimate_reward_gradient(self, param_values_batch):
        """Measure how one SGD step on the reward loss changes predictions and solution.

        The step is taken on a copy of the predictor (obtained via
        ``copy_predictor_and_optimizer``), not on ``self.optimizer``.
        NOTE(review): this assumes ``predictor.create_copy()`` returns an
        independent copy, so ``self.predictor`` is left unchanged -- confirm.

        Args:
            param_values_batch: dict of parameter tensors whose leading (batch)
                dimension must be 1.

        Returns:
            dict mapping ``'rew-d_<key>'`` to the change of each predicted value,
            and ``'rew-d_solution'`` to the change of the solution.
        """
        for val in param_values_batch.values():
            assert val.shape[0] == 1, 'Compute gradients only works for bs=1!'
        obs = self.problem.get_observation(param_values_batch).detach()
        predictor, optimizer = self.copy_predictor_and_optimizer(optimizer_class_to_use='sgd')
        optimizer.zero_grad()
        prediction_dict_before = predictor.predict(obs)
        # The solver returns a dict; the reward and the difference need the solution itself.
        solution_before = self._solve_controller_problem(param_values_batch, prediction_dict_before)['solution']
        reward_loss_before, _, _ = self.compute_reward_loss(solution_before, param_values_batch)
        reward_loss_before.mean(0).backward()
        optimizer.step()
        prediction_dict_after = predictor.predict(obs)
        solution_after = self._solve_controller_problem(param_values_batch, prediction_dict_after)['solution']
        grads_dict = {'rew-d_' + key: val - prediction_dict_before[key]
                      for key, val in prediction_dict_after.items()}
        grads_dict['rew-d_solution'] = solution_after - solution_before
        return grads_dict

    def estimate_all_gradients(self, param_values_batch, **kwargs):
        """Return every tracked gradient estimate; for this runner, only the reward one.

        ``kwargs`` is accepted for interface compatibility and ignored.
        """
        return self.estimate_reward_gradient(param_values_batch)

    def run_on_batch(self, param_values_batch, train=True, **kwargs):
        """Predict, solve and, optionally, take one optimizer step on a batch.

        Args:
            param_values_batch: dict of batched parameter tensors; must contain
                ``'r_max'``.
            train: when True, back-propagate the batch-mean loss and step
                ``self.optimizer``.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            dict of per-sample results (losses, reward, regret, solution, ...),
            updated with the predictor's ``prediction_dict`` (same-named keys are
            overwritten by the predictions).
        """
        obs = self.problem.get_observation(param_values_batch).detach()
        if train:
            self.optimizer.zero_grad()
        prediction_dict = self.predictor.predict(obs)
        x_u = self.problem.get_x_u(prediction_dict, param_values_batch)
        solution_dict = self._solve_controller_problem(param_values_batch, prediction_dict)
        solution = solution_dict['solution']
        # Compute losses
        reward_loss, reward, r_max = self.compute_reward_loss(solution, param_values_batch)
        # Without a separate 'solution_true', the "true" metrics are detached copies
        # of the predicted ones (detach().clone() avoids torch.tensor's copy warning).
        reward_true, r_max_true = reward.detach().clone(), r_max.detach().clone()
        if 'solution_true' in solution_dict:
            _, reward_true, r_max_true = self.compute_reward_loss(solution_dict['solution_true'],
                                                                  param_values_batch)
        # Total loss
        total_loss = reward_loss
        if train:
            total_loss.mean(dim=0).backward()
            self.optimizer.step()
        # Batch shape is every dimension of obs except the last one.
        batch_shape = obs.shape[:-1]
        # Losses this runner does not compute are reported as NaN, presumably so the
        # results share keys with other runners -- confirm.
        results_dict = {'reward_loss': reward_loss,
                        'proximity_loss': torch.full(batch_shape, torch.nan),
                        'solution_distance': torch.full(batch_shape, torch.nan),
                        'projection_distance': torch.full(batch_shape, torch.nan),
                        'mse_loss': torch.full(batch_shape, torch.nan),
                        'total_loss': total_loss,
                        'x_u': x_u,
                        'reward': reward,
                        'reward_true': reward_true,
                        'regret': r_max - reward,
                        'regret_true': r_max_true - reward_true,
                        'r_max': r_max,
                        'solution': solution,
                        }
        results_dict.update(prediction_dict)
        return results_dict

    def run_epoch(self, batch_size=1, track_gradients=False, mode='train', normalize_losses=False, pretraining=False,
                  **kwargs):
        """Run one pass over the ``mode`` split of the problem's dataset.

        Samples are shuffled, and the predictor is trained, only when
        ``mode == 'train'``; other modes only evaluate.

        Args:
            batch_size: samples per batch (the last batch may be smaller).
            track_gradients: when True and ``mode == 'train'``, additionally call
                ``estimate_all_gradients`` on every sample (batch size 1), in
                dataset order, after the pass and store its outputs.
            mode: key of ``problem.dataset_indices`` selecting the split.
            normalize_losses, pretraining: forwarded to ``run_on_batch``.
            **kwargs: unused.

        Returns:
            ``defaultdict(list)`` mapping each result key to a list of per-sample
            numpy arrays.
        """
        is_train = mode == 'train'
        # np.array copies, so shuffling does not reorder problem.dataset_indices.
        indices = np.array(self.problem.dataset_indices[mode])
        if is_train:
            np.random.shuffle(indices)
        n_batches = int(np.ceil(len(indices) / batch_size))
        epoch_results = defaultdict(list)
        for batch_ind in range(n_batches):
            inds_in_batch = indices[batch_ind * batch_size: (batch_ind + 1) * batch_size]
            param_values_batch, _ = self.problem.sample_parameters(batch_size=len(inds_in_batch),
                                                                   indices=inds_in_batch,
                                                                   mode=mode)
            results_dict = self.run_on_batch(param_values_batch, train=is_train,
                                             normalize_losses=normalize_losses, pretraining=pretraining)
            for key, val in results_dict.items():
                for i in range(len(inds_in_batch)):
                    epoch_results[key].append(val[i].detach().numpy())
        if is_train and track_gradients:
            grads_dict = defaultdict(list)
            for ind in np.array(self.problem.dataset_indices[mode]):
                param_values_batch, _ = self.problem.sample_parameters(batch_size=1,
                                                                       indices=[ind],
                                                                       mode=mode)
                grads_dict_ind = self.estimate_all_gradients(param_values_batch)
                for key, val in grads_dict_ind.items():
                    grads_dict[key].append(val.detach().numpy())
            for key, val in grads_dict.items():
                epoch_results[key] = val
        return epoch_results
