import math
import os
import os.path
import random
from collections import defaultdict
from functools import partial

import brainpy as bp
import brainpy.math as bm
import cv2
import jax
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from torch.autograd import Variable
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.transforms import transforms

from config import load_config
from MovingMNIST import MovingMNIST
from tPCN import tPCN

# Amplitude of the one-hot action / location code: the sequence generators and
# prediction loops below set the active entry of each action vector to 1*yita.
yita=0.5
# Log the installed torch version (runs at import time).
print(torch.__version__)
# Side length in pixels of one square observation patch; observations are
# flattened to channel * seq_size * seq_size values before being fed to the model.
seq_size=32
# Number of image channels in each observation patch.
channel=1

def transform_taxi(X):
    """Min-max normalise ``X`` into the range [0, 1].

    Works on any array-like that provides ``min()``/``max()`` and
    elementwise arithmetic (e.g. numpy arrays or torch tensors).

    Args:
        X: Input array.

    Returns:
        ``(X - X.min()) / (X.max() - X.min())`` as floats. If ``X`` is
        constant (zero range), that formula would divide by zero and produce
        NaNs, so a float array of zeros with the same shape is returned instead.
    """
    lo = X.min()
    span = X.max() - lo
    if span == 0:
        # Constant input: avoid 0/0 -> NaN and map every element to 0.
        return 0. * (X - lo)
    return 1. * (X - lo) / span

def get_rotating_mnist(datapath, seq_len, sample_size, batch_size, seed, angle):
    """Build a DataLoader of rotating-MNIST sequences.

    ``sample_size`` images are drawn without replacement from the MNIST
    training split. Each image becomes a sequence of ``seq_len`` frames,
    where frame ``l`` is the image rotated by ``angle * l``.

    Args:
        datapath: Directory used to download/cache MNIST.
        seq_len: Number of frames per sequence.
        sample_size: Number of MNIST images to sample.
        batch_size: Batch size of the returned DataLoader.
        seed: Seed passed to ``random.seed``. This reseeds Python's global
            RNG, so it also affects later ``random`` calls elsewhere.
        angle: Rotation increment between consecutive frames (passed to
            ``TF.rotate``).

    Returns:
        ``DataLoader(DataWrapper(train_sequences), batch_size=batch_size)``,
        where ``train_sequences`` has shape ``(sample_size, seq_len, h, w)``.
        NOTE(review): ``DataWrapper`` is not defined in this section; confirm
        that it exists elsewhere in the module.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    train = datasets.MNIST(datapath, train=True, transform=transform, download=True)

    # ``train.data`` is indexed directly (bypassing ``transform``), so scale
    # the raw pixel values to [0, 1] here.
    train_data = train.data / 255.

    random.seed(seed)
    picked = random.sample(range(len(train_data)), sample_size)
    train_data = train_data[picked]  # [sample_size, h, w]
    h, w = train_data.shape[-2], train_data.shape[-1]

    # Frame l of every sequence is the source image rotated by angle * l.
    train_sequences = torch.zeros((sample_size, seq_len, h, w))
    for l in range(seq_len):
        train_sequences[:, l] = TF.rotate(train_data, angle * l)

    return DataLoader(DataWrapper(train_sequences), batch_size=batch_size)

def seq_generate(input, seq_len):
    """Sample ``seq_len`` random patches from a 4x4 grid over ``input``.

    At each step a grid cell (row, col), with row and col in 0..3, is drawn
    uniformly. The patch ``input[:, :, row*seq_size:(row+1)*seq_size,
    col*seq_size:(col+1)*seq_size]`` is flattened to
    ``(-1, channel*seq_size*seq_size)``. The matching action is a length-16
    vector that is zero except at ``row*4 + col``, where it is ``1*yita``.

    ``input`` is indexed with four axes, so it must be at least 4-D. Its last
    two dims presumably span 4*seq_size pixels (TODO confirm).

    Returns:
        Tuple ``(observations, actions)`` of two lists of length ``seq_len``.
    """
    patch = seq_size
    flat_dim = channel * patch * patch
    observations, actions = [], []
    for _ in range(seq_len):
        row = random.choice(range(4))
        col = random.choice(range(4))
        top, left = row * patch, col * patch
        crop = input[:, :, top:top + patch, left:left + patch]
        observations.append(crop.reshape(-1, flat_dim))

        one_hot = bm.zeros(16)
        one_hot[row * 4 + col] = 1 * yita
        actions.append(one_hot)
    return observations, actions

def rio_generate(input, seq_len):
    """Sample ``seq_len`` random frames (indices 0..7) from ``input``.

    At each step a frame index ``f`` in 0..7 is drawn uniformly. The
    observation is ``input[:, f, :, :]`` flattened to
    ``(-1, channel*seq_size*seq_size)``. The action is a length-8 vector that
    is zero except at ``f``, where it is ``1*yita``.

    Returns:
        Tuple ``(observations, actions)`` of two lists of length ``seq_len``.
    """
    flat_dim = channel * seq_size * seq_size
    observations, actions = [], []
    for _ in range(seq_len):
        frame = random.choice(range(8))
        observations.append(input[:, frame, :, :].reshape(-1, flat_dim))

        one_hot = bm.zeros(8)
        one_hot[frame] = 1 * yita
        actions.append(one_hot)
    return observations, actions


def test_rio_generate(input, seq_len, num, action_size=None):
    """Sample ``seq_len`` random frames among the first ``num`` frames of ``input``.

    This is the same scheme as ``rio_generate``, but frame indices are drawn
    from ``0..num-1``. Each observation is ``input[:, f, :, :]`` flattened to
    ``(-1, channel*seq_size*seq_size)``. Each action is a one-hot vector with
    ``1*yita`` at index ``f``.

    Args:
        input: Sequence tensor indexed as ``input[:, frame]``.
        seq_len: Number of samples to draw.
        num: Number of leading frames eligible for sampling.
        action_size: Length of each one-hot action vector. Defaults to
            ``max(8, num)``, which is 8 as before whenever ``num <= 8``.
            Previously the length was always 8, so for ``num > 8`` the write
            ``loc[ii]`` could target an index past the end of the vector.
            brainpy/JAX arrays may silently drop such an update rather than
            raise.

    Returns:
        Tuple ``(ob_seq, action_seq)`` of two lists of length ``seq_len``.

    Raises:
        ValueError: If ``num`` is larger than ``action_size``.
    """
    if action_size is None:
        action_size = max(8, num)
    if num > action_size:
        raise ValueError(
            f"num={num} frame indices do not fit in action_size={action_size}")

    flat_dim = channel * seq_size * seq_size
    ob_seq, action_seq = [], []
    for _ in range(seq_len):
        ii = random.choice(range(num))
        ob_seq.append(input[:, ii, :, :].reshape(-1, flat_dim))
        loc = bm.zeros(action_size)
        loc[ii] = 1 * yita
        action_seq.append(loc)
    return ob_seq, action_seq


def make_phase(l_duration, duration, dt, L):
    """Return integer step boundaries for ``L`` two-part segments.

    Each segment contributes two intervals, measured in steps of ``dt``:
    first ``duration // L * l_duration // dt`` steps, then
    ``(duration / L - duration // L * l_duration) // dt`` steps. The result
    is the running sum of those interval lengths, starting at 0, with every
    boundary truncated to ``int``. It has ``2 * L + 1`` entries.
    """
    learn_steps = duration // L * l_duration // dt
    rest_steps = (duration / L - duration // L * l_duration) // dt
    boundaries = [0]
    cursor = 0
    for _ in range(L):
        for step in (learn_steps, rest_steps):
            cursor += step
            boundaries.append(cursor)
    return list(map(int, boundaries))

def make_phase3(l_duration, duration, dt, L):
    """Return integer step boundaries using only the second interval of each segment.

    This is ``make_phase`` without the first interval: each of the ``L``
    segments contributes ``(duration / L - duration // L * l_duration) // dt``
    steps. The result is the running sum, starting at 0, truncated to
    ``int``, with ``L + 1`` entries.
    """
    step = (duration / L - duration // L * l_duration) // dt
    boundaries = [0]
    for _ in range(L):
        boundaries.append(boundaries[-1] + step)
    return [int(b) for b in boundaries]


def show_grad(mon_list, num, phase):
    """Plot per-layer gradient traces for every monitor in ``mon_list``.

    Draws a grid with ``num`` rows (row ``i`` uses key ``avg_grad{i}``) and
    ``len(mon_list)`` columns. Each trace is drawn against a time axis in
    [0, 1] and split at the indices in ``phase``. Even-numbered segments are
    red and odd-numbered segments blue. The figure is shown with
    ``plt.show()``.
    """
    n_cols = len(mon_list)
    plt.figure(figsize=(3 * n_cols, num * 2))
    for row in range(num):
        key = f'avg_grad{row}'
        for col, mon in enumerate(mon_list):
            grad = mon[key]
            time = bm.linspace(0, 1, len(grad))
            plt.subplot(num, n_cols, row * n_cols + col + 1)
            for k, (start, stop) in enumerate(zip(phase[:-1], phase[1:])):
                color = 'red' if k % 2 == 0 else 'blue'
                plt.plot(time[start:stop], grad[start:stop], color=color)
            plt.xticks([])
    plt.show()


def show_grad2(mon_list, num, phase):
    """Plot summed adjacent-layer gradients for the last (up to) 10 monitors.

    Panel (i, j) shows ``avg_grad{i} + avg_grad{i+1}`` of monitor
    ``mon_list[len(mon_list) - n + j]``, where ``n = min(10, len(mon_list))``.
    The trace is drawn against a time axis in [0, 1] and split at the indices
    in ``phase``; even segments are red, odd segments blue. Each monitor must
    therefore provide keys ``avg_grad0`` .. ``avg_grad{num}``. The figure is
    saved to ``error_4_64.png`` and then closed.

    Raises:
        ValueError: If ``mon_list`` is empty.
    """
    length = len(mon_list)
    if length == 0:
        raise ValueError("mon_list is empty")
    # With a fixed window of 10, fewer monitors made ``length - 10 + j``
    # negative, which silently re-plotted wrapped-around entries.
    seq_len = min(10, length)
    fig = plt.figure(figsize=(3 * seq_len, num * 2))
    for i in range(num):
        for j in range(seq_len):
            mon = mon_list[length - seq_len + j]
            grad = mon[f'avg_grad{i}'] + mon[f'avg_grad{i+1}']
            time = bm.linspace(0, 1, len(grad))
            plt.subplot(num, seq_len, i * seq_len + j + 1)
            for k in range(len(phase) - 1):
                c = 'red' if k % 2 == 0 else 'blue'
                plt.plot(time[phase[k]:phase[k + 1]],
                         grad[phase[k]:phase[k + 1]],
                         color=c)
            plt.xticks([])
    fig.savefig('error_4_64.png')
    # Release the figure; repeated calls would otherwise accumulate open figures.
    plt.close(fig)

def show_grad3(mon_list, num, phase):
    """Plot summed adjacent-layer gradients for the last (up to) 20 monitors.

    Panel (i, j) shows ``avg_grad{i} + avg_grad{i+1}`` of monitor
    ``mon_list[len(mon_list) - n + j]``, where ``n = min(20, len(mon_list))``.
    The trace is drawn in blue against a time axis in [0, 1], as one line per
    ``phase`` interval. Each monitor must provide keys ``avg_grad0`` ..
    ``avg_grad{num}``. The figure is saved to ``error_multi.png`` and then
    closed.

    Raises:
        ValueError: If ``mon_list`` is empty.
    """
    length = len(mon_list)
    if length == 0:
        raise ValueError("mon_list is empty")
    # With a fixed window of 20, fewer monitors made ``length - 20 + j``
    # negative, which silently re-plotted wrapped-around entries.
    seq_len = min(20, length)
    fig = plt.figure(figsize=(3 * seq_len, num * 2))
    for i in range(num):
        for j in range(seq_len):
            mon = mon_list[length - seq_len + j]
            grad = mon[f'avg_grad{i}'] + mon[f'avg_grad{i+1}']
            time = bm.linspace(0, 1, len(grad))
            plt.subplot(num, seq_len, i * seq_len + j + 1)
            for k in range(len(phase) - 1):
                plt.plot(time[phase[k]:phase[k + 1]],
                         grad[phase[k]:phase[k + 1]],
                         color='blue')
            plt.xticks([])
    fig.savefig('error_multi.png')
    # Release the figure; repeated calls would otherwise accumulate open figures.
    plt.close(fig)


def dict2array(mon_list):
    """Stack a list of monitor dicts into a single ``bm`` array.

    Row ``i`` holds the values of ``mon_list[i]`` in the dict's iteration
    (insertion) order. The values must have compatible shapes for
    ``bm.array`` to stack them.
    """
    rows = [list(mon.values()) for mon in mon_list]
    return bm.array(rows)


def generate_num(tpcn, batch_size, input):
    """Reassemble a 32x32 RGB image from 16 predicted 8x8 patches and plot it.

    The model is stepped with a zero observation and each of the 16 one-hot
    location actions in turn. After each step, ``-tpcn.e[0]`` is reshaped to
    ``(-1, 3, 8, 8)`` and pasted into cell ``(i // 4, i % 4)`` of a 4x4 grid.
    For locations ``i > num``, the squared error against the matching patch
    of ``input`` is recorded.

    NOTE(review): ``num`` is neither a parameter nor a local here. It must
    exist as a module-level global; confirm.

    Args:
        tpcn: Model providing ``next_predict(obs, action)`` and ``e[0]``.
        batch_size: Number of images in the batch.
        input: torch tensor indexed as ``input[:, :, rows, cols]``. Presumably
            shaped (batch, 3, 32, 32); TODO confirm.

    Returns:
        Mean squared error per batch element, averaged over the recorded
        patches and their pixels (axes 0, 2, 3, 4 of the stacked errors).
    """
    ans = bm.zeros((batch_size, 3, 32, 32))
    # Zero observation sized like one flattened 3x8x8 patch, matching the
    # reshape of tpcn.e[0] below. The previous size used seq_size, which is
    # not the 8-pixel patch side used here.
    no_input = bm.zeros((batch_size, 3 * 8 * 8))

    errors = []
    for i in range(16):
        action_i = bm.zeros(16)
        action_i[i] = 1 * yita

        tpcn.next_predict(no_input, action_i)
        ob_i = (-1 * tpcn.e[0]).reshape(-1, 3, 8, 8)
        h_i, w_i = divmod(i, 4)
        rows = slice(h_i * 8, (h_i + 1) * 8)
        cols = slice(w_i * 8, (w_i + 1) * 8)
        ans[:, :, rows, cols] = ob_i
        if i > num:
            errors.append(bm.square(ob_i - input[:, :, rows, cols].numpy()))
    errors = bm.array(errors)
    print(errors.shape)
    print(bm.mean(errors))

    # Row 1: inputs; row 2: reconstructions (first 16 batch elements).
    plt.figure(figsize=(2 * 16, 4))
    for j in range(16):
        plt.subplot(2, 16, j + 16 + 1)
        plt.imshow(bm.transpose(ans[j], (1, 2, 0)))
        plt.xticks([])
        plt.yticks([])
        plt.subplot(2, 16, j + 1)
        plt.imshow(input[j].permute(1, 2, 0))
        plt.xticks([])
        plt.yticks([])
    # Figure-level title; plt.title would only label the last subplot.
    plt.suptitle("seq_len=3000 neuron_size=[512,256,64] batch=16")
    plt.savefig('pic_cifar10/16/bitch16.png')
    plt.show()
    plt.close()
    return bm.mean(errors, axis=(0, 2, 3, 4))

def generate(tpcn, batch_size, input, ii):
    """Predict all 16 patches of each image and save comparison figures.

    The model is stepped with a zero observation and each of the 16 one-hot
    location actions in turn. ``-tpcn.e[0]`` after each step is pasted into
    cell ``(i // 4, i % 4)`` of a ``4*seq_size`` square canvas. Four pages of
    up to 16 batch elements each are saved to
    ``pic_final/layer3/64/4<page>_test6.pdf``. On each page, row 1 holds the
    inputs, row 2 the predictions, and row 3 their difference.

    Args:
        tpcn: Model providing ``next_predict(obs, action)`` and ``e[0]``.
        batch_size: Number of images in the batch.
        input: torch tensor indexed as ``input[j].permute(1, 2, 0)``.
            Presumably (batch, channel, 4*seq_size, 4*seq_size); TODO confirm.
        ii: Unused; kept for call compatibility.

    Returns:
        List of per-image mean squared errors; entry ``j`` belongs to batch
        element ``j``.
    """
    ans = bm.zeros((batch_size, channel, seq_size * 4, seq_size * 4))
    no_input = bm.zeros((batch_size, channel * seq_size * seq_size))

    for i in range(16):
        action_i = bm.zeros(16)
        action_i[i] = 1 * yita

        tpcn.next_predict(no_input, action_i)
        ob_i = (-1 * tpcn.e[0]).reshape(-1, channel, seq_size, seq_size)
        h_i, w_i = divmod(i, 4)
        ans[:, :, h_i * seq_size:(h_i + 1) * seq_size,
            w_i * seq_size:(w_i + 1) * seq_size] = ob_i

    # Compute each image's difference and MSE once. Previously this ran inside
    # the 4-page loop, so every error was appended 4 times.
    diff_imgs = []
    diff = []
    for j in range(batch_size):
        diff_j = bm.transpose(ans[j], (1, 2, 0)) - input[j].permute(1, 2, 0).numpy()
        diff_imgs.append(diff_j)
        diff.append(bm.square(diff_j).mean())

    for z in range(4):
        plt.figure(figsize=(30, 6))
        # Page z shows batch elements 16*z .. 16*z+15 (bounded by batch_size).
        for j in range(16 * z, min(16 * (z + 1), batch_size)):
            col = j - 16 * z
            plt.subplot(3, 16, col + 16 + 1)
            plt.imshow(bm.transpose(ans[j], (1, 2, 0)))
            plt.xticks([])
            plt.yticks([])
            plt.subplot(3, 16, col + 1)
            plt.imshow(input[j].permute(1, 2, 0))
            plt.xticks([])
            plt.yticks([])
            plt.subplot(3, 16, col + 1 + 16 * 2)
            plt.imshow(diff_imgs[j])
            plt.xticks([])
            plt.yticks([])
        plt.savefig('pic_final/layer3/64/4' + str(z) + '_test6.pdf')
        plt.close()
    print(diff)
    return diff

def generate_rio(tpcn, batch_size, input, ii):
    """Predict all 8 frames of each sequence and save comparison figures.

    The model is stepped with a zero observation and each of the 8 one-hot
    frame actions in turn. ``-tpcn.e[0]`` after step ``i`` is stored as the
    predicted frame ``i``. For each of the first (up to) 8 batch elements, a
    3x8 figure is saved to ``pic_final/taxi/<index>.pdf``. Row 1 holds the
    input frames, row 2 the predictions, and row 3 their difference.

    Args:
        tpcn: Model providing ``next_predict(obs, action)`` and ``e[0]``.
        batch_size: Number of sequences in the batch.
        input: torch tensor indexed as ``input[b, frame:frame+1]``. Presumably
            (batch, 8, H, W) single-channel frames; TODO confirm.
        ii: Unused; kept for call compatibility.

    Returns:
        Flat list of per-frame mean squared errors, ordered by batch element
        and then by frame.
    """
    ans = bm.zeros((batch_size, 8, channel, seq_size, seq_size))
    no_input = bm.zeros((batch_size, channel * seq_size * seq_size))

    for i in range(8):
        action_i = bm.zeros(8)
        action_i[i] = 1 * yita

        tpcn.next_predict(no_input, action_i)
        ans[:, i, :, :, :] = (-1 * tpcn.e[0]).reshape(-1, channel, seq_size, seq_size)

    diff = []
    # Plot at most the first 8 batch elements. A fixed count of 8 raised
    # IndexError on input[z] for smaller batches.
    for z in range(min(8, batch_size)):
        plt.figure(figsize=(20, 5))
        for j in range(8):
            pred = bm.transpose(ans[z, j], (1, 2, 0))
            target = input[z, j:j + 1].permute(1, 2, 0)
            diff_i = pred - target.numpy()
            diff.append(bm.square(diff_i).mean())

            plt.subplot(3, 8, j + 8 + 1)
            plt.imshow(pred)
            plt.xticks([])
            plt.yticks([])
            plt.subplot(3, 8, j + 1)
            plt.imshow(target)
            plt.xticks([])
            plt.yticks([])
            plt.subplot(3, 8, j + 1 + 8 * 2)
            plt.imshow(diff_i)
            plt.xticks([])
            plt.yticks([])
        plt.savefig('pic_final/taxi/' + str(z) + '.pdf')
        plt.close()
    return diff

def generate_diff(tpcn, batch_size, input):
    """Return the mean absolute reconstruction error of a 4x4 grid of 8x8 patches.

    The model is stepped with a zero observation and each of the 16 one-hot
    location actions. After each step, ``-tpcn.e[0]`` is reshaped to
    ``(-1, 3, 8, 8)`` and placed at grid cell ``divmod(index, 4)`` of a
    32x32 canvas. The result is the mean over batch elements of
    ``mean(|canvas - input|)``.
    """
    patch = 8
    ans = bm.zeros((batch_size, 3, 4 * patch, 4 * patch))
    no_input = bm.zeros((batch_size, 3 * patch * patch))

    for idx in range(16):
        action = bm.zeros(16)
        action[idx] = 1 * yita
        tpcn.next_predict(no_input, action)

        pred = (-1 * tpcn.e[0]).reshape(-1, 3, patch, patch)
        row, col = divmod(idx, 4)
        ans[:, :, row * patch:(row + 1) * patch, col * patch:(col + 1) * patch] = pred

    total = 0
    for j in range(batch_size):
        residual = bm.transpose(ans[j], (1, 2, 0)) - input[j].permute(1, 2, 0).numpy()
        total = total + abs(residual).mean()
    return total / batch_size

def generate_multi(tpcn, batch_size, input, ii=0):
    """Run online inference on random patches, then predict all 16 patches.

    Steps:
      1. Draw a 300-step observation/action sequence with
         ``test_seq_generate(input, 300, num + 1)`` (``num = 7``). Feed each
         step to ``tpcn.next_predict`` and call ``tpcn.test_run()`` after it,
         printing the mean of the returned monitor. Every 100 steps (including
         step 0), ``tpcn.eta`` and ``tpcn.eta_s`` are multiplied by 0.95. This
         permanently changes the model.
      2. Run a warm-up pass over the 16 location actions with a zero
         observation.
      3. Run a second pass in which each prediction ``-tpcn.e[0]`` is tiled
         into a 4x4 grid. Squared errors against ``input`` are kept for
         locations ``> num``.
      4. Plot input / prediction / difference for the first 16 batch elements
         and save the figure to ``pic_final/fmnist/test_<ii+1>.png``.

    NOTE(review): ``test_seq_generate`` is not defined in the visible part of
    this module; confirm that it exists.

    Args:
        tpcn: Model providing ``next_predict``, ``test_run``, ``e[0]``,
            ``eta`` and ``eta_s``.
        batch_size: Number of images in the batch.
        input: torch tensor indexed as ``input[:, :, rows, cols]``. Presumably
            (batch, channel, 4*seq_size, 4*seq_size); TODO confirm.
        ii: Used only in the output filename. Defaults to 0 so that callers
            which omit it (e.g. ``draw_out``) no longer raise TypeError.

    Returns:
        Mean squared error over the unseen locations (``> num``) for the first
        16 batch elements.
    """
    patch = seq_size
    ans = bm.zeros((batch_size, channel, patch * 4, patch * 4))
    no_input = bm.zeros((batch_size, channel * patch * patch))
    seq_len = 300
    num = 7

    ob_seq, action_seq = test_seq_generate(input, seq_len, num + 1)
    for i in range(seq_len):
        if i % 100 == 0:
            tpcn.eta = tpcn.eta * 0.95
            tpcn.eta_s = tpcn.eta_s * 0.95
        tpcn.next_predict(ob_seq[i], action_seq[i])
        mon3 = tpcn.test_run()
        print(bm.mean(dict2array([mon3])))

    # Warm-up pass. Its predictions were always overwritten by the next pass,
    # but next_predict presumably advances the model state, so the calls are
    # kept to leave the second pass's results unchanged.
    for i in range(16):
        action_i = bm.zeros(16)
        action_i[i] = 1 * yita
        tpcn.next_predict(no_input, action_i)

    errors = []
    for i in range(16):
        action_i = bm.zeros(16)
        action_i[i] = 1 * yita

        tpcn.next_predict(no_input, action_i)
        ob_i = (-1 * tpcn.e[0]).reshape(-1, channel, patch, patch)
        h_i, w_i = divmod(i, 4)
        rows = slice(h_i * patch, (h_i + 1) * patch)
        cols = slice(w_i * patch, (w_i + 1) * patch)
        ans[:, :, rows, cols] = ob_i
        if i > num:
            errors.append(bm.square(ob_i - input[:, :, rows, cols].numpy()))
    errors = bm.array(errors)
    unseen_err = bm.mean(errors[:, :16, :, :, :])

    diff = 0
    plt.figure(figsize=(30, 6))
    for j in range(batch_size):
        diff_i = bm.transpose(ans[j], (1, 2, 0)) - input[j].permute(1, 2, 0).numpy()
        diff = diff + abs(diff_i).mean()
        if j < 16:
            plt.subplot(3, 16, j + 16 + 1)
            plt.imshow(bm.transpose(ans[j], (1, 2, 0)))
            plt.xticks([])
            plt.yticks([])
            plt.subplot(3, 16, j + 1)
            plt.imshow(input[j].permute(1, 2, 0))
            plt.xticks([])
            plt.yticks([])
            plt.subplot(3, 16, j + 1 + 16 * 2)
            plt.imshow(diff_i)
            plt.xticks([])
            plt.yticks([])
    print(errors[:, :16, :, :, :].shape)
    print(unseen_err)
    diff = diff / batch_size
    print(diff)
    # Figure-level title; plt.title would only label the last subplot.
    plt.suptitle('test_multi:diff=' + str(unseen_err) + ' ;seq_len=2000, neuron_size =[750], batch=128')
    plt.savefig('pic_final/fmnist/test_' + str(ii + 1) + '.png')
    plt.close()
    return unseen_err


def generate_multi_z(tpcn, batch_size, input, ii, save_dir=None):
    """Run online inference on random frames, then predict all 8 frames.

    Steps:
      1. Draw a 1000-step sequence with ``test_rio_generate(input, 1000, num + 1)``
         (``num = 3``, so frames 0..3 are observed). Feed each step to
         ``tpcn.next_predict``, call ``tpcn.test_run()``, and print the mean
         of its monitor. Every 100 steps (including step 0), ``tpcn.eta`` and
         ``tpcn.eta_s`` are multiplied by 0.95. This permanently changes the
         model.
      2. Run a warm-up pass over the 8 frame actions with a zero observation.
      3. Run a second pass that stores each prediction ``-tpcn.e[0]``.
         Squared errors are kept for frames ``> num``.
      4. Optionally save comparison figures (see ``save_dir``).

    Args:
        tpcn: Model providing ``next_predict``, ``test_run``, ``e[0]``,
            ``eta`` and ``eta_s``.
        batch_size: Number of sequences in the batch.
        input: torch tensor indexed as ``input[:, frame]``. Presumably
            (batch, 8, H, W); TODO confirm.
        ii: Unused; kept for call compatibility.
        save_dir: If given, one figure (inputs / predictions / differences)
            for each of the first ``min(30, batch_size)`` batch elements is
            saved to ``<save_dir>/<index>_test_multi.pdf``. The default
            ``None`` skips plotting. This matches the previous observable
            behavior, where the figures were built but never saved.

    Returns:
        Mean squared error per unseen frame (frames ``num+1 .. 7``),
        averaged over batch and pixels.
    """
    n_loc = 8
    ans = bm.zeros((batch_size, n_loc, channel, seq_size, seq_size))
    no_input = bm.zeros((batch_size, channel * seq_size * seq_size))
    seq_len = 1000
    num = 3

    ob_seq, action_seq = test_rio_generate(input, seq_len, num + 1)
    for i in range(seq_len):
        if i % 100 == 0:
            tpcn.eta = tpcn.eta * 0.95
            tpcn.eta_s = tpcn.eta_s * 0.95
        tpcn.next_predict(ob_seq[i], action_seq[i])
        mon3 = tpcn.test_run()
        print(bm.mean(dict2array([mon3])))

    # Warm-up pass. Its predictions were always overwritten by the next pass,
    # but next_predict presumably advances the model state, so the calls are
    # kept.
    for i in range(n_loc):
        action_i = bm.zeros(n_loc)
        action_i[i] = 1 * yita
        tpcn.next_predict(no_input, action_i)

    errors = []
    for i in range(n_loc):
        action_i = bm.zeros(n_loc)
        action_i[i] = 1 * yita

        tpcn.next_predict(no_input, action_i)
        ob_i = (-1 * tpcn.e[0]).reshape(-1, channel, seq_size, seq_size)
        ans[:, i, :, :, :] = ob_i
        if i > num:
            errors.append(bm.square(ob_i - input[:, i:i + 1, :, :].numpy()))
    errors = bm.array(errors)

    if save_dir is not None:
        for z in range(min(30, batch_size)):
            fig = plt.figure(figsize=(20, 5))
            for j in range(n_loc):
                pred = bm.transpose(ans[z, j], (1, 2, 0))
                target = input[z, j:j + 1].permute(1, 2, 0)
                plt.subplot(3, 8, j + 8 + 1)
                plt.imshow(pred)
                plt.xticks([])
                plt.yticks([])
                plt.subplot(3, 8, j + 1)
                plt.imshow(target)
                plt.xticks([])
                plt.yticks([])
                plt.subplot(3, 8, j + 1 + 8 * 2)
                plt.imshow(pred - target.numpy())
                plt.xticks([])
                plt.yticks([])
            fig.savefig(os.path.join(save_dir, str(z) + '_test_multi.pdf'))
            plt.close(fig)
    return bm.mean(errors, (1, 2, 3, 4))

def cross(t, p):
    """Binary cross-entropy between target ``t`` and prediction ``p``.

    Computes ``-mean(t*log(p + eps) + (1 - t)*log(1 - p + eps))`` with
    ``eps = 1e-7``, which keeps both logs finite when ``p`` is exactly 0 or 1.
    """
    eps = 1e-7
    log_likelihood = t * np.log(p + eps) + (1 - t) * np.log(1 - p + eps)
    return -np.mean(log_likelihood)

def generate_moving(tpcn, batch_size, input, ii, save_dir=None):
    """Run online inference on random frames, then predict all 20 frames.

    Steps:
      1. Run 200 online steps. Each step feeds frame ``input[:, f]`` for a
         random ``f`` in ``0..num`` (``num = 9``), together with a length-20
         one-hot action (active entry ``1*yita``), then calls
         ``tpcn.test_run()`` and prints the mean of its monitor. Every 100
         steps (including step 0), ``tpcn.eta`` and ``tpcn.eta_s`` are
         multiplied by 0.95. This permanently changes the model.
      2. Run a warm-up pass over all 20 frame actions with a zero observation.
      3. Run a second pass that stores each prediction ``-tpcn.e[0]``. For
         frames ``> num``, record the binary cross-entropy of the prediction
         against the true frame.
      4. Optionally save a comparison figure for batch element 0.

    Args:
        tpcn: Model providing ``next_predict``, ``test_run``, ``e[0]``,
            ``eta`` and ``eta_s``.
        batch_size: Number of sequences in the batch.
        input: torch tensor indexed as ``input[:, frame]``. Presumably
            (batch, 20, H, W); TODO confirm.
        ii: Used only in the optional output filename.
        save_dir: If given, the figure is saved to ``<save_dir>/<ii>.pdf``.
            The default ``None`` skips plotting. This matches the previous
            observable behavior, where the figure was built but never saved.

    Returns:
        Mean binary cross-entropy over the unseen frames ``num+1 .. 19``.
    """
    n_frames = 20
    seq_len = 200
    num = 9
    flat_dim = channel * seq_size * seq_size
    ans = bm.zeros((batch_size, n_frames, channel, seq_size, seq_size))
    no_input = bm.zeros((batch_size, flat_dim))

    # Draw every frame index up front, before any model call. Actions are
    # one-hot over all 20 frames, matching the prediction loops below; a
    # length-8 action vector cannot represent frames 8 and 9.
    frames = [random.choice(range(num + 1)) for _ in range(seq_len)]
    for i, frame in enumerate(frames):
        if i % 100 == 0:
            tpcn.eta = tpcn.eta * 0.95
            tpcn.eta_s = tpcn.eta_s * 0.95
        action = bm.zeros(n_frames)
        action[frame] = 1 * yita
        tpcn.next_predict(input[:, frame, :, :].reshape(-1, flat_dim), action)
        mon3 = tpcn.test_run()
        print(bm.mean(dict2array([mon3])))

    # Warm-up pass. Its predictions were always overwritten by the next pass,
    # but next_predict presumably advances the model state, so the calls are
    # kept.
    for i in range(n_frames):
        action_i = bm.zeros(n_frames)
        action_i[i] = 1 * yita
        tpcn.next_predict(no_input, action_i)

    errors = []
    for i in range(n_frames):
        action_i = bm.zeros(n_frames)
        action_i[i] = 1 * yita

        tpcn.next_predict(no_input, action_i)
        ob_i = (-1 * tpcn.e[0]).reshape(-1, channel, seq_size, seq_size)
        ans[:, i, :, :, :] = ob_i
        if i > num:
            target = input[:, i:i + 1, :, :].numpy()
            # cross(t, p) takes the target first (previously the arguments
            # were swapped). Nothing here constrains the prediction to [0, 1],
            # and cross's logs are NaN outside it, so clip before scoring.
            pred = np.clip(np.asarray(ob_i), 0.0, 1.0)
            errors.append(cross(target, pred))
    errors = bm.array(errors)

    if save_dir is not None and batch_size > 0:
        fig = plt.figure(figsize=(30, 5))
        for j in range(n_frames):
            pred_img = bm.transpose(ans[0, j], (1, 2, 0))
            target_img = input[0, j:j + 1].permute(1, 2, 0)
            plt.subplot(3, 20, j + 20 + 1)
            plt.imshow(pred_img)
            plt.xticks([])
            plt.yticks([])
            plt.subplot(3, 20, j + 1)
            plt.imshow(target_img)
            plt.xticks([])
            plt.yticks([])
            plt.subplot(3, 20, j + 1 + 20 * 2)
            plt.imshow(pred_img - target_img.numpy())
            plt.xticks([])
            plt.yticks([])
        fig.savefig(os.path.join(save_dir, str(ii) + '.pdf'))
        plt.close(fig)
    print(bm.mean(errors))
    return bm.mean(errors)
    
def draw_s(tpcn, batch_size, input, layer_index, neuron_index):
    """Plot a window of one layer's activities for each of the 16 locations.

    The model is stepped with a zero observation and each of the 16 one-hot
    location actions. After each step, the activities
    ``tpcn.s[layer_index][:, neuron_index + x]`` are recorded for
    ``x < yy = tpcn.neuron_size[layer_index] // 8``. In the 16x16 grid, panel
    (i, j) plots those ``yy`` activities for batch element ``i`` at location
    ``j``, so at least 16 batch elements are needed. The figure is saved
    under ``pic_cifar10/128/s/`` and closed. ``input`` is unused.
    """
    no_input = bm.zeros((batch_size, 3 * 8 * 8))
    yy = tpcn.neuron_size[layer_index] // 8

    neuron_firing = bm.zeros((batch_size, 16, yy))
    # Rows of the grid are batch elements; columns are location actions.
    fig, axes = plt.subplots(16, 16, figsize=(50, 30))
    for i in range(16):
        action_i = bm.zeros(16)
        action_i[i] = 1 * yita
        tpcn.next_predict(no_input, action_i)
        for x in range(yy):
            neuron_firing[:, i, x] = tpcn.s[layer_index][:, neuron_index + x]

    for i in range(16):
        for j in range(16):
            axes[i, j].plot(range(yy), neuron_firing[i, j, :])

    fig.savefig('pic_cifar10/128/s/s_' + str(layer_index) + '_' + str(neuron_index) + '_yita01_128.png')
    # Release the figure; repeated calls would otherwise accumulate open figures.
    plt.close(fig)
    
def draw_out(tpcn, batch_size, input, layer_index, neuron_index):
    """Overlay layer activities before (default color) and after (red) ``generate_multi``.

    Pass 1: step the model with a zero observation and each of the 16
    location actions, recording ``tpcn.s[layer_index][:, neuron_index + x]``
    for ``x < yy = tpcn.neuron_size[layer_index] // 8``. Debug differences
    between batch elements and between locations are printed along the way.
    Then call ``tpcn.init_neuron(batch_size)`` and
    ``generate_multi(tpcn, batch_size, input, 0)``. That call also trains
    online and saves its own figure to ``pic_final/fmnist/test_1.png``.
    Pass 2: repeat the recording.

    In the 16x16 grid, panel (i, j) shows batch element ``i`` at location
    ``j`` for both passes. The figure is saved under ``pic_cifar10/128/s/``
    and closed.
    """
    no_input = bm.zeros((batch_size, 3 * 8 * 8))
    yy = tpcn.neuron_size[layer_index] // 8

    neuron_firing = bm.zeros((2, batch_size, 16, yy))
    # Rows of the grid are batch elements; columns are location actions.
    fig, axes = plt.subplots(16, 16, figsize=(50, 30))
    for i in range(16):
        action_i = bm.zeros(16)
        action_i[i] = 1 * yita
        tpcn.next_predict(no_input, action_i)
        # Debug: how much do batch elements 1 and 2 differ at this location?
        print(bm.mean(abs(tpcn.s[1][1, :] - tpcn.s[1][2, :])))
        print(bm.mean(abs(tpcn.e[0][1, :] - tpcn.e[0][2, :])))
        for x in range(yy):
            neuron_firing[0, :, i, x] = tpcn.s[layer_index][:, neuron_index + x]

    # Debug: how much do the recorded activities differ between locations?
    for a, b in ((0, 1), (0, 2), (0, 3), (0, 4), (1, 2)):
        print(bm.mean(abs(neuron_firing[0, :, a, :] - neuron_firing[0, :, b, :])))
    for i in range(16):
        for j in range(16):
            axes[i, j].plot(range(yy), neuron_firing[0, i, j, :])

    tpcn.init_neuron(batch_size)
    # generate_multi's 4th parameter ``ii`` (used only in its output filename)
    # is required there; omitting it raised TypeError.
    generate_multi(tpcn, batch_size, input, 0)
    for i in range(16):
        action_i = bm.zeros(16)
        action_i[i] = 1 * yita
        tpcn.next_predict(no_input, action_i)
        for x in range(yy):
            neuron_firing[1, :, i, x] = tpcn.s[layer_index][:, neuron_index + x]
    print(bm.mean(abs(neuron_firing[0, :, :, :] - neuron_firing[1, :, :, :])))
    for i in range(16):
        for j in range(16):
            axes[i, j].plot(range(yy), neuron_firing[1, i, j, :], color='r')

    # Save this function's own figure explicitly instead of relying on
    # pyplot's "current figure" after generate_multi opened and closed another.
    fig.savefig('pic_cifar10/128/s/diff_' + str(layer_index) + '_' + str(neuron_index) + '_yita01_128.png')
    plt.close(fig)
    
def draw_out_loc(tpcn, batch_size, input, layer_index, neuron_index):
    """Overlay predicted pixels before (default color) and after (red) ``generate_multi``.

    Pass 1: step the model with a zero observation and each of the 16
    location actions, recording the prediction ``-tpcn.e[0][:, x]`` for
    ``x < 8*8*3``. Then call ``tpcn.init_neuron(batch_size)`` and
    ``generate_multi(tpcn, batch_size, input, 0)``. That call also trains
    online and saves its own figure to ``pic_final/fmnist/test_1.png``.
    Pass 2: repeat the recording.

    In the 16x16 grid, panel (i, j) plots predicted value ``i`` at location
    ``j`` across the batch, for both passes. Only the first 16 of the 192
    recorded values are plotted. ``layer_index`` and ``neuron_index`` appear
    only in the output filename. The figure is saved under
    ``pic_cifar10/128/s/`` and closed.
    """
    no_input = bm.zeros((batch_size, 3 * 8 * 8))
    yy = 8 * 8 * 3

    neuron_firing = bm.zeros((2, batch_size, 16, yy))
    fig, axes = plt.subplots(16, 16, figsize=(50, 30))
    for i in range(16):
        action_i = bm.zeros(16)
        action_i[i] = yita
        tpcn.next_predict(no_input, action_i)
        for x in range(yy):
            neuron_firing[0, :, i, x] = -1 * tpcn.e[0][:, x]
    for i in range(16):
        for j in range(16):
            axes[i, j].plot(range(batch_size), neuron_firing[0, :, j, i])

    tpcn.init_neuron(batch_size)
    # generate_multi's 4th parameter ``ii`` (used only in its output filename)
    # is required there; omitting it raised TypeError.
    generate_multi(tpcn, batch_size, input, 0)
    for i in range(16):
        action_i = bm.zeros(16)
        action_i[i] = 1 * yita
        tpcn.next_predict(no_input, action_i)
        for x in range(yy):
            neuron_firing[1, :, i, x] = -1 * tpcn.e[0][:, x]
    for i in range(16):
        for j in range(16):
            axes[i, j].plot(range(batch_size), neuron_firing[1, :, j, i], color='r')

    # Save this function's own figure explicitly instead of relying on
    # pyplot's "current figure" after generate_multi opened and closed another.
    fig.savefig('pic_cifar10/128/s/diff_' + str(layer_index) + '_' + str(neuron_index) + '_b.png')
    plt.close(fig)
    
def draw_b(tpcn, batch_size, input, layer_index, neuron_index):
    """Plot 16 neurons' activities across the batch for each of 16 locations.

    The model is stepped with a zero observation and each of the 16 one-hot
    location actions. In the 16x16 grid, panel (x, i) plots
    ``tpcn.s[layer_index][:, neuron_index + x]`` across the batch after
    location ``i``. The figure is saved under ``pic_cifar10/128/s/`` and
    closed. ``input`` is unused.
    """
    no_input = bm.zeros((batch_size, 3 * 8 * 8))
    n_neurons = 16

    neuron_firing = bm.zeros((batch_size, 16, n_neurons))
    # Rows of the grid are neurons; columns are location actions.
    fig, axes = plt.subplots(n_neurons, 16, figsize=(40, 30))
    for i in range(16):
        action_i = bm.zeros(16)
        action_i[i] = 1 * yita
        tpcn.next_predict(no_input, action_i)
        for x in range(n_neurons):
            neuron_firing[:, i, x] = tpcn.s[layer_index][:, neuron_index + x]
            axes[x, i].plot(range(batch_size), neuron_firing[:, i, x])

    fig.savefig('pic_cifar10/128/s/neuron_' + str(layer_index) + '_' + str(neuron_index) + '.png')
    # Release the figure; repeated calls would otherwise accumulate open figures.
    plt.close(fig)

def draw_feature(tpcn, batch_size, input, layer_index, neuron_index):
    """Plot 16 neurons' responses across the 16 locations, plus reconstruction error.

    The model is stepped with a zero observation and each of the 16 one-hot
    location actions. Each prediction ``-tpcn.e[0]`` (reshaped to 3x8x8) is
    tiled into a 32x32 canvas at cell ``divmod(i, 4)``. The activities
    ``tpcn.s[layer_index][:, neuron_index + x]`` for ``x < 16`` are recorded
    at every location. For the first ``min(16, batch_size)`` batch elements,
    panel (x, j) plots neuron ``x`` across the 16 locations for element
    ``j``. The figure title is the mean absolute reconstruction error over
    those same elements. The figure is saved under ``pic_cifar10/16/s/`` and
    closed.
    """
    no_input = bm.zeros((batch_size, 3 * 8 * 8))
    yy = 16
    n_show = min(16, batch_size)

    # Sized by batch_size (was hard-coded to 16 rows, which only matched
    # batch_size == 16).
    neuron_firing = bm.zeros((batch_size, 16, yy))
    fig, axes = plt.subplots(yy, 16, figsize=(40, 30))
    ans = bm.zeros((batch_size, 3, 32, 32))
    for i in range(16):
        action_i = bm.zeros(16)
        action_i[i] = 1 * yita
        tpcn.next_predict(no_input, action_i)
        ob_i = (-1 * tpcn.e[0]).reshape(-1, 3, 8, 8)
        h_i, w_i = divmod(i, 4)
        ans[:, :, h_i * 8:(h_i + 1) * 8, w_i * 8:(w_i + 1) * 8] = ob_i
        for x in range(yy):
            neuron_firing[:, i, x] = tpcn.s[layer_index][:, neuron_index + x]

    diff = 0
    for j in range(n_show):
        diff_i = bm.transpose(ans[j], (1, 2, 0)) - input[j].permute(1, 2, 0).numpy()
        diff = diff + abs(diff_i).mean()
        for x in range(yy):
            axes[x, j].plot(range(16), neuron_firing[j, :, x])
    # Average over the samples actually summed (previously divided by
    # batch_size even though only the first 16 were summed).
    diff = diff / n_show
    fig.suptitle('diff=' + str(diff))
    fig.savefig('pic_cifar10/16/s/firing_' + str(layer_index) + '_' + str(neuron_index) + '_test.png')
    plt.close(fig)
    
def draw_theta(tpcn, batch_size, start_neuron, neuron_layer):
    """Plot the top-down projection of 16 consecutive neurons as RGB patches.

    For each neuron ``start_neuron + i`` (i = 0..15) of layer ``neuron_layer``
    this sets that neuron to 1 (all others 0) in every batch row, propagates
    the activity down through ``tpcn.theta`` to the observation layer, and
    draws the resulting 3x8x8 patch in a 4x16 grid:
    row 0 = the patch as an image, rows 1-3 = its channels 0/1/2 with the
    Reds/Greens/Blues colour maps.  The figure is saved under
    ``pic_cifar10/128/theta/`` and then closed.

    Args:
        tpcn: the tPCN model.  Side effect: ``tpcn.s[neuron_layer]`` and the
            lower states ``tpcn.s[1] .. tpcn.s[neuron_layer - 1]`` are
            overwritten.
        batch_size: number of rows in the model's state arrays.
        start_neuron: index of the first neuron to visualise.
        neuron_layer: layer (>= 1) whose neurons are visualised.
    """
    n_cols = 16
    # Canvas size and subplot layout: 4 rows (image + 3 channels) x 16 neurons.
    fig, axes = plt.subplots(4, n_cols, figsize=(20, 5))
    for i in range(n_cols):
        frag = i + start_neuron
        # One-hot activation of the selected neuron in every batch row.
        tpcn.s[neuron_layer][:, :] = bm.zeros((batch_size, tpcn.neuron_size[neuron_layer]))
        tpcn.s[neuron_layer][:, frag] = bm.ones(batch_size)
        # Top-down pass: each lower state is predicted from f(state above).
        for j in range(neuron_layer - 1, 0, -1):
            tpcn.s[j].value = tpcn.f(tpcn.s[j + 1]) @ tpcn.theta[j]
        ob_i = (tpcn.f(tpcn.s[1]) @ tpcn.theta[0]).reshape(-1, 3, 8, 8)
        # CHW -> HWC for imshow; only batch row 0 is plotted.
        ob_i = bm.transpose(ob_i[0], (1, 2, 0))
        axes[0, i].imshow(ob_i)
        axes[1, i].imshow(ob_i[:, :, 0], cmap='Reds')
        axes[2, i].imshow(ob_i[:, :, 1], cmap='Greens')
        axes[3, i].imshow(ob_i[:, :, 2], cmap='Blues')

    fig.savefig('pic_cifar10/128/theta/thetalist_' + str(start_neuron) + '_layer' + str(neuron_layer) + '_128.png')
    # Release the figure; repeated calls would otherwise accumulate open figures.
    plt.close(fig)
    
    
def draw_theta_gray(tpcn, batch_size, start_neuron, neuron_layer):
    """Plot the top-down projection of 16 consecutive neurons in grayscale.

    Same propagation as ``draw_theta``. Each neuron ``start_neuron + i``
    (i = 0..15) of layer ``neuron_layer`` is set to 1 (all others 0) and
    projected down through ``tpcn.theta`` to a 3x8x8 patch. The 4x16 grid
    then shows:

    - row 0: the patch converted to luminance;
    - rows 1-3: channels 0/1/2 (red/green/blue), each with a gray colour map.

    The figure is saved under ``pic_cifar10/128/theta/`` and then closed.

    Args:
        tpcn: the tPCN model.  Side effect: ``tpcn.s[neuron_layer]`` and the
            lower states ``tpcn.s[1] .. tpcn.s[neuron_layer - 1]`` are
            overwritten.
        batch_size: number of rows in the model's state arrays.
        start_neuron: index of the first neuron to visualise.
        neuron_layer: layer (>= 1) whose neurons are visualised.
    """
    n_cols = 16
    # Canvas size and subplot layout: 4 rows (gray + 3 channels) x 16 neurons.
    fig, axes = plt.subplots(4, n_cols, figsize=(20, 5))
    for i in range(n_cols):
        frag = i + start_neuron
        # One-hot activation of the selected neuron in every batch row.
        tpcn.s[neuron_layer][:, :] = bm.zeros((batch_size, tpcn.neuron_size[neuron_layer]))
        tpcn.s[neuron_layer][:, frag] = bm.ones(batch_size)
        # Top-down pass: each lower state is predicted from f(state above).
        for j in range(neuron_layer - 1, 0, -1):
            tpcn.s[j].value = tpcn.f(tpcn.s[j + 1]) @ tpcn.theta[j]
        ob_i = (tpcn.f(tpcn.s[1]) @ tpcn.theta[0]).reshape(-1, 3, 8, 8)
        # CHW -> HWC, batch row 0 only.  cv2.cvtColor accepts only 8U/16U/32F
        # input, so materialise a contiguous float32 numpy copy.
        ob_i = np.array(bm.transpose(ob_i[0], (1, 2, 0)), dtype=np.float32)
        # Channels are in RGB order (0=red, 1=green, 2=blue); BGR2GRAY would
        # swap the red and blue luminance weights.
        image_gray = cv2.cvtColor(ob_i, cv2.COLOR_RGB2GRAY)
        axes[0, i].imshow(image_gray, cmap='gray')
        axes[1, i].imshow(ob_i[:, :, 0], cmap='gray')  # red channel
        axes[2, i].imshow(ob_i[:, :, 1], cmap='gray')  # green channel
        axes[3, i].imshow(ob_i[:, :, 2], cmap='gray')  # blue channel

    fig.savefig('pic_cifar10/128/theta/thetalist_gray_' + str(start_neuron) + '_layer' + str(neuron_layer) + '_128.png')
    # Release the figure; repeated calls would otherwise accumulate open figures.
    plt.close(fig)

if __name__ == '__main__':
    # Experiment selector -- `State` picks which branch below runs:
    #   0 - resume training from a ./modelcif checkpoint
    #   1 - train on one taxi batch, save the model, evaluate with generate_rio
    #   2 - train on one batch and plot two monitored traces
    #   3 - load a checkpoint and run generate() on one batch
    #   4 - load the taxi checkpoint and run generate_multi_z 10 times
    #   5 - evaluate a checkpoint over increasing sequence lengths
    #   6 - sweep the width of neuron_size[4]
    #   7 - train while recording generate_diff at every step
    #   8 - evaluation on rotating MNIST
    config = 'taxib'  ### 'CIFAR10'  'Fashionmnist'
    State = 4
    batch_size = 128
    seq_len = 1000  # most State branches set their own seq_len
    ob_size = channel * seq_size * seq_size
    action_size = 8
    # NOTE(review): JAX reads this variable when its backend is first
    # initialised.  jax/brainpy are imported at the top of the file, so if the
    # backend has already been touched this assignment has no effect --
    # consider setting it before those imports.
    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'

    # Number of training epochs (used by State 0).
    Epoch = 1

    if torch.cuda.is_available():
        bm.set_platform('gpu')
        print('use gpu')
    else:
        print('use cpu')
    #bm.set_platform('cpu')
    print(State)
    # Hyper-parameters and dataset for the selected config.
    neuron_size, l_duration, duration, eta, dt, f, noise, \
    model_name, using_epoch, training_data, normalize, inv_normalize, _transforms = load_config(config)

    data_lenth = len(training_data)
    train_dataloader = DataLoader(training_data,
                                  batch_size=batch_size,
                                  shuffle=False)

    # Probe the first batch to get the image shape; its first 8 frames of
    # sample 0 are also dumped to 0.png .. 7.png for a quick visual check.
    for input in train_dataloader:
        print(input.shape)
        input=transform_taxi(input)
        for i in range(8):
            plt.imshow(input[0,i,::])
            plt.savefig(str(i)+'.png')
        image_shape = input.shape[2:]
        image_size = 1
        for j in image_shape:
            image_size *= j
        break
    else:
        # Without at least one batch image_shape / image_size would be unbound.
        raise RuntimeError('training_data yielded no batches')
    print('image_shape = ', image_shape)
    print('image_size = ', image_size)
    print('data_length = ', data_lenth)
    print('simulation_length = ', int(duration / dt))

    # model: the observation layer width is forced to channel * seq_size**2.
    neuron_size[0] = ob_size
    L = len(neuron_size) - 1
    print('pcn_ob_size = ', neuron_size[0])
    tpcn = tPCN(neuron_size,
                action_size,
                eta,
                l_duration,
                duration,
                f,
                noise=noise,
                dt=dt)
    phase = make_phase(l_duration, duration, dt, L)
    if State == 0:
        # Resume training: load a checkpoint if one exists, train for `Epoch`
        # epochs (each cut short after ~1% of the data), and save a checkpoint
        # after every epoch.
        # NOTE(review): this loop unpacks (input, label) batches, whereas the
        # 'taxib' data probed above is iterated as bare tensors; presumably this
        # branch targets a labelled dataset config -- confirm before running.
        seq_len=2
        if os.path.exists('./modelcif/' + model_name + '_' + str(using_epoch) +
                          '_big.bp'):
            states = bp.checkpoints.load_pytree('./modelcif/' + model_name + '_' +
                                                str(using_epoch) + '_big.bp')
            tpcn.load_state_dict(states)

        for epoch in range(Epoch):
            # p: fraction of the dataset consumed so far in this epoch.
            p = 0
            for input, label in train_dataloader:
                ob_seq, action_seq = seq_generate(input, seq_len)
                tpcn.init_neuron(batch_size)
                mon_list = []
                # Feed the sequence step by step; whatever tpcn.run() returns is
                # collected and averaged via dict2array for the 'grad' log.
                for i in range(seq_len):
                    tpcn.next_predict(ob_seq[i], action_seq[i])
                    mon_list.append(tpcn.run())
                
                p += batch_size / data_lenth
                # NOTE(review): the hard-coded +28 shifts both the logged epoch and
                # the saved checkpoint name; presumably it continues numbering from
                # an earlier run and is unrelated to `using_epoch` used for loading.
                print(epoch+28, ' : ', int(p * 100), '%, grad=',
                       bm.mean(dict2array(mon_list)))
                # Stop the epoch once ~1% of the data has been seen.
                if p >= 0.01:
                     break
                # print('grad=', bm.as_numpy(bm.mean(dict2array(mon_list), axis=(0, 2))))
                # show_grad(mon_list, L + 1, phase)
                # show_grad2(mon_list, L, phase)
            bp.checkpoints.save_pytree(
                './modelcif/' + model_name + '_' + str(epoch+28) + '_big.bp',
                tpcn.state_dict())
    # Training
    if State == 1:
        # Train on a single batch of the (min-max normalised) taxi data for
        # `seq_len` steps, save the model, evaluate it with generate_rio, and
        # plot the monitored values with show_grad2.
        seq_len = 2000

        count = 0
        error_list=[]
        for iii in range(1):
            ii=iii
            count = 0
            #tpcn = tPCN(neuron_size,
            #    action_size,
            #    eta,
            #    l_duration,
            #    duration,
            #    f,
            #    noise=noise,
            #    dt=dt)
            # flag starts at 1: the `flag<1` skip never fires, the first batch is
            # processed, flag becomes 2 and the next iteration breaks -- so only
            # the first batch is used.
            flag=1
            for input in train_dataloader:
                #print(torch.sum(input[0][0]))
                
                if flag<1:
                    flag+=1
                    continue
                if flag==2:
                    break
                
                input0 =transform_taxi(input)
                #print(torch.sum(input0[0][0]))
                print(input0.shape)
                ob_seq, action_seq = rio_generate(input0, seq_len)
                tpcn.init_neuron(batch_size)
                mon_list = []
                for i in range(seq_len):
                    # Decay tpcn.eta and tpcn.eta_s by 5% every 100 steps.
                    if i%100==0:
                        tpcn.eta=tpcn.eta*0.95
                        tpcn.eta_s=tpcn.eta_s*0.95
                    tpcn.next_predict(ob_seq[i], action_seq[i])
                    mon = tpcn.run()
                    mon_list.append(mon)
                    count += 1
                    print(count, '/', seq_len, ': grad =',
                          bm.mean(dict2array([mon])))

                bp.checkpoints.save_pytree(
                    'model_final/taxi/'+model_name+'_'+str(ii)+'.bp',
                    tpcn.state_dict())
                # The returned value is kept in error_list but is only used by
                # the commented-out np.save below.
                error_list=generate_rio(tpcn, batch_size, input0,ii)
                #plt.plot(np.arange(48), tpcn.s[-1][0,16:])
                #plt.plot(np.arange(48), tpcn.s[-1][1,16:])
                #plt.plot(np.arange(48), tpcn.s[-1][3,16:])
                #plt.show()

                # print('grad=', bm.as_numpy(bm.mean(dict2array(mon_list), axis=(0, 2))))
                # show_grad(mon_list, L + 1, phase)
                show_grad2(mon_list, L, phase)
                flag+=1
        #np.save("pic_final/layer3/64/train_1500.npy", error_list)
    #print_loss123
    if State == 2:
        # Train on the first batch for `seq_len` steps while recording two
        # monitored scalars per step. Then plot them on a log y-axis, save the
        # model, run generate() on the batch, and save the traces to .npy.
        # NOTE(review): unpacks (input, label) batches -- assumes a labelled
        # dataset config, unlike the 'taxib' data probed above.
        seq_len = 2000

        count = 0
        for input, label in train_dataloader:
            ob_seq, action_seq = seq_generate(input, seq_len)
            tpcn.init_neuron(batch_size)
            mon_list = []
            # Two traces (rows 0 and 1), one value per step.
            error_layers=bm.zeros((2,seq_len))
            loss=[]
            for i in range(seq_len):
                # Decay tpcn.eta and tpcn.eta_s by 5% every 100 steps.
                if i%100==0:
                    tpcn.eta=tpcn.eta*0.95
                    tpcn.eta_s=tpcn.eta_s*0.95
                tpcn.next_predict(ob_seq[i], action_seq[i])
                mon = tpcn.run()
                mon_list.append(mon)
                mon=dict2array([mon])
                #print(mon)
                # NOTE(review): indices 29 and 128 select fixed entries of the
                # dict2array output; what they represent is not visible here.
                error_layers[0,i]=mon[0][0][29]
                error_layers[1,i]=mon[0][1][128]
                
              
                
                count += 1
                grad=bm.mean(mon)
                print(count, '/', seq_len, ': grad =',
                      grad)
                loss.append(grad)
            #show_grad2(mon_list, L, phase)
            plt.yscale("log")
            plt.plot(range(seq_len),error_layers[0,:])
            plt.plot(range(seq_len),error_layers[1,:])
            #plt.plot(range(seq_len),error_layers[2,:])
            
            #
            plt.savefig('pic_cifar64/layer2/error/loss_layer2_64.png')
            plt.close()
            bp.checkpoints.save_pytree(
                './model_cifar64/' + model_name + '_' + str(seq_len) + '_2.bp',
                tpcn.state_dict())
            generate(tpcn, batch_size, input)
            np.save("cifar64_layer2.npy", error_layers)
            #plt.close()
            #plt.plot(range(seq_len),loss)
            #plt.savefig('pic_cifar10/loss/loss_128_yita01_128.png')
            # Only the first batch is used.
            break
            
    #test
    if State == 3:
        # Test: load a fixed checkpoint and run generate() on the first batch.
        # Nothing happens if the checkpoint file does not exist.
        # start_neuron / neuron_layer are only consumed by the commented-out
        # draw_* calls below.
        start_neuron=128*0
        neuron_layer = 1
        if os.path.exists('model_final/layer3/64/CIFAR10_64_net_layer3_2000_1.bp'):
            states = bp.checkpoints.load_pytree('model_final/layer3/64/CIFAR10_64_net_layer3_2000_1.bp')
            for input, label in train_dataloader:
                tpcn.init_neuron(batch_size)
                tpcn.load_state_dict(states)
                #tpcn.s[-1][:,16:]=0
                #print(tpcn.s[-1][0,16:])
                count = 0
                #draw_b(tpcn, batch_size, input,neuron_layer,start_neuron)
                #draw_s(tpcn, batch_size, input,neuron_layer,start_neuron)
                #draw_feature(tpcn, batch_size, input,neuron_layer,start_neuron)
                #draw_theta_gray(tpcn, batch_size, start_neuron, neuron_layer)
                #draw_theta(tpcn, batch_size, start_neuron, neuron_layer)
                generate(tpcn,batch_size,input,0)
                break
            #for input, label in train_dataloader:
            #    ob_seq, action_seq = test_seq_generate(input, seq_len ,12)
            #    tpcn.init_neuron(batch_size)
            #    #action0 = bm.zeros(16)
            #    #action0[0] = 1
            #    #tpcn.next_predict(input[:, :, 0:8, 0:8].reshape(-1, 3 * 8 * 8), action0)
            #    #tpcn.test_init()
            #    for i in range(seq_len):
            #        tpcn.next_predict(ob_seq[i], action_seq[i])
            #        mon = tpcn.test_run()
            #        #print(count, '/', seq_len, ': grad =',
            #        #      bm.mean(dict2array([mon])))
            #    generate(tpcn, batch_size, input)
            #    break
                #plt.plot(np.arange(48), tpcn.s[-1][0,16:])
                #if cont ==3:
                #    plt.show()
    #test_1
    if State == 4:
        # Evaluate the saved taxi model ten times. Each run reloads the same
        # checkpoint and calls generate_multi_z on the second batch, passing the
        # run index i. The collected results for that run are saved to
        # taxibj1000_<i>.npy and their mean is printed.
        seq_len = 2000
        #test_dataloader = get_rotating_mnist('./data', 20, 128, 128, 1, 18)
        #if os.path.exists('model_final/layer2/128/CIFAR10_64_net_layer2_2000_1.bp'):
        
        for i in range(10):
            error_list=[]
            states = bp.checkpoints.load_pytree('model_final/taxi/taxib_net_0.bp')
            tpcn.init_neuron(batch_size)
            tpcn.load_state_dict(states)
            #print(tpcn.s[-1][0,16:])
            count = 0
            # flag starts at 0: the first batch is skipped (flag -> 1), the second
            # is evaluated (flag -> 2), and the third iteration breaks.
            flag=0
            for input in train_dataloader:
                #input0 = torch.concatenate((input,target),dim = 1)/256
                #input0 = F.interpolate(input0, size=(seq_size, seq_size), mode='bilinear', align_corners=False)
                # Normalised before the skip check, so it is also computed for
                # the skipped batch.
                input0 = transform_taxi(input)
                #print(input0[0,0,30])
                if flag<1:
                    flag+=1
                    continue
                if flag==2:
                    break
                tpcn.init_neuron(batch_size)
                error_list.append(generate_multi_z(tpcn, batch_size, input0,i))
                plt.close()
                #draw_out(tpcn, batch_size, input,4,0)
                flag+=1
            np.save('taxibj1000_'+str(i)+'.npy', error_list)
            print(np.mean(error_list))
    #loss
    if State==5:
        # Load a fixed checkpoint and, for num = 0..14, run the model in test mode
        # (tpcn.test_run) on the first batch with seq_len = 2000 + 100*num.
        # The sequences come from test_seq_generate(..., num+1). After each run,
        # generate_num's result is collected and the list is saved to errors.npy.
        # Nothing happens if the checkpoint file does not exist.
        # `error` and `count` are assigned but not used in this branch.
        seq_len = 2000
        error=0
        if os.path.exists('model_final/layer3/32/CIFAR10_32_net_3000_test.bp'):
            states = bp.checkpoints.load_pytree('model_final/layer3/32/CIFAR10_32_net_3000_test.bp')
            tpcn.load_state_dict(states)
            #print(tpcn.s[-1][0,16:])
            count = 0
            for input, label in train_dataloader:
                error_list=[]
                for num in range(15):
                    seq_len=2000+100*num
                    ob_seq, action_seq = test_seq_generate(input, seq_len,num+1)
                    tpcn.init_neuron(batch_size)
                    #action0 = bm.zeros(16)
                    #action0[0] = 1
                    #tpcn.next_predict(input[:, :, 0:8, 0:8].reshape(-1, 3 * 8 * 8), action0)
                    #tpcn.test_init()
                    
                    for i in range(seq_len):
                        tpcn.next_predict(ob_seq[i], action_seq[i])
                        mon = tpcn.test_run()
                        #print(count, '/', seq_len, ': grad =',
                        #      bm.mean(dict2array([mon])))
                    error_list.append(generate_num(tpcn, batch_size, input))
                    plt.close()
                #print(error_list.shape)
                #plt.plot(range(15),error_list)
                #plt.savefig('errors.png')
                np.save('errors.npy',error_list)
                # Only the first batch is used.
                break
    #diff_n
    if State ==6:
        # Width sweep over neuron_size[4] = 16, 48, ..., 16 + 63*32. For each
        # width:
        #   - build a fresh tPCN and train it on the first batch for `seq_len`
        #     steps;
        #   - save its per-step loss curve;
        #   - record the value returned by generate().
        # Finally, plot that value against the width.
        # NOTE(review): assumes the dataloader yields (input, label) pairs and
        # that neuron_size has at least 5 entries -- confirm for the config used.
        seq_len=2000
        n_widths = 64
        widths = [16 + j * 32 for j in range(n_widths)]
        diff_list=bm.zeros(n_widths)
        for j, width in enumerate(widths):
            neuron_size[4]=width
            tpcn_i = tPCN(neuron_size,
                action_size,
                eta,
                l_duration,
                duration,
                f,
                noise=noise,
                dt=dt)
            # Reset per width so the progress print and the saved loss curve
            # refer to this model only (and `loss` is always bound).
            count=0
            loss=[]
            for input, label in train_dataloader:
                ob_seq, action_seq = seq_generate(input, seq_len)
                tpcn_i.init_neuron(batch_size)
                mon_list = []
                for i in range(seq_len):
                    tpcn_i.next_predict(ob_seq[i], action_seq[i])
                    mon = tpcn_i.run()
                    mon_list.append(mon)
                    mon=dict2array([mon])
                    count += 1
                    grad=bm.mean(mon)
                    loss.append(grad)
                    print(count, '/', seq_len, ': grad =',
                        grad)
                diff_list[j]=generate(tpcn_i, batch_size, input)
                # Only the first batch is used per width.
                break
            plt.close()
            plt.figure()
            plt.plot(range(len(loss)),loss)
            plt.savefig('pic_cifar10/loss/diff_128_'+str(width)+'.png')
        plt.close()
        plt.figure()
        # Plot against the actual layer width rather than the sweep index.
        plt.plot(widths,diff_list)
        plt.xlabel('neuron_size[4]')
        # Derive the title from the values actually used so it cannot drift
        # from the sweep configuration.
        plt.title('seq='+str(seq_len)+',layer3:'+str(widths[0])+'-'+str(widths[-1]))
        plt.savefig('pic_cifar10/error/diff_n_128.png')

    if State==7:
        # Train on the first batch for `seq_len` steps. After every step, record
        # generate_diff's result, then plot the curve to diff.png and save it as
        # .npy.
        # NOTE(review): generate_diff is called once per step (2000 times), so
        # this branch is presumably much slower than plain training.
        seq_len = 2000

        count = 0
        loss=[]
        for input, label in train_dataloader:
            ob_seq, action_seq = seq_generate(input, seq_len)
            tpcn.init_neuron(batch_size)
            mon_list = []
            for i in range(seq_len):
                # Decay tpcn.eta and tpcn.eta_s by 5% every 100 steps.
                if i%100==0:
                    tpcn.eta=tpcn.eta*0.95
                    tpcn.eta_s=tpcn.eta_s*0.95
                tpcn.next_predict(ob_seq[i], action_seq[i])
                mon = tpcn.run()
                mon_list.append(mon)
                count += 1
                print(count, '/', seq_len, ': grad =',
                      bm.mean(dict2array([mon])))
                loss.append(generate_diff(tpcn, batch_size, input))
            plt.plot(range(seq_len),loss)
            plt.savefig('diff.png')
            np.save("pic_final/layer2/64/mse.npy", loss)
            #bp.checkpoints.save_pytree(
            #    'model_final/layer3/128/' + model_name + '_' + str(seq_len) + '_1.bp',
            #    tpcn.state_dict())
            #generate(tpcn, batch_size, input)
            # Only the first batch is used.
            break
    if State == 8:
        seq_len = 2000
        test_dataloader = get_rotating_mnist('./data', 20, 1024, 128, 1, 18)
        #if os.path.exists('model_final/layer2/128/CIFAR10_64_net_layer2_2000_1.bp'):
        error_list=[]
        for i in range(1):
            states = bp.checkpoints.load_pytree('model_final/rio/MNIST_rio_net_layer3_0.bp')
            tpcn.init_neuron(batch_size)
            tpcn.load_state_dict(states)
            #print(tpcn.s[-1][0,16:])
            count = 0
            flag=1
            for input, label in test_dataloader:
                if flag<1:
                    flag+=1
                    continue
                if flag==2:
                    break
                tpcn.init_neuron(batch_size)
                error_list=generate_multi_z(tpcn, batch_size, input,i)
                plt.close()
                #draw_out(tpcn, batch_size, input,4,0)
                flag+=1