from __future__ import print_function
import torchvision
import torchvision.transforms as transforms
import torch, time, os
import torch.nn as nn
import numpy as np
import random
import argparse

# ---------------------------------------------------------------------------
# Command-line options and global hyper-parameters.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='adjust')
parser.add_argument('--window', type=int, default=5,
                    help='number of simulation time steps per forward pass (>= 1)')
parser.add_argument('--repeat', type=int, default=100,
                    help='number of times each training batch is tiled; every copy '
                         'receives independent noise in Linear.forward (>= 1)')
args = parser.parse_args()
# Both values later act as divisors / tile counts (spike-rate averaging,
# gradient averaging over time steps, torch.cat of batch copies). Reject
# non-positive values here instead of failing with a NaN or an empty cat.
if args.window < 1:
    parser.error('--window must be >= 1, got %d' % args.window)
if args.repeat < 1:
    parser.error('--repeat must be >= 1, got %d' % args.repeat)

thresh = 0.3         # membrane potential above which a neuron spikes (see mem_update)
lens = 0.5           # half-width of the rectangular surrogate-gradient window (ActFun)
decay = 0.5          # membrane leak factor applied every time step (see mem_update)
num_classes = 10
batch_size = 100
num_epochs = 400
learning_rate = 1e-3
time_window = args.window

device = torch.device("cuda:0")

def setup_seed(seed):
    """Seed every random number generator the script uses.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU plus the
    current CUDA device), then asks cuDNN for deterministic algorithms.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

class ActFun(torch.autograd.Function):
    """Heaviside spike function with a rectangular surrogate gradient.

    Forward emits 1.0 where the input is strictly positive and 0.0 elsewhere.
    Backward lets the upstream gradient through only where
    ``|input| < lens`` (module-level constant) and zeroes it elsewhere.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        fired = input > 0.
        return fired.float()

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        # Rectangular window around the threshold: gradient is passed through
        # unchanged inside it and blocked outside.
        inside = saved_input.abs() < lens
        return grad_output * inside.float()


class Layer:
    """Base class for the hand-written (autograd-free) layers in this script.

    Subclasses override ``forward``/``backward``; this base class provides
    ``init`` for drawing initial parameter tensors on ``self.device``.
    """

    def __init__(self, device=torch.device('cuda:0')):
        self.device = device
        self.cache = None  # subclasses may stash the last forward input here

    def forward(self, input):
        pass

    def backward(self, loss):
        pass

    def init(self, init_name, weight_size, mode="xavier_normal", act="TanH"):
        """Return a normally distributed tensor of shape ``weight_size``.

        Args:
            init_name: ``"xavier_normal"`` (std = sqrt(2 / (fan_in + fan_out)))
                or ``"kaiming_normal"`` (std = gain / sqrt(fan)).
            weight_size: shape of the tensor to create. ``fan_out`` is the last
                dimension; ``fan_in`` is the product of all other dimensions
                (1 for a 1-D shape).
            mode: ``kaiming_normal`` only; ``'fan_in'`` selects fan_in, any
                other value selects fan_out.
            act: ``kaiming_normal`` only; chooses the gain
                (``"TanH"`` -> 5/3, ``"ReLU"`` -> sqrt(2), otherwise 1).

        Returns:
            A float tensor on ``self.device``.

        Raises:
            ValueError: if ``init_name`` is not a supported scheme.
        """
        shape = tuple(int(s) for s in weight_size)
        # np.prod of an empty tuple is 1.0, so 1-D shapes get fan_in == 1.
        fan_in = int(np.prod(shape[:-1]))
        fan_out = shape[-1]
        gain = 1.
        if init_name == "xavier_normal":
            std = gain * np.sqrt(2.0 / float(fan_in + fan_out))
        elif init_name == "kaiming_normal":
            if act == "TanH":
                gain = 5 / 3
            elif act == "ReLU":
                gain = np.sqrt(2)
            fan = fan_in if mode == 'fan_in' else fan_out
            std = gain / np.sqrt(fan)
        else:
            raise ValueError("unsupported init_name: %r" % (init_name,))
        return torch.randn(size=shape, device=self.device) * std


class Linear(Layer):
    """Fully connected layer trained with a Gaussian-perturbation gradient estimate.

    In training mode every ``forward`` adds noise ``eps ~ N(0, noise_std**2)``
    to the pre-activation and accumulates per-sample score terms
    ``x (outer) eps / noise_std**2`` for the weights and ``eps / noise_std**2``
    for the bias. ``backward`` weights those terms by the per-sample loss,
    sums over the batch, and averages over the number of training-mode
    ``forward`` calls since the previous ``backward``. Gradients are computed
    by hand, without autograd.
    """

    def __init__(self, input_features, output_features, noise_std=1e-0,
                 act='ReLU', device=torch.device('cuda:0')):
        super().__init__(device)
        self.input_features = input_features
        self.output_features = output_features
        self.act = act
        # Per-sample score accumulators; reset to 0. by backward().
        self.gradient_w = 0.
        self.gradient_b = 0.
        self.count = 0.  # training-mode forward calls since the last backward()
        self.w_size = (self.input_features, self.output_features)
        self.b_size = (self.output_features,)
        self.W = {'val': self.init("xavier_normal", self.w_size, act=self.act), 'grad': 0.}
        self.b = {'val': self.init("xavier_normal", self.b_size, act=self.act), 'grad': 0.}

        self.noise_std = noise_std
        self.noise = None

    def forward(self, input, train=True):
        """Return ``input @ W + b``, plus sampled noise when ``train`` is true.

        Args:
            input: 2-D tensor ``(batch, input_features)``.
            train: when true, perturb the output and record the score terms
                consumed by ``backward``; when false, return the clean output.
        """
        self.cache = input
        out = input @ self.W['val'] + self.b['val']
        if not train:
            return out
        self.noise = torch.randn(size=out.shape, device=self.device) * self.noise_std
        var = self.noise_std ** 2
        # The score of a Gaussian w.r.t. its mean is eps / sigma^2. The weight
        # term also carries the input through an outer product. Weights and
        # bias must use the same 1/sigma^2 scaling.
        self.gradient_w += torch.einsum('ni,nj->nij', input, self.noise) / var
        self.gradient_b += self.noise / var
        self.count += 1
        return out + self.noise

    def backward(self, loss):
        """Convert the accumulated score terms into parameter gradients.

        Args:
            loss: 1-D tensor ``(batch,)`` of per-sample losses, one per row of
                the inputs seen by ``forward``.

        Returns:
            ``(grad_W, grad_b)``, also stored in ``self.W['grad']`` and
            ``self.b['grad']``. The accumulators are reset afterwards.

        Raises:
            RuntimeError: if no training-mode ``forward`` preceded this call.
        """
        if self.count == 0:
            raise RuntimeError('Linear.backward called without a preceding '
                               'training-mode forward')
        # Contract over the batch directly rather than materialising the
        # (batch, in, out) product gradient_w * loss as a temporary.
        w_grad = torch.einsum('nij,n->ij', self.gradient_w, loss)
        b_grad = torch.einsum('nj,n->j', self.gradient_b, loss)
        # Average over time steps exactly once, identically for W and b.
        self.W['grad'] = w_grad / self.count
        self.b['grad'] = b_grad / self.count
        self.count = 0.
        self.gradient_b = self.gradient_w = 0.
        return self.W['grad'], self.b['grad']


# Spike nonlinearity used by mem_update: Heaviside forward, surrogate backward.
act_fun = ActFun.apply
# Widths of the two fully connected layers: [hidden, output].
cfg_fc = [50, 10]
# NOTE(review): ``probs`` is not referenced anywhere else in this source.
probs = 0.5


def lr_scheduler(optimizer, epoch, init_lr=0.1, lr_decay_epoch=100):
    """Halve each param group's learning rate every ``lr_decay_epoch`` epochs.

    The decay fires only when ``epoch`` is a multiple of ``lr_decay_epoch``
    and greater than 1. ``init_lr`` is accepted but unused. The optimizer's
    ``param_groups`` are modified in place, and the optimizer is returned.
    """
    on_boundary = epoch % lr_decay_epoch == 0
    if not (on_boundary and epoch > 1):
        return optimizer
    for group in optimizer.param_groups:
        group['lr'] = group['lr'] * 0.5
    return optimizer


class Net:
    """Minimal container for the hand-written layers.

    Parameters and gradients are exchanged as flat dicts keyed by each layer's
    1-based position: ``W_<k>`` / ``b_<k>`` and ``grad_W_<k>`` / ``grad_b_<k>``.
    """

    def __init__(self, device=torch.device('cuda:0')):
        self.device = device
        self.layers = []

    def forward(self, input):
        pass

    def backward(self, loss):
        """Call ``backward(loss)`` on every layer and collect their gradients."""
        grads = {}
        for position, layer in enumerate(self.layers, start=1):
            grad_w, grad_b = layer.backward(loss)
            grads['grad_W_{}'.format(position)] = grad_w
            grads['grad_b_{}'.format(position)] = grad_b
        return grads

    def get_params(self):
        """Return the current ``W``/``b`` values of every layer as a flat dict."""
        return {
            '{}_{}'.format(name, position): getattr(layer, name)['val']
            for position, layer in enumerate(self.layers, start=1)
            for name in ('W', 'b')
        }

    def set_params(self, params):
        """Write values from a dict shaped like ``get_params()`` back into the layers."""
        for position, layer in enumerate(self.layers, start=1):
            for name in ('W', 'b'):
                getattr(layer, name)['val'] = params['{}_{}'.format(name, position)]


class LrScheduler:
    """Step decay: multiply the learning rate by ``gamma`` every ``step_size`` calls."""

    def __init__(self, step_size, gamma):
        self.step_size = step_size
        self.gamma = gamma
        self.epoch = 0  # number of step() calls made so far

    def step(self, lr):
        """Advance one epoch and return ``lr``, decayed if this epoch is a boundary."""
        self.epoch += 1
        decay_now = self.epoch % self.step_size == 0
        return lr * self.gamma if decay_now else lr


class Optimizer:
    """Base class for the hand-written optimizers.

    Stores the learning rate, the parameter dict being optimised and the
    device on which subclasses may allocate optimizer state.
    """

    def __init__(self, lr, params, device=torch.device("cuda:0")):
        self.lr, self.params, self.device = lr, params, device

    def update_params(self, grads):
        """Apply ``grads`` to ``self.params``; the base implementation does nothing."""
        return None

    def update_lr(self, lr):
        """Replace the current learning rate."""
        self.lr = lr


class SGD(Optimizer):
    """Plain gradient descent: ``p <- p - lr * grad`` for every parameter."""

    def __init__(self, lr, params, device=torch.device("cuda:0")):
        super().__init__(lr, params, device)

    def update_params(self, grads):
        """Update ``self.params`` in place and return the same dict.

        ``grads`` must contain ``'grad_' + key`` for every key of ``self.params``.
        """
        step = self.lr
        for name, value in self.params.items():
            self.params[name] = value - step * grads['grad_' + name]
        return self.params


class AdamGD(Optimizer):
    """Adam optimizer over a dict of tensors.

    Keeps exponential moving averages of the gradient (``momentum``) and of
    its square (``rmsprop``) per parameter, and applies the bias correction
    ``1 / (1 - beta**t)`` to both before the update. Without it, the zero-
    initialised averages inflate the first steps by roughly
    ``sqrt(1 - beta2) / (1 - beta1)``, about 3x for the default betas.

    Args:
        lr: learning rate.
        params: dict of tensors; updated in place and returned by
            ``update_params``.
        beta1, beta2: decay rates of the first / second moment estimates.
        epsilon: added to the denominator for numerical safety.
        device: where the optimizer state tensors are allocated.
        bias_correction: set to False to reproduce the uncorrected update.
    """

    def __init__(self, lr, params, beta1=0.9, beta2=0.999, epsilon=1e-8,
                 device=torch.device("cuda:0"), bias_correction=True):
        super().__init__(lr, params, device)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.bias_correction = bias_correction
        self.t = 0  # number of update_params() calls so far

        self.momentum = {}
        self.rmsprop = {}

        for key in self.params:
            self.momentum['vd' + key] = torch.zeros(self.params[key].shape, device=self.device)
            self.rmsprop['sd' + key] = torch.zeros(self.params[key].shape, device=self.device)

    def update_params(self, grads):
        """Apply one Adam step using ``grads['grad_' + key]`` for every parameter key."""
        self.t += 1
        if self.bias_correction:
            m_scale = 1.0 / (1 - self.beta1 ** self.t)
            v_scale = 1.0 / (1 - self.beta2 ** self.t)
        else:
            m_scale = v_scale = 1.0

        for key in self.params:
            grad = grads['grad_' + key]
            # Momentum (first moment) update.
            m = self.beta1 * self.momentum['vd' + key] + (1 - self.beta1) * grad
            # RMSprop (second moment) update.
            v = self.beta2 * self.rmsprop['sd' + key] + (1 - self.beta2) * (grad ** 2)
            self.momentum['vd' + key] = m
            self.rmsprop['sd' + key] = v
            # Parameter update with bias-corrected moments.
            m_hat = m * m_scale
            v_hat = v * v_scale
            self.params[key] = self.params[key] - (self.lr * m_hat) / (
                    torch.sqrt(v_hat) + self.epsilon)

        return self.params

def repeat_data(data, repeat):
    """Flatten ``data`` to ``(len(data), -1)`` and stack ``repeat`` copies along dim 0.

    The result has ``len(data) * repeat`` rows: ``repeat`` back-to-back
    copies of the whole flattened batch (not row-by-row interleaving).
    """
    flat = data.reshape(len(data), -1)
    copies = tuple(flat for _ in range(repeat))
    return torch.cat(copies, dim=0)

class SNN_Model(Net):
    """Two-layer fully connected spiking network: 784 -> cfg_fc[0] -> cfg_fc[1].

    ``p`` is accepted for backward compatibility but is not used.
    """

    def __init__(self, p=0.5):
        super(SNN_Model, self).__init__()
        self.fc1 = Linear(784, cfg_fc[0], )
        self.fc2 = Linear(cfg_fc[0], cfg_fc[1], )
        self.layers = [self.fc1, self.fc2]

    def forward(self, input, wins=time_window, train=True):
        """Simulate the network for ``wins`` time steps on a static input.

        Args:
            input: tensor whose first dimension is the batch; it is flattened
                to ``(batch, -1)`` and fed unchanged at every step.
            wins: number of simulation time steps.
            train: passed to the layers; toggles noise injection.

        Returns:
            ``(rate, counts)``: output spike counts divided by ``wins`` and the
            raw counts, both of shape ``(batch, cfg_fc[1])``.
        """
        batch_size = input.size(0)
        # The input is constant over time, so flatten/cast it once.
        x = input.view(batch_size, -1).float()
        h1_mem = h1_spike = torch.zeros(batch_size, cfg_fc[0], device=self.device)
        h2_mem = h2_spike = h2_sumspike = torch.zeros(batch_size, cfg_fc[1], device=self.device)
        for _ in range(wins):
            h1_spike, h1_mem = mem_update(self.fc1, x, h1_spike, h1_mem, train)
            h2_spike, h2_mem = mem_update(self.fc2, h1_spike, h2_spike, h2_mem, train)
            h2_sumspike = h2_sumspike + h2_spike
        outs = h2_sumspike / wins
        return outs, h2_sumspike

def mem_update(fc, inputs, spike, mem, train):
    """Advance one leaky integrate-and-fire step.

    The membrane is reset where the previous step spiked, leaked by ``decay``
    and driven by ``fc.forward(inputs, train)``. A spike is emitted where the
    new membrane potential exceeds ``thresh``.

    Returns:
        ``(spike, mem)``: the new float spike tensor and membrane potential.
    """
    drive = fc.forward(inputs, train)
    leaked = mem * (1 - spike) * decay
    new_mem = leaked + drive
    fired = act_fun(new_mem - thresh)
    return fired.float(), new_mem

import pickle as pkl

# ---------------------------------------------------------------------------
# Data, bookkeeping, model and optimizer.
# ---------------------------------------------------------------------------
data_path = r'./'
saving_names = '.'

train_dataset = torchvision.datasets.MNIST(root=data_path, train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

test_set = torchvision.datasets.MNIST(root=data_path, train=False, download=True, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

best_acc = 0.  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
acc_record = []
loss_train_record = []
loss_test_record = []
criterion = nn.MSELoss()  # NOTE(review): not used below
total_best_acc = []
total_acc_record = []
total_hid_state = []
train_acc_record = []
test_acc_record = []
snn = SNN_Model()

optimizer = AdamGD(lr=learning_rate, params=snn.get_params(), device=device)
repeat = args.repeat

# ---------------------------------------------------------------------------
# Training / evaluation loop.
# ---------------------------------------------------------------------------
for epoch in range(num_epochs):
    running_loss = 0.
    start_time = time.time()
    total = 0.
    correct = 0.
    for i, (images, targets) in enumerate(train_loader):
        images = images.float().to(device)
        # Tile the batch ``repeat`` times; every copy receives independent
        # noise in Linear.forward, which the perturbation gradient averages.
        images = repeat_data(images, repeat)
        targets = repeat_data(targets, repeat).reshape(-1)
        train_batch_size = len(targets)
        outputs, spikes = snn.forward(input=images, wins=time_window)
        targets_ = torch.nn.functional.one_hot(targets, num_classes).float().to(device)

        # log_softmax over the class dimension is the numerically stable form
        # of log(softmax(.)); the per-sample cross-entropy is the "reward"
        # consumed by the layers' backward().
        log_probs = torch.nn.functional.log_softmax(outputs, dim=1)
        loss = -torch.sum(targets_ * log_probs, dim=-1)
        running_loss += loss.sum().item()
        grads = snn.backward(loss)
        params = optimizer.update_params(grads)
        snn.set_params(params)
        _, predicted = log_probs.cpu().max(1)
        total += float(targets.size(0))
        correct += float(predicted.eq(targets).sum().item())

    # The step total is the number of loader batches, i.e. what i + 1 counts.
    print('Epoch [%d/%d], Correct:%.2f Step [%d/%d], Loss: %.5f' % (
        epoch + 1, num_epochs, correct / total * 100, i + 1, len(train_loader), running_loss / total))
    loss_train_record.append(running_loss / total)
    train_acc_record.append(correct / total * 100)

    correct = 0.
    total = 0.
    running_loss = 0.
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs = inputs.to(device)
            outputs, sumspike = snn.forward(input=inputs, wins=time_window, train=False)
            targets_ = torch.nn.functional.one_hot(targets, num_classes).float().to(device)
            log_probs = torch.nn.functional.log_softmax(outputs, dim=1)
            loss = -torch.sum(targets_ * log_probs, dim=-1)
            running_loss += loss.sum().item()
            _, predicted = log_probs.cpu().max(1)
            total += float(targets_.size(0))
            correct += float(predicted.eq(targets).sum().item())

    acc = 100. * float(correct) / float(total)
    print('Test Accuracy of the model on the 10000 test images: %.3f' % acc)
    loss_test_record.append(running_loss / total)
    acc_record.append(acc)
    test_acc_record.append(acc)
    # Checkpoints are only written once epoch index 30 has passed.
    if epoch > 30 and acc > best_acc:
        best_acc = acc
        print(acc)
        print('Saving..')
        with open('bstwindow:{}-repeat:{}.pkl'.format(time_window, args.repeat), 'wb') as file:
            pkl.dump([loss_train_record, train_acc_record, loss_test_record, test_acc_record, snn.get_params()], file)
print(' best acc:', best_acc)
