import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from code.exp_utils import normal_logp
from code.realnvp_v2 import RealNVP, ConditionalRealNVP
from code.utils import seed_everything


class CondModel(nn.Module):
    """Abstract interface shared by the density models in this file.

    Subclasses override `sample` and `log_prob`; every method here simply
    raises `NotImplementedError`.
    """

    def forward(self):
        """Not implemented on the base class."""
        raise NotImplementedError()

    def sample(self, n, device=None):
        """Draw `n` samples; must be overridden by subclasses."""
        raise NotImplementedError()

    def log_prob(self, x):
        """Evaluate the log-density of `x`; must be overridden by subclasses."""
        raise NotImplementedError()


class CondModel_RealNVP(CondModel):
    """`CondModel` backed by an unconditional `RealNVP` flow.

    The flow is built with `logit_eps=None`; all other hyper-parameters are
    forwarded unchanged from the constructor arguments.
    """

    def __init__(self, *, image_shape, d_hidden, n_blocks, n_scales):
        super().__init__()
        self.realnvp = RealNVP(
            image_shape=image_shape,
            d_hidden=d_hidden,
            n_blocks=n_blocks,
            n_scales=n_scales,
            logit_eps=None,
        )
        self.D = int(np.prod(image_shape))
        self.image_shape = image_shape

    def sample(self, n, temp=1.0, device=None):
        """Draw `n` samples and their log-densities.

        Latents come from `realnvp.sample_prior` (with temperature `temp`),
        are pushed through the flow with `inverse=True`, and the returned
        log-density is `log_prior(z) - logdet`.

        Returns:
            Tuple `(x, log_px)` where `log_px` has shape `(n,)`.
        """
        latent = self.realnvp.sample_prior(n=n, temp=temp, device=device)
        samples, log_det = self.realnvp(latent, inverse=True)
        prior_logp = self.realnvp.log_prior(latent)
        sample_logp = prior_logp - log_det

        expected = (n,)
        assert all(t.shape == expected for t in (prior_logp, log_det, sample_logp))
        return samples, sample_logp

    def log_prob(self, x):
        """Return the per-example log-density of `x`, shape `(len(x),)`.

        `x` must have shape `(len(x), *image_shape)`. The value is the first
        element of whatever `realnvp.log_prob(x)` returns.
        """
        batch = len(x)
        assert x.shape == (batch, *self.image_shape)
        result = self.realnvp.log_prob(x)
        logp = result[0]
        assert logp.shape == (batch,)
        return logp


class CondModel_GMM(CondModel):
    """Mixture of axis-aligned Gaussians over flattened images.

    Learnable parameters:
        w:    unnormalised mixture logits, shape `(n_mix,)`.
        mu:   component means, shape `(n_mix, D)`.
        logs: per-dimension log-scales, shape `(n_mix, D)`.

    `D` is the number of elements in one image (`prod(image_shape)`).
    """

    def __init__(self, *, image_shape, n_mix: int):
        super().__init__()
        self.n_mix = n_mix
        self.image_shape = image_shape
        self.D = int(np.prod(image_shape))

        self.w = nn.Parameter(torch.zeros(n_mix), requires_grad=True)
        self.mu = nn.Parameter(torch.randn(n_mix, self.D) / np.sqrt(self.D) * 2, requires_grad=True)
        self.logs = nn.Parameter(torch.zeros(n_mix, self.D), requires_grad=True)

    def sample(self, n, device=None, temp=None):
        """Draw `n` samples from the mixture together with their log-density.

        Args:
            n: number of samples.
            device: device on which the Gaussian noise is created. When
                `None`, the device of the model parameters is used so that a
                model moved to an accelerator keeps working with the default
                arguments.
            temp: accepted for signature compatibility with the other models
                in this file; it is not used.

        Returns:
            Tuple `(x, log_px)` with `x` of shape `(n, *image_shape)` and
            `log_px` of shape `(n,)`.
        """
        if device is None:
            # Noise must live where the parameters live, otherwise the
            # arithmetic below fails with a device mismatch.
            device = self.mu.device

        # Component indices are drawn with NumPy's global RNG from the
        # (detached) softmax of the mixture logits.
        weights = F.softmax(self.w.detach().cpu(), dim=0).numpy()
        indices = np.random.choice(self.n_mix, n, replace=True, p=weights)
        mu = self.mu[indices]
        scale = self.logs[indices].exp()
        assert mu.shape == scale.shape == (n, self.D)

        x = torch.randn(n, self.D, device=device) * scale + mu
        assert x.shape == (n, self.D)

        x = x.view(n, *self.image_shape)
        log_px = self.log_prob(x)
        assert log_px.shape == (n,)

        return x, log_px

    def log_prob(self, x):
        """Return the mixture log-density of each row of `x`, shape `(len(x),)`.

        `x` must have shape `(len(x), *image_shape)`; it may be
        non-contiguous.
        """
        n = len(x)
        assert x.shape == (n, *self.image_shape)
        # reshape (not view) so non-contiguous inputs are accepted too.
        flat = x.reshape(n, self.D)

        # Standardise against every component: (n, 1, D) vs (1, n_mix, D).
        x_unit = (flat[:, None] - self.mu[None]) / self.logs[None].exp()
        assert x_unit.shape == (n, self.n_mix, self.D)

        # NOTE(review): normal_logp is assumed to return one value per row
        # (standard-normal log-density over the last axis) - confirm in
        # code.exp_utils. Subtracting the summed log-scales accounts for the
        # change of variables from x to x_unit.
        unit_logp = normal_logp(x_unit.reshape(-1, self.D)).reshape(n, self.n_mix)
        component_logp = unit_logp - self.logs.sum(dim=1)
        assert component_logp.shape == (n, self.n_mix)

        joint_logp = component_logp + F.log_softmax(self.w, dim=0)
        log_px = torch.logsumexp(joint_logp, dim=1)

        assert log_px.shape == (n,)
        return log_px


class AmortizedCondModel_RealNVP(CondModel):
    """`CondModel` backed by a `ConditionalRealNVP` flow.

    Both sampling and density evaluation take an extra conditioning input
    `x_cond`, which is forwarded unchanged to the underlying flow.
    """

    def __init__(self, *, image_shape, d_hidden, n_blocks, n_scales):
        super().__init__()
        self.realnvp = ConditionalRealNVP(
            image_shape=image_shape, d_hidden=d_hidden, n_blocks=n_blocks,
            n_scales=n_scales, logit_eps=None)
        self.D = int(np.prod(image_shape))
        self.image_shape = image_shape

    def sample(self, n, x_cond, temp=1.0, device=None):
        """Draw `n` samples conditioned on `x_cond` with their log-densities.

        Latents come from `realnvp.sample_prior`, are pushed through the flow
        with `inverse=True`, and the returned log-density is
        `log_prior(z) - logdet`.

        Returns:
            Tuple `(x, log_px)` where `log_px` has shape `(n,)`.
        """
        z = self.realnvp.sample_prior(n=n, temp=temp, device=device)
        x, logdet = self.realnvp(z, x_cond=x_cond, inverse=True)
        log_pz = self.realnvp.log_prior(z)
        log_px = log_pz - logdet
        assert log_pz.shape == logdet.shape == log_px.shape == (n,)
        return x, log_px

    def log_prob(self, x, x_cond=None):
        """Return the log-density of `x` given `x_cond`, shape `(len(x),)`.

        Args:
            x: batch of shape `(len(x), *image_shape)`.
            x_cond: conditioning input forwarded to the flow. It defaults to
                `None` only so the single-argument call form still binds;
                a value must be supplied.

        Raises:
            ValueError: if `x_cond` is not provided.
        """
        if x_cond is None:
            raise ValueError('x_cond is required to evaluate a conditional log_prob')
        assert x.shape == (len(x), *self.image_shape)
        # The flow is stored as `self.realnvp`; there is no `self.model`.
        log_px = self.realnvp.log_prob(x, x_cond=x_cond)[0]
        assert log_px.shape == (len(x),)
        return log_px


# dataset -> (image_shape, presets); each preset pairs the accepted model
# names with its (d_hidden, n_blocks, n_scales) architecture.
_VI_MODEL_PRESETS = {
    'mnist': ((1, 28, 28), (
        (('vi', 'vi_default', 'naive', 'naive_default', 'realnvp'), (32, 3, 3)),
        (('vi_big', 'naive_big'), (32, 6, 3)),
        (('vi_small', 'naive_small', 'realnvp_small'), (16, 3, 3)),
        (('vi_tiny', 'naive_tiny', 'realnvp_tiny'), (8, 3, 3)),
        (('vi_minute', 'naive_minute', 'realnvp_minute'), (4, 3, 3)),
        (('vi_micro', 'naive_micro', 'realnvp_micro'), (4, 2, 3)),
        (('vi_original', 'naive_original', 'realnvp_original'), (32, 8, 3)),
    )),
    'celebahq64': ((3, 64, 64), (
        (('vi_small', 'naive_small'), (32, 6, 4)),  # old `realnvp`
        (('vi', 'vi_default', 'naive', 'naive_default'), (48, 4, 4)),  # old `realnvp_big`
    )),
    'cifar10_5bit': ((3, 32, 32), (
        (('vi_tiny', 'naive_tiny'), (16, 3, 3)),
        (('vi_small', 'naive_small'), (32, 3, 3)),
        (('vi_default', 'naive_default'), (48, 3, 3)),
    )),
}


def create_vi_model(dataset, model_name, logit_eps: float=1e-4, seed=None) -> nn.Module:
    """Build the model registered under `model_name` for `dataset`.

    `seed_everything(seed)` is called first. Names starting with `naive`
    produce a plain `RealNVP` (which additionally receives `logit_eps`);
    names starting with `vi` or `realnvp` produce a `CondModel_RealNVP`.

    Args:
        dataset: one of `'mnist'`, `'celebahq64'`, `'cifar10_5bit'`.
        model_name: preset name; the accepted set depends on `dataset`.
        logit_eps: forwarded to `RealNVP` for the `naive*` presets only.
        seed: passed to `seed_everything`.

    Raises:
        ValueError: if `dataset` is unknown, or `model_name` is not a preset
            of that dataset.
    """
    seed_everything(seed)

    # Pick the class (and class-specific arguments) from the name prefix.
    model_cls = None
    ctor_args = {}
    if model_name.startswith('naive'):
        model_cls = RealNVP
        ctor_args['logit_eps'] = logit_eps
    elif model_name.startswith(('vi', 'realnvp')):
        model_cls = CondModel_RealNVP

    entry = _VI_MODEL_PRESETS.get(dataset)
    if entry is None:
        raise ValueError(f'Invalid dataset name: {dataset}')
    image_shape, presets = entry
    ctor_args['image_shape'] = image_shape

    for names, (d_hidden, n_blocks, n_scales) in presets:
        if model_name in names:
            ctor_args.update(d_hidden=d_hidden, n_blocks=n_blocks, n_scales=n_scales)
            break
    else:
        raise ValueError(f'Invalid model_name: {model_name}')

    return model_cls(**ctor_args)
