import os

import lpips
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F

from code import arch_v2
from code.data import get_dataset
from code.realnvp_v2 import RealNVP
from code.utils import EMA


# log(2 * pi) as a plain Python float; additive constant used by the Gaussian
# log-density helpers in this module.
_LOG2PI = np.log(2 * np.pi).item()


class ForwardOp:
    """Forward (measurement) operator applied to a batch of 4-D image tensors.

    Supported ``type_`` values:

    * ``'top_half'`` / ``'bottom_half'`` / ``'left_half'`` -- observe one half
      of the image along the height (dim 2) or width (dim 3) axis.
    * ``'mar'`` -- observe the entries selected by a boolean ``mask`` whose
      shape equals ``x.shape[1:]``.
    * ``'cs'`` -- flatten each sample and right-multiply it by the matrix ``A``.
    * ``'blur'`` / ``'blur2'`` -- 2x2 / 4x4 average pooling.
    * ``'grayscale'`` -- mean over the channel dimension.
    * ``'digit<k>'`` -- sigmoid output of a frozen ``arch_v2.DigitClassifier``
      loaded from the checkpoint for digit ``k``. After construction the
      operator's ``type_`` attribute is rewritten to ``'module'``.

    Args:
        type_: Operator name (see above).
        keep_shape: Only used by ``'mar'``; when true the observed/missing
            parts are returned as masked tensors with the shape of the input
            instead of flattened selections.
        mask: Boolean tensor, required for ``'mar'``.
        shape: Stored on the instance; not used by the methods of this class.
        A: Measurement matrix, required for ``'cs'``.
    """

    def __init__(self, type_: str, keep_shape=False, mask=None, shape=None, A=None):
        assert (type_ in ('top_half', 'bottom_half', 'left_half', 'mar', 'cs', 'blur', 'grayscale', 'blur2')
                or type_.startswith('digit'))

        self.type_ = type_
        self.keep_shape = keep_shape
        self.mask = mask
        self.shape = shape

        if self.type_ == 'cs':
            assert A is not None
            self.A = A
        elif self.type_ == 'mar':
            assert mask is not None and mask.dtype == torch.bool
        elif self.type_.startswith('digit'):
            # The last character of the type name selects the digit.
            digit = int(self.type_[-1])
            assert digit in range(10)
            ckpt_path = arch_v2.DigitClassifier.get_ckpt_path(digit)
            self.net = arch_v2.DigitClassifier()
            self.net.load_state_dict(torch.load(ckpt_path))
            self.net.cuda()
            self.net.eval()
            # The classifier is a fixed measurement function: freeze its weights.
            for p in self.net.parameters():
                p.requires_grad_(False)
            print(f'Loaded DigitClasifier({digit}) checkpoint')
            # From here on every method dispatches on 'module', not 'digit<k>'.
            self.type_ = 'module'

    def observe(self, x: torch.Tensor):
        """Apply the operator to ``x``.

        Args:
            x: 4-D tensor.

        Returns:
            ``(x_obs, x_mis)``: the observed part and the unobserved part.
            ``x_mis`` is ``None`` for operators that do not split the input
            ('cs', 'blur', 'blur2', 'grayscale', 'module').

        Raises:
            ValueError: If ``self.type_`` is not a known operator.
        """
        assert x.ndim == 4

        if self.type_ == 'top_half':
            x_obs, x_mis = x.chunk(2, dim=2)
        elif self.type_ == 'bottom_half':
            x_mis, x_obs = x.chunk(2, dim=2)
        elif self.type_ == 'left_half':
            x_obs, x_mis = x.chunk(2, dim=3)
        elif self.type_ == 'mar':
            assert self.mask.shape == x.shape[1:]
            if self.keep_shape:
                x_obs, x_mis = x * self.mask.float(), x * (~self.mask).float()
            else:
                x_obs, x_mis = x[:, self.mask], x[:, ~self.mask]
        elif self.type_ == 'cs':
            x_obs = x.reshape(len(x), -1) @ self.A
            x_mis = None
        elif self.type_ == 'blur':
            x_obs = F.avg_pool2d(x, 2, 2)
            assert x_obs.shape[2] == x.shape[2] // 2
            assert x_obs.shape[3] == x.shape[3] // 2
            x_mis = None
        elif self.type_ == 'blur2':
            x_obs = F.avg_pool2d(x, 4, 4)
            assert x_obs.shape[2] == x.shape[2] // 4
            assert x_obs.shape[3] == x.shape[3] // 4
            x_mis = None
        elif self.type_ == 'grayscale':
            x_obs = x.mean(dim=1, keepdim=True)
            x_mis = None
        elif self.type_ == 'module':
            x_obs = torch.sigmoid(self.net(x))
            x_mis = None
        else:
            raise ValueError(f'Invalid forward operator type: {self.type_}')

        return x_obs, x_mis

    def combine(self, x_obs: torch.Tensor, x_mis: torch.Tensor) -> torch.Tensor:
        """Reassemble a full image from its observed and missing parts.

        Only the half-image operators and ``'mar'`` with ``keep_shape=True``
        can be inverted this way.

        Raises:
            NotImplementedError: For ``'mar'`` with ``keep_shape=False``.
            ValueError: For every operator that does not split the input
                ('cs', 'module', 'blur', 'blur2', 'grayscale').
        """
        if self.type_ in ('top_half', 'bottom_half', 'left_half'):
            assert x_obs.ndim == x_mis.ndim == 4

        if self.type_ == 'top_half':
            return torch.cat([x_obs, x_mis], dim=2)
        if self.type_ == 'bottom_half':
            return torch.cat([x_mis, x_obs], dim=2)
        if self.type_ == 'left_half':
            return torch.cat([x_obs, x_mis], dim=3)
        if self.type_ == 'mar':
            if self.keep_shape:
                # Both parts are zero where the other is defined, so they add up.
                return x_obs + x_mis
            raise NotImplementedError()
        if self.type_ == 'cs':
            raise ValueError('Cannot combine CS')
        if self.type_ == 'module':
            # Digit-classifier operators are stored under the name 'module'.
            raise ValueError('Cannot combine net')

        # Remaining operators (blur, blur2, grayscale) have no missing part.
        raise ValueError(f'Cannot combine {self.type_}')

    @torch.no_grad()
    def visualize(self, x: torch.Tensor) -> torch.Tensor:
        """Return an image-shaped preview of what the operator observes.

        The input is cloned, never modified. Unobserved regions are set to
        zero; 'cs' and 'module' measurements have no image form, so an
        all-zero tensor is returned for them. Blurred previews are upsampled
        back with nearest-neighbour interpolation.

        Args:
            x: 4-D tensor.
        """
        assert x.ndim == 4
        N, C, H, W = x.shape
        x = x.clone()

        if self.type_ == 'top_half':
            x[:, :, H//2:, :] = 0.0
        elif self.type_ == 'bottom_half':
            x[:, :, :H//2, :] = 0.0
        elif self.type_ == 'left_half':
            x[:, :, :, W//2:] = 0.0
        elif self.type_ == 'mar':
            x[:, ~self.mask] = 0.0
        elif self.type_ in ('cs', 'module'):
            x = torch.zeros_like(x)
        elif self.type_ == 'blur':
            x = F.avg_pool2d(x, 2, 2)
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        elif self.type_ == 'blur2':
            x = F.avg_pool2d(x, 4, 4)
            x = F.interpolate(x, scale_factor=4, mode='nearest')
        elif self.type_ == 'grayscale':
            # The channel mean is replicated to three channels.
            x = x.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1)

        return x


class DistFunc(nn.Module):
    """Per-sample distance between two batches of identical shape.

    Args:
        loss_type: ``'l2'`` for the squared Euclidean distance summed over all
            non-batch dimensions, or ``'lpips_vgg'`` for the LPIPS distance
            with a VGG backbone.

    Raises:
        ValueError: If ``loss_type`` is not one of the values above.

    Calling the module with ``(x, y)`` returns a 1-D tensor with one distance
    per batch element.
    """

    def __init__(self, loss_type):
        super().__init__()
        if loss_type == 'l2':
            def _fn(x, y):
                assert x.shape == y.shape
                # reshape, not view: the difference of non-contiguous inputs
                # (e.g. transposed or channels_last) can itself be
                # non-contiguous, which view() rejects.
                return (x - y).reshape(len(x), -1).pow(2).sum(dim=1)
        elif loss_type == 'lpips_vgg':
            self._lpips_loss = lpips.LPIPS(net='vgg')
            assert isinstance(self._lpips_loss, nn.Module)
            def _fn(x, y):
                # Assume x, y are scaled to be in [0, 1]
                assert x.shape == y.shape
                assert x.min() >= 0.0 and x.max() <= 1.0
                assert y.min() >= 0.0 and y.max() <= 1.0
                # Rescale from [0, 1] to [-1, 1] before calling LPIPS.
                x = x * 2.0 - 1.0
                y = y * 2.0 - 1.0
                # Collapse to exactly one value per sample. Unlike squeeze(),
                # this keeps the batch dimension for a batch of one.
                return self._lpips_loss(x, y).reshape(len(x))
        else:
            raise ValueError(f'Invalid rec_loss: {loss_type}')

        self.loss_fn = _fn

    def forward(self, x, y):
        """Return the per-sample distance between ``x`` and ``y``."""
        return self.loss_fn(x, y)


def normal_logp(x: torch.Tensor):
    """Log-density of each sample under a standard normal distribution.

    Every sample is flattened over its non-batch dimensions; the returned
    tensor has one entry per sample.
    """
    dim = np.prod(x.shape[1:]).item()
    flat = x.reshape(-1, dim)
    sq_norm = (flat ** 2).sum(dim=1)
    log_px = -0.5 * (dim * _LOG2PI + sq_norm)
    assert log_px.shape == (len(flat),)
    return log_px


def diag_normal_logp(x: torch.Tensor, sig: float):
    """Log-density of each sample under a zero-mean normal with std ``sig``.

    The same standard deviation ``sig`` is used for every dimension. Samples
    are flattened over their non-batch dimensions before evaluation.
    """
    dim = np.prod(x.shape[1:]).item()
    flat = x.reshape(-1, dim)
    # Normalisation term: dim * log(2 * pi * sig^2).
    log_norm = dim * (_LOG2PI + 2 * np.log(sig))
    scaled_sq_norm = flat.pow(2).sum(dim=1) / sig**2
    return -0.5 * (log_norm + scaled_sq_norm)


def load_base_model(model_name: str, base_ckpt: str=None) -> RealNVP:
    """Build the RealNVP architecture named ``model_name``.

    When ``base_ckpt`` is given, the checkpoint's ``'model_state_dict'`` is
    loaded, an optional ``'ema_state_dict'`` is applied through ``EMA.assign``
    (presumably replacing the weights with their EMA values -- confirm
    against ``code.utils.EMA``), all parameters are frozen and the model is
    put in eval mode. Without a checkpoint the freshly constructed model is
    returned as is.

    Raises:
        KeyError: If ``model_name`` is not a known architecture.
        ValueError: If ``base_ckpt`` is given but is not an existing file.
    """
    constructors = {
        'mnist': arch_v2.RealNVP_MNIST,
        'celebahq64': arch_v2.RealNVP_CelebAHQ64_5bit,       # should be deprecated
        'celebahq64_5bit': arch_v2.RealNVP_CelebAHQ64_5bit,
        'cifar10_5bit': arch_v2.RealNVP_CIFAR10_5bit,
        'cifar10': arch_v2.RealNVP_CIFAR10,
    }
    model = constructors[model_name]()

    # No checkpoint requested: hand back the untouched model.
    if base_ckpt is None:
        return model

    if not os.path.isfile(base_ckpt):
        raise ValueError(f'Invalid base model checkpoint: {base_ckpt}')

    checkpoint = torch.load(base_ckpt)
    model.load_state_dict(checkpoint['model_state_dict'])

    if 'ema_state_dict' in checkpoint:
        # The decay value is irrelevant here; the EMA object is only a
        # vehicle for loading and assigning the stored state.
        ema = EMA(model, 0.99)
        ema.load_state_dict(checkpoint['ema_state_dict'])
        ema.assign(model)

    for param in model.parameters():
        param.requires_grad_(False)
    model.eval()

    return model


def get_eval_data(dataset: str, data_root: str=None):
    """Return the evaluation split for the dataset family named by ``dataset``.

    ``'mnist'`` maps to the MNIST ``'test'`` split. Names starting with
    ``'celebahq64'`` (e.g. ``'celebahq64_5bit'``) map to the CelebA-HQ 64
    ``'valid'`` split, and names starting with ``'cifar10'`` (e.g.
    ``'cifar10_5bit'``) map to the CIFAR-10 ``'test'`` split.

    Args:
        dataset: Dataset or model name.
        data_root: Forwarded unchanged to ``get_dataset``.

    Raises:
        ValueError: If ``dataset`` matches none of the known families.
    """
    if dataset == 'mnist':
        return get_dataset('mnist', split='test', data_root=data_root)
    # Prefix match, like CIFAR-10 below, so that bit-depth-suffixed model
    # names such as 'celebahq64_5bit' resolve to the same data.
    if dataset.startswith('celebahq64'):
        return get_dataset('celebahq64', split='valid', data_root=data_root)
    if dataset.startswith('cifar10'):
        return get_dataset('cifar10', split='test', data_root=data_root)

    raise ValueError(f'Invalid dataset name: {dataset}')


def compute_variance(xs: torch.Tensor):
    """Normalised per-pixel variance maps across a set of samples.

    The variance is taken over dim 1 of a 5-D tensor, divided by the largest
    variance of each batch item, and then averaged over the following
    dimension, giving one 2-D map per batch item. A 4-D input is treated as a
    single batch item.

    When every sample of a batch item is identical (zero variance everywhere)
    the map is all zeros rather than NaN.

    Args:
        xs: 5-D tensor, or 4-D tensor that gets a leading batch dim of 1.

    Returns:
        Tensor of shape ``(xs.shape[0], xs.shape[3], xs.shape[4])`` (after the
        optional batch dim has been added).
    """
    if xs.ndim == 4:
        xs = xs[None, ...]
    assert xs.ndim == 5

    N = xs.shape[0]
    vs = xs.var(dim=1)

    peak = vs.reshape(N, -1).max(dim=1)[0].view(N, 1, 1, 1)
    # Avoid 0 / 0 when the samples do not vary at all; positive peaks are
    # left untouched so regular results are unchanged.
    peak = torch.where(peak > 0, peak, torch.ones_like(peak))

    vs = (vs / peak).mean(dim=1)
    assert vs.shape == (N, xs.shape[3], xs.shape[4])
    return vs


def plot_variance(vs: torch.Tensor, fn):
    """Save a row of heatmaps, one per 2-D map in ``vs``, to ``fn``.

    Each map is drawn with a fixed colour range of [0, 1], without colour bar,
    ticks or axis labels.

    Args:
        vs: 3-D collection of maps; the first dimension indexes the maps.
        fn: Destination passed to ``Figure.savefig``.
    """
    assert vs.ndim == 3
    N = len(vs)

    plt.clf()
    # squeeze=False always yields a 2-D array of axes; without it a single
    # map (N == 1) returns a bare Axes object that cannot be iterated.
    fig, axes = plt.subplots(1, N, figsize=(4 * N, 4), squeeze=False)
    try:
        for ax, v in zip(axes.ravel(), vs):
            if isinstance(v, torch.Tensor):
                # Plotting needs host memory and no autograd history.
                v = v.detach().cpu().numpy()
            sns.heatmap(v, vmin=0.0, vmax=1.0, cbar=False, ax=ax)
            ax.set(xlabel=None, ylabel=None, xticks=[], yticks=[], aspect=1.0)
        fig.tight_layout(pad=2.5)
        fig.savefig(fn)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


def test_normal_logp():
    """Check normal_logp against torch's MultivariateNormal with identity covariance."""
    num_samples, dim = 1000, 20
    reference = torch.distributions.multivariate_normal.MultivariateNormal(
        loc=torch.zeros(dim), covariance_matrix=torch.eye(dim))

    samples = torch.randn(num_samples, dim)
    actual = normal_logp(samples)
    expected = reference.log_prob(samples)

    assert actual.shape == (num_samples,)
    assert expected.shape == (num_samples,)
    assert torch.allclose(actual, expected)


def test_diag_normal_logp():
    """Check diag_normal_logp against MultivariateNormal with covariance sig^2 * I."""
    num_samples, dim = 1000, 20
    sig = 0.7
    covariance = torch.eye(dim) * sig**2
    reference = torch.distributions.multivariate_normal.MultivariateNormal(
        loc=torch.zeros(dim), covariance_matrix=covariance)

    samples = torch.randn(num_samples, dim)
    actual = diag_normal_logp(samples, sig)
    expected = reference.log_prob(samples)

    assert actual.shape == (num_samples,)
    assert expected.shape == (num_samples,)
    assert torch.allclose(actual, expected)
