import torch
import global_v as glv


class Recurrent_Dendrite(torch.autograd.Function):
    """Expose time slot ``t`` of ``y`` as a separate autograd output.

    ``forward`` returns ``y`` together with the slice ``y[..., t]``.  In
    ``backward`` the gradient for slot ``t`` of ``y`` is *replaced* (not
    summed) by the gradient arriving through the ``y[..., t]`` output; every
    other slot receives the gradient of the ``y`` output unchanged.
    """

    @staticmethod
    def forward(ctx, y, t):
        """Return ``(y, y[..., t])``.

        Args:
            y: tensor whose last axis is indexed by ``t``.
            t: integer index into the last axis of ``y``.
        """
        # Non-tensor arguments can be stored on ctx directly; there is no need
        # to wrap them in a tensor for save_for_backward.
        ctx.t = t
        return y, y[..., t]

    @staticmethod
    def backward(ctx, grad_y, grad):
        """Return the gradient for ``y`` and ``None`` for ``t``.

        NOTE(review): slot ``t`` of ``grad_y`` is overwritten by ``grad``
        rather than accumulated with it — confirm this routing is intended.
        """
        t = int(ctx.t)
        grad_y_out = grad_y.clone()
        grad_y_out[..., t] = grad
        return grad_y_out, None


class Recurrent_LIF(torch.autograd.Function):
    """One time step of a spiking (LIF-style) neuron update, with optional
    intrinsic-plasticity (IP) state updates.

    ``forward`` thresholds the membrane potential ``mem``, writes the spike
    and the filtered synaptic response into slot ``t`` of copies of
    ``outputs`` / ``syns``, resets ``mem`` where a spike fired and, when
    ``ip`` is truthy, updates the IP state ``(theta, cal, R)``.
    ``backward`` multiplies a kernel-weighted sum of downstream gradients
    (kernel ``glv.syn_a``) by a sigmoid-derivative term plus a spike term.
    """

    @staticmethod
    def forward(ctx, syns, outputs, mem, mem_pre, theta, cal, R, network_config, layer_config, t, ip):
        """Advance the neuron state by one time step.

        Args:
            syns: synaptic responses; last axis indexes time.
            outputs: spike trains; last axis indexes time.
            mem: membrane potential at step ``t`` (before reset).
            mem_pre: membrane potential at the previous step; ``mem - mem_pre``
                is saved for the backward pass.
            theta, cal, R: IP state tensors.  Boolean masks derived from
                ``cal`` index ``theta`` and ``R``, so they are assumed to share
                ``mem``'s shape — TODO confirm.
            network_config: dict providing ``'tau_s'``.
            layer_config: dict providing ``'threshold'``.
            t: integer time-step index into the last axis of ``syns``/``outputs``.
            ip: apply the IP update when truthy.

        Returns:
            ``(syns_out, outputs_out, response, mem_reset, theta_out, cal_out, R_out)``
            where ``mem_reset`` is ``mem`` zeroed wherever a spike fired.
        """
        threshold = layer_config['threshold']
        tau_s = network_config['tau_s']
        theta_s = 1 / tau_s

        syns_out = syns.clone()
        outputs_out = outputs.clone()
        theta_out = theta.clone()
        cal_out = cal.clone()
        R_out = R.clone()

        # Binary (float32) spike wherever the membrane exceeds the threshold.
        out = (mem > threshold).type(torch.float32)

        if ip:
            # Exponentially decaying spike trace (decay 1/16) and its scaled value.
            cal_out = cal_out - cal_out / 16 + out
            y = cal_out / 16
            lr = 5
            beta = 5
            W = threshold / (torch.exp(1 / (theta_out * y)) - 1)

            # Both masks are built explicitly (not ``~high``) so that NaN
            # entries of ``y`` fall into neither branch.
            low = y <= 0.01
            high = y > 0.01
            y_high = y[high]
            W_high = W[high]

            theta_out[low] -= lr * 0.5
            theta_out[high] += lr * (-1 + 5 * y_high) / theta_out[high]

            # NOTE(review): the R rule reads the *incoming* ``theta`` rather
            # than the just-updated ``theta_out`` — kept as-is; confirm intent.
            R_out[low] += lr * 0.5
            R_out[high] += lr * (
                (2 - beta * y_high) * y_high * theta[high] * threshold - W_high - threshold
            ) / (R_out[high] * W_high)

            theta_out.clamp_(4, 128)
            R_out.clamp_(4, 128)

        # First-order synaptic filter of the spike train.
        if t == 0:
            response = out * theta_s
        else:
            prev = syns[..., t - 1]
            response = prev + (out - prev) * theta_s
        syns_out[..., t] = response
        outputs_out[..., t] = out

        # Tensors go through save_for_backward; plain scalars live on ctx.
        ctx.save_for_backward(mem - mem_pre, out, mem)
        ctx.threshold = threshold
        ctx.t = t
        ctx.tau_s = tau_s

        return syns_out, outputs_out, response, mem * (1 - out), theta_out, cal_out, R_out

    @staticmethod
    def backward(ctx, grad_forward, grad_recurrent, grad_response, grad_mem, grad_theta, grad_cal, grad_R):
        """Return gradients for ``(syns, outputs, mem)`` and ``None`` for the
        remaining eight forward inputs.

        ``grad_forward`` is passed through unchanged as the ``syns`` gradient.
        The ``mem`` gradient at step ``t`` is
        ``f * (sum_k syn_a[k] * grad_forward[t+k] + sum_k syn_a[k] * grad_recurrent[t+1+k])``
        where ``f`` is the derivative of ``sigmoid((u - threshold) / a)``
        (exponent clamped to [-8, 8]) plus ``clamp(1/delta_u, 0, 8) * spike / tau_s``.
        """
        delta_u, out, u = ctx.saved_tensors
        threshold = ctx.threshold
        t = int(ctx.t)
        tau_s = ctx.tau_s
        n_steps = glv.n_steps
        time_len = n_steps - t

        if t == n_steps - 1:
            grad_recurrent_out = torch.zeros_like(grad_forward)
        else:
            grad_recurrent_out = grad_recurrent.clone()
            grad_recurrent_out[..., t + 1] = grad_response

        # Broadcast over the leading axes instead of repeat()-ing a full-size
        # copy; assumes glv.syn_a has time on its last axis and singleton (or
        # absent) leading axes — TODO confirm.
        syn_a = glv.syn_a

        a = 0.2
        f = torch.clamp((threshold - u) / a, -8, 8)
        f = torch.exp(f)
        f = f / ((1 + f) * (1 + f) * a)
        f += torch.clamp(1 / delta_u, 0, 8) * out / tau_s

        grad = torch.sum(syn_a[..., 0:time_len] * grad_forward[..., t:n_steps], dim=-1)
        # NOTE(review): this reads the incoming ``grad_recurrent``, not
        # ``grad_recurrent_out``, so ``grad_response`` does not contribute
        # here — kept as-is; confirm intent.
        grad += torch.sum(syn_a[..., 0:time_len - 1] * grad_recurrent[..., t + 1:n_steps], dim=-1)
        grad = grad * f

        # Exactly one entry per forward input (11 in total).
        return (grad_forward, grad_recurrent_out, grad) + (None,) * 8


class Feedforward_LIF(torch.autograd.Function):
    """Spiking layer unrolled over every time step in a single call.

    ``forward`` integrates ``inputs`` into a leaky membrane potential
    (``mem += -mem / tau_m + input``), emits a binary spike wherever it
    exceeds the threshold, resets spiking neurons to zero and returns the
    spikes filtered by a first-order synaptic trace.  ``backward`` multiplies
    a ``glv.syn_a``-weighted sum of future output gradients by a
    sigmoid-derivative term plus a spike term.
    """

    @staticmethod
    def forward(ctx, inputs, network_config, layer_config):
        """Run the layer over all time steps.

        Args:
            inputs: input current; assumed 5-D with time on the last axis
                (``inputs.shape[4]`` steps).
            network_config: dict providing ``'tau_m'`` and ``'tau_s'``.
            layer_config: dict providing ``'threshold'``.

        Returns:
            Filtered spike trains with the same leading shape as ``inputs``
            and one entry per time step on the last axis.
        """
        n_steps = inputs.shape[4]
        tau_m = network_config['tau_m']
        tau_s = network_config['tau_s']
        theta_s = 1 / tau_s
        threshold = layer_config['threshold']

        # Allocate state on ``inputs``' device rather than assuming CUDA;
        # the dtype stays float32.
        state_shape = inputs.shape[:4]
        mem = torch.zeros(state_shape, dtype=torch.float32, device=inputs.device)
        syn = torch.zeros(state_shape, dtype=torch.float32, device=inputs.device)

        mems, mem_updates, outputs, syns_posts = [], [], [], []
        for t in range(n_steps):
            mem_update = (-1 / tau_m) * mem + inputs[..., t]
            # In-place add keeps ``mem`` float32; ``mem`` is a fresh tensor
            # every step, so snapshots already appended are never aliased.
            mem += mem_update
            out = (mem > threshold).type(torch.float32)
            mems.append(mem)
            mem = mem * (1 - out)  # reset neurons that spiked
            outputs.append(out)
            mem_updates.append(mem_update)
            syn = syn + (out - syn) * theta_s
            syns_posts.append(syn)

        syns_posts = torch.stack(syns_posts, dim=4)
        # Tensors go through save_for_backward; plain scalars live on ctx.
        ctx.save_for_backward(
            torch.stack(mem_updates, dim=4),
            torch.stack(outputs, dim=4),
            torch.stack(mems, dim=4),
        )
        ctx.threshold = threshold
        ctx.tau_s = tau_s

        return syns_posts

    @staticmethod
    def backward(ctx, grad_delta):
        """Return the gradient for ``inputs`` and ``None`` for both configs.

        For every step ``t``:
        ``grad[..., t] = f[..., t] * sum_k syn_a[k] * grad_delta[..., t + k]``,
        where ``f`` is the derivative of ``sigmoid((u - threshold) / a)``
        (exponent clamped to [-8, 8]) plus ``clamp(1/delta_u, 0, 8) * spike / tau_s``.
        """
        delta_u, outputs, u = ctx.saved_tensors
        threshold = ctx.threshold
        tau_s = ctx.tau_s
        n_steps = grad_delta.shape[4]

        # Broadcast over the leading axes instead of repeat()-ing a full-size
        # copy; assumes glv.syn_a has time on its last axis and singleton (or
        # absent) leading axes — TODO confirm.
        syn_a = glv.syn_a

        # The surrogate factor is element-wise, so compute it for all steps at once.
        a = 0.2
        f = torch.clamp((threshold - u) / a, -8, 8)
        f = torch.exp(f)
        f = f / ((1 + f) * (1 + f) * a)
        f += torch.clamp(1 / delta_u, 0, 8) * outputs / tau_s

        grad = torch.zeros_like(grad_delta)
        for t in range(n_steps):
            weighted = torch.sum(syn_a[..., 0:n_steps - t] * grad_delta[..., t:], dim=-1)
            grad[..., t] = weighted * f[..., t]

        # Exactly one entry per forward input.
        return grad, None, None
    
