import numpy as np
from numpy.fft import fft2, ifft2
import matplotlib.pyplot as plt
from functools import lru_cache
from torch.nn import functional as F
import torch
        


def int_to_bool_list(num,l):
    """Return the ``l`` lowest bits of ``num`` as booleans, least-significant bit first."""
    flags = []
    for position in range(l):
        mask = 1 << position
        flags.append(bool(num & mask))
    return flags

def get_rule_lists(rule_num):
    """Decode an 18-bit rule number into ``(born, survive)`` neighbour-count lists.

    The low 9 bits encode the survive rule and the next 9 bits the born rule.
    Within each 9-bit group the most significant bit stands for count 0 and the
    least significant bit for count 8.

    Example: 16480 -> born bits 000100000, survive bits 001100000 -> ([3], [2, 3]).
    """
    survive_code = rule_num % 2**9
    born_code = (rule_num - survive_code) // (2 ** 9)

    # int_to_bool_list yields the least-significant bit first; flipping the
    # order makes the list index equal to the neighbour count.
    born_flags = int_to_bool_list(born_code, 9)[::-1]
    survive_flags = int_to_bool_list(survive_code, 9)[::-1]

    born = [count for count, flag in enumerate(born_flags) if flag == True]
    survive = [count for count, flag in enumerate(survive_flags) if flag == True]
    return born, survive

def fft_convolve2d(board, fft_kernel):
    """Circularly convolve a batch of boards with a kernel given in Fourier space.

    board: array of shape (batch, height, width).
    fft_kernel: ``fft2`` of a kernel; it is multiplied element-wise with the
        transformed boards, so it must broadcast against them.

    The result is rolled by ``1 - int(size / 2)`` along both spatial axes to
    undo the offset of a kernel centred at index ``int(size / 2) - 1``, and is
    then rounded to the nearest integer.
    """
    spectrum = fft2(board)
    _, rows, cols = spectrum.shape
    result = np.real(ifft2(spectrum * fft_kernel))
    row_shift = 1 - int(rows / 2)
    col_shift = 1 - int(cols / 2)
    result = np.roll(result, (row_shift, col_shift), axis=(-2, -1))
    return result.round()




class LifeLikeAutomaton2D():
    """Batched life-like 2D cellular automaton on a periodic (wrap-around) grid.

    The rule is an integer that ``get_rule_lists`` decodes into the neighbour
    counts for which a dead cell is born and the counts for which a live cell
    survives. Neighbour counts are computed with an FFT convolution against a
    3x3 ring kernel, which is why the grid wraps around at its borders.
    """

    def __init__(self,rule_num, grid_size, init_function=None):
        """Store the rule and precompute the Fourier-space neighbour kernel.

        rule_num: integer rule code, decoded by ``get_rule_lists``.
        grid_size: side length of the square grid (must be at least 3).
        init_function: optional callable ``(grid_size, density, batch_size)``
            returning the initial batch of grids used by ``get_sample``.
        """
        self.rule_num = rule_num
        self.rule = rule_num
        self.born_list, self.survive_list = get_rule_lists(rule_num)
        self.grid_size = grid_size
        self.init_function = init_function
        self.fft_kernel = self._build_fft_kernel(grid_size)

    @staticmethod
    def _build_fft_kernel(grid_size):
        """Return ``fft2`` of the 8-neighbour kernel centred at ``grid_size // 2 - 1``.

        That centre is exactly the offset ``fft_convolve2d`` rolls back, for
        even and odd grid sizes alike. The kernel is placed with a circular
        roll so small or odd grids never produce a wrongly sized slice.
        """
        if grid_size < 3:
            raise ValueError("grid_size must be at least 3 to hold a 3x3 neighbourhood")

        kernel = np.zeros((grid_size, grid_size))
        # Ring of ones centred at (1, 1) ...
        kernel[:3, :3] = np.array([[1, 1, 1], [1, 0, 1], [1, 1, 1]])
        # ... moved so that its centre sits at (center, center).
        center = grid_size // 2 - 1
        kernel = np.roll(kernel, center - 1, axis=(0, 1))
        return fft2(kernel)

    def apply_rule(self, grid):
        """Advance a batch of grids (batch, height, width) by one time step.

        Returns a float array of zeros and ones with the same shape.
        """
        neighbours = fft_convolve2d(grid, self.fft_kernel)
        survives = np.isin(neighbours, self.survive_list) & (grid == 1)
        born = np.isin(neighbours, self.born_list) & (grid == 0)

        new_grid = np.zeros(neighbours.shape)
        new_grid[survives | born] = 1
        return new_grid

    def _forward(self,grid, iterations, naturalize_iterations=0):
        """Run the automaton and return ``(sample, target, intermediate)``.

        The grid is first advanced ``naturalize_iterations`` steps to obtain
        ``sample``; ``target`` is ``sample`` advanced ``iterations`` further
        steps, and ``intermediate`` lists the states before each of those steps
        (so its first entry is ``sample`` itself).
        """
        for _ in range(naturalize_iterations):
            grid = self.apply_rule(grid)

        sample = grid
        intermediate = []

        for _ in range(iterations):
            intermediate.append(grid)
            grid = self.apply_rule(grid)

        target = grid

        return sample, target, intermediate

    def get_sample(self, density, iterations, naturalize_iterations, batch_size=16, grid_size=None):
        """Draw random initial grids, simulate them and return the trajectories.

        density: probability of a cell starting alive (passed on to
            ``init_function`` when one is set).
        iterations / naturalize_iterations: see ``_forward``.
        batch_size: number of grids simulated in parallel.
        grid_size: if given and smaller than ``self.grid_size``, the results are
            cropped to their top-left ``grid_size`` x ``grid_size`` corner; the
            simulation itself always runs on the full grid.
        """
        assert grid_size is None or grid_size <= self.grid_size

        if self.init_function is None:
            grid = np.random.uniform(0, 1, (batch_size, self.grid_size, self.grid_size))
            grid = grid < density
        else:
            grid = self.init_function(self.grid_size, density, batch_size)

        sample, target, intermediate = self._forward(grid, iterations, naturalize_iterations)

        if grid_size is not None and grid_size < self.grid_size:
            sample = sample[:, :grid_size, :grid_size]
            target = target[:, :grid_size, :grid_size]
            intermediate = [step[:, :grid_size, :grid_size] for step in intermediate]

        return sample, target, intermediate

    def calc_lambda(self):
        """Return the rule's lambda: the fraction of neighbourhoods mapped to alive.

        Every born/survive count ``i`` contributes ``0.5 * C(8, i) / 2**8``: the
        share of 8-neighbour configurations with ``i`` live cells, times one
        half for the state of the centre cell.
        """
        from math import comb

        counts = list(self.born_list) + list(self.survive_list)
        return sum(0.5 * comb(8, i) * (0.5) ** 8 for i in counts)

    def calc_perturbation_sensitivity(self,iterations, num_samples=64):
        """Estimate how far a single flipped cell spreads within ``iterations`` steps.

        Random grids of density 0.5 are simulated twice, once unchanged and once
        with cell ``(iterations, iterations)`` flipped. The number of differing
        cells afterwards is divided by ``num_samples * (2 * iterations + 1)**2``,
        i.e. by the size of the window a 3x3 rule can reach in that many steps.
        """
        assert iterations*2+1<=self.grid_size

        density = 0.5
        grid = np.random.uniform(0, 1, (num_samples,self.grid_size,self.grid_size))
        grid = grid < density
        grid2 = np.copy(grid)
        grid2[:,iterations,iterations] = 1-grid2[:,iterations,iterations]

        for _ in range(iterations):
            grid = self.apply_rule(grid)
            grid2 = self.apply_rule(grid2)

        window = (iterations*2+1)*(iterations*2+1)
        return np.sum(grid!=grid2)/(num_samples*window)

    def calc_rep_lambda(self,num_samples=64, iterations=32):
        """Return the mean live-cell density after ``iterations`` steps from random grids of density 0.5."""
        density = 0.5
        grid = np.random.uniform(0, 1, (num_samples, self.grid_size, self.grid_size))
        grid = grid < density

        for _ in range(iterations):
            grid = self.apply_rule(grid)

        return np.mean(grid)

import random
def init_512_rule(uniform_lambda = False):
    """Return a random rule table of 512 entries, each 0 or 1.

    With ``uniform_lambda`` unset every entry is an independent fair coin flip,
    so most generated rules have a fraction of ones close to 0.5. With
    ``uniform_lambda`` set, a probability is first drawn uniformly from [0, 1)
    and every entry is 1 with that probability.
    """
    if not uniform_lambda:
        return [random.randint(0,1) for i in range(512)]

    _lambda = np.random.rand()
    return [1 if _lambda > np.random.rand() else 0 for _ in range(512)]
    
def init_big_stuff(filter_size,uniform_lambda = False):
    """Return a random rule table for a ``filter_size[0]`` x ``filter_size[1]`` neighbourhood.

    The table has ``2 ** (filter_size[0] * filter_size[1])`` entries, each 0 or
    1. With ``uniform_lambda`` unset every entry is a fair coin flip; otherwise
    a probability is drawn uniformly from [0, 1) and every entry is 1 with that
    probability.
    """
    table_length = 2 ** (filter_size[0]*filter_size[1])

    if not uniform_lambda:
        return [random.randint(0,1) for i in range(table_length)]

    _lambda = np.random.rand()
    return [1 if _lambda > np.random.rand() else 0 for _ in range(table_length)]


# handles all possible 2d rules
class CellularAutomaton2D:
    """General 2D cellular automaton over a 3x3 neighbourhood on a periodic grid.

    ``rule`` is a lookup table indexed by the 9-bit code of a neighbourhood:
    the cells are read row by row and the top-left cell is the most significant
    bit (weight 2**8), the centre cell has weight 2**4.
    """

    def __init__(self, rule, grid_size):
        """Store the rule table and build the convolution filter that encodes neighbourhoods.

        rule: sequence of 0/1 entries, one per neighbourhood code (0..511).
        grid_size: default side length of the grids produced by ``get_sample``.
        """
        self.rule = torch.LongTensor(rule)
        self.grid_size = grid_size
        self.pot = np.array([2**(8-i) for i in range(9)])
        self.weights = self._create_conv_filter()

    def _create_conv_filter(self):
        """Return the powers of two in ``self.pot`` as a (1, 1, 3, 3) float filter.

        Convolving a 0/1 grid with this filter yields, for every cell, the
        integer code of its 3x3 neighbourhood, which indexes ``self.rule``.
        """
        weights = np.asarray(self.pot, dtype=float).reshape(1, 1, 3, 3)
        return torch.from_numpy(weights).float()

    def apply_rule(self, grid):
        """Advance a batch of grids (batch, height, width) by one step.

        grid: array-like of bools or 0/1 numbers.
        Returns a boolean numpy array of the same shape.
        """
        grid = torch.Tensor(grid)
        expanded_padding = (1,1,1,1)
        # Circular padding makes the grid wrap around at its borders.
        padded = F.pad(grid.unsqueeze(1), expanded_padding, mode='circular')
        codes = F.conv2d(padded, self.weights).long()
        out = torch.index_select(self.rule, 0, codes.view(-1)).view(grid.shape)

        return out.numpy().astype(bool)

    def _forward(self,grid, iterations, naturalize_iterations=0):
        """Run the automaton and return ``(sample, target, intermediate)``.

        The grid is first advanced ``naturalize_iterations`` steps to obtain
        ``sample``; ``target`` is ``sample`` advanced ``iterations`` further
        steps, and ``intermediate`` lists the states before each of those steps.
        """
        for _ in range(naturalize_iterations):
            grid = self.apply_rule(grid)

        sample = grid
        intermediate = []

        for _ in range(iterations):
            intermediate.append(grid)
            grid = self.apply_rule(grid)

        target = grid

        return sample, target, intermediate

    def get_sample(self, density, iterations, naturalize_iterations, batch_size=16, grid_size=None):
        """Simulate random grids of the given density; see ``_forward`` for the result.

        grid_size: side length of the sampled grids; defaults to
            ``self.grid_size`` when omitted.
        """
        # The original condition was inverted: it left the default ``None`` in
        # place (crashing below) and discarded any explicitly requested size.
        if grid_size is None:
            grid_size = self.grid_size
        grid = np.random.uniform(0, 1, (batch_size, grid_size, grid_size))
        grid = grid < density

        return self._forward(grid, iterations, naturalize_iterations)

    def calc_lambda(self):
        """Return the fraction of the 512 neighbourhoods that map to alive (a 0-dim tensor)."""
        return self.rule.sum()/512

    def calc_perturbation_sensitivity(self,iterations, num_samples=64):
        """Estimate how far a single flipped cell spreads within ``iterations`` steps.

        Random grids of density 0.5 are simulated twice, once unchanged and once
        with cell ``(iterations, iterations)`` flipped. The number of differing
        cells afterwards is divided by ``num_samples * (2 * iterations + 1)**2``,
        i.e. by the size of the window a 3x3 rule can reach in that many steps.
        """
        assert iterations*2+1<=self.grid_size

        density = 0.5
        grid = np.random.uniform(0, 1, (num_samples,self.grid_size,self.grid_size))
        grid = grid < density
        grid2 = np.copy(grid)
        grid2[:,iterations,iterations] = 1-grid2[:,iterations,iterations]

        for _ in range(iterations):
            grid = self.apply_rule(grid)
            grid2 = self.apply_rule(grid2)

        window = (iterations*2+1)*(iterations*2+1)
        return np.sum(grid!=grid2)/(num_samples*window)



class CoarseWrapper():
    """Wrap an automaton to produce training batches that are coarse-grained in time and space.

    Time coarse-graining means the target lies ``time_factor`` automaton steps
    after the input; space coarse-graining replaces each cell by the majority
    vote of a ``spatial_factor`` x ``spatial_factor`` window.
    """

    def __init__(self, automaton, time_factor, spatial_factor, only_output_coarse=False, init_function=None, **kwargs):
        """Store the configuration and measure the automaton's perturbation sensitivity.

        automaton: object providing ``grid_size``, ``rule``, ``get_sample``,
            ``_forward`` and ``calc_perturbation_sensitivity``.
        time_factor: number of automaton steps between input and target.
        spatial_factor: side length of the pooling window.
        only_output_coarse: if true, only the target is pooled; the input keeps
            its full resolution.
        init_function: optional initialiser that is installed on ``automaton``.
        kwargs: extra entries copied verbatim into ``get_config``.
        """
        self.automaton = automaton
        self.time_factor = time_factor
        self.spatial_factor = spatial_factor
        self.grid_size = automaton.grid_size
        self.rule = automaton.rule
        self.perturbation_sensitivity = automaton.calc_perturbation_sensitivity(self.time_factor,256).item()
        self.only_output_coarse = only_output_coarse
        self.init_function = init_function
        if init_function is not None:
            self.automaton.init_function = init_function

        self.additional_args = kwargs

    def get_config(self):
        """Return a dict describing this wrapper, including freshly measured coarse sensitivity."""
        d = {'time_factor': self.time_factor, 'spatial_factor': self.spatial_factor}
        d.update(self.additional_args)
        d["grid_size"] = self.grid_size
        d["rule"] = self.rule
        d["perturbation_sensitivity"] = self.perturbation_sensitivity
        d["coarse_sensitivity"] = self.calc_perturbation_sensitivity_coarse(256).item()
        d["rule_type"] = "LifeLike" if hasattr(self.automaton, "born_list") else "General"
        d["only_output_coarse"] = self.only_output_coarse
        if self.init_function is not None:
            d["init_function"] = self.init_function.__name__
        return d

    def spatialize(self, x, y):
        """Pool input ``x`` and target ``y`` (each batch, height, width) over the spatial window.

        Both are wrap-padded so the pooled result keeps the grid size, averaged
        over a ``spatial_factor`` window with stride 1 and thresholded at 0.5.
        ``x`` is returned untouched when ``only_output_coarse`` is set. Pooled
        values are float tensors of zeros and ones.
        """
        window = (self.spatial_factor, self.spatial_factor)

        if not self.only_output_coarse:
            # The input window extends to the right/bottom of each cell.
            x = np.pad(x, ((0,0),(0,self.spatial_factor-1),(0,self.spatial_factor-1)),mode="wrap")
            x = F.avg_pool2d(torch.tensor(x, dtype=torch.float), kernel_size=window, stride=1)
            x= (x>0.5)*1.0

        # The target window is centred on each cell; for even factors the
        # left/top side gets one cell less.
        l_pad = self.spatial_factor//2
        r_pad = self.spatial_factor//2
        if self.spatial_factor%2==0:
            l_pad = self.spatial_factor//2-1

        y = np.pad(y, ((0,0),(l_pad,r_pad),(l_pad,r_pad)), mode="wrap")
        y = F.avg_pool2d(torch.tensor(y, dtype=torch.float), kernel_size=window, stride=1)
        y= (y>0.5)*1.0

        return x,y

    def _forward_sample(self, x):
        """Advance ``x`` by ``time_factor`` steps and return the pooled ``(input, target)`` pair."""
        x,y,intermediate = self.automaton._forward(x, self.time_factor, 0)

        x,y = self.spatialize(x,y)

        return x,y

    def get_batch(self, batchsize, naturalize_iterations=0):
        """Sample a training batch.

        Returns ``x`` as a float tensor of shape (batch, 2, height, width), the
        one-hot encoding of the input state (channel 1 = alive), and ``y`` as a
        long tensor of shape (batch, height, width) holding the target state.
        """
        x,y,intermediate = self.automaton.get_sample(0.5, self.time_factor, naturalize_iterations, batchsize, self.grid_size)

        if self.spatial_factor>1:
            x,y = self.spatialize(x,y)

        # num_classes is pinned: otherwise a batch without any live cell would
        # be encoded with a single channel instead of two.
        x = F.one_hot(torch.tensor(x, dtype=torch.long), num_classes=2).permute(0,3,1,2).float()
        y = torch.tensor(y, dtype=torch.long)
        return x,y

    def calc_perturbation_sensitivity_coarse(self, num_samples=64):
        """Measure how a single flipped input cell changes the coarse target.

        Random grids of density 0.5 are pushed through ``_forward_sample`` twice,
        once unchanged and once with cell ``(time_factor, time_factor)`` flipped.
        The number of differing target cells is divided by
        ``num_samples * (2 * time_factor + 1)**2``.

        Raises NotImplementedError unless ``only_output_coarse`` is set, since
        the coarse-to-coarse variant does not exist yet.
        """
        if not self.only_output_coarse:
            # Fail before running two full simulations whose result would be discarded.
            raise NotImplementedError("not implemented coarse to coarse perturbation sensitivity yet")

        density = 0.5
        grid = np.random.uniform(0, 1, (num_samples,self.grid_size,self.grid_size))
        grid = grid < density
        grid2 = np.copy(grid)
        grid2[:,self.time_factor,self.time_factor] = 1-grid2[:,self.time_factor,self.time_factor]

        in1, out1 = self._forward_sample(grid)
        in2, out2 = self._forward_sample(grid2)

        diff = (out1!=out2).sum()

        pert_sens = diff/(num_samples*(self.time_factor*2+1)*(self.time_factor*2+1))
        return pert_sens
        
        
        

def create_init_only_center(center_size):
    """Build an initialiser that randomises only a centred square patch of the grid.

    The returned function ``init_only_center(grid_size, density, batch_size)``
    yields a boolean array of shape (batch_size, grid_size, grid_size). Cells
    inside the centred ``center_size`` x ``center_size`` patch are alive with
    probability ``density``; every cell outside the patch is dead.
    """
    def init_only_center(grid_size, density, batch_size):
        # The patch can only be centred exactly when both sizes share parity.
        if grid_size % 2 != center_size % 2:
            raise ValueError("grid size and center size must be even or odd together")
        if center_size > grid_size:
            raise ValueError("center size must not exceed grid size")

        start = grid_size // 2 - center_size // 2
        stop = start + center_size

        # Threshold only the patch: comparing the zero-filled border against
        # ``density`` would turn every border cell alive (0 < density).
        grid = np.zeros((batch_size, grid_size, grid_size), dtype=bool)
        patch = np.random.uniform(0, 1, (batch_size, center_size, center_size))
        grid[:, start:stop, start:stop] = patch < density
        return grid

    return init_only_center

if __name__=="__main__":
    # Visual smoke test: sample one batch from a random life-like rule whose
    # target is coarse-grained, plot input and target, and print diagnostics.
    #
    # Conway's Game of Life would be int('000100000001100000', 2), i.e. born
    # on 3 neighbours and survive on 2 or 3.
    # randint includes its upper bound, so 2**18 - 1 keeps the rule within the
    # 18 bits that get_rule_lists decodes.
    CA = LifeLikeAutomaton2D(random.randint(0, 2**18 - 1), 30)

    coarse = CoarseWrapper(CA, 10, 2, only_output_coarse=True, init_function=create_init_only_center(8))
    x,y = (coarse.get_batch(1))

    # Channel 1 of the one-hot input marks the live cells.
    fig,axs = plt.subplots(1,2)
    axs[0].imshow(x[0][1][:, :], cmap='gray')
    axs[1].imshow(y[0][:, :], cmap='gray')
    plt.show()

    print(x.shape)
    print(y.shape)
    print(coarse.get_config())

    print(x[0][1][:10,:10])
    print(y[0][:10,:10])

    print(coarse.perturbation_sensitivity)
    print(coarse.calc_perturbation_sensitivity_coarse(256))