import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import utils
from agent.dreamer import DreamerAgent, stop_gradient
import agent.dreamer_utils as common


class Disagreement(nn.Module):
    """Ensemble of small MLPs that each predict a target from (obs, action).

    Every member is ``Linear -> ReLU -> Linear`` and receives the
    concatenation of ``obs`` and ``action`` along the last axis. The ensemble
    is used in two ways: ``forward`` yields each member's prediction error
    against a given target, and ``get_disagreement`` yields the variance of
    the members' predictions.

    Args:
        obs_dim: Size of the last axis of ``obs``.
        action_dim: Size of the last axis of ``action``.
        hidden_dim: Width of the hidden layer of every member.
        n_models: Number of ensemble members.
        pred_dim: Size of each member's output; defaults to ``obs_dim``.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, n_models=5, pred_dim=None):
        super().__init__()
        if pred_dim is None:
            pred_dim = obs_dim
        self.ensemble = nn.ModuleList([
            nn.Sequential(
                nn.Linear(obs_dim + action_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, pred_dim),
            )
            for _ in range(n_models)
        ])

    def _predict_all(self, obs, action):
        """Return the list of every member's prediction for (obs, action)."""
        # The network input is identical for all members, so build it once.
        inp = torch.cat([obs, action], dim=-1)
        return [model(inp) for model in self.ensemble]

    def forward(self, obs, action, next_obs):
        """Compute each member's L2 prediction error against ``next_obs``.

        Args:
            obs: Tensor of shape ``(..., obs_dim)``.
            action: Tensor of shape ``(..., action_dim)`` with the same
                leading dimensions as ``obs``.
            next_obs: Target tensor of shape ``(..., pred_dim)``.

        Returns:
            Tensor of shape ``(..., n_models)`` holding one error per member.

        Raises:
            AssertionError: If the first dimensions of the inputs differ.
        """
        assert obs.shape[0] == next_obs.shape[0]
        assert obs.shape[0] == action.shape[0]

        errors = [
            torch.norm(next_obs - next_obs_hat, dim=-1, p=2, keepdim=True)
            for next_obs_hat in self._predict_all(obs, action)
        ]
        # Join on the last axis (not axis 1) so that inputs with extra leading
        # dimensions still get one column per ensemble member; for 2-D inputs
        # the two are the same axis.
        return torch.cat(errors, dim=-1)

    def get_disagreement(self, obs, action):
        """Return the ensemble's prediction variance for (obs, action).

        The variance is taken across members and then averaged over the
        prediction features. ``torch.var`` is called with its default
        settings, so at least two members are needed for a finite result.

        Args:
            obs: Tensor of shape ``(..., obs_dim)``.
            action: Tensor of shape ``(..., action_dim)`` with the same
                leading dimensions as ``obs``.

        Returns:
            Tensor with the leading dimensions of ``obs`` (last axis removed).

        Raises:
            AssertionError: If the first dimensions of the inputs differ.
        """
        assert obs.shape[0] == action.shape[0]

        preds = torch.stack(self._predict_all(obs, action), dim=0)
        return torch.var(preds, dim=0).mean(dim=-1)


class Plan2Explore(DreamerAgent):
    """Dreamer agent extended with a ``Disagreement`` ensemble.

    While ``reward_free`` is true, each ``update`` call also fits the
    ensemble to predict the world model's ``embed`` output from its ``feat``
    output and the following action, and the task behavior is trained on
    ``compute_intr_reward`` (the ensemble's prediction variance). Otherwise
    the behavior is trained on the world model's ``reward`` head.
    """

    def __init__(self, **kwargs):
        """Build the base agent, then attach the ensemble and its optimizer.

        Args:
            **kwargs: Forwarded unchanged to ``DreamerAgent.__init__``.
        """
        super().__init__(**kwargs)
        feat_dim, embed_dim = self.wm.inp_size, self.wm.embed_dim
        # The ensemble's hidden width is tied to its output width.
        self.hidden_dim = embed_dim
        self.reward_free = True

        ensemble = Disagreement(
            feat_dim,
            self.act_dim,
            self.hidden_dim,
            pred_dim=embed_dim,
        )
        self.disagreement = ensemble.to(self.device)

        # Dedicated optimizer for the ensemble, configured like the model one.
        self.disagreement_opt = common.Optimizer(
            'disagreement',
            self.disagreement.parameters(),
            **self.cfg.model_opt,
            use_amp=self._use_amp,
        )
        self.disagreement.train()
        # Gradients are disabled on the whole agent here; `update` re-enables
        # them for the ensemble through `common.RequiresGrad`.
        self.requires_grad_(requires_grad=False)

    def update_disagreement(self, obs, action, next_obs, step):
        """Run one optimizer step on the ensemble's mean prediction error.

        Args:
            obs: Ensemble input features.
            action: Actions paired with ``obs``.
            next_obs: Prediction targets.
            step: Accepted for interface compatibility; not read here.

        Returns:
            Dict with whatever the optimizer call reports, plus
            ``'disagreement_loss'``.
        """
        loss = self.disagreement(obs, action, next_obs).mean()
        metrics = dict(
            self.disagreement_opt(loss, self.disagreement.parameters()))
        metrics['disagreement_loss'] = loss.item()
        return metrics

    def compute_intr_reward(self, seq):
        """Compute the intrinsic reward for a sequence as ensemble variance.

        ``seq['feat']`` without its last entry is paired with ``seq['action']``
        without its first entry (both sliced along the leading axis), and the
        resulting disagreement is written to positions ``1:`` of the output;
        position 0 stays zero.

        Args:
            seq: Mapping providing ``'feat'`` and ``'action'`` tensors.

        Returns:
            Tensor shaped like ``seq['action']`` with the last axis replaced
            by size 1.
        """
        prev_feat = seq['feat'][:-1]
        next_action = stop_gradient(seq['action'][1:])

        if next_action.ndim > 2:
            # Collapse the two leading axes for the ensemble, then restore.
            lead, inner, _ = next_action.shape
            flat_scores = self.disagreement.get_disagreement(
                prev_feat.reshape(lead * inner, -1),
                next_action.reshape(lead * inner, -1),
            )
            scores = flat_scores.reshape(lead, inner, 1)
        else:
            scores = self.disagreement.get_disagreement(
                prev_feat, next_action).unsqueeze(-1)

        out_shape = [*seq['action'].shape[:-1], 1]
        intr_rew = torch.zeros(out_shape, device=self.device)
        intr_rew[1:] = scores
        return intr_rew

    def update(self, data, step):
        """Update the world model, the ensemble (if reward-free) and behavior.

        Args:
            data: Batch mapping; ``'action'`` must be 3-D and
                ``'is_terminal'`` is passed on to the behavior update.
            step: Forwarded to ``update_disagreement``.

        Returns:
            Tuple ``(state, metrics)`` where ``state`` is what the world-model
            update returned and ``metrics`` merges all reported metrics.
        """
        batch, length, _ = data['action'].shape
        state, outputs, wm_metrics = self.wm.update(data, state=None)
        metrics = dict(wm_metrics)
        start = {
            key: stop_gradient(val) for key, val in outputs['post'].items()
        }

        if not self.reward_free:
            def reward_fn(seq):
                # `.mean` is read as an attribute, not called — presumably a
                # distribution property; confirm against the reward head.
                return self.wm.heads['reward'](seq['feat']).mean

            metrics.update(self._task_behavior.update(
                self.wm, start, data['is_terminal'], reward_fn))
            return state, metrics

        # Pair the features at step t with the action and embedding at t + 1,
        # flattening the first two axes into one.
        n_pairs = batch * (length - 1)
        inp = stop_gradient(outputs['feat'][:, :-1]).reshape(n_pairs, -1)
        action = data['action'][:, 1:].reshape(n_pairs, -1)
        target = stop_gradient(outputs['embed'][:, 1:]).reshape(n_pairs, -1)

        with common.RequiresGrad(self.disagreement), \
                torch.cuda.amp.autocast(enabled=self._use_amp):
            metrics.update(
                self.update_disagreement(inp, action, target, step))

        metrics.update(self._task_behavior.update(
            self.wm, start, data['is_terminal'],
            reward_fn=self.compute_intr_reward))
        return state, metrics