import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import toim, tovec, cor_mat
from torch.utils.data import DataLoader

from sklearn.cluster import KMeans
from sklearn.utils import check_random_state
from sklearn.cluster import kmeans_plusplus

import numpy as np
import time

class InvariantSubspaceModule(nn.Module):
    """Convolutional feature layer built from a bank of learned linear subspaces.

    The layer holds ``n_subspaces`` templates; each template consists of
    ``subspace_dim`` filters of shape ``(in_channels, kernel_size, kernel_size)``
    stored in ``self.v`` with shape
    ``(n_subspaces, subspace_dim, in_channels, kernel_size, kernel_size)``.

    The response of a template at an output location is the norm of the
    filter responses, ``sqrt(sum_i conv2d(x, v[:, i])**2)``.  Responses are
    then thresholded per location (see ``p``) and optionally rescaled.

    The filters are not trained with an optimizer.  :meth:`fit` back-propagates
    :meth:`energy` and replaces ``v`` by the SVD-orthonormalised gradient.

    Args:
        n_subspaces: number of templates (= output channels).
        subspace_dim: number of filters per template.
        p: float in [0, 1]; converted to a channel rank ``int(p * n_subspaces)``
            clipped to ``[0, n_subspaces - 2]``.  At every location the
            ``(rank + 1)``-th smallest response is subtracted and the result is
            passed through a ReLU.
        rescale_features: if True, rescale the output so that its channel norm
            equals the norm of the corresponding input patch.
        signed_features: if True, use the signed response of the first filter
            instead of the subspace norm (requires ``subspace_dim == 1``).
        random_features: if True, :meth:`fit` only draws random orthonormal
            filters and does not look at the data beyond initialisation.
        kernel_size, stride, padding: convolution hyper-parameters (ints).
        n_epochs: number of passes over the loader in :meth:`fit`.
        chunks_per_update: number of batches whose gradients are accumulated
            before ``v`` is updated.
        warmup_iter: number of initial updates during which the cluster
            assignment is computed from the first filter of each template only.
        verbose: 0 or 1; enables progress printing.
        eps: small constant used when dividing by norms.
    """

    def __init__(self, n_subspaces=64, subspace_dim=1, p=0,
                    rescale_features=False, signed_features=False, random_features=False,
                    kernel_size=1, stride=1, padding=0,
                    n_epochs=1, chunks_per_update=1, warmup_iter=0,
                    verbose=0, eps=1e-4):
        super().__init__()

        # subspace hparams
        self.n_subspaces = n_subspaces
        self.subspace_dim = subspace_dim

        # normalization hparams
        # signed responses are only read from the first filter of each template
        if signed_features: assert(subspace_dim == 1)

        # p is a float between 0 and 1; stored as a plain Python int rank so
        # that torch ops and formatting never see a numpy scalar
        self.p = int(np.clip(int(p*n_subspaces), 0, n_subspaces-2))
        self.rescale_features = rescale_features
        self.signed_features = signed_features
        self.random_features = random_features

        # convolution hparams (same kernel size in both spatial dimensions)
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # training params
        self.n_epochs = n_epochs
        self.chunks_per_update = chunks_per_update
        self.warmup_iter = warmup_iter

        # logging params
        self.verbose = verbose

        # small number hparam
        self.eps = eps

        # subspaces; created lazily by init() on the first call to fit()
        self.v = None

    def forward(self, x):
        """Map ``x`` of shape (batch, channels, H, W) to thresholded subspace responses.

        Returns a tensor with ``n_subspaces`` channels; ``self.v`` must have
        been created (via :meth:`fit` / :meth:`init`) beforehand.
        """
        # part 1: compute subspace norms
        if self.signed_features:
            snrm = F.conv2d(x, self.v[:, 0], stride=self.stride, padding=self.padding)
        else:
            # snrm = subspace norm (not squared); clamp guards sqrt against
            # tiny negative values
            snrm = self.subspace_norm_sq(x, self.v[:, :self.subspace_dim]).clamp(0).sqrt()

        # part 2: threshold at the (p+1)-th smallest response across channels
        if self.p is not None:
            # tau.shape = (batch_size, 1, output_size_x, output_size_y)
            tau = snrm.kthvalue(self.p+1, dim=1, keepdim=True)[0]
        else:
            tau = 0
        y = (snrm - tau).relu()

        # part 3: rescale so the channel norm of y matches the input patch norm
        if self.rescale_features:
            xnrm = self.input_norm(x)  # (batch_size, 1, output_size_x, output_size_y)
            ynrm = y.norm(dim=1, keepdim=True)  # same shape as xnrm
            y = xnrm * y / ynrm.clamp(self.eps)

        return y

    def fit(self, loader, logger=None):
        """Learn the subspaces from the batches yielded by ``loader``.

        Each batch contributes the gradient of :meth:`energy` to ``v.grad``;
        after every ``chunks_per_update`` batches ``v`` is replaced by the
        SVD projection of the accumulated gradient.  If ``logger`` is given,
        its ``save(name, value, step)`` method is called with the filters
        (after every update) and the energy (after every batch).
        """
        # lazy init
        if self.v is None:
            self.init(loader)

        if self.random_features:
            self.v.data = self.svd_project_(torch.randn_like(self.v))
            if self.verbose:
                print('Using random features. Not going to fit on data')
            return None

        # log
        if logger is not None: logger.save('v', self.v.clone(), 0)

        # fit
        t0 = time.time()
        n_chunks = 0
        n_updates = 0

        for epoch in range(self.n_epochs):
            for i, x_ in enumerate(loader):
                # accumulate the gradient of the energy w.r.t. v
                e_ = self.energy(x_, warmup=(n_updates < self.warmup_iter))
                e_.backward()
                n_chunks += 1

                # update: the new filters are the orthonormalised gradient
                if n_chunks % self.chunks_per_update == 0:
                    n_updates += 1
                    self.v.data = self.svd_project_(self.v.grad)
                    self.v.grad = None

                    if logger is not None: logger.save('v', self.v.clone(), n_updates)

                # print roughly ten progress lines over the whole run
                if self.verbose and n_chunks % max(1, (self.n_epochs*len(loader) // 10)) == 0:
                    print('epoch: {}/{}, iter: {}/{}, energy: {:.5f}, elapsed time: {:.2f}, n updates: {}'.format(
                        epoch+1, self.n_epochs, i+1, len(loader), e_.item(), time.time()-t0, n_updates))

                # log
                if logger is not None: logger.save('energy', e_.item(), n_chunks)

    def energy(self, x, warmup=False):
        """Return the scalar objective: best squared subspace norm per location.

        The maximum over templates is taken at every location, summed over the
        batch and averaged over spatial positions.  With ``warmup=True`` the
        *value* used for the maximum comes from the first filter of each
        template only, while the gradient still flows through all filters.
        """
        snrm_sq = self.subspace_norm_sq(x, self.v[:, :self.subspace_dim])
        if warmup:
            # value of the first term, gradient of the second
            first_only = self.subspace_norm_sq(x, self.v[:, :1]).detach()
            snrm_sq = first_only + snrm_sq - snrm_sq.detach()

        e = snrm_sq.max(dim=1)[0].sum(0).mean()
        return e

    def cluster(self, x):
        """Return, per location, the index of the template with the largest response."""
        snrm_sq = self.subspace_norm_sq(x, self.v[:, :self.subspace_dim])
        return snrm_sq.max(dim=1)[1]

    def input_norm(self, x):
        """Return the Euclidean norm of every input patch, shape (batch, 1, H_out, W_out)."""
        ones = torch.ones(1, x.shape[1], self.kernel_size, self.kernel_size, dtype=x.dtype, device=x.device)
        return F.conv2d(x**2, ones, stride=self.stride, padding=self.padding).clamp(0).sqrt()

    def subspace_norm_sq(self, x, v):
        """Return the squared subspace norms.

        ``v`` has shape (k, d, channels, kernel, kernel); the result holds k
        feature maps, each the sum over the d filters of the squared response.
        """
        return sum(F.conv2d(x, v[:, i], stride=self.stride, padding=self.padding)**2 for i in range(v.shape[1]))

    def svd_project_(self, v):
        """Replace the filters of every template by an orthonormal basis.

        For each template the ``subspace_dim`` filters are flattened into the
        rows of a matrix and replaced by its right singular vectors.  If the
        SVD raises a ``RuntimeError`` it is retried once on a slightly
        perturbed copy.  Returns a new tensor with the shape of ``v``.
        """
        shape = v.shape[1:]
        out = v.clone()
        for j in range(v.shape[0]):
            vj = v[j].view(shape[0], -1)
            try:
                basis = vj.svd()[2]
            except RuntimeError:
                # retry with a small random perturbation
                basis = (vj + 1e-3*torch.randn_like(vj)).svd()[2]
            out[j] = basis.t().reshape(shape)
        return out

    def init(self, loader):
        """Create ``self.v`` from k-means++ seeds over patches of the first batch.

        All ``kernel_size`` x ``kernel_size`` patches of the first batch are
        normalised, all-zero patches are dropped, and ``kmeans_plusplus``
        selects ``n_subspaces`` of them as the first filter of each template.
        Remaining filters start as small Gaussian noise; everything is then
        orthonormalised with :meth:`svd_project_`.
        """
        # extract patches
        x = next(iter(loader))
        k = self.kernel_size
        patches = []
        # Iterate over the images that are actually in the batch.  The bound
        # used to be max(batch_size, n_subspaces), which indexed past the end
        # of the batch whenever n_subspaces was larger than the batch size.
        for ix in range(x.shape[0]):
            for cx in range(x.shape[2]-k+1):
                for cy in range(x.shape[3]-k+1):
                    patch = x[ix, :, cx:cx+k, cy:cy+k].clone()
                    patch = patch / patch.norm().clamp(self.eps)
                    if patch.norm() != 0:
                        patches.append(patch)

        n_missing = self.n_subspaces - len(patches)
        if n_missing > 0:
            print('warning: have to init some patches with zero')
            # built from x rather than from the last patch, so this also works
            # when no patch could be extracted at all
            zero_patch = torch.zeros(x.shape[1], k, k, dtype=x.dtype, device=x.device)
            patches += [zero_patch] * n_missing

        patches = torch.stack(patches)
        patches = tovec(patches)
        patches = patches.cpu().numpy()

        centroids, _ = kmeans_plusplus(patches, self.n_subspaces, random_state=np.random.randint(2**32))

        # set templates: noise everywhere, selected patch in the first filter
        self.v = nn.Parameter(0.001*torch.randn(self.n_subspaces, self.subspace_dim, x.shape[1], k, k, device=x.device))
        seeds = torch.from_numpy(centroids).reshape(self.n_subspaces, x.shape[1], k, k)
        self.v.data[:, 0] = seeds.to(x.device)

        # svd
        self.v.data = self.svd_project_(self.v.data)

    def extra_repr(self):
        """Return the hyper-parameter summary shown in the module repr."""
        return 'n_subspaces={n_subspaces}, subspace_dim={subspace_dim}, p={p}, \n'\
                'kernel_size={kernel_size}, padding={padding}, stride={stride}, \n'\
                'rescale_features={rescale_features}, signed_features={signed_features}, random_features={random_features},\n'\
                'n_epochs={n_epochs}, chunks_per_update={chunks_per_update}, warmup_iter={warmup_iter}, \n'\
                'verbose={verbose}, eps={eps}'.format(**self.__dict__)
    
class OnOff(nn.Module):
    """Split a signal into its positive and negative parts along the channel axis.

    The output has twice as many entries along dim 1 as the input: first
    ``relu(x)``, then ``relu(-x)``.
    """

    def forward(self, x):
        positive_part = torch.relu(x)
        negative_part = torch.relu(-x)
        return torch.cat((positive_part, negative_part), dim=1)
        
class ZCA(nn.Module):
    """Patch-based whitening applied as a reflect-padded convolution.

    ``fit`` estimates a whitening kernel from the first batch of a loader and
    stores it as the buffer ``q``; ``forward`` convolves with that kernel.
    ``forward`` therefore only works after ``fit`` (or after ``q`` has been
    registered some other way).

    Args:
        n_components: number of leading singular directions that get rescaled
            in ``fit``; the remaining directions are left with unit gain.
        kernel_size: odd spatial size of the whitening kernel.
    """

    def __init__(self, n_components=0, kernel_size=5):
        super().__init__()
        # an odd size is required so that the kernel has a centre pixel
        assert(kernel_size % 2 == 1)

        self.n_components = n_components
        self.kernel_size = kernel_size
        self.padding = self.kernel_size//2

    def forward(self, x):
        """Reflect-pad ``x`` by ``padding`` on all four sides and convolve with ``q``."""
        padded = F.pad(x, [self.padding]*4, mode='reflect')
        return F.conv2d(padded, self.q, stride=1)

    def fit(self, loader):
        """Estimate the whitening kernel from the first batch of ``loader``."""
        x = next(iter(loader))

        ksize = self.kernel_size
        m = x.shape[1]
        dim = m*ksize**2

        # compute correlation matrix: it is obtained as the gradient of half
        # the mean squared response of an identity filter bank
        probe = torch.eye(dim, device=x.device, requires_grad=True)
        probe = probe.reshape(dim, m, ksize, ksize)
        e = 1/2*(F.conv2d(x, probe, stride=1, padding=0)**2).sum(1).mean()
        c = torch.autograd.grad(e, probe)[0].reshape(dim, dim)

        # update project
        _, s, basis = c.svd()
        s = s.relu().sqrt()
        # keep a CPU copy of the spectrum (before rescaling) for inspection
        self.s_ = s.clone().detach().cpu()

        # leading directions get gain s[n_components] / s[i], the rest gain 1
        s[:self.n_components] = s[self.n_components] / s[:self.n_components]
        s[self.n_components:] = 1
        q = basis @ s.diag() @ basis.t()

        # get kernel: for every channel take the entries of q that belong to
        # the centre pixel (toim presumably reshapes flat vectors back into
        # (ch, k, k) images -- see utils)
        centre = ksize//2
        q_out = torch.zeros(m, m, ksize, ksize, device=q.device, dtype=q.dtype)
        for i in range(m):
            q_ = toim(q, ch=m)[:, i, centre, centre]
            q_out[i] = toim(q_.unsqueeze(0), ch=m)[0]

        # rescale so that the filtered first batch has unit standard deviation
        scale = F.conv2d(x, q_out, stride=1, padding=0).std()
        q_out = q_out / scale

        # save
        self.register_buffer('q', q_out)

    def extra_repr(self):
        """Return the hyper-parameter summary shown in the module repr."""
        return 'n_components={n_components}, kernel_size={kernel_size}, padding={padding}'.format(**self.__dict__)

class StandardizationModule(nn.Module):
    """Subtract a per-channel mean and divide by a single global scale.

    The statistics are estimated by ``fit`` from the first batch of a loader
    and stored as the buffers ``mu`` (one value per channel) and ``sig`` (a
    scalar: the mean over channels of the per-channel standard deviation).

    Args:
        center: if False, the mean is not subtracted in ``forward``.
    """

    def __init__(self, center=True):
        super().__init__()
        self.eps = 1e-4  # lower bound for the divisor in forward()
        self.center = center

    def forward(self, x):
        """Standardise ``x`` of shape (batch, channels, H, W)."""
        shifted = x - self.mu.view(1, -1, 1, 1) if self.center else x
        return shifted / self.sig.clamp(self.eps)

    def fit(self, loader):
        """Estimate ``mu`` and ``sig`` from the first batch of ``loader``."""
        batch = next(iter(loader))
        n_channels = batch.shape[1]
        # one row per (sample, pixel), one column per channel
        per_pixel = batch.permute(0, 2, 3, 1).reshape(-1, n_channels)
        self.register_buffer('mu', per_pixel.mean(0).detach())
        self.register_buffer('sig', per_pixel.std(0).mean().detach())
        
class Pool2d(nn.Module):
    """Thin wrapper that selects average or max pooling by name.

    Args:
        pool_type: ``'avg'`` for ``nn.AvgPool2d`` or ``'max'`` for ``nn.MaxPool2d``.
        kernel_size, stride, padding: forwarded unchanged to the pooling layer.

    Raises:
        ValueError: if ``pool_type`` is neither ``'avg'`` nor ``'max'``.
            Previously such a module was created without a ``pool`` attribute
            and only failed later, inside ``forward``.
    """

    def __init__(self, pool_type, kernel_size, stride, padding):
        super().__init__()

        if pool_type == 'avg':
            self.pool = nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=padding)
        elif pool_type == 'max':
            self.pool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=padding)
        else:
            raise ValueError("pool_type must be 'avg' or 'max', got {!r}".format(pool_type))

    def forward(self, x):
        """Apply the selected pooling layer to ``x``."""
        return self.pool(x)
        
class _PrefixCollate(nn.Module):
    """Collate function that batches samples and runs them through fixed layers.

    Used by ``StackedInvariantSubspaceModule.fit`` so that every stage is
    fitted on the output of the stages before it.
    """

    def __init__(self, layers):
        super().__init__()
        self.layers = layers

    def forward(self, x):
        # x is the list of samples handed over by the DataLoader
        with torch.no_grad():
            # same as the former hard-coded .cuda() when a GPU is present;
            # otherwise stay on the CPU instead of raising
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            x = torch.stack(x, dim=0).to(device)
            x = self.layers(x)
        return x


class StackedInvariantSubspaceModule(nn.Module):
    """Sequential stack of whitening, subspace, pooling and standardisation stages.

    The stack is described by ``params``, a dict whose keys select the stages
    and whose values are the keyword arguments of the corresponding module:

    * ``'input_norm'``            -> ``ZCA``
    * ``'layer{i}'`` (i = 1..100) -> ``InvariantSubspaceModule``, optionally
      followed by ``'layer{i}_pool'`` (``Pool2d``), ``'layer{i}_standardize'``
      (``StandardizationModule``) and ``'layer{i}_zca'`` (``ZCA``); the
      optional keys are only looked at when ``'layer{i}'`` itself is present
    * ``'output_pool'``           -> ``nn.AdaptiveAvgPool2d``

    Stages are applied in the order listed above.

    Args:
        params: stage description as explained above.
        verbose: if truthy, ``fit`` prints the name of every stage it visits.
    """

    def __init__(self, params, verbose=0):
        super().__init__()

        self.params = params
        self.verbose = verbose

        modules = {}
        if 'input_norm' in params:
            modules['input_norm'] = ZCA(**params['input_norm'])

        for i in range(100):
            layer = 'layer{}'.format(i+1)
            if layer not in params:
                continue
            modules[layer] = InvariantSubspaceModule(**params[layer])

            if layer+'_pool' in params:
                modules[layer+'_pool'] = Pool2d(**params[layer+'_pool'])

            if layer+'_standardize' in params:
                modules[layer+'_standardize'] = StandardizationModule(**params[layer+'_standardize'])

            if layer+'_zca' in params:
                modules[layer+'_zca'] = ZCA(**params[layer+'_zca'])

        if 'output_pool' in params:
            modules['output_pool'] = nn.AdaptiveAvgPool2d(**params['output_pool'])

        self.module_dict = nn.ModuleDict(modules)

    def forward(self, x):
        """Apply all stages in order."""
        for module in self.module_dict.values():
            x = module(x)
        return x

    def fit(self, loader):
        """Fit the stages one after another, each on the output of its predecessors.

        Only ``loader.dataset`` and ``loader.batch_size`` are used: for every
        stage a fresh shuffled ``DataLoader`` (``drop_last=True``) is built
        whose collate function stacks the samples and passes them through the
        stages that come before.  Stages without a ``fit`` method are skipped.
        Exceptions raised inside a stage's ``fit`` -- including
        ``AttributeError`` -- propagate to the caller; they used to be
        swallowed together with the "has no fit" case.
        """
        dataset = loader.dataset
        batch_size = loader.batch_size
        stages = list(self.module_dict.items())

        for i, (module_name, module) in enumerate(stages):
            # print
            if self.verbose: print('fitting {}'.format(module_name))

            # stages such as pooling have nothing to fit
            fit_fn = getattr(module, 'fit', None)
            if not callable(fit_fn):
                continue

            # create loader that yields the input of stage i
            collate_fn = _PrefixCollate(nn.Sequential(*[stage for _, stage in stages[:i]]))
            stage_loader = DataLoader(dataset=dataset, batch_size=batch_size,
                                      drop_last=True, shuffle=True, collate_fn=collate_fn)

            # fit
            fit_fn(stage_loader)
         
# def chunked_inference(module, x, input_device='cuda', output_device='cuda', chunk_size=128):
#     with torch.no_grad():
#         y = []
#         for x_ in torch.split(x,chunk_size):
#             y_ = module(x_.to(input_device))
#             y.append(y_.to(output_device))
#         y = torch.cat(y)
#     return y

def chunked_inference(module, x, input_device='cuda', output_device='cuda', chunk_size=128):
    """Run ``module`` over ``x`` in chunks and collect the results in one tensor.

    Args:
        module: callable mapping a batch tensor to a batch tensor.
        x: input tensor; it is split along dim 0 into pieces of at most
            ``chunk_size`` samples.
        input_device: device every chunk is moved to before calling ``module``.
        output_device: device on which the result tensor is allocated.
        chunk_size: maximum number of samples per forward pass.

    Returns:
        Tensor on ``output_device`` with ``x.shape[0]`` rows whose per-sample
        shape and dtype are those of the module output.  No autograd graph is
        recorded.
    """
    with torch.no_grad():
        # One-sample probe to learn the per-sample output shape and dtype.
        # It runs under no_grad so that it does not build an autograd graph,
        # and its dtype is reused so that non-float32 outputs are not
        # silently cast.
        probe = module(x[:1].to(input_device))
        shape = [x.shape[0]] + list(probe[0].shape)
        y = torch.zeros(shape, dtype=probe.dtype, device=output_device)

        # populate values
        i = 0
        for x_ in torch.split(x, chunk_size):
            y[i:i+x_.shape[0]] = module(x_.to(input_device)).to(output_device)
            i += x_.shape[0]
    return y

if __name__ == '__main__':
    # Smoke test: fit a ZCA + one subspace layer on MNIST on the GPU and print
    # the host memory reported by psutil after every step.
    import psutil

    def _report_mem(step):
        # field 3 of psutil.virtual_memory() (documented by psutil as `used`)
        print('{} mem: {}'.format(step, psutil.virtual_memory()[3]))

    _report_mem(0)

    from torch.utils.data import DataLoader
    from data import cifar10, mnist
    from utils import print_memory_usage
    import os
    # must be set before the first CUDA call below
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'

    _report_mem(1)
    x = mnist()[0]  # kept on the host here; moved to the GPU where it is used
    _report_mem(2)

    params = {
        'input_norm': {},
        'layer1': {'n_subspaces': 64, 'subspace_dim': 1, 'kernel_size': 3, 'verbose': 1},
    }
    _report_mem(3)

    net = StackedInvariantSubspaceModule(params, verbose=1)
    _report_mem(4)

    loader = DataLoader(x.cuda(), batch_size=512)
    _report_mem(5)

    net.fit(loader)
    _report_mem(6)

    y = chunked_inference(net, x.cuda(), chunk_size=512, output_device='cpu')
    print_memory_usage()
    _report_mem(7)
    print('mem total: {}'.format(psutil.virtual_memory()))

    print(net.module_dict.layer1.v.shape)