import torch
import tools.global_v as glv


class TSSLBP(torch.autograd.Function):
    """Custom autograd function for a thresholded, leaky membrane layer.

    ``forward`` steps through the last axis of ``inputs`` (treated as time),
    accumulates a membrane value, emits a binary output whenever the membrane
    exceeds the threshold, resets the membrane on such an output, and filters
    the binary outputs with a first-order low-pass (``tau_s``).

    ``backward`` computes the gradient w.r.t. ``inputs`` either with the
    "TSSLBP" formula or with a sigmoid surrogate gradient, depending on
    ``network_config['rule']`` and on how active the layer was.
    """

    @staticmethod
    def forward(ctx, inputs, network_config, layer_config, name, target_mem_train, theta_v):
        """Run the recurrence over time and return the filtered outputs.

        Args:
            ctx: autograd context used to stash values for ``backward``.
            inputs: tensor indexed as ``inputs[..., t]`` with time on axis 4
                (so it must have at least 5 dimensions; the state tensors are
                built from its first four sizes).
            network_config: mapping read for the keys ``tau_m``, ``tau_s``,
                ``model``, ``fix_target``, ``dynamic_v``, ``rule`` and
                ``save_target``.
            layer_config: mapping read for the key ``threshold``.
            name: layer name; the literal ``"output"`` disables ALIF mode and
                forces ``fix_target`` to True.
            target_mem_train: target membrane trace, only used in ALIF mode.
                It is expanded with a leading batch dimension, so it is
                presumably shaped like ``inputs`` without the batch axis
                (TODO confirm against the caller).
            theta_v: threshold adaptation rate, only used in ALIF mode; must
                broadcast against the 4-D state tensors.

        Returns:
            A pair ``(syns_posts, target)``. ``syns_posts`` has the same shape
            as ``inputs``. ``target`` is the batch mean of the recorded
            membrane values when ``network_config['save_target']`` is truthy,
            otherwise ``None``.
        """
        device = inputs.device
        shape = inputs.shape
        n_steps = shape[4]
        state_shape = shape[:4]
        theta_m = 1 / network_config['tau_m']
        theta_s = 1 / network_config['tau_s']

        # decide alif mode
        if name == "output":
            is_alif = False
            fix_target = True
        else:
            is_alif = (network_config['model'] == "ALIF")
            fix_target = network_config['fix_target']

        syns_posts = []
        mems = []
        mem_updates = []
        outputs = []
        thresholds = []

        with torch.no_grad():
            # State is allocated directly on the input's device (float32).
            mem = torch.zeros(state_shape, device=device)
            syn = torch.zeros(state_shape, device=device)
            threshold = layer_config['threshold'] * torch.ones(state_shape, device=device)
            mem_train_error_sum = torch.zeros(state_shape, device=device)

            if is_alif:
                target_mem_train = target_mem_train.expand(shape[0], -1, -1, -1, -1).to(device)

            for t in range(n_steps):
                mem_update = (-theta_m) * mem + inputs[..., t]
                mem += mem_update

                # update threshold
                if is_alif:
                    mem_train_error = mem - target_mem_train[..., t]
                    threshold = threshold + theta_v * mem_train_error
                    mem_train_error_sum += mem_train_error
                thresholds.append(threshold)

                out = (mem > threshold).type(torch.float32)
                outputs.append(out)
                # Record the membrane before the reset; the reset below
                # rebinds ``mem`` to a new tensor, so the stored one is not
                # touched by the next in-place update.
                mems.append(mem)
                mem = mem * (1 - out)
                mem_updates.append(mem_update)
                syn = syn + (out - syn) * theta_s
                syns_posts.append(syn)

            outputs = torch.stack(outputs, dim=4)
            mems = torch.stack(mems, dim=4)
            mem_updates = torch.stack(mem_updates, dim=4)
            syns_posts = torch.stack(syns_posts, dim=4)
            thresholds = torch.stack(thresholds, dim=4)

        ctx.save_for_backward(mem_updates, outputs, mems, theta_v,
                              mem_train_error_sum / n_steps, thresholds)
        # Non-tensor state lives directly on ctx rather than being packed
        # into a tensor and decoded again with .item() in backward.
        ctx.rule = network_config['rule']
        ctx.threshold = layer_config['threshold']
        ctx.is_alif = bool(is_alif)
        ctx.fix_target = bool(fix_target)
        ctx.dynamic_v = bool(network_config["dynamic_v"])

        if network_config["save_target"]:
            return syns_posts, torch.mean(mems, dim=0)
        return syns_posts, None

    @staticmethod
    def backward(ctx, grad_delta, grad_outputs):
        """Compute gradients for ``inputs``, ``target_mem_train`` and ``theta_v``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_delta: gradient w.r.t. the first forward output
                (``syns_posts``).
            grad_outputs: gradient w.r.t. the second forward output; it is
                not used.

        Returns:
            One entry per ``forward`` input: the gradient for ``inputs``,
            ``None`` for the three non-tensor arguments, then the gradients
            for ``target_mem_train`` and ``theta_v`` (both ``None`` outside
            ALIF mode; the former also ``None`` when the target is fixed).

        Raises:
            ValueError: if the rule stored by ``forward`` is neither
                ``"TSSLBP"`` nor ``"SGRAD"``.
        """
        rule = ctx.rule
        # Validate before doing any work, and raise instead of calling
        # exit(), which would terminate the whole interpreter from inside
        # the autograd engine.
        if rule not in ("TSSLBP", "SGRAD"):
            raise ValueError(f"Wrong lr rule: {rule!r}")

        (delta_u, outputs, u, theta_v, ei, thresholds) = ctx.saved_tensors

        device = outputs.device
        shape = grad_delta.shape
        n_steps = shape[4]
        threshold = ctx.threshold
        is_alif = ctx.is_alif
        fix_target = ctx.fix_target
        dynamic_v = ctx.dynamic_v

        # First row of the precomputed kernel. It is left un-repeated and
        # broadcast against grad_delta, which gives the same values as an
        # explicit repeat without allocating a full-size copy.
        # NOTE(review): assumes glv.partial_a has singleton (or no) leading
        # dimensions, as the previous repeat-based code also required.
        partial_a_row = glv.partial_a[..., 0, :].to(device)
        grad_a = torch.empty_like(delta_u)
        for t in range(n_steps):
            # Correlate the incoming gradient from step t onwards with the
            # kernel starting at lag 0.
            grad_a[..., t] = torch.sum(
                partial_a_row[..., 0:n_steps - t] * grad_delta[..., t:n_steps],
                dim=-1,
            )

        if rule == "TSSLBP":
            # Only use the TSSLBP formula when more than 10% of the output
            # entries are active; otherwise fall back to the surrogate.
            n_entries = shape[0] * shape[1] * shape[2] * shape[3] * shape[4]
            use_tsslbp = torch.sum(outputs) / n_entries > 0.1
        else:
            use_tsslbp = False

        if use_tsslbp:
            # cal tsslbp
            partial_u = torch.clamp(1 / delta_u, -10, 10) * outputs
            grad = grad_a * partial_u
        else:
            # cal surgrad
            a = 0.2
            # Use the per-step recorded thresholds when dynamic_v is set,
            # otherwise the layer's configured scalar threshold.
            reference = thresholds if dynamic_v else threshold
            f = torch.sigmoid(torch.clamp(-(u - reference) / a, -8, 8))
            grad = grad_a * f * (1 - f) / a

        grad_theta_v = None
        grad_target = None
        if is_alif:
            grad_v = torch.mean(-grad, dim=0)
            grad_theta_v = 0.1 * torch.mean(grad_v, dim=-1) * ei
            if not fix_target:
                grad_target = 0.1 * grad_v * (-theta_v.unsqueeze(-1))

        return grad, None, None, None, grad_target, grad_theta_v


