import torch
import torch.nn as nn
import torch.nn.functional as F
import os

import numpy as np
from neuron_model import *
import loss_f


# Candidate connection types for every searched synapse. The position of
# each entry is the index along the last (softmax) axis of the architecture
# logits in Recurrent_Layer: 0 -> no connection, 1 -> excitatory,
# 2 -> inhibitory.
connection_types = [
    'n',  # none
    'e',  # excitatory connection
    'i',  # inhibitory connection
]

# Candidate cell sizes explored in "search_cell_size" mode. A recurrent
# layer's n_outputs must be divisible by every entry (asserted when the
# search weights are created).
cell_sizes = [2, 4, 8, 16, 25]
    
def linear_f(x, w, b=None):
    """Apply a fully-connected projection independently at every step.

    Dims 1-3 of ``x`` are flattened into one feature axis and the last dim
    is kept as the step axis, i.e. ``x`` is treated as
    ``(batch, d1, d2, d3, steps)`` with ``d1 * d2 * d3 == w.shape[1]``.

    Args:
        x: 5-D input tensor. It does not need to be contiguous.
        w: weight of shape ``(out_features, d1 * d2 * d3)``.
        b: optional bias of shape ``(out_features,)``.

    Returns:
        Tensor of shape ``(batch, out_features, 1, 1, steps)``.
    """
    # ``reshape`` rather than ``view``: it returns a view whenever one is
    # possible and copies only for non-contiguous inputs (e.g. permuted
    # tensors), which ``view`` would reject.
    x = x.reshape(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3], x.shape[4])
    # F.linear contracts the last dim, so put features last: (batch, steps, in),
    # then swap back to (batch, out, steps).
    y = F.linear(x.transpose(1, 2), w, b).transpose(1, 2)
    # Insert the two singleton spatial dims -> (batch, out, 1, 1, steps).
    return y.unsqueeze(2).unsqueeze(3)

class Recurrent_Layer(nn.Linear):
    """Fully-connected spiking layer with searchable recurrent connectivity.

    The input projection is the (bias-free) ``nn.Linear`` weight. On top of
    it, every output neuron receives recurrent input from

    * an intra-cell block: the ``n_outputs`` neurons are grouped into
      ``n_outputs / cs`` cells of ``cs`` neurons, and each cell has its own
      ``cs x cs`` recurrent matrix (see ``recurrent_f``);
    * a lateral connection: the response of neuron ``j`` is added to neuron
      ``j + cs`` (the shifted add in ``forward``).

    Each candidate connection is one of ``connection_types``. The
    excitatory channel uses trainable weights initialised to ``w_e``
    (returned by ``get_parameters``). The inhibitory channel uses the fixed
    scalar ``w_i``. The "none" channel contributes nothing.

    Modes (set with ``set_mode`` and then call ``init_weights``):

    * ``"search_cell_size"``: all cell sizes in ``cell_sizes`` run in
      parallel, and their intra-cell terms are mixed with
      ``softmax(cell_size_weight)``;
    * ``"search_connection_type"``: the cell size is fixed to
      ``cell_sizes[best_cell_size_ind]`` and only connection types are
      searched;
    * ``"finetune"``: connection types are frozen to the argmax of the
      recorded best architecture.
    """

    def __init__(self, network_config, layer_config, name):
        """Read the configuration and create the input weight on the GPU.

        Args:
            network_config: dict. Read here: ``arch_batch_size``,
                ``batch_size``, ``n_steps``, ``n_class``. ``forward`` also
                reads ``tau_m``.
            layer_config: dict with ``n_inputs``, ``n_outputs``, ``type``.
                Optional ``w_e`` / ``w_i``; both fall back to 0 if either
                key is missing.
            name: layer name, also used to build ``self.fname``.
        """
        self.name = name
        self.n_inputs = layer_config['n_inputs']
        self.n_outputs = layer_config['n_outputs']
        self.layer_config = layer_config
        self.network_config = network_config
        self.arch_batch_size = network_config['arch_batch_size']
        self.batch_size = network_config['batch_size']
        self.n_steps = network_config['n_steps']
        self.n_class = network_config['n_class']
        self.type = layer_config['type']

        self.fname = 'arch_' + name + '.txt'

        super(Recurrent_Layer, self).__init__(self.n_inputs, self.n_outputs, bias=False)

        # Sentinel string (not None); get_parameters / forward reject it.
        self.mode = 'None'

        # search cell size
        self.weight_e_list = []
        self.lateral_weight_e_list = []
        self.arch_weight_list = []
        self.lateral_arch_weight_list = []
        self.cell_size_weight = None
        self.best_cell_size_weight = None
        self.best_cell_size_ind = None

        # Snapshots taken by record_best_arch() in "search_cell_size" mode
        # and restored by init_weights() in "search_connection_type" mode.
        self.best_connection_type_weight = None
        self.best_lateral_connection_type_weight = None
        self.best_weight_e = None
        self.best_lateral_weight_e = None

        # search connection type
        self.connection_type_weight = None
        self.lateral_connection_type_weight = None
        self.weight_e = None
        self.lateral_weight_e = None
        self.best_arch = None
        self.best_lateral_arch = None
        self.final_cs = None
        self.final_n_cell = None

        # finetune: 0/1 masks selecting the excitatory / inhibitory channel
        self.final_prob_e = None
        self.final_prob_i = None
        self.final_lateral_prob_e = None
        self.final_lateral_prob_i = None

        # Per-forward state (set in forward()).
        self.probs_cell_size = None
        self.syns = None

        # init input weights: move to the GPU, then Kaiming-normal init
        self.weight = nn.Parameter(self.weight.cuda(), requires_grad=True)
        nn.init.kaiming_normal_(self.weight)

        try:
            self.w_e = self.layer_config['w_e']
            self.w_i = self.layer_config['w_i']
        except KeyError:
            # Both default to 0 if either key is absent.
            self.w_e = 0
            self.w_i = 0

    def init_weights(self):
        """Create or reset the tensors required by the current mode.

        * ``"search_cell_size"``: for every candidate in ``cell_sizes``,
          create intra-cell and lateral excitatory weights and their
          architecture logits, plus zero logits over cell sizes.
        * ``"search_connection_type"``: fix the cell size to
          ``cell_sizes[best_cell_size_ind]`` and create fresh connection-type
          logits and excitatory weights. If ``record_best_arch`` stored
          snapshots, copy them in. Then delete the cell-size search tensors.
        * ``"finetune"``: re-initialise the input weight and the excitatory
          weights, and turn the argmax of ``best_arch`` /
          ``best_lateral_arch`` into 0/1 excitatory / inhibitory masks.

        Raises:
            Exception: for any other mode.
        """
        if self.mode == "search_cell_size":
            for cs in cell_sizes:
                # Every candidate cell size must tile the output neurons.
                assert self.n_outputs % cs == 0
                n_cell = int(self.n_outputs/cs)
                self.weight_e_list.append(nn.Parameter(self.w_e * torch.ones(n_cell, cs, cs).cuda(), requires_grad=True))
                self.lateral_weight_e_list.append(nn.Parameter(self.w_e * torch.ones(1, self.n_outputs).cuda(), requires_grad=True))

                # Architecture logits; the last axis indexes connection_types.
                self.arch_weight_list.append(nn.Parameter(torch.randn(cs, cs, len(connection_types)).cuda() * 0.001, requires_grad=True))
                self.lateral_arch_weight_list.append(nn.Parameter(torch.randn(1, self.n_outputs, len(connection_types)).cuda() * 0.001, requires_grad=True))
            self.cell_size_weight = nn.Parameter(torch.zeros(len(cell_sizes)).cuda(), requires_grad=True)
        elif self.mode == "search_connection_type":
            self.final_cs = cell_sizes[self.best_cell_size_ind]
            self.final_n_cell = int(self.n_outputs/self.final_cs)

            self.connection_type_weight = nn.Parameter(torch.randn(self.final_cs, self.final_cs, len(connection_types)).cuda() * 0.001, requires_grad=True)
            self.lateral_connection_type_weight = nn.Parameter(torch.randn(1, self.n_outputs, len(connection_types)).cuda() * 0.001, requires_grad=True)
            self.weight_e = nn.Parameter(self.w_e * torch.ones(self.final_n_cell, self.final_cs, self.final_cs).cuda(), requires_grad=True)
            self.lateral_weight_e = nn.Parameter(self.w_e * torch.ones(1, self.n_outputs).cuda(), requires_grad=True)

            # Warm-start from the snapshot of the winning cell size, if any.
            if self.best_lateral_connection_type_weight is not None:
                self.lateral_connection_type_weight.data.copy_(self.best_lateral_connection_type_weight.data)
                self.weight_e.data.copy_(self.best_weight_e.data)
                self.connection_type_weight.data.copy_(self.best_connection_type_weight.data)
                self.lateral_weight_e.data.copy_(self.best_lateral_weight_e.data)

            del self.weight_e_list
            del self.lateral_weight_e_list
            del self.arch_weight_list
            del self.lateral_arch_weight_list
            del self.cell_size_weight

        elif self.mode == "finetune":
            nn.init.kaiming_normal_(self.weight)
            self.weight_e = nn.Parameter(self.w_e * torch.ones(self.final_n_cell, self.final_cs, self.final_cs).cuda(), requires_grad=True)
            self.lateral_weight_e = nn.Parameter(self.w_e * torch.ones(1, self.n_outputs).cuda(), requires_grad=True)

            # Intra-cell masks: index 1 = excitatory, 2 = inhibitory.
            probs = self.get_prob(self.best_arch)
            connections = torch.argmax(probs, dim=-1)

            self.final_prob_e = torch.zeros_like(probs[..., 1])
            self.final_prob_i = torch.zeros_like(probs[..., 2])
            self.final_prob_e[connections == 1] = 1
            self.final_prob_i[connections == 2] = 1

            # Lateral masks, same encoding.
            probs = self.get_prob(self.best_lateral_arch)
            connections = torch.argmax(probs, dim=-1)

            self.final_lateral_prob_e = torch.zeros_like(probs[..., 1])
            self.final_lateral_prob_i = torch.zeros_like(probs[..., 2])
            self.final_lateral_prob_e[connections == 1] = 1
            self.final_lateral_prob_i[connections == 2] = 1

        else:
            raise Exception('Unrecognized mode.')

    def get_parameters(self):
        """Return the trainable weights for the current mode.

        This is the input weight plus the excitatory recurrent weights.

        Raises:
            Exception: for an unknown mode.
        """
        params = [self.weight]
        if self.mode == "search_cell_size":
            params.extend(self.weight_e_list)
            params.extend(self.lateral_weight_e_list)
        elif self.mode in ("search_connection_type", "finetune"):
            params.append(self.weight_e)
            params.append(self.lateral_weight_e)
        else:
            raise Exception('Unrecognized mode.')
        return params

    def get_arch_parameters(self):
        """Return the architecture logits optimised in the current search mode.

        Raises:
            Exception: outside the two search modes.
        """
        params = []
        if self.mode == "search_cell_size":
            params.append(self.cell_size_weight)
            params.extend(self.arch_weight_list)
            params.extend(self.lateral_arch_weight_list)
        elif self.mode == "search_connection_type":
            params.append(self.connection_type_weight)
            params.append(self.lateral_connection_type_weight)
        else:
            raise Exception('Unrecognized mode.')
        return params

    def get_other_variables(self):
        """Return ``[final_cs, final_n_cell, best_cell_size_ind]`` for checkpointing."""
        return [self.final_cs, self.final_n_cell, self.best_cell_size_ind]

    def set_other_variables(self, params):
        """Restore the values produced by ``get_other_variables``."""
        self.final_cs = params[0]
        self.final_n_cell = params[1]
        self.best_cell_size_ind = params[2]

    def forward(self, x):
        """Simulate the layer over all steps of ``x``.

        Args:
            x: input accepted by ``linear_f`` (5-D; the last dim is iterated).

        Returns:
            ``syns``. It is allocated as ``(batch, n_outputs, 1, 1,
            self.n_steps)`` and updated at each step by ``Recurrent_LIF``.
            It is also stored in ``self.syns`` for ``get_outputs``.

        Raises:
            Exception: if the mode is not one of the three known modes.
        """
        y = linear_f(x, self.weight, self.bias)
        shape = y.shape
        n_steps = shape[-1]

        if self.mode == "search_cell_size":
            # Softmax mixing weights over the candidate cell sizes.
            self.probs_cell_size = self.get_prob(self.cell_size_weight)

        mem_pre = torch.zeros(shape[0], shape[1], shape[2], shape[3]).cuda()
        syns = torch.zeros(shape[0], shape[1], shape[2], shape[3], self.n_steps).cuda()
        mem = torch.zeros_like(mem_pre)
        response = torch.zeros_like(mem_pre)
        outputs = torch.zeros_like(syns)
        theta = torch.ones_like(mem_pre) * (self.network_config['tau_m'])
        R = torch.ones_like(mem_pre) * (self.network_config['tau_m'])
        cal = torch.zeros_like(mem_pre)

        for t in range(n_steps):
            y, I = Recurrent_Dendrite.apply(y, t)
            # Leaky integration: decay mem_pre by (1 - 1/theta), add R*I/theta.
            mem = (1-1/theta) * mem_pre + R * I / theta

            if self.mode == "search_cell_size":
                candidates = zip(cell_sizes, self.weight_e_list, self.lateral_weight_e_list,
                                 self.arch_weight_list, self.lateral_arch_weight_list)
                for ind, (cs, weight_e, lateral_weight_e, arch_weight, lateral_arch_weight) in enumerate(candidates):
                    n_cell = int(self.n_outputs/cs)
                    probs = self.get_prob(arch_weight)
                    # Soft mix of excitatory (trainable) and inhibitory (fixed) weights.
                    weight_final = weight_e * probs[..., 1] + self.w_i * probs[..., 2]

                    mem += self.probs_cell_size[ind] * self.recurrent_f(response, weight_final, n_cell, cs)

                    probs = self.get_prob(lateral_arch_weight)
                    weight_lateral = probs[..., 1] * lateral_weight_e + probs[..., 2] * self.w_i

                    # Neuron j's response feeds neuron j + cs.
                    # NOTE(review): unlike the intra-cell term above, this lateral
                    # term is not scaled by probs_cell_size[ind] — confirm intended.
                    tmp = response * weight_lateral[0, ...].view(-1, 1, 1)
                    mem[:, cs:, ...] += tmp[:, 0:-cs, ...]

                syns, outputs, response, mem_pre, theta, cal, R = Recurrent_LIF.apply(syns, outputs, mem, mem_pre, theta, cal, R, self.network_config, self.layer_config, t, True)

            elif self.mode == "search_connection_type":
                probs = self.get_prob(self.connection_type_weight)
                weight_final = self.weight_e * probs[..., 1] + self.w_i * probs[..., 2]

                mem += self.recurrent_f(response, weight_final, self.final_n_cell, self.final_cs)

                probs = self.get_prob(self.lateral_connection_type_weight)

                weight_lateral = probs[..., 1] * self.lateral_weight_e + probs[..., 2] * self.w_i

                # Neuron j's response feeds neuron j + final_cs.
                tmp = response * weight_lateral[0, ...].view(-1, 1, 1)
                mem[:, self.final_cs:, ...] += tmp[:, 0:-self.final_cs, ...]
                syns, outputs, response, mem_pre, theta, cal, R = Recurrent_LIF.apply(syns, outputs, mem, mem_pre, theta, cal, R, self.network_config, self.layer_config, t, True)

            elif self.mode == "finetune":
                # Hard 0/1 masks select exactly one connection type per synapse.
                weight_final = self.weight_e * self.final_prob_e + self.w_i * self.final_prob_i
                mem += self.recurrent_f(response, weight_final, self.final_n_cell, self.final_cs)

                weight_lateral = self.final_lateral_prob_e * self.lateral_weight_e + self.final_lateral_prob_i * self.w_i

                tmp = response * weight_lateral[0, ...].view(-1, 1, 1)
                mem[:, self.final_cs:, ...] += tmp[:, 0:-self.final_cs, ...]

                syns, outputs, response, mem_pre, theta, cal, R = Recurrent_LIF.apply(syns, outputs, mem, mem_pre, theta, cal, R, self.network_config, self.layer_config, t, False)
            else:
                raise Exception('Set mode before training.')

        # Keep a reference so get_outputs() can return it.
        self.syns = syns
        return syns

    def recurrent_f(self, inputs, weight_r, n_cell, cell_size):
        """Apply a separate ``cell_size x cell_size`` matrix to each cell.

        Args:
            inputs: tensor whose dims after the first hold
                ``n_cell * cell_size`` elements, e.g. ``(batch, n_outputs, 1, 1)``.
            weight_r: ``(n_cell, cell_size, cell_size)`` recurrent weights.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        shape = inputs.shape
        # (batch, n_cell, 1, cs) @ (n_cell, cs, cs) -> (batch, n_cell, 1, cs)
        x = inputs.view(shape[0], n_cell, 1, cell_size)
        y = torch.matmul(x, weight_r)
        return y.view(shape)

    def set_mode(self, m):
        """Set the mode string. Call ``init_weights`` afterwards."""
        self.mode = m

    def get_mode(self):
        """Return the current mode string."""
        return self.mode

    def get_prob(self, arch_weight):
        """Return the softmax of ``arch_weight`` over its last axis."""
        probs = F.softmax(arch_weight, dim=-1)
        return probs

    def get_outputs(self):
        """Return ``syns`` from the most recent ``forward``, or None before the first call."""
        return self.syns

    # record and load
    def record_best_arch(self):
        """Snapshot the current architecture as the best one.

        In ``"search_cell_size"`` mode, pick the cell size with the largest
        logit. Store copies of that candidate's logits and excitatory
        weights; ``init_weights`` restores them in
        ``"search_connection_type"`` mode.

        In ``"search_connection_type"`` mode, store copies of the
        connection-type logits; ``init_weights`` uses them in ``"finetune"``
        mode.

        All snapshots are ``detach().clone()``d. They neither alias the live
        parameters, which may keep training, nor carry autograd history.
        Other modes: no-op.
        """
        if self.mode == "search_cell_size":
            # 0-d index tensor; Python lists accept it via __index__.
            self.best_cell_size_ind = torch.argmax(self.cell_size_weight)
            ind = self.best_cell_size_ind
            self.best_cell_size_weight = self.cell_size_weight.detach().clone()
            self.best_connection_type_weight = self.arch_weight_list[ind].detach().clone()
            self.best_lateral_connection_type_weight = self.lateral_arch_weight_list[ind].detach().clone()
            self.best_weight_e = self.weight_e_list[ind].detach().clone()
            self.best_lateral_weight_e = self.lateral_weight_e_list[ind].detach().clone()
        elif self.mode == "search_connection_type":
            self.best_arch = self.connection_type_weight.detach().clone()
            self.best_lateral_arch = self.lateral_connection_type_weight.detach().clone()


class Feedforward_Layer(nn.Linear):
    """Bias-free fully-connected layer followed by ``Feedforward_LIF``.

    The layer has no architecture parameters. Its mode only affects
    ``init_weights``, which re-draws the weight in ``"finetune"`` mode.
    """

    def __init__(self, network_config, layer_config, name):
        self.name = name
        n_in = layer_config['n_inputs']
        n_out = layer_config['n_outputs']
        self.n_inputs, self.n_outputs = n_in, n_out
        self.layer_config = layer_config
        self.network_config = network_config
        self.n_steps = network_config['n_steps']
        self.type = 'feedforward'
        self.mode = None

        super().__init__(n_in, n_out, bias=False)

        # Kaiming-normal init on the CPU tensor, then move it to the GPU.
        nn.init.kaiming_normal_(self.weight)
        self.weight = nn.Parameter(self.weight.cuda(), requires_grad=True)

    def get_parameters(self):
        """Return the trainable tensors (only the weight)."""
        return [self.weight]

    def init_weights(self):
        """Re-draw the weight when in ``"finetune"`` mode; otherwise do nothing."""
        if self.mode != "finetune":
            return
        nn.init.kaiming_normal_(self.weight)

    def set_mode(self, m):
        """Store the mode string."""
        self.mode = m

    def get_mode(self):
        """Return the stored mode string."""
        return self.mode

    def get_final_parameters(self):
        """Return the tensors trained after the search (only the weight)."""
        return [self.weight]

    def forward(self, x):
        """Project ``x`` with ``linear_f`` and pass it through ``Feedforward_LIF``."""
        projected = linear_f(x, self.weight, self.bias)
        return Feedforward_LIF.apply(projected, self.network_config, self.layer_config)


class Direct_Layer(nn.Linear):
    """Bias-free fully-connected layer with no spiking nonlinearity.

    ``forward`` returns the ``linear_f`` projection unchanged. The mode only
    affects ``init_weights``, which re-draws the weight in ``"finetune"``
    mode.
    """

    def __init__(self, network_config, layer_config, name):
        self.name = name
        n_in = layer_config['n_inputs']
        n_out = layer_config['n_outputs']
        self.n_inputs, self.n_outputs = n_in, n_out
        self.layer_config = layer_config
        self.network_config = network_config
        self.n_steps = network_config['n_steps']
        self.type = 'feedforward'
        self.mode = None

        super().__init__(n_in, n_out, bias=False)

        # Kaiming-normal init on the CPU tensor, then move it to the GPU.
        nn.init.kaiming_normal_(self.weight)
        self.weight = nn.Parameter(self.weight.cuda(), requires_grad=True)

    def get_parameters(self):
        """Return the trainable tensors (only the weight)."""
        return [self.weight]

    def init_weights(self):
        """Re-draw the weight when in ``"finetune"`` mode; otherwise do nothing."""
        if self.mode != "finetune":
            return
        nn.init.kaiming_normal_(self.weight)

    def set_mode(self, m):
        """Store the mode string."""
        self.mode = m

    def get_mode(self):
        """Return the stored mode string."""
        return self.mode

    def get_final_parameters(self):
        """Return the tensors trained after the search (only the weight)."""
        return [self.weight]

    def forward(self, x):
        """Return the linear projection of ``x`` (see ``linear_f``)."""
        return linear_f(x, self.weight, self.bias)


class Pooling_Layer(nn.Conv3d):
    """Fixed-weight spatial pooling implemented as a 1-channel 3-D convolution.

    ``forward`` merges input dims 1 and 2 into one axis, convolves with a
    frozen all-``theta`` kernel of shape ``(k_h, k_w, 1)``, and reshapes
    back. The last axis always gets kernel 1, stride 1, padding 0 and
    dilation 1, so it is never pooled (presumably the time axis — confirm
    with callers).

    Config keys:
        type: stored on ``self.type``.
        kernel_size: int or 2-sequence.
        stride: int or 2-sequence; defaults to the kernel size.
        padding: int or 2-sequence; default 0.
        dilation: int or 2-sequence; default 1.
        theta: value of every kernel tap; default 1.1.

    Raises:
        ValueError: if a size parameter is neither an int nor of length 2.
    """

    def __init__(self, network_config, config, name):
        self.name = name
        self.layer_config = config
        self.network_config = network_config
        self.type = config['type']
        kernel_size = config['kernel_size']
        self.mode = None
        padding = config.get('padding', 0)
        stride = config.get('stride')
        dilation = config.get('dilation', 1)
        theta = config.get('theta', 1.1)

        kernel = self._to_3d(kernel_size, 1, 'kernelSize can only be of 1 or 2 dimension. It was: {}')
        if stride is None:
            stride = kernel  # non-overlapping windows by default
        else:
            stride = self._to_3d(stride, 1, 'stride can be either int or tuple of size 2. It was: {}')
        padding = self._to_3d(padding, 0, 'padding can be either int or tuple of size 2. It was: {}')
        dilation = self._to_3d(dilation, 1, 'dilation can be either int or tuple of size 2. It was: {}')

        super(Pooling_Layer, self).__init__(1, 1, kernel, stride, padding, dilation, bias=False)

        # Frozen kernel: every tap equals theta and it is never trained.
        self.weight = torch.nn.Parameter(1 * theta * torch.ones(self.weight.shape).cuda(), requires_grad=False)

    @staticmethod
    def _to_3d(value, last, error_msg):
        """Expand an int or 2-sequence to the 3-tuple ``(a, b, last)``.

        Raises:
            ValueError: for any other length. ``error_msg`` is formatted
                with the offending value itself. Formatting ``value.shape``
                would fail on ints, tuples and lists.
        """
        if isinstance(value, int):
            return (value, value, last)
        if len(value) == 2:
            return (value[0], value[1], last)
        raise ValueError(error_msg.format(value))

    def forward(self, x):
        """Pool ``x``; assumes a 5-D input ``(batch, d1, d2, d3, d4)``.

        Dims 1 and 2 are merged into one convolved axis, and the result is
        reshaped back to ``(batch, d1, -1, d3', d4)``.
        """
        result = F.conv3d(x.reshape((x.shape[0], 1, x.shape[1] * x.shape[2], x.shape[3], x.shape[4])),
                          self.weight, self.bias,
                          self.stride, self.padding, self.dilation)
        return result.reshape((result.shape[0], x.shape[1], -1, result.shape[3], result.shape[4]))

    def get_parameters(self):
        """Return no trainable tensors (the kernel is frozen)."""
        return []

    def weight_clipper(self):
        """No-op: the kernel is fixed."""
        return

    def set_mode(self, m):
        """Store the mode string."""
        self.mode = m

    def get_mode(self):
        """Return the stored mode string."""
        return self.mode

    def init_weights(self):
        """No-op: the kernel is fixed."""
        return

