
import torch
import torch.nn as nn
from torch.nn import Conv2d

import cv2
import tables
import numpy as np
import torchvision.models as models
from torch.nn.utils.rnn import pack_padded_sequence

import torch.nn.functional as F
from torch.utils.data import Dataset
#import lmdb
#from PIL import Image
#import pyarrow as pa
import os.path as osp
#import six
from skimage.measure import block_reduce

import time






class ApplyPooling_GPU(object):
    """Pool a feature map with precomputed foveal and peripheral pooling masks.

    Masks are read from cached ``.npy`` files in ``<root_dir>/<folder_name>``.
    When a cache file is missing, it is rebuilt from the matching MATLAB
    ``.mat`` file (downsampled with a block sum) and saved. It is then loaded
    exactly like a cached file, so both paths leave the object in the same
    state.

    Attributes read by ``__call__``:
        normalized_data: float32 tensor (R, C, H, W, 1) of pooling masks.
        pooling_limits: tensor (R, C, 4) with the [top, bottom, left, right]
            non-zero bounds of each pooling mask.
        fnormalized_data: float32 tensor (1, H, W, 1) holding the foveal mask.
        foveal_limits: tensor (1, 1, 4) with the foveal mask bounds.
        flim: the same foveal bounds as a NumPy int array.
        mask_shape: (H, W, 1).
        f_y, f_x: centre row / column of the mask.
    """

    def __init__(self, folder_name='Scale_0.84/', root_dir='Pooling_Regions',
                 downsample=16, device='cuda'):
        """Load (or build and cache) the pooling masks.

        Args:
            folder_name: sub-folder of ``root_dir`` holding the mask files.
            root_dir: directory that contains the mask folders.
            downsample: block size used by ``block_reduce`` when masks are
                rebuilt from the ``.mat`` files.
            device: torch device the mask tensors are moved to.
        """
        self.flag_enfovea = 1
        self.flag_enPM = 1
        self.flag_debug = False
        self._folder = osp.join(root_dir, folder_name)
        self._downsample = downsample

        if self.flag_enPM:
            data, limits = self._load_or_build('Pooling_Masks', self._build_pooling_masks)
            self.normalized_data = torch.tensor(data).unsqueeze(4).to(device, torch.float32)
            self.pooling_limits = torch.tensor(limits).to(device)
        if self.flag_enfovea:
            data, limits = self._load_or_build('Foveal_Mask', self._build_foveal_mask)
            self.fnormalized_data = torch.tensor(data).unsqueeze(3).to(device, torch.float32)
            self.foveal_limits = torch.tensor(limits).to(device)
            # Host-side copy so __call__ can slice without a device sync.
            self.flim = self.foveal_limits[0][0].type(torch.int).cpu().numpy()

        if self.flag_enPM:
            self.mask_shape = tuple(self.normalized_data.shape[2:])
            self.f_y = self.mask_shape[0] // 2
            self.f_x = self.mask_shape[1] // 2
            self.data_shape = tuple(self.normalized_data.shape)
        if self.flag_enfovea:
            self.mask_shape = tuple(self.fnormalized_data.shape[1:])
            self.f_y = self.mask_shape[0] // 2
            self.f_x = self.mask_shape[1] // 2
            self.fdata_shape = tuple(self.fnormalized_data.shape)

    def _load_or_build(self, prefix, build):
        """Return ``(data, limits)`` from ``<prefix>_data.npy`` / ``<prefix>_limits.npy``.

        If either cache file is missing, ``build()`` computes both arrays,
        which are saved and then returned.
        """
        data_path = osp.join(self._folder, prefix + '_data.npy')
        limits_path = osp.join(self._folder, prefix + '_limits.npy')
        try:
            return np.load(data_path), np.load(limits_path)
        except FileNotFoundError:
            print('**************** I did not find the pooling regions ***************************')
        data, limits = build()
        np.save(limits_path, limits)
        np.save(data_path, data)
        return data, limits

    def _build_pooling_masks(self):
        """Downsample every peripheral pooling mask and normalize it to unit sum.

        Returns:
            (masks, limits): masks of shape (R, C, h, w); limits of shape
            (R, C, 4) with the [top, bottom, left, right] non-zero bounds.
        """
        scale = self._downsample
        mat_path = osp.join(self._folder, 'Pooling_Masks_data.mat')
        with tables.open_file(mat_path) as mat_file:
            pooling_data = mat_file.root.Pooling_Masks
            print(pooling_data.shape)
            n_rows, n_cols = pooling_data.shape[:2]
            limits = np.zeros((n_rows, n_cols, 4))
            masks = []
            for i in range(n_rows):
                row = []
                for j in range(n_cols):
                    # Index all axes at once so only this one mask is read from disk.
                    pm = block_reduce(pooling_data[i, j, 0], block_size=(scale, scale), func=np.sum)
                    pm = pm / np.sum(pm)
                    limits[i, j] = self._nonzero_bounds(pm)
                    row.append(pm)
                masks.append(row)
        return np.asarray(masks), limits

    def _build_foveal_mask(self):
        """Downsample and binarize the foveal mask.

        Returns:
            (mask, limits): mask of shape (1, h, w); limits of shape (1, 1, 4)
            with [top, bottom, left, right].
        """
        scale = self._downsample
        mat_path = osp.join(self._folder, 'foveal_Mask_data.mat')
        with tables.open_file(mat_path) as mat_file:
            pm = mat_file.root.foveal_Mask[0, 0, 0]
        pm = block_reduce(pm, block_size=(scale, scale), func=np.sum)
        # After dividing by the max and flooring, only blocks whose sum equals
        # the maximum block sum stay at 1; everything else becomes 0.
        pm = np.floor(np.clip(pm / np.max(pm), 0, 1))
        limits = np.zeros((1, 1, 4))  # [top, bottom, left, right]
        limits[0, 0] = self._nonzero_bounds(pm)
        print('***** AJ: fovea limits: {}'.format(limits[0][0]))
        return np.asarray([pm]), limits

    @staticmethod
    def _nonzero_bounds(mask):
        """Return [min_row, max_row, min_col, max_col] of the non-zero entries of a 2-D mask."""
        rows, cols = np.nonzero(mask)
        return [np.min(rows), np.max(rows), np.min(cols), np.max(cols)]

    @staticmethod
    def _window_bounds(centre, shift, size, limit):
        """Clip the window [centre - shift, centre - shift + size) to [0, limit).

        Returns:
            (start, stop, pad_before, pad_after): the exclusive slice into the
            mask, plus the zero padding that restores the window to ``size``.
        """
        start = centre - shift
        stop = start + size
        return max(start, 0), min(stop, limit), max(0, -start), max(0, stop - limit)

    def __call__(self, feat, x_shift, y_shift, size_xy):
        """Pool ``feat`` with the masks positioned relative to the given shift.

        A ``size_xy`` x ``size_xy`` window is cut from the masks, starting at
        row ``f_y - y_shift`` and column ``f_x - x_shift``. Any part of the
        window that falls outside the mask is zero-padded.

        Args:
            feat: features multiplied element-wise with the mask window. The
                `__main__` demo passes a (56, 56, 12) tensor; presumably
                (size_xy, size_xy, channels) in general. TODO confirm.
            x_shift: column offset of the window.
            y_shift: row offset of the window.
            size_xy: side length of the square window.

        Returns:
            (foveal_feat, pooled_feat, pooling_limits):
                foveal_feat: (n, channels) weighted features inside the foveal
                    bounding box, or None when the fovea is disabled.
                pooled_feat: (k, channels) mask-weighted sums of the kept
                    pooling regions, or None when pooling is disabled.
                pooling_limits: (k, 4) limits of those same regions, or None.
        """
        foveal_feat = pooled_feat = pooling_limits = None

        # x indexes mask columns and y indexes mask rows.
        left, right, left_pad, right_pad = self._window_bounds(
            self.f_x, x_shift, size_xy, self.mask_shape[1])
        top, bottom, top_pad, bottom_pad = self._window_bounds(
            self.f_y, y_shift, size_xy, self.mask_shape[0])
        # F.pad pairs run from the last dim backwards: (channel, columns, rows).
        pad = (0, 0, left_pad, right_pad, top_pad, bottom_pad)
        if self.flag_debug:
            print('self.mask_shape: {}'.format(self.mask_shape))
            print([left, right, top, bottom])

        if self.flag_enPM:
            window = F.pad(self.normalized_data[:, :, top:bottom, left:right, :],
                           pad, mode='constant', value=0)
            pooled = torch.sum(window * feat, dim=[2, 3])  # (R, C, channels)
            pooled_feat = pooled.reshape(-1, pooled.shape[-1])
            # NOTE(review): a region is kept when its pooled channel-0 value is
            # non-zero, so validity depends on the features, not only on mask
            # coverage of the window.
            keep = pooled_feat[:, 0] != 0
            pooled_feat = pooled_feat[keep]
            pooling_limits = self.pooling_limits.reshape(-1, 4)[keep]
            if self.flag_debug:
                print('pooled_feat.shape: {}, pooling_limits.shape: {}'.format(
                    pooled_feat.shape, pooling_limits.shape))

        if self.flag_enfovea:
            window = F.pad(self.fnormalized_data[:, top:bottom, left:right, :],
                           pad, mode='constant', value=0)
            weighted = window * feat
            # Place the window back into full-mask coordinates. Negative pad
            # amounts crop away the zero padding added above.
            x0 = self.f_x - x_shift
            y0 = self.f_y - y_shift
            weighted = F.pad(weighted,
                             (0, 0, x0, self.mask_shape[1] - x0 - size_xy,
                              y0, self.mask_shape[0] - y0 - size_xy),
                             mode='constant', value=0)
            r0, r1, c0, c1 = (int(v) for v in self.flim)
            foveal_feat = weighted[:, r0:r1 + 1, c0:c1 + 1, :].reshape(-1, weighted.shape[-1])
            if self.flag_debug:
                print('foveal_feat.shape: {}'.format(foveal_feat.shape))

        return (foveal_feat, pooled_feat, pooling_limits)





if __name__=='__main__':
    # Demo: sweep a 14 x 14 grid of window shifts over a random (1, 56, 56, 12)
    # feature map and report how many pooling regions are kept for each shift.
    # A caller that needs fixed-size outputs can zero-pad pooled_feat and
    # pooling_limits up to max(len_list) rows.
    apply_pooling = ApplyPooling_GPU()
    input_image = torch.rand(1, 56, 56, 12).cuda()

    len_list = []
    for fix_x in range(14):
        for fix_y in range(14):
            foveal_feat, pooled_feat, pooling_limits = apply_pooling(
                input_image[0], fix_x * 4 + 1, fix_y * 4 + 1, input_image.shape[2])
            print(pooled_feat.shape)
            len_list.append(pooled_feat.shape[0])

    print('\n\n')
    print('len_list: {} (max - {})'.format(len_list, np.max(len_list)))
