from abc import ABC, abstractmethod
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import defaults
from .normalization import Normalizer
from .train import get_optimizer, epochal_training, supervised_loss, L2Loss
from .torch_util import device, numpyify, torchify, Module, mlp, random_indices, sequence_mean


class BaseModel(ABC):
    """Abstract interface shared by every dynamics model in this module."""

    @abstractmethod
    def sample(self, states, actions):
        """Draw a transition for the given inputs.

        Args:
            states: current state(s) ``s``.
            actions: action(s) ``a`` taken in those states.

        Returns:
            A pair ``(s', r)`` of next state(s) and reward(s).
        """


class RandomChoiceModel(BaseModel):
    """Ensemble wrapper that forwards each ``sample`` call to one random member.

    A single member model is drawn with ``random.choice`` on every call, and
    that member handles the whole batch of ``states``/``actions``.
    """

    def __init__(self, models):
        # Sequence of member models, each exposing ``sample(states, actions)``.
        self.models = models

    def sample(self, states, actions):
        """Return ``chosen.sample(states, actions)`` for a randomly chosen member."""
        chosen = random.choice(self.models)
        return chosen.sample(states, actions)


class AverageModel(BaseModel):
    """Ensemble wrapper that combines every member's sample with ``sequence_mean``."""

    def __init__(self, models):
        # Sequence of member models, each exposing ``sample(states, actions)``.
        self.models = models

    def sample(self, states, actions):
        """Query each member in order and combine their outputs.

        Returns:
            ``(sequence_mean(next_states_per_model), sequence_mean(rewards_per_model))``.
        """
        per_model = [member.sample(states, actions) for member in self.models]
        next_state_samples = [next_s for next_s, _ in per_model]
        reward_samples = [reward for _, reward in per_model]
        averaged_next_states = sequence_mean(next_state_samples)
        averaged_rewards = sequence_mean(reward_samples)
        return averaged_next_states, averaged_rewards


class DeterministicDynamicsModel(Module, BaseModel):
    """Deterministic learned dynamics model.

    An MLP maps normalized ``(state, action)`` inputs to normalized
    ``(state delta, reward)`` targets. At prediction time the outputs are
    unnormalized and the delta is added back onto the input state.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, n_hidden_layers,
                 layer_factory=nn.Linear, optimizer_factory=defaults.OPTIMIZER):
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim

        # One normalizer for the network inputs (states) and one per regression target.
        self.state_normalizer = Normalizer(state_dim)
        self.diff_normalizer = Normalizer(state_dim)
        self.reward_normalizer = Normalizer(1)

        input_dim = state_dim + action_dim
        output_dim = state_dim + 1  # state delta followed by a single reward column
        self.net = mlp([input_dim, *([hidden_dim] * n_hidden_layers), output_dim], layer_factory=layer_factory)
        self.compute_loss = supervised_loss(self.net, L2Loss())
        self.to(device)
        self.optimizer = optimizer_factory(self.net.parameters())

    def forward(self, states, actions, grad=True):
        """Predict next states and rewards.

        Args:
            states: batch of states, concatenated with ``actions`` along dim 1.
            actions: batch of actions.
            grad: whether autograd tracks the network evaluation.

        Returns:
            ``(states + predicted_deltas, predicted_rewards)``. The rewards come
            from the last output column (1-D before unnormalization).
        """
        normalized_states = self.state_normalizer(states)
        with torch.set_grad_enabled(grad):
            outputs = self.net(torch.cat([normalized_states, actions], dim=1))
        diffs = self.diff_normalizer.unnormalize(outputs[:,:self.state_dim])
        rewards = self.reward_normalizer.unnormalize(outputs[:,self.state_dim])
        return states + diffs, rewards

    def _format_for_training(self, states, actions, diffs, rewards):
        """Build normalized ``(inputs, targets)`` tensors for supervised training.

        ``rewards`` must be a column (2-D), because it is concatenated with the
        normalized diffs along dim 1.
        """
        normalized_states = self.state_normalizer(states)
        normalized_diffs = self.diff_normalizer(diffs)
        inputs = torch.cat([normalized_states, actions], dim=1)
        normalized_rewards = self.reward_normalizer(rewards)
        targets = torch.cat([normalized_diffs, normalized_rewards], dim=1)
        return inputs, targets

    def fit(self, buffer, epochs, **kwargs):
        """Refit all normalizers on the whole buffer, then train for ``epochs`` epochs.

        Extra keyword arguments are forwarded to ``epochal_training``, whose
        result is returned.
        """
        states, actions, next_states, rewards = buffer.get('states', 'actions', 'next_states', 'rewards')
        diffs = next_states - states
        # Reuse the rewards fetched above (as a column) instead of querying the buffer again.
        rewards = rewards.unsqueeze(1)
        self.state_normalizer.fit(states)
        self.diff_normalizer.fit(diffs)
        self.reward_normalizer.fit(rewards)
        inputs, targets = self._format_for_training(states, actions, diffs, rewards)
        return epochal_training(self.compute_loss, self.optimizer, [inputs, targets], epochs, **kwargs)

    def update(self, buffer, batch_size=defaults.BATCH_SIZE):
        """Take one optimizer step on a sampled minibatch and return the loss as a float.

        The normalizers are not refit here; they keep the statistics from the last ``fit``.
        """
        # Fifth value from buffer.sample is unused here (presumably done flags — confirm).
        states, actions, next_states, rewards, _ = buffer.sample(batch_size)
        diffs = next_states - states
        rewards = rewards.unsqueeze(1)
        inputs, targets = self._format_for_training(states, actions, diffs, rewards)
        loss = self.compute_loss(inputs, targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sample(self, states, actions):
        """Deterministic model: a sample is just the prediction from ``forward``."""
        return self.forward(states, actions)

    def mean(self, states, actions):
        """Mean prediction, identical to ``sample`` for this deterministic model."""
        return self.forward(states, actions)


# def diagonal_gaussian(mean, std):
#     return ptd.Independent(ptd.Normal(loc=mean, scale=std), 1)

class GaussianDynamicsModel(Module, BaseModel):
    """Probabilistic dynamics model with a diagonal Gaussian over ``(s', r)``.

    A shared trunk feeds two heads. One head predicts the mean offset: a state
    delta plus an absolute reward. The other predicts per-dimension log
    variances. Those log variances are softly bounded by the learnable
    ``min_log_var`` / ``max_log_var`` parameters.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, trunk_layers, head_hidden_layers,
                 init_min_log_var=-10.0, init_max_log_var=1.0,
                 layer_factory=nn.Linear, optimizer_factory=defaults.OPTIMIZER):
        # Resolves to Module.__init__ via the MRO; matches DeterministicDynamicsModel.
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        input_dim = state_dim + action_dim
        output_dim = state_dim + 1  # next-state dims followed by a single reward dim

        # Learnable per-output bounds on the predicted log variance.
        self.min_log_var = nn.Parameter(torch.full([output_dim], init_min_log_var, device=device))
        self.max_log_var = nn.Parameter(torch.full([output_dim], init_max_log_var, device=device))
        self.state_normalizer = Normalizer(state_dim)

        trunk_dims = [input_dim] + [hidden_dim] * trunk_layers
        head_dims = [hidden_dim] * (head_hidden_layers + 1) + [output_dim]
        self.trunk = mlp(trunk_dims, layer_factory=layer_factory, output_activation=nn.ReLU())
        self.diff_net = mlp(head_dims, layer_factory=layer_factory)
        self.log_var_net = mlp(head_dims, layer_factory=layer_factory)
        self.to(device)
        self.optimizer = optimizer_factory(
            list(self.trunk.parameters()) +
            list(self.diff_net.parameters()) +
            list(self.log_var_net.parameters()) +
            [self.min_log_var, self.max_log_var]
        )

    def _forward(self, states, actions):
        """Return ``(means, log_vars)`` of the predictive Gaussian over ``[s', r]``.

        The state part of the mean is ``states + predicted_delta``. The reward
        part is predicted directly, because the padded column added to it is zero.
        """
        inputs = torch.cat([self.state_normalizer(states), actions], dim=1)
        shared_hidden = self.trunk(inputs)
        diffs = self.diff_net(shared_hidden)
        # Append a zero column so the states line up with the (state, reward) outputs;
        # padding keeps the dtype/device of `states`.
        means = diffs + F.pad(states, (0, 1))
        log_vars = self.log_var_net(shared_hidden)
        # Softplus-based soft clamping toward [min_log_var, max_log_var].
        log_vars = self.max_log_var - F.softplus(self.max_log_var - log_vars)
        log_vars = self.min_log_var + F.softplus(log_vars - self.min_log_var)
        return means, log_vars

    def compute_loss(self, states, actions, targets):
        """Gaussian negative log-likelihood (up to constants and a factor 1/2), batch-averaged.

        A penalty ``0.01 * (sum(max_log_var) - sum(min_log_var))`` is added,
        which discourages the log-variance bounds from drifting apart.
        """
        means, log_vars = self._forward(states, actions)
        inv_vars = torch.exp(-log_vars)
        squared_errors = torch.sum((targets - means)**2 * inv_vars, dim=1)
        log_dets = torch.sum(log_vars, dim=1)
        mle_loss = torch.mean(squared_errors + log_dets)
        return mle_loss + 0.01 * (self.max_log_var.sum() - self.min_log_var.sum())

    def fit(self, buffer, epochs, **kwargs):
        """Fit the state normalizer on the whole buffer, then train for ``epochs`` epochs.

        Targets are absolute next states concatenated with a reward column.
        Extra keyword arguments are forwarded to ``epochal_training``, whose
        result is returned.
        """
        # Fifth value from buffer.get is unused here (presumably done flags — confirm).
        states, actions, next_states, rewards, _ = buffer.get()
        self.state_normalizer.fit(states)
        targets = torch.cat([next_states, rewards.unsqueeze(1)], dim=1)
        return epochal_training(self.compute_loss, self.optimizer, [states, actions, targets], epochs, **kwargs)

    def update(self, buffer, batch_size=defaults.BATCH_SIZE):
        """Take one optimizer step on a sampled minibatch and return the loss as a float."""
        states, actions, next_states, rewards, _ = buffer.sample(batch_size)
        targets = torch.cat([next_states, rewards.unsqueeze(1)], dim=1)
        loss = self.compute_loss(states, actions, targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sample(self, states, actions):
        """Draw one sample from the predictive Gaussian and split it into ``(next_states, rewards)``."""
        means, log_vars = self._forward(states, actions)
        # std = exp(log_var / 2), computed directly rather than via sqrt(exp(log_var)).
        stds = torch.exp(0.5 * log_vars)
        samples = means + stds * torch.randn_like(means)
        return samples[:,:-1], samples[:,-1]

    def mean(self, states, actions):
        """Return the predictive mean split into ``(next_states, rewards)``."""
        means, _ = self._forward(states, actions)
        return means[:,:-1], means[:,-1]


class OracleDynamics(BaseModel):
    """Ground-truth dynamics backed by the environment class's ``oracle_dynamics``.

    A fresh instance of the unwrapped environment's class is built with no
    constructor arguments, and its ``oracle_dynamics(state, action)`` method
    supplies ``(next_state, reward)`` pairs.
    """

    def __init__(self, env):
        oracle_class = env.unwrapped.__class__
        assert hasattr(oracle_class, 'oracle_dynamics')
        self.env = oracle_class()

    def mean(self, states, actions):
        """Query the oracle for one transition or for a batch of them.

        When both inputs are 1-D, returns ``(torchify(next_state), reward)``,
        with the reward exactly as the oracle produced it. Otherwise the oracle
        is called row by row. The result is a stacked tensor of next states
        plus a reward tensor (built from ``float`` values) on ``device``.
        """
        oracle = self.env.oracle_dynamics
        single = states.ndim == 1 and actions.ndim == 1
        if single:
            raw_next, raw_reward = oracle(numpyify(states), numpyify(actions))
            return torchify(raw_next), raw_reward

        stacked_next, reward_values = [], []
        for row_state, row_action in zip(states, actions):
            raw_next, raw_reward = oracle(numpyify(row_state), numpyify(row_action))
            # Convert immediately, before the oracle is queried again.
            stacked_next.append(torchify(raw_next))
            reward_values.append(float(raw_reward))
        return torch.stack(stacked_next), torch.tensor(reward_values, device=device)

    def sample(self, states, actions):
        """The oracle is deterministic, so sampling just returns ``mean``."""
        return self.mean(states, actions)