import torch
import torch.nn as nn
import numpy as np
from .basic_blocks import InvertibleLinear, InvertibleResNet1d, GlowRevNet1d, GlowRevNet2d
from .density_modeling_utils import GaussianDiag
from .kernel_fn import *
from .utils import weighted_max, normalize, LinearWrapper, batch_subset_kernel
from .kernel_approx import Nystrom_PerronFreboniusOperator

import os
from tqdm import tqdm

import pickle

class Glow1d(nn.Module):
    """Chain of ``GlowRevNet1d`` steps followed by a ``GaussianDiag`` prior.

    ``forward`` scores a batch and ``sample`` draws from the prior and runs
    the steps backwards.
    """

    def __init__(self, input_depth, steps):
        """Build ``steps`` ``GlowRevNet1d(input_depth)`` operators and the prior."""
        super(Glow1d, self).__init__()
        self.input_depth = input_depth
        self.steps = steps

        flow_steps = []
        for _ in range(steps):
            flow_steps.append(GlowRevNet1d(input_depth))
        self.operators = nn.ModuleList(flow_steps)

        self.prior = GaussianDiag(input_depth, 0., 0., trainable=False)

    def forward(self, x):
        """Return (sum of step log-determinants + prior log-probability) / dim.

        ``x`` is flattened to ``(batch, -1)`` before the first step; ``dim`` is
        the number of elements per sample in the final representation.
        """
        batch_size = x.size(0)
        z = x.view(batch_size, -1)

        total_logdet = 0.
        for step in range(self.steps):
            z, step_logdet = self.operators[step](z, ignore_logdet=False)
            total_logdet = total_logdet + step_logdet

        n_dims = float(np.prod(z.shape[1:]))
        return (total_logdet + self.prior.logp(z)) / n_dims

    def sample(self, n_samples, temp=1.):
        """Draw ``n_samples`` from the prior and invert the steps, last to first."""
        with torch.no_grad():
            x = self.prior.sample(n_samples, temp)
            for step in range(self.steps - 1, -1, -1):
                x = self.operators[step].inverse(x)

        return x

class Glow2dStage(nn.Module):
    """Sequence of ``GlowRevNet2d`` operators applied one after another."""

    def __init__(self, in_channels, steps):
        """Create ``steps`` operators on ``in_channels`` channels (actnorm off)."""
        super(Glow2dStage, self).__init__()
        blocks = []
        for _ in range(steps):
            blocks.append(GlowRevNet2d(in_channels, do_actnorm=False))
        self.ops = nn.ModuleList(blocks)

    def forward(self, x):
        """Apply every operator in order; return the output and summed log-det."""
        total_logdet = 0.
        out = x
        for block in self.ops:
            out, step_logdet = block(out)
            total_logdet = total_logdet + step_logdet
        return out, total_logdet

    def inverse(self, Fx):
        """Undo ``forward`` by inverting the operators from last to first."""
        out = Fx
        for idx in range(len(self.ops) - 1, -1, -1):
            out = self.ops[idx].inverse(out)
        return out

class Glow2d(nn.Module):
    """Multi-stage 2-D flow: squeeze, then alternating stages and squeezes.

    Each squeeze moves ``factor x factor`` spatial blocks into the channel
    axis, so channels grow by ``factor ** 2`` while height and width shrink
    by ``factor``.
    """

    def __init__(self, input_shape, stages, steps):
        """Build the stages and the prior.

        ``input_shape`` is ``(c, h, w)``; ``steps[i]`` is the number of steps
        handed to the ``i``-th ``Glow2dStage``.
        """
        super(Glow2d, self).__init__()
        channels, height, width = input_shape

        # Shape bookkeeping for the initial squeeze.
        self.init_squeeze_factor = 2
        channels *= (self.init_squeeze_factor ** 2)
        height = height // self.init_squeeze_factor
        width = width // self.init_squeeze_factor

        stage_modules = []
        for stage_idx in range(stages):
            stage_modules.append(Glow2dStage(channels, steps[stage_idx]))
            # A factor-2 squeeze follows every stage in ``forward``.
            channels *= 4
            height = height // 2
            width = width // 2
        self.stages = nn.ModuleList(stage_modules)
        self.final_shape = (channels, height, width)

        self.prior = GaussianDiag(np.prod(self.final_shape), 0., 0., trainable=False)

    @staticmethod
    def _squeeze(tensor, factor=2):
        """Fold ``factor x factor`` spatial blocks of a 4-D tensor into channels."""
        batch, channels, height, width = tensor.shape
        out_h = height // factor
        out_w = width // factor
        blocks = tensor.view(batch, channels, out_h, factor, out_w, factor)
        blocks = blocks.permute(0, 1, 3, 5, 2, 4)
        return blocks.reshape(batch, channels * (factor ** 2), out_h, out_w)

    @staticmethod
    def _unsqueeze(tensor, factor=2):
        """Inverse of ``_squeeze``: spread channel groups back over space."""
        batch, channels, height, width = tensor.shape
        out_c = channels // (factor ** 2)
        blocks = tensor.view(batch, out_c, factor, factor, height, width)
        blocks = blocks.permute(0, 1, 4, 2, 5, 3)
        return blocks.reshape(batch, out_c, height * factor, width * factor)

    def forward(self, x):
        """Return (sum of stage log-determinants + prior log-probability) / dim."""
        batch = x.size(0)

        z = self._squeeze(x, self.init_squeeze_factor)
        total_logdet = 0.
        for stage in self.stages:
            z, stage_logdet = stage(z)
            z = self._squeeze(z, 2)
            total_logdet = total_logdet + stage_logdet

        n_dims = float(np.prod(z.shape[1:]))
        return (total_logdet + self.prior.logp(z.view(batch, -1))) / n_dims

    def sample(self, n_samples, temp=1.):
        """Draw from the prior, reshape to ``final_shape`` and invert the flow."""
        with torch.no_grad():
            x = self.prior.sample(n_samples, temp).view(n_samples, *self.final_shape)
            for idx in range(len(self.stages) - 1, -1, -1):
                x = self._unsqueeze(x, 2)
                x = self.stages[idx].inverse(x)
            x = self._unsqueeze(x, self.init_squeeze_factor)

        return x

class InvResNet1d(nn.Module):
    """Chain of ``InvertibleResNet1d`` steps followed by a ``GaussianDiag`` prior."""

    def __init__(self, input_depth, steps):
        """Build ``steps`` ``InvertibleResNet1d(input_depth)`` operators and the prior."""
        super(InvResNet1d, self).__init__()
        self.input_depth = input_depth
        self.steps = steps
        self.operators = nn.ModuleList([InvertibleResNet1d(input_depth) for _ in range(steps)])

        self.prior = GaussianDiag(input_depth, 0., 0., trainable=False)

    def forward(self, x):
        """Return (sum of step log-determinants + prior log-probability) / dim.

        ``x`` is flattened to ``(batch, -1)`` and gradient tracking is enabled
        on the flattened tensor before it enters the first step (presumably
        the steps obtain their log-determinants through autograd -- their
        implementation is not visible here).
        """
        b = x.size(0)

        logpk = 0.

        K1 = x.view(b, -1)
        # ``requires_grad`` may only be assigned on leaf tensors. When ``x``
        # already tracks gradients the view is a non-leaf tensor that tracks
        # them too, so the assignment is skipped instead of raising.
        if not K1.requires_grad:
            K1.requires_grad = True
        for i in range(self.steps):
            K1, logdet = self.operators[i](K1, ignore_logdet=False)
            logpk = logpk + logdet

        prior_logp = self.prior.logp(K1)
        logpk = logpk + prior_logp
        # Normalise by the number of elements per sample.
        dim = float(np.prod(K1.shape[1:]))
        logpk = logpk / dim

        return logpk

    def sample(self, n_samples, temp=1.):
        """Draw ``n_samples`` from the prior and invert the steps, last to first."""
        with torch.no_grad():
            K0 = self.prior.sample(n_samples, temp)
            for i in reversed(range(self.steps)):
                K0 = self.operators[i].inverse(K0)

        return K0

class Kernel_Perron_Frobenius(nn.Module):
    """Kernel transfer operator from prior samples ``x`` to stored features ``y``.

    At construction a linear operator is estimated from the Gram matrices of
    a prior sample and of the concatenated features. ``sample`` pushes fresh
    prior samples through that operator and turns the resulting weights over
    the stored features into new feature tensors (the "pre-image" step).
    """

    # Both spellings select the kernel Karcher mean: the constructor has
    # always advertised 'kkm' while the solver only knew the long name.
    _KARCHER_NAMES = ('kkm', 'kernel_karcher_mean')
    _PREIMAGE_METHODS = ('gi', 'fm') + _KARCHER_NAMES

    def __init__(self, kernel_obj, features, labels=None, prior_sample_fn=None, nystrom_compression=False, nystrom_points=1000, epsilon=0., p_dim=-1, preimage_method='gi', device=torch.device('cpu')):
        """Store the features and estimate the operator.

        Args:
            kernel_obj: callable ``kernel(a, b)`` producing a Gram matrix.
            features: iterable of tensors sharing their first dimension; each
                is stored as a non-trainable parameter.
            labels: stored as ``self.labels``; not used in this class.
            prior_sample_fn: optional callable ``fn(size)`` returning prior
                samples; when ``None`` a built-in Gaussian-based prior is used.
            nystrom_compression: use ``Nystrom_PerronFreboniusOperator``
                instead of the exact Gram-matrix inverse.
            nystrom_points: forwarded to the Nystrom operator.
            epsilon: regulariser; ``epsilon * b * I`` is added to ``Gxx``.
            p_dim: prior dimension; ``-1`` means the sum of ``size(1)`` over
                all features.
            preimage_method: ``'gi'`` (geodesic interpolation), ``'fm'``
                (Frechet mean) or ``'kkm'`` / ``'kernel_karcher_mean'``.
            device: accepted for interface compatibility; unused here.
        """
        super(Kernel_Perron_Frobenius, self).__init__()
        self.kernel = kernel_obj
        self.prior_sample_fn = prior_sample_fn

        self.y = nn.ParameterList([nn.Parameter(f, requires_grad=False) for f in features])
        self.y_size = self.y[0].size(0)
        self.labels = labels

        assert preimage_method in self._PREIMAGE_METHODS
        self.preimage_method = preimage_method

        if p_dim == -1:
            # Plain int rather than np.int64 so it is safe as a tensor size.
            self.p_dim = int(sum(_y.size(1) for _y in self.y))
        else:
            self.p_dim = p_dim

        with torch.no_grad():
            b = self.y[0].size(0)
            y = torch.cat([_y.view(b, -1) for _y in self.y], dim=-1)

            x = self.get_prior_sample(b).to(y.device)
            if nystrom_compression:
                op = Nystrom_PerronFreboniusOperator(x, y, self.kernel, nystrom_points, epsilon=epsilon)
            else:
                Gxx = self.kernel(x, x)
                Gyy = self.kernel(y, y)

                Gxx_aug = (Gxx + epsilon * b * torch.eye(b).to(Gxx.device))

                print('Eigenvalues:', self._symmetric_eigenvalues(Gxx_aug).detach().cpu().numpy())

                Gxx_inv = Gxx_aug.inverse()
                op = LinearWrapper(Gyy.mm(Gxx_inv), trainable=False)
            self.x = nn.Parameter(x)
            self.operator = op

    @staticmethod
    def _symmetric_eigenvalues(mat):
        """Return the eigenvalues of a symmetric matrix.

        Uses ``torch.linalg.eigvalsh`` when this PyTorch build has it and
        falls back to ``Tensor.symeig`` otherwise.
        """
        linalg = getattr(torch, 'linalg', None)
        if linalg is not None and hasattr(linalg, 'eigvalsh'):
            return linalg.eigvalsh(mat)
        return mat.symeig()[0]

    @staticmethod
    def _topk_weight_matrix(K, inds, weight):
        """Scatter ``weighted_max(weight)`` into a zero matrix shaped like ``K``."""
        return torch.zeros_like(K).scatter_(-1, inds, weighted_max(weight, dim=-1))

    def get_prior_sample(self, size):
        """Return ``size`` prior samples.

        Uses ``prior_sample_fn`` when one was given; otherwise passes a
        standard normal draw shifted by ``1 / p_dim`` through ``normalize``.
        """
        if self.prior_sample_fn is None:
            return normalize(torch.randn(size, self.p_dim) + (torch.ones(1, self.p_dim) / self.p_dim))
        return self.prior_sample_fn(size)

    def forward(self, sample_size):
        """Alias for ``sample(sample_size)``."""
        return self.sample(sample_size)

    def sample(self, sample_size, topk=10):
        """Generate ``sample_size`` new feature sets.

        Returns:
            ``(y_prime, inds)`` where ``y_prime`` holds one pre-image tensor
            per stored feature tensor and ``inds`` are the (squeezed) indices
            of the ``topk`` largest operator weights for every sample.
        """
        with torch.no_grad():
            x = self.x
            x_prime = self.get_prior_sample(sample_size).to(x.device)
            Gxx_prime = self.kernel(x, x_prime)

            # Transposed operator output: one row of weights per new sample.
            K = self.operator(Gxx_prime).t()

        y_prime = [self._solve_preimage(K, _y, topk=topk) for _y in self.y]
        inds = K.topk(topk, dim=-1).indices.squeeze()

        return y_prime, inds

    def _solve_preimage(self, K, y, topk=10):
        """Turn the weight rows ``K`` into points built from the rows of ``y``.

        Only the ``topk`` largest weights of each row are used.

        Raises:
            ValueError: if ``self.preimage_method`` is not a known method.
        """
        b = K.size(0)

        K_top = K.topk(topk, dim=-1)
        inds = K_top.indices
        weight = K_top.values

        if self.preimage_method == 'gi': # Geodesic interpolation
            with torch.no_grad():
                mu = y[inds[:, 0]].view(b, -1)
                # Clone so the running sum does not overwrite ``weight[:, 0]``.
                _w_sum = weight[:, 0].clone()
                for i in range(1, inds.size(1)):
                    v = y[inds[:, i]].view(b, -1)
                    # Angle between the running estimate and the next point;
                    # the clamp keeps acos away from its end points.
                    theta = torch.acos(torch.clamp((mu * v).sum(-1), min=-1 + 1e-7, max=1 - 1e-7)).unsqueeze(1)
                    _w_sum += weight[:, i]
                    t = (weight[:, i] / _w_sum).unsqueeze(1)

                    # Spherical interpolation from ``mu`` towards ``v``.
                    mu_prime = (torch.sin((1 - t) * theta) * mu + torch.sin(t * theta) * v) / torch.sin(theta)
                    mu = torch.where(theta == 0., mu, mu_prime)
                preimage = mu.view(b, *y.shape[1:])

        elif self.preimage_method == 'fm': # Frechet mean
            with torch.no_grad():
                preimage = self._topk_weight_matrix(K, inds, weight).mm(y)
        elif self.preimage_method in self._KARCHER_NAMES: # kernel Karcher mean
            # Start from the weighted combination, then refine it with 50
            # gradient steps on the squared difference between kernel rows.
            with torch.no_grad():
                mu = self._topk_weight_matrix(K, inds, weight).mm(y)

            mu = mu.detach().requires_grad_(True)
            for _ in range(50):
                K_prime = self.kernel(mu, y, requires_grad=True)
                diff = torch.pow(K_prime - K, 2).sum(-1).mean()
                grad = torch.autograd.grad(diff, mu, retain_graph=True)
                mu.data -= 1e2 * grad[0].data
            preimage = mu
        else:
            raise ValueError()

        return preimage
        

        




