from __future__ import print_function
import torchvision
import torchvision.transforms as transforms
import os
import time
from event_data_reader import N_Caltech_Dataset
import numpy as np
import matplotlib.pyplot as plt
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F

# Module-level hyper-parameters and run configuration shared by the code below.

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
nb_outputs = 101  # number of output classes (copied into num_classes below)
nb_steps = 100  # not referenced in the visible part of this file
max_time = 1.4  # not referenced in the visible part of this file
thresh = 0.5 # firing threshold: ActFun emits a spike where its input exceeds this value
lens = 0.5 # half-width of the window around thresh in which ActFun.backward passes gradient
num_classes = nb_outputs
batch_size  = 10  # mini-batch size used throughout the file
learning_rate = 1e-3  # initial learning rate handed to the Adam optimizer
num_epochs = 300# maximum number of training epochs
lr_decay_epoch= 150  # period (in epochs) passed to lr_scheduler for the x0.1 decay
refraction = 1  # factor applied to the spike tensor when mem_update accumulates timer_matrix
dt = 1  # step factor applied to the membrane increment in mem_update
input_sample_window = 10  # forwarded to N_Caltech_Dataset as sample_window
input_num_of_frame = 20  # forwarded to N_Caltech_Dataset as num_of_frame
reduced_training = 30  # percentage of the training split to keep; only applied when 0 < value < 100
dataset_file_loc = ''  # forwarded to N_Caltech_Dataset as img_dir (must be filled in before training)
annotation_file_loc = ''  # forwarded to N_Caltech_Dataset as annotations_file (must be filled in before training)

class ActFun(torch.autograd.Function):
    """Threshold spike function with a rectangular surrogate gradient.

    Forward: returns 1.0 wherever the input is strictly greater than the
    module-level ``thresh`` and 0.0 elsewhere.
    Backward: lets the incoming gradient through unchanged wherever
    ``|input - thresh| < lens`` and blocks it (multiplies by 0) elsewhere.
    """

    @staticmethod
    def forward(ctx, input):
        # Keep the pre-activation values; backward needs them to build the window.
        ctx.save_for_backward(input)
        fired = input.gt(thresh)
        return fired.float()

    @staticmethod
    def backward(ctx, grad_output):
        (pre_activation,) = ctx.saved_tensors
        # 1.0 inside the pass-through window around the threshold, 0.0 outside.
        window = (abs(pre_activation - thresh) < lens).float()
        upstream = grad_output.clone()
        return upstream * window


# Functional handle used by the membrane update routines.
act_fun = ActFun.apply
# membrane potential update
def mem_update_old(ops, x, mem, spike, decay):
    """Advance a leaky membrane by one step and emit spikes.

    The previous potential is scaled by ``decay`` and zeroed wherever
    ``spike`` is 1 (reset), then the synaptic drive ``ops(x)`` is added.

    Returns:
        (mem, spike): the new membrane potential and the spikes produced by
        applying ``act_fun`` to it.
    """
    retained = mem * decay * (1. - spike)
    new_mem = retained + ops(x)
    # act_fun: thresholding with an approximate (surrogate) gradient.
    new_spike = act_fun(new_mem)
    return new_mem, new_spike

def mem_update(ops, x, mem, spike, lstm_enable, this_decay, timer_matrix, R_m, tau_m_inv):
    """Advance a membrane by one step, emit spikes and accumulate the timer.

    Two update rules are supported, selected by ``lstm_enable``:

    * falsy: ``mem * this_decay * (1 - spike) + 0.3 * ops(x)``
    * truthy: the potential is first reset where ``spike`` is 1, then moved by
      ``dt * (-mem - this_decay + ops(x) * R_m) * tau_m_inv`` and clamped at
      zero with ReLU. Note that ``this_decay`` is subtracted here, not used
      as a multiplicative factor.

    Returns:
        (mem, spike, timer_matrix): new potential, new spikes from
        ``act_fun``, and ``timer_matrix`` increased by ``spike * refraction``.
    """
    if not lstm_enable:
        new_mem = mem * this_decay * (1. - spike) + ops(x)*0.3
    else:
        reset_mem = mem * (1. - spike)
        step = dt*(-reset_mem-this_decay+ops(x)*R_m)*tau_m_inv
        new_mem = F.relu(reset_mem + step)

    # act_fun: thresholding with an approximate (surrogate) gradient.
    new_spike = act_fun(new_mem)
    new_timer = new_spike*refraction + timer_matrix
    return new_mem, new_spike, new_timer


# Decay learning_rate
def lr_scheduler(optimizer, epoch, init_lr, lr_decay_epoch):
    """Multiply every parameter group's learning rate by 0.1 on decay epochs.

    A decay is applied when ``epoch`` is a multiple of ``lr_decay_epoch`` and
    is greater than 1; on all other epochs the optimizer is left untouched.
    ``init_lr`` is accepted for signature compatibility but is not used.

    Returns:
        The same ``optimizer`` object that was passed in.
    """
    on_decay_epoch = epoch % lr_decay_epoch == 0 and epoch > 1
    if not on_decay_epoch:
        return optimizer

    for group in optimizer.param_groups:
        current_lr = group['lr']
        group['lr'] = current_lr * 0.1
    return optimizer

# Convolution specs for the first three layers. Each tuple is read as
# (in_channels, out_channels, stride, padding, kernel_size) when the first
# Conv2d is built; the out_channels entries also size the per-layer state.
cfg_cnn = [(2, 16, 2, 1, 3),
           (16, 16, 1, 1, 3),
           (16, 16, 1, 1, 3)]
# Square feature-map side lengths: entries 0-2 size the state tensors of the
# first three conv layers, entry -2 is reused for every deeper layer, and the
# last entry is the side length used to size the first fully connected layer.
cfg_kernel = [120, 40, 13, 3]
# Fully connected widths: hidden layer size, then number of output units.
cfg_fc = [128, num_classes]


class heterogeneous_with_skip(nn.Module):
    """Convolutional spiking network with several decay "dynamics" per layer
    and optional skip connections between selected conv layers.

    Structure (as built by ``__init__``):

    * ``num_of_layer`` 3x3 convolutions stored in ``conv_list``. Layer 0 uses
      ``cfg_cnn[0]``; every other layer maps
      ``16 * depth_scale * num_of_dynamic`` input channels to
      ``16 * depth_scale`` output channels. Every layer listed in the skip
      list except the first one encountered gets twice as many input
      channels, to make room for the skip input.
    * two fully connected layers ``fc1`` and ``fc2`` sized from
      ``cfg_kernel[-1]`` and ``cfg_fc``.

    Each conv layer keeps ``num_of_dynamic`` independent membrane states that
    share the layer's weights but use a different entry of ``decay_list``.

    Args:
        num_of_layer: number of conv layers; must be at least 4.
        depth_scale: channel multiplier; must be positive.
        num_of_dynamic: number of parallel membrane states per conv layer.
        decay_list: one decay value per dynamic; defaults to ``[0.9]``.
        skip_layer_connection_list: ascending conv-layer indices chained by
            skip connections; defaults to no skips. Layer 0 cannot act as a
            skip source, because its input is never stored for reuse.
        R_m: gain applied to the synaptic drive in ``mem_update``.
        tau_m_inv: rate factor applied to the membrane increment in
            ``mem_update``.
    """

    def __init__(self, num_of_layer=4, depth_scale=1, num_of_dynamic=1, decay_list=None, skip_layer_connection_list=None, R_m=60, tau_m_inv=1/110):
        super(heterogeneous_with_skip, self).__init__()
        print("Using BP-LS_SNN")

        # NOTE(review): invalid configurations terminate the interpreter via
        # exit(); kept as-is so existing callers see the same behaviour.
        if num_of_layer<4:
            print("Using a layer num less than 4.")
            exit()
        if depth_scale<=0:
            print("Using depth scale less than 0.")
            exit()

        self.num_of_layer = num_of_layer
        self.depth_scale = depth_scale
        self.num_of_dynamic = num_of_dynamic
        # Copy the lists so that neither a shared default nor a list the
        # caller keeps mutating can change this network after construction.
        self.decay_list = [0.9] if decay_list is None else list(decay_list)

        self.R_m = R_m
        self.tau_m_inv = tau_m_inv

        self.skip_conn = [] if skip_layer_connection_list is None else list(skip_layer_connection_list)

        # Spec shared by every conv layer after the first:
        # (in_channels, out_channels, stride, padding, kernel_size).
        self.addition_cfg_cnn = (cfg_cnn[2][1]*depth_scale, 16*depth_scale, 1, 1, 3)
        base_in, out_channels, stride, padding, kernel_size = self.addition_cfg_cnn

        self.conv_list = torch.nn.ModuleList()
        first_skip_seen = False
        for i in range(self.num_of_layer):
            if i == 0:
                layer = nn.Conv2d(cfg_cnn[0][0], cfg_cnn[0][1]*depth_scale, kernel_size=cfg_cnn[0][4], stride=cfg_cnn[0][2], padding=cfg_cnn[0][3])
            else:
                in_channels = base_in*self.num_of_dynamic
                if i in self.skip_conn:
                    # Only skip *targets* need the doubled input; the first
                    # skip layer is a pure source and keeps the normal width.
                    if first_skip_seen:
                        in_channels *= 2
                    first_skip_seen = True
                layer = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
            self.conv_list.append(layer)

        fc1_inputs = cfg_kernel[-1] * cfg_kernel[-1] * self.addition_cfg_cnn[1]*self.num_of_dynamic
        print("fc1 size:", fc1_inputs)
        self.fc1 = nn.Linear(fc1_inputs, cfg_fc[0])
        self.fc2 = nn.Linear(cfg_fc[0], cfg_fc[1])

    @staticmethod
    def _stack_dynamics(spikes, pool=None):
        """Concatenate the per-dynamic spike maps of one layer along channels.

        Each map is optionally average-pooled with window ``pool`` and then
        converted to float32 on whatever device it already lives on.
        """
        parts = []
        for spike in spikes:
            if pool is not None:
                spike = F.avg_pool2d(spike, pool)
            parts.append(spike.float())
        return parts[0] if len(parts) == 1 else torch.cat(parts, 1)

    def forward(self, input, time_window = 2):
        """Simulate the network over all frames and return output firing rates.

        Frames are taken from dimension 1 of ``input``; each frame is shown
        for ``time_window`` steps and rate-coded by comparing it against
        uniform random noise. Only ONE conv layer
        (``sim_time % num_of_layer``) is advanced per step, while the two
        fully connected layers are advanced on every step.

        Args:
            input: assumes a tensor laid out as (batch, frames, channels,
                height, width) with values usable as firing probabilities,
                and a spatial size that yields the feature-map sizes in
                ``cfg_kernel`` — TODO confirm against the dataset reader.
            time_window: number of simulation steps per frame.

        Returns:
            Tensor of shape (batch, cfg_fc[1]): accumulated output spikes
            divided by ``time_window`` and by the number of frames.
        """
        # Size all state from the actual input rather than the global
        # batch_size, and keep everything on the input's device.
        n_batch = input.size(0)
        n_frames = input.size(1)
        dev = input.device

        # Per-layer, per-dynamic state: membrane potential, last spikes, timer.
        layer_dims = [(cfg_cnn[i][1]*self.depth_scale, cfg_kernel[i]) for i in range(3)]
        layer_dims += [(self.addition_cfg_cnn[1], cfg_kernel[-2])] * (self.num_of_layer-3)

        c_mem_list = []
        c_spike_list = []
        c_timer_list = []
        for channels, side in layer_dims:
            shape = (n_batch, channels, side, side)
            c_mem_list.append([torch.zeros(shape, device=dev) for _ in range(self.num_of_dynamic)])
            c_spike_list.append([torch.zeros(shape, device=dev, dtype=torch.int8) for _ in range(self.num_of_dynamic)])
            c_timer_list.append([torch.zeros(shape, device=dev, dtype=torch.int8) for _ in range(self.num_of_dynamic)])

        # Each FC state gets its own tensor: the spike accumulator is updated
        # in place and must not alias the membrane or spike tensors.
        h1_mem = torch.zeros(n_batch, cfg_fc[0], device=dev)
        h1_spike = torch.zeros(n_batch, cfg_fc[0], device=dev)
        h1_timer = torch.zeros(n_batch, cfg_fc[0], device=dev, dtype=torch.int8, requires_grad=False)
        h2_mem = torch.zeros(n_batch, cfg_fc[1], device=dev)
        h2_spike = torch.zeros(n_batch, cfg_fc[1], device=dev)
        h2_sumspike = torch.zeros(n_batch, cfg_fc[1], device=dev)
        h2_timer = torch.zeros(n_batch, cfg_fc[1], device=dev, dtype=torch.int8, requires_grad=False)

        sim_time = 0
        res_x = None  # input of the most recent skip-source layer

        for img in range(n_frames):
            one_img = input[:,img,:,:]
            for _ in range(time_window):  # simulation time steps
                x = one_img > torch.rand(one_img.size(), device=dev)  # prob. firing

                # Exactly one conv layer is advanced per simulation step.
                i = sim_time%(self.num_of_layer)

                if i == 0:
                    for j in range(self.num_of_dynamic):
                        c_mem_list[i][j], c_spike_list[i][j], c_timer_list[i][j] = mem_update(self.conv_list[i], x.float(), c_mem_list[i][j], c_spike_list[i][j], True, self.decay_list[j], c_timer_list[i][j], self.R_m, self.tau_m_inv)
                else:
                    # Layers 1 and 2 see a 3x3 average-pooled view of the
                    # previous layer's spikes; deeper layers see them as-is.
                    pool = 3 if i in (1, 2) else None
                    x = self._stack_dynamics(c_spike_list[i-1], pool)

                    # Remember this layer's input if it feeds a later skip target.
                    if i in self.skip_conn[:-1]:
                        res_x = x

                    # One step after a skip source ran, deliver its stored
                    # input to the next layer in the skip chain. The direct
                    # half of the target's input channels is zero-filled.
                    if i-1 in self.skip_conn[:-1]:
                        target_layer = self.skip_conn[self.skip_conn.index(i-1)+1]
                        pre_layer_shape = list(c_spike_list[target_layer-1][0].shape)
                        pre_layer_shape[1] *= self.num_of_dynamic
                        x_temp = torch.cat((res_x.new_zeros(pre_layer_shape), res_x), dim=1)
                        for j in range(self.num_of_dynamic):
                            c_mem_list[target_layer][j], c_spike_list[target_layer][j], c_timer_list[target_layer][j] = mem_update(self.conv_list[target_layer], x_temp, c_mem_list[target_layer][j], c_spike_list[target_layer][j], True, self.decay_list[j], c_timer_list[target_layer][j], self.R_m, self.tau_m_inv)

                    # Regular update of layer i. A skip target has a doubled
                    # input width, so its skip half is zero-filled here.
                    layer_input = x
                    if i in self.skip_conn[1:]:
                        skip_layer_idx = self.skip_conn[self.skip_conn.index(i)-1]
                        skip_layer_shape = list(c_spike_list[skip_layer_idx][0].shape)
                        skip_layer_shape[1] *= self.num_of_dynamic
                        layer_input = torch.cat((x, x.new_zeros(skip_layer_shape)), dim=1)
                    for j in range(self.num_of_dynamic):
                        c_mem_list[i][j], c_spike_list[i][j], c_timer_list[i][j] = mem_update(self.conv_list[i], layer_input, c_mem_list[i][j], c_spike_list[i][j], True, self.decay_list[j], c_timer_list[i][j], self.R_m, self.tau_m_inv)

                sim_time += 1

                # Readout: pool the last conv layer's spikes and drive the FC layers.
                x = self._stack_dynamics(c_spike_list[-1])
                x = F.avg_pool2d(x, 4).view(n_batch, -1)

                h1_mem, h1_spike, _ = mem_update(self.fc1, x, h1_mem, h1_spike, False, 0.9, h1_timer, self.R_m, self.tau_m_inv)
                h2_mem, h2_spike, _ = mem_update(self.fc2, h1_spike, h2_mem, h2_spike, False, 0.9, h2_timer, self.R_m, self.tau_m_inv)
                h2_sumspike += h2_spike

        outputs = h2_sumspike / time_window / n_frames

        return outputs







class snn_network_object():
    """Configure, train and validate one ``heterogeneous_with_skip`` network.

    The constructor only records the configuration and derives the list of
    skip-connected layers; the network is built inside
    ``training_process``. The best validation accuracy seen so far is kept in
    ``best_acc`` and returned by ``get_best_acc``.

    Args:
        num_of_layer: number of conv layers of the network.
        depth_scale: channel multiplier of the network.
        skip_layer_start, skip_layer_end, skip_layer_gap: first index,
            inclusive upper bound and stride of the skip-connected layers.
        dynamic_no: stored on the instance; the number of dynamics actually
            used is ``len(decay_rate)``.
        decay_rate: list with one decay value per dynamic; defaults to
            ``[-100]``.
        R_m: gain forwarded to the network.
        tau_m: time constant; its inverse is forwarded to the network.
        save_model: when true, the best model is checkpointed to disk.
    """

    def __init__(self, num_of_layer=12, depth_scale=3, skip_layer_start=4, skip_layer_end=10, skip_layer_gap=3, dynamic_no=1, decay_rate=None, R_m=50, tau_m=100, save_model=False):
        super(snn_network_object, self).__init__()
        print("Initializing a network object")
        # A fresh list per instance; a literal default would be shared by
        # every object constructed without an explicit decay_rate.
        if decay_rate is None:
            decay_rate = [-100]

        self.snn = []  # placeholder until training_process builds the model
        self.best_acc = 1e-6
        self.skip_layer_connection_list = []
        self.total_layer_num = num_of_layer
        self.depth_scale = depth_scale
        self.save_model = save_model

        self.R_m = R_m
        self.tau_m_inv = 1/tau_m

        self.dynamic_no = dynamic_no
        self.decay_rate = decay_rate

        self.valid_config = self.skip_layer_list_filler(total_layer_num=num_of_layer, start=skip_layer_start, end=skip_layer_end, gap=skip_layer_gap)

        print("Params: num of layer ", num_of_layer, " depth scale ", depth_scale, " skip start ", skip_layer_start, " end ", skip_layer_end, " gap: ", skip_layer_gap, " skip-layer-conn: ", self.skip_layer_connection_list, " dynamic Num ", dynamic_no, " decay list ", decay_rate, " R_m ", R_m, "Tau_m", tau_m)

    def skip_layer_list_filler(self, total_layer_num, start, end, gap):
        """Fill ``skip_layer_connection_list`` and report whether the
        requested skip layout is valid.

        The layout is rejected (returns False, list untouched) when ``gap`` is
        below 2, when ``gap`` exceeds ``end - start``, or when
        ``start >= end - 1``. With three layers or fewer the configuration is
        accepted without any skip layers. Otherwise the indices ``start``,
        ``start + gap``, ... are appended while they stay below
        ``total_layer_num`` and do not exceed ``end``.
        """
        assert isinstance(total_layer_num, int)
        assert isinstance(start, int)
        assert isinstance(end, int)

        if gap < 2 or gap > (end-start) or start >= (end-1):
            return False
        if total_layer_num <= 3:
            return True

        current_idx = start
        while current_idx < total_layer_num and current_idx <= end:
            self.skip_layer_connection_list.append(current_idx)
            current_idx += gap
        return True

    def _prepare_save_dir(self):
        """Create and return a fresh, numbered checkpoint directory."""
        names = 'bpsmlp_smnist_withRes'+'_'+str(self.total_layer_num)+"LayerX"+str(self.depth_scale)
        base_dir = './checkpoint/temporal_dataset/' + names
        # Pick the first numeric suffix that is not taken yet.
        save_dir_index = 0
        while os.path.isdir(base_dir+"_"+str(save_dir_index)):
            save_dir_index += 1
        save_dir = base_dir+"_"+str(save_dir_index)
        if not os.path.isdir(save_dir):
            # makedirs also creates the missing parent directories, which a
            # plain mkdir would fail on during the very first run.
            os.makedirs(save_dir)
        print("save dir:", save_dir)
        return save_dir

    def _build_loaders(self):
        """Build the training and validation loaders from the dataset.

        All splits use a generator seeded with 42, so they are reproducible.
        The held-out part holds ``len // 10 + 10`` samples; the validation
        set is the largest multiple of ``batch_size`` that fits into half of
        it. When ``0 < reduced_training < 100`` only that percentage of the
        training split (rounded down to a multiple of ``batch_size``) is kept.
        """
        dataset = N_Caltech_Dataset(annotations_file=annotation_file_loc, img_dir=dataset_file_loc, sample_window=input_sample_window, num_of_frame=input_num_of_frame)

        n_total = len(dataset)
        n_holdout = n_total//10 + 10
        trainset, testset = torch.utils.data.random_split(dataset, [n_total-n_holdout, n_holdout], generator=torch.Generator().manual_seed(42))

        n_test = len(testset)
        n_val = int(n_test/2)//batch_size*batch_size
        validationset, _ = torch.utils.data.random_split(testset, [n_val, n_test-n_val], generator=torch.Generator().manual_seed(42))

        if reduced_training>0 and reduced_training<100:
            n_train = len(trainset)
            target_training_size = int(n_train*reduced_training/100//batch_size*batch_size)
            trainset, _ = torch.utils.data.random_split(trainset, [target_training_size, n_train-target_training_size], generator=torch.Generator().manual_seed(42))

        train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=16)
        validation_loader = torch.utils.data.DataLoader(validationset, batch_size=batch_size, shuffle=False, num_workers=16)
        return train_loader, validation_loader

    def _validate(self, validation_loader, criterion, number_of_class):
        """Return ``(accuracy_percent, mean_loss)`` over the validation set."""
        correct = 0
        total = 0
        test_losses = []

        with torch.no_grad():
            for inputs, targets in validation_loader:
                inputs = inputs.float().to(device)
                outputs = self.snn(inputs)
                targets = targets.to(device)
                targets = torch.nn.functional.one_hot(targets, num_classes=number_of_class)
                loss = criterion(outputs, targets.float())
                test_losses.append(loss.item())

                _, predicted_class = outputs[:,0:number_of_class].max(1)
                _, target_class = targets[:,0:number_of_class].max(1)

                total += targets.size(0)
                correct += predicted_class.eq(target_class).sum().item()

        # An empty validation set yields 0% accuracy instead of dividing by zero.
        acc = 100. * float(correct) / float(total) if total else 0.0
        mean_test_loss = np.mean(test_losses) if test_losses else float('nan')
        return acc, mean_test_loss

    def training_process(self, get_train_acc=False):
        """Train the network for ``num_epochs`` epochs, validating after each.

        Does nothing but print a message when the skip configuration was
        found invalid at construction time. Training minimises ten times the
        MSE between the output firing rates and one-hot labels with Adam;
        the learning rate is decayed through ``lr_scheduler``. Whenever the
        validation accuracy beats ``best_acc`` the value is updated and,
        when ``save_model`` is set, a checkpoint is written.

        Args:
            get_train_acc: also compute and print the training accuracy of
                every epoch (off by default).
        """
        if not self.valid_config:
            print("Invalid config")
            return

        save_dir = self._prepare_save_dir() if self.save_model else None

        number_of_class = num_classes
        train_loader, validation_loader = self._build_loaders()
        print_interval = 100

        # Per-epoch histories; collected for inspection, not returned.
        acc_train_record = []
        acc_test_record = []
        loss_train_record = []
        loss_test_record = []

        self.snn = heterogeneous_with_skip(num_of_layer=self.total_layer_num, depth_scale=self.depth_scale, num_of_dynamic=len(self.decay_rate), decay_list=self.decay_rate, skip_layer_connection_list=self.skip_layer_connection_list, R_m=self.R_m, tau_m_inv=self.tau_m_inv)
        self.snn.to(device)

        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(self.snn.parameters(), lr=learning_rate)

        for epoch in range(num_epochs):
            start_time = time.time()
            train_losses = []
            correct = 0
            total = 0

            for i, (input_data, labels_) in enumerate(train_loader):
                self.snn.zero_grad()
                optimizer.zero_grad()

                input_data = input_data.float().to(device)
                outputs = self.snn(input_data)
                labels_ = labels_.to(device)
                labels_ = torch.nn.functional.one_hot(labels_, num_classes=number_of_class)
                loss = criterion(outputs, labels_.float())*10
                loss.backward()
                optimizer.step()

                # Record first so the running mean below includes this batch.
                train_losses.append(loss.item())
                if (i+1)%print_interval == 0:
                    print ('Epoch [%d/%d]'%(epoch+1, num_epochs ))
                    print('Time elasped:', time.time()-start_time)
                    print('train loss : {:.4f}'.format(np.mean(train_losses) ))

                if get_train_acc:
                    _, predicted_class = outputs[:,0:number_of_class].max(1)
                    _, target_class = labels_[:,0:number_of_class].max(1)
                    total += labels_.size(0)
                    correct += predicted_class.eq(target_class).sum().item()

            loss_train_record.append(np.mean(train_losses))
            if get_train_acc:
                train_acc = 100.*correct/total if total else 0.0
                acc_train_record.append(train_acc)
                print("Train Acc: ", train_acc)

            optimizer = lr_scheduler(optimizer, epoch, learning_rate, lr_decay_epoch)

            acc, mean_test_loss = self._validate(validation_loader, criterion, number_of_class)
            print("Acc: ", acc)

            loss_test_record.append(mean_test_loss)
            acc_test_record.append(acc)

            if acc>self.best_acc:
                print("New Best Acc:", acc, " Epoch:", epoch)
                self.best_acc = acc

                if self.save_model:
                    state = {
                    'net': self.snn.state_dict(),
                    'acc': acc,
                    'epoch': epoch,
                    'best_acc': self.best_acc
                    }
                    torch.save(state, save_dir +'/best_acc'+ '.t7')
                    print("Model Saved")

    def get_best_acc(self):
        """Return the best validation accuracy (in percent) seen so far."""
        return self.best_acc


