import os
import numpy as np
import torch


class PriMORL:
    """Model-based offline RL trainer tying together a policy and a dynamics model.

    The class learns the dynamics model from the offline buffer
    (``learn_dynamics``), generates synthetic transitions with it
    (``rollout_transitions``) and trains the policy on a mix of real and
    synthetic samples (``learn_policy``).

    NOTE(review): the Poisson sampling of whole episodes and the
    per-episode weights look like user-level (trajectory-level) private
    training -- confirm against the dynamics model implementation.
    """

    def __init__(
            self,
            policy,
            dynamics_model,
            offline_buffer,
            model_buffer,
            reward_penalty_coef,
            rollout_length,
            batch_size,
            real_ratio,
            logger,
            model_batch_size,
            rollout_batch_size=50000,
            rollout_mini_batch_size=1000,
            model_retain_epochs=1,
            num_env_steps_per_epoch=1000,
            max_epoch=100000,
            max_model_update_epochs_to_improve=15,
            max_model_train_iterations="None",
            hold_out_ratio=0.1,
            poisson_q=0.01,
            model_rounds=100,
            load_model_name=None,
            continue_training=False,
            **kwargs
    ):
        """Store the components and hyper-parameters used by the trainer.

        Args:
            policy: object exposing ``sample_action`` and ``learn``.
            dynamics_model: learned environment model (see the methods
                called in ``learn_dynamics`` / ``rollout_transitions``).
            offline_buffer: buffer of real transitions (``sample``,
                ``sample_all``).
            model_buffer: buffer receiving synthetic transitions
                (``add_batch``, ``sample``).
            reward_penalty_coef: stored as ``_reward_penalty_coef``.
            rollout_length: maximum number of model steps per rollout.
            batch_size: policy batch size (real + synthetic samples).
            real_ratio: fraction of ``batch_size`` drawn from real data.
            logger: object exposing ``print``, ``record`` and ``log_path``.
            model_batch_size: stored and printed; not used in this class.
            rollout_batch_size: number of start states per rollout.
            max_model_update_epochs_to_improve: rounds without validation
                improvement before dynamics training stops.
            max_model_train_iterations: ``None`` or the legacy string
                ``"None"`` mean "no limit" (stored as ``np.inf``).
            hold_out_ratio: fraction of episodes held out for evaluation.
            poisson_q: per-episode selection probability for each step.
            model_rounds: bound on the number of training rounds.
            load_model_name: when not ``None``, a model is considered
                loaded and training is skipped unless
                ``continue_training`` is true.
            continue_training: keep training a loaded model.
            **kwargs: accepted for call-site compatibility and ignored.
        """
        self.policy = policy
        self.dynamics_model = dynamics_model
        self.offline_buffer = offline_buffer
        self.model_buffer = model_buffer
        self._reward_penalty_coef = reward_penalty_coef
        self._rollout_length = rollout_length
        self._rollout_batch_size = rollout_batch_size
        self._batch_size = batch_size
        self._real_ratio = real_ratio
        self.model_batch_size = model_batch_size
        print(f"Model batch size: {self.model_batch_size}")
        self.rollout_mini_batch_size = rollout_mini_batch_size
        self.model_retain_epochs = model_retain_epochs
        self.num_env_steps_per_epoch = num_env_steps_per_epoch
        self.max_model_update_epochs_to_improve = max_model_update_epochs_to_improve
        # The string "None" is the historical sentinel (e.g. coming from a
        # config file); a real ``None`` is accepted as well.
        if max_model_train_iterations is None or max_model_train_iterations == "None":
            self.max_model_train_iterations = np.inf
        else:
            self.max_model_train_iterations = max_model_train_iterations
        self.max_epoch = max_epoch
        self.hold_out_ratio = hold_out_ratio
        self.model_tot_train_timesteps = 0
        self.logger = logger

        self.poisson_q = poisson_q
        self.model_rounds = model_rounds
        self.load_model = load_model_name is not None
        self.continue_training = continue_training

    def _sample_initial_transitions(self):
        """Sample ``rollout_batch_size`` real transitions to start rollouts from."""
        return self.offline_buffer.sample(self._rollout_batch_size)

    def rollout_transitions(self):
        """Roll the policy out in the dynamics model and store the transitions.

        Starts from observations sampled from the offline buffer and
        performs up to ``rollout_length`` model steps. After each step the
        trajectories flagged terminal are dropped; the rollout stops early
        once none remain.
        """
        init_transitions = self._sample_initial_transitions()
        observations = init_transitions["observations"]
        for _ in range(self._rollout_length):
            actions = self.policy.sample_action(observations)
            next_observations, rewards, terminals, _infos = self.dynamics_model.predict(observations, actions)
            self.model_buffer.add_batch(observations, next_observations, actions, rewards, terminals)
            nonterm_mask = (~terminals).flatten()
            if nonterm_mask.sum() == 0:
                break
            observations = next_observations[nonterm_mask]

    def _split_by_episode(self, env_data):
        """Split ``env_data`` into train / eval dicts along episode boundaries.

        Episode ids are shuffled with ``np.random.shuffle``; the first
        ``int(num_episodes * (1 - hold_out_ratio))`` of them form the
        training set and the remainder the evaluation set.
        """
        episode_of_row = env_data['episode_idx'].flatten()
        episode_ids = np.unique(env_data['episode_idx'])
        np.random.shuffle(episode_ids)
        num_train_episodes = int(len(episode_ids) * (1.0 - self.hold_out_ratio))

        train_rows = np.flatnonzero(np.isin(episode_of_row, episode_ids[:num_train_episodes]))
        eval_rows = np.flatnonzero(np.isin(episode_of_row, episode_ids[num_train_episodes:]))

        train_data = {key: values[train_rows] for key, values in env_data.items()}
        eval_data = {key: values[eval_rows] for key, values in env_data.items()}
        return train_data, eval_data

    @staticmethod
    def _build_episode_batches(train_data, feature_keys, episode_of_row, selected_episodes):
        """Return one dict per selected episode holding that episode's rows."""
        batch_list = []
        for episode_id in selected_episodes:
            rows = np.flatnonzero(episode_of_row == episode_id)
            batch_list.append({key: train_data[key][rows] for key in feature_keys})
        return batch_list

    def learn_dynamics(self):
        """Train the dynamics model on the offline data and return log infos.

        The data is split by episode into train / eval sets. Each round
        runs ``int(1 / poisson_q)`` steps; in every step each training
        episode is selected independently with probability ``poisson_q``
        and the model is updated on the selected episodes. After each
        round the model is evaluated and the best snapshots are updated.
        Training is skipped when a model was loaded and
        ``continue_training`` is false.

        Side effects: writes the number of rounds to ``nb_rounds.txt`` in
        ``logger.log_path`` and increments ``model_tot_train_timesteps``.

        Returns:
            dict: the infos returned by the last ``dynamics_model.update``
            call (empty if no update happened) extended with ``misc/*``
            entries.
        """
        # get train and eval data
        env_data = self.offline_buffer.sample_all()
        train_data, eval_data = self._split_by_episode(env_data)
        self.dynamics_model.reset_normalizers()
        self.dynamics_model.update_normalizer(train_data['observations'], train_data['actions'])

        # train model
        model_train_iters = 0
        model_train_epochs = 0
        num_epochs_since_prev_best = 0
        total_rounds = 0
        # Initialised up-front: a run in which no episode is ever selected
        # (or training is skipped) must still be able to return log infos.
        model_log_infos = {}
        self.dynamics_model.reset_best_snapshots()

        # init eval_mse_losses
        self.logger.print("Start training dynamics")
        eval_mse_losses, _ = self.dynamics_model.eval_data(eval_data, update_elite_models=False)
        self.logger.record("loss/model_eval_mse_loss", eval_mse_losses.mean(), self.model_tot_train_timesteps)
        self.dynamics_model.update_best_snapshots(eval_mse_losses)
        print(f"Max update epochs to improve: {self.max_model_update_epochs_to_improve}")

        break_training = self.load_model and not self.continue_training

        # Loop invariants: the training set does not change between steps.
        episode_of_row = train_data['episode_idx'].flatten()
        train_episodes = np.unique(episode_of_row)
        # Exclude the episode id by name rather than by position in the dict.
        feature_keys = [key for key in train_data if key != 'episode_idx']

        while not break_training:
            for _step in range(int(1 / self.poisson_q)):
                # Select each episode independently with probability poisson_q
                keep = np.random.rand(len(train_episodes)) < self.poisson_q
                selected_episodes = train_episodes[keep]
                if len(selected_episodes) == 0:
                    continue

                batch_list = self._build_episode_batches(
                    train_data, feature_keys, episode_of_row, selected_episodes)

                # Every selected episode gets unit weight; the multiplier is
                # the inverse of the expected number of selected episodes.
                ep_weights = np.ones(len(selected_episodes), dtype=np.float64)
                weight_multiplier = 1 / (len(train_episodes) * self.poisson_q)

                model_log_infos = self.dynamics_model.update(batch_list, ep_weights, weight_multiplier)
                model_train_iters += 1
                self.model_tot_train_timesteps += 1

            eval_mse_losses, _ = self.dynamics_model.eval_data(eval_data, update_elite_models=False)
            self.logger.record("loss/model_eval_mse_loss", eval_mse_losses.mean(), self.model_tot_train_timesteps)
            updated = self.dynamics_model.update_best_snapshots(eval_mse_losses)
            num_epochs_since_prev_best += 1
            total_rounds += 1
            print(num_epochs_since_prev_best)
            if updated:
                model_train_epochs += num_epochs_since_prev_best
                num_epochs_since_prev_best = 0
            # NOTE: with the strict ">" up to model_rounds + 1 rounds are run.
            if num_epochs_since_prev_best >= self.max_model_update_epochs_to_improve or total_rounds > self.model_rounds:
                break

        self.dynamics_model.load_best_snapshots()
        # Context manager so the file is flushed and closed deterministically.
        with open(os.path.join(self.logger.log_path, 'nb_rounds.txt'), 'w') as f:
            f.write(f'Nb. rounds: {total_rounds}')

        # evaluate data to update the elite models
        self.dynamics_model.eval_data(eval_data, update_elite_models=True)
        model_log_infos['misc/norm_obs_mean'] = torch.mean(torch.Tensor(self.dynamics_model.obs_normalizer.mean)).item()
        model_log_infos['misc/norm_obs_var'] = torch.mean(torch.Tensor(self.dynamics_model.obs_normalizer.var)).item()
        model_log_infos['misc/norm_act_mean'] = torch.mean(torch.Tensor(self.dynamics_model.act_normalizer.mean)).item()
        model_log_infos['misc/norm_act_var'] = torch.mean(torch.Tensor(self.dynamics_model.act_normalizer.var)).item()
        model_log_infos['misc/model_train_epochs'] = model_train_epochs
        model_log_infos['misc/model_train_train_steps'] = model_train_iters
        return model_log_infos

    def learn_policy(self):
        """Run one policy update on a mixed real / synthetic batch.

        ``int(batch_size * real_ratio)`` samples come from the offline
        buffer and the rest from the model buffer; the two batches are
        concatenated field by field (real samples first).

        Returns:
            Whatever ``policy.learn`` returns for the mixed batch.
        """
        real_sample_size = int(self._batch_size * self._real_ratio)
        fake_sample_size = self._batch_size - real_sample_size
        real_batch = self.offline_buffer.sample(batch_size=real_sample_size)
        fake_batch = self.model_buffer.sample(batch_size=fake_sample_size)
        fields = ("observations", "actions", "next_observations", "terminals", "rewards")
        data = {
            field: np.concatenate([real_batch[field], fake_batch[field]], axis=0)
            for field in fields
        }
        return self.policy.learn(data)

    def save_dynamics_model(self, info):
        """Delegate saving to ``dynamics_model.save_model``."""
        self.dynamics_model.save_model(info)
