import numpy as np
import torch
import os
from unstable_baselines.common import util, functional
from unstable_baselines.common.networks import get_optimizer
from unstable_baselines.model_based_rl.common.models import EnsembleModel
from operator import itemgetter
from unstable_baselines.common.normalizer import StandardNormalizer
from copy import deepcopy

class PriorBuffer:
    """Fixed-capacity buffer of (model input, target, weight) rows.

    Rows are appended until ``capacity`` is reached. Once the buffer is full,
    every incoming row first multiplies all stored weights by ``decay`` and
    then replaces the lowest-weight stored row if its own weight is strictly
    larger; otherwise the incoming row is dropped.

    Args:
        capacity: maximum number of rows kept.
        input_dim: width of each model-input row.
        target_dim: width of each target row.
        decay: factor applied to all stored weights for each incoming row
            once the buffer is full.
        device: torch device the storage tensors live on.
        dtype: torch dtype of the storage tensors.
    """

    def __init__(self, capacity, input_dim, target_dim, decay=0.99, device="cpu", dtype=torch.float32):
        self.capacity = capacity
        self.input_dim = input_dim
        self.target_dim = target_dim
        self.device = device
        self.dtype = dtype
        self.decay = decay
        # Allocates empty storage and sets size = 0.
        self.reset()

    @torch.no_grad()
    def add(self, temp):
        """Insert a batch of rows.

        Args:
            temp: ``(model_input, groundtruths, w)``. ``model_input`` has one
                row per sample (batch, input_dim), ``groundtruths`` has
                (batch, target_dim), and ``w`` holds one weight per row or is
                ``None`` for unit weights. Non-tensor inputs are converted
                with ``torch.as_tensor``.
        """
        model_input, groundtruths, w = temp
        model_input = torch.as_tensor(model_input, dtype=self.dtype, device=self.device)
        groundtruths = torch.as_tensor(groundtruths, dtype=self.dtype, device=self.device)
        batch = model_input.size(0)
        if w is None:
            w = torch.ones((batch, 1), dtype=self.dtype, device=self.device)
        else:
            # reshape (not view) accepts non-contiguous weights; moving them to
            # the buffer's device keeps the comparison with w_buf on one device.
            w = torch.as_tensor(w, dtype=self.dtype, device=self.device).reshape(-1, 1)

        for i in range(batch):
            if self.size < self.capacity:
                idx = self.size
                self.size += 1
            else:
                # Age every stored weight so stale rows become easier to evict.
                self.w_buf = self.w_buf * self.decay
                min_w, min_idx = torch.min(self.w_buf[:self.size], dim=0)
                if w[i] <= min_w:
                    continue
                idx = min_idx.item()

            self.input_buf[idx] = model_input[i]
            self.target_buf[idx] = groundtruths[i]
            self.w_buf[idx] = w[i]

    def get(self, sort_by_weight=True, descending=True):
        """Return the stored rows, optionally ordered by weight.

        Returns:
            ``(model_input, groundtruths, w)`` with shapes (size, input_dim),
            (size, target_dim) and (size, 1).
        """
        model_input = self.input_buf[:self.size]
        groundtruths = self.target_buf[:self.size]
        w = self.w_buf[:self.size]

        if sort_by_weight:
            # Squeeze only the trailing weight axis: a bare squeeze() turns a
            # single stored row into a 0-d tensor, and indexing with the
            # resulting 0-d index would drop the row dimension.
            _, sorted_idx = torch.sort(w.squeeze(-1), descending=descending)
            model_input = model_input[sorted_idx]
            groundtruths = groundtruths[sorted_idx]
            w = w[sorted_idx]
        return model_input, groundtruths, w

    def sample(self, batch_size, b=3):
        """Draw up to ``batch_size`` rows, favouring high-weight rows.

        Rows are ranked by descending weight. Rank r (1-based) out of n is drawn
        without replacement with probability proportional to
        ``((n - r + 1) / n) ** (b - 1)``. If ``batch_size >= size``, all rows
        are returned in descending-weight order.

        Args:
            batch_size: number of rows to draw.
            b: rank-bias exponent; ``b=1`` samples uniformly.

        Returns:
            ``(model_input, groundtruths)``; empty ``(0, dim)`` tensors when the
            buffer holds no rows.
        """
        if self.size == 0:
            empty_input = torch.zeros((0, self.input_dim), dtype=self.dtype, device=self.device)
            empty_target = torch.zeros((0, self.target_dim), dtype=self.dtype, device=self.device)
            return empty_input, empty_target

        model_input, groundtruths, _ = self.get()
        n = model_input.shape[0]
        if n > batch_size:
            ranks = np.arange(1, n + 1)
            weights = ((n - ranks + 1) / n) ** (b - 1)
            p = weights / weights.sum()
            idx = np.random.choice(n, size=batch_size, replace=False, p=p)
        else:
            idx = np.arange(n)
        idx = torch.as_tensor(idx, dtype=torch.long, device=self.device)
        return model_input[idx], groundtruths[idx]

    def reset(self):
        """Discard all rows and reallocate zeroed storage."""
        self.size = 0
        self.input_buf = torch.zeros((self.capacity, self.input_dim), dtype=self.dtype, device=self.device)
        self.target_buf = torch.zeros((self.capacity, self.target_dim), dtype=self.dtype, device=self.device)
        self.w_buf = torch.zeros((self.capacity, 1), dtype=self.dtype, device=self.device)


class TransitionModel:
    """Learned ensemble dynamics model for model-based RL.

    Predicts ``[next_obs - obs, reward]`` from normalized ``[obs, action]``
    inputs with an ``EnsembleModel``. It also keeps observation/action
    normalizers, per-member best snapshots chosen by validation loss, and a
    ``PriorBuffer`` of high-loss training rows. Those rows are replayed in
    later ``update`` calls when ``prior_ratio`` is non-zero.
    """

    def __init__(self,
            obs_space,
            action_space,
            env_name,
            prior_ratio,
            holdout_ratio: float,
            inc_var_loss: bool,
            use_weight_decay: bool,
            **kwargs):
        """Build the ensemble, optimizer, normalizers and priority buffer.

        Args:
            obs_space: observation space; only ``shape[0]`` is read.
            action_space: action space; only ``shape[0]`` is read.
            env_name: environment id used to pick the termination rule.
            prior_ratio: fraction of each batch drawn from, and pushed to, the
                priority buffer.
            holdout_ratio: stored as ``self.holdout_ratio``; not read elsewhere
                in this class.
            inc_var_loss: include the log-variance term in the training loss.
                ``update`` requires this to be True.
            use_weight_decay: add ``model.get_decay_loss()`` to the training loss.
            **kwargs: must contain ``model`` (kwargs for ``EnsembleModel``),
                ``optimizer_class`` and ``learning_rate``.
        """
        obs_dim = obs_space.shape[0]
        action_dim = action_space.shape[0]

        self.model = EnsembleModel(obs_dim=obs_dim, action_dim=action_dim, device=util.device, **kwargs['model'])
        self.env_name = env_name
        self.prior_ratio = prior_ratio

        self.model_optimizer = get_optimizer(optimizer_class=kwargs['optimizer_class'], network=self.model, learning_rate=kwargs['learning_rate'])
        self.networks = {
            "model": self.model
        }
        self.obs_space = obs_space
        self.holdout_ratio = holdout_ratio
        self.inc_var_loss = inc_var_loss
        self.use_weight_decay = use_weight_decay
        self.obs_normalizer = StandardNormalizer()
        self.act_normalizer = StandardNormalizer()
        self.model_train_timesteps = 0
        # Buffer rows mirror what `update` feeds the model:
        # inputs are [obs, action] and targets are [delta obs, reward].
        input_dim = obs_dim + action_dim
        target_dim = obs_dim + 1
        self.priority_buffer = PriorBuffer(
            capacity=2000,
            input_dim=input_dim,
            target_dim=target_dim,
            device=util.device,
        )

    def _termination_fn(self, env_name, obs, act, next_obs):
        """Return per-row termination flags for predicted transitions.

        Args:
            env_name: environment id selecting the termination rule.
            obs, act, next_obs: 2-D numpy arrays, one transition per row.

        Returns:
            Boolean numpy array of shape ``(batch,)``, except for
            ``AntTruncatedObs-v2``, which returns ``(batch, 1)``.

        Raises:
            NotImplementedError: if ``env_name`` has no rule here.
        """
        # NOTE(review): Ant's (batch, 1) shape differs from the other
        # environments; kept unchanged because callers may depend on it.
        assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2
        if env_name in ["Hopper-v2", "Hopper-v3"]:
            height = next_obs[:, 0]
            angle = next_obs[:, 1]
            # abs() must wrap the state *before* comparing: applying it to the
            # boolean comparison result silently dropped the lower bound (> -100).
            not_done = np.isfinite(next_obs).all(axis=-1) \
                       * (np.abs(next_obs[:, 1:]) < 100).all(axis=-1) \
                       * (height > .7) \
                       * (np.abs(angle) < .2)
            return ~not_done
        elif env_name in ["Walker2d-v2", "Walker2d-v3"]:
            height = next_obs[:, 0]
            angle = next_obs[:, 1]
            not_done = (height > 0.8) \
                       * (height < 2.0) \
                       * (angle > -1.0) \
                       * (angle < 1.0)
            return ~not_done
        elif env_name == "HumanoidTruncatedObs-v2":
            z = next_obs[:, 0]
            # Boolean "+" acts as logical OR.
            return (z < 1.0) + (z > 2.0)
        elif env_name == "InvertedDoublePendulum-v2":
            sin1, cos1 = next_obs[:, 1], next_obs[:, 3]
            sin2, cos2 = next_obs[:, 2], next_obs[:, 4]
            theta_1 = np.arctan2(sin1, cos1)
            theta_2 = np.arctan2(sin2, cos2)
            y = 0.6 * (cos1 + np.cos(theta_1 + theta_2))
            return y <= 1
        elif env_name == "InvertedPendulum-v2":
            notdone = np.isfinite(next_obs).all(axis=-1) \
                    * (np.abs(next_obs[:, 1]) <= .2)
            return ~notdone
        elif env_name == "AntTruncatedObs-v2":
            x = next_obs[:, 0]
            not_done = np.isfinite(next_obs).all(axis=-1) \
                       * (x >= 0.2) \
                       * (x <= 1.0)
            done = ~not_done
            return done[:, None]
        elif "Swimmer" in env_name or "HalfCheetah" in env_name:
            # These environments never terminate.
            return np.array([False for _ in obs])
        raise NotImplementedError

    @torch.no_grad()
    def eval_data(self, data, update_elite_models=False):
        """Compute each ensemble member's MSE on ``data``.

        Args:
            data: mapping with tensors under ``"obs"``, ``"action"``,
                ``"next_obs"`` and ``"reward"``. Reward must be 2-D, since it
                is concatenated on dim 1.
            update_elite_models: if True, set ``self.model.elite_model_idxes``
                to the ``num_elite`` members with the lowest MSE.

        Returns:
            ``(per-member MSE as a numpy array, None)``.
        """
        obs_list, action_list, next_obs_list, reward_list = \
            itemgetter("obs", 'action', 'next_obs', 'reward')(data)
        delta_obs_list = next_obs_list - obs_list
        obs_list, action_list = self.transform_obs_action(obs_list, action_list)
        model_input = torch.cat([obs_list, action_list], dim=-1)
        # Inference is chunked because the evaluated data grows with the model buffer.
        predictions = functional.minibatch_inference(args=[model_input], rollout_fn=self.model.predict, batch_size=10000, cat_dim=1)
        groundtruths = torch.cat((delta_obs_list, reward_list), dim=1)
        eval_mse_losses, _, _ = self.model_loss(predictions, groundtruths, mse_only=True)
        if update_elite_models:
            elite_idx = np.argsort(eval_mse_losses.cpu().numpy())
            self.model.elite_model_idxes = elite_idx[:self.model.num_elite]
        return eval_mse_losses.detach().cpu().numpy(), None

    def reset_normalizers(self):
        """Reset the observation and action normalizers."""
        self.obs_normalizer.reset()
        self.act_normalizer.reset()

    def update_normalizer(self, obs, action):
        """Feed ``obs`` and ``action`` to their respective normalizers."""
        self.obs_normalizer.update(obs)
        self.act_normalizer.update(action)

    def transform_obs_action(self, obs, action):
        """Return ``(obs, action)`` transformed by their normalizers."""
        obs = self.obs_normalizer.transform(obs)
        action = self.act_normalizer.transform(action)
        return obs, action

    def update(self, data_batch):
        """Take one optimizer step on ``data_batch`` plus replayed priority rows.

        After the step, the highest-loss rows of the combined batch are pushed
        into the priority buffer, weighted by their squared error.

        Args:
            data_batch: mapping with tensors under ``"obs"``, ``"action"``,
                ``"next_obs"`` and ``"reward"``. Reward must be 2-D so it can be
                concatenated on the last axis.

        Returns:
            dict of scalar training statistics for logging.
        """
        obs_batch, action_batch, next_obs_batch, reward_batch = \
            itemgetter("obs", 'action', 'next_obs', 'reward')(data_batch)

        delta_obs_batch = next_obs_batch - obs_batch
        obs_batch, action_batch = self.transform_obs_action(obs_batch, action_batch)
        model_input = torch.cat([obs_batch, action_batch], dim=-1)
        prior_sample_size = int(obs_batch.shape[0] * self.prior_ratio)
        if self.prior_ratio:
            # Replay rows that had high loss in earlier updates.
            model_input_prior, groundtruths_prior = self.priority_buffer.sample(prior_sample_size)
        else:
            # With a falsy prior_ratio, prior_sample_size is 0, so no rows are added.
            idx_sample = np.random.randint(0, obs_batch.shape[0], size=prior_sample_size)
            model_input_prior = model_input[idx_sample]
            delta_obs_batch_prior = delta_obs_batch[idx_sample]
            reward_batch_prior = reward_batch[idx_sample]
            groundtruths_prior = torch.cat((delta_obs_batch_prior, reward_batch_prior), dim=-1)

        model_input = torch.cat((model_input, model_input_prior), dim=0)
        predictions = self.model.predict(model_input)
        groundtruths = torch.cat((delta_obs_batch, reward_batch), dim=-1)
        groundtruths = torch.cat((groundtruths, groundtruths_prior), dim=0)

        train_mse_losses, train_var_losses, sqr_losses = self.model_loss(predictions, groundtruths)
        train_mse_loss = torch.sum(train_mse_losses)
        train_var_loss = torch.sum(train_var_losses)
        train_transition_loss = train_mse_loss + train_var_loss
        # Gradient descent on this term lowers max_logvar and raises min_logvar.
        train_transition_loss += 0.01 * torch.sum(self.model.max_logvar) - 0.01 * torch.sum(self.model.min_logvar)
        if self.use_weight_decay:
            decay_loss = self.model.get_decay_loss()
            train_transition_loss += decay_loss
        else:
            decay_loss = None

        self.model_optimizer.zero_grad()
        train_transition_loss.backward()
        self.model_optimizer.step()

        # Per-row priority: squared error averaged over ensemble members (dim 0)
        # and summed over output dimensions.
        weight = sqr_losses.mean(dim=0).sum(dim=-1).detach()
        order = weight.argsort(descending=True)
        num_prior = int(model_input.shape[0] * self.prior_ratio)
        top = order[:num_prior]
        # NOTE(review): model_input also contains the rows just replayed from the
        # buffer, so high-loss replayed rows can be stored again as duplicates.
        self.priority_buffer.add([model_input[top], groundtruths[top], weight[top]])

        return {
            "loss/train_model_loss_mse": train_mse_loss.item(),
            "loss/train_model_loss_var": train_var_loss.item(),
            # The total optimized loss (previously logged the var loss by mistake).
            "loss/train_model_loss": train_transition_loss.item(),
            "loss/decay_loss": decay_loss.item() if decay_loss is not None else 0,
            # NOTE(review): these keys say "std" but report mean log-variance bounds.
            "misc/max_std": self.model.max_logvar.mean().item(),
            "misc/min_std": self.model.min_logvar.mean().item()
        }

    def model_loss(self, predictions, groundtruths, mse_only=False):
        """Compute per-member losses.

        Args:
            predictions: ``(pred_means, pred_logvars)`` of shape
                (ensemble, batch, dim).
            groundtruths: targets broadcastable against ``pred_means``.
            mse_only: return plain MSE and no variance term.

        Returns:
            ``(mse_losses, var_losses, sqr_losses)``. ``mse_losses`` and
            ``var_losses`` have one entry per member; ``var_losses`` is None
            when ``mse_only``. ``sqr_losses`` holds the elementwise squared errors.

        Raises:
            NotImplementedError: if neither ``mse_only`` nor ``self.inc_var_loss``.
        """
        pred_means, pred_logvars = predictions
        sqr_losses = torch.pow(pred_means - groundtruths, 2)
        if mse_only:
            mse_losses = torch.mean(sqr_losses, dim=(1, 2))
            var_losses = None
        elif self.inc_var_loss:
            # Precision-weighted MSE and mean log-variance, averaged over
            # batch and dim for each member.
            inv_var = torch.exp(-pred_logvars)
            mse_losses = torch.mean(torch.mean(sqr_losses * inv_var, dim=-1), dim=-1)
            var_losses = torch.mean(torch.mean(pred_logvars, dim=-1), dim=-1)
        else:
            raise NotImplementedError("model_loss requires inc_var_loss=True unless mse_only=True")
        return mse_losses, var_losses, sqr_losses

    @torch.no_grad()
    def predict(self, obs, act, deterministic=False):
        """Sample next observations, rewards and terminals from the ensemble.

        For each row, one member is drawn uniformly from
        ``self.model.elite_model_idxes``. Unless ``deterministic``, that
        member's mean is perturbed with Gaussian noise scaled by its predicted std.

        Args:
            obs, act: numpy arrays or tensors, either one row (1-D) or a batch (2-D).
            deterministic: use predicted means without noise.

        Returns:
            ``(next_obs, rewards, terminals)`` as numpy arrays; terminals come
            from ``_termination_fn``.
        """
        if len(obs.shape) == 1:
            obs = obs[None,]
            act = act[None,]
        if not isinstance(obs, torch.Tensor):
            obs = torch.FloatTensor(obs).to(util.device)
        if not isinstance(act, torch.Tensor):
            act = torch.FloatTensor(act).to(util.device)

        scaled_obs, scaled_act = self.transform_obs_action(obs, act)
        model_input = torch.cat([scaled_obs, scaled_act], dim=-1)
        pred_diff_means, pred_diff_logvars = self.model.predict(model_input)
        pred_diff_means = pred_diff_means.detach().cpu().numpy()
        obs = obs.detach().cpu().numpy()
        act = act.detach().cpu().numpy()
        if not deterministic:
            ensemble_model_stds = pred_diff_logvars.exp().sqrt().detach().cpu().numpy()
            pred_diff_means = pred_diff_means + np.random.normal(size=pred_diff_means.shape) * ensemble_model_stds

        num_models, batch_size, _ = pred_diff_means.shape
        model_idxes = np.random.choice(self.model.elite_model_idxes, size=batch_size)
        batch_idxes = np.arange(0, batch_size)
        pred_diff_samples = pred_diff_means[model_idxes, batch_idxes]

        # Outputs are [delta obs, reward]; add the current obs back.
        next_obs, rewards = pred_diff_samples[:, :-1] + obs, pred_diff_samples[:, -1]
        terminals = self._termination_fn(self.env_name, obs, act, next_obs)

        assert(type(next_obs) == np.ndarray)
        return next_obs, rewards, terminals

    def update_best_snapshots(self, val_losses):
        """Snapshot every member whose loss improved by more than 1% relative.

        Args:
            val_losses: one validation loss per ensemble member.

        Returns:
            True if at least one member's snapshot was updated.
        """
        updated = False
        for i, current_loss in enumerate(val_losses):
            best_loss = self.best_snapshot_losses[i]
            improvement = (best_loss - current_loss) / best_loss
            if improvement > 0.01:
                self.best_snapshot_losses[i] = current_loss
                self.save_model_snapshot(i)
                updated = True
        return updated

    def reset_best_snapshots(self):
        """Snapshot each member's current weights and reset its best loss to 1e10."""
        self.model_best_snapshots = [deepcopy(self.model.ensemble_models[idx].state_dict()) for idx in range(self.model.ensemble_size)]
        self.best_snapshot_losses = [1e10 for _ in range(self.model.ensemble_size)]

    def save_model_snapshot(self, idx):
        """Store a deep copy of member ``idx``'s current state dict."""
        self.model_best_snapshots[idx] = deepcopy(self.model.ensemble_models[idx].state_dict())

    def load_best_snapshots(self):
        """Restore all members from their stored best snapshots."""
        self.model.load_state_dicts(self.model_best_snapshots)

    def save_model(self, info):
        """Pickle every network to ``<log_path>/models/ite_<info>/<name>.pt``."""
        model_save_dir = os.path.join(util.logger.log_path, 'models', "ite_{}".format(info))
        os.makedirs(model_save_dir, exist_ok=True)
        for network_name, network in self.networks.items():
            save_path = os.path.join(model_save_dir, network_name + ".pt")
            torch.save(network, save_path)

    def load_model(self, info):
        """Load networks previously written by ``save_model(info)``.

        Parameters are copied into the existing networks in ``self.networks``,
        so references held elsewhere (e.g. by the optimizer) stay valid.

        Raises:
            FileNotFoundError: if a checkpoint file is missing.
        """
        model_save_dir = os.path.join(util.logger.log_path, 'models', "ite_{}".format(info))
        for network_name, network in self.networks.items():
            save_path = os.path.join(model_save_dir, network_name + ".pt")
            # save_model pickles whole nn.Module objects, so full unpickling is
            # needed. Only load checkpoints this code produced: unpickling an
            # untrusted file can execute arbitrary code.
            loaded = torch.load(save_path, map_location=util.device, weights_only=False)
            state_dict = loaded.state_dict() if isinstance(loaded, torch.nn.Module) else loaded
            network.load_state_dict(state_dict)
