# 10.31: first attempt at learning empty_room images along the horizontal axis (16 steps)
import math
import os
import random
from collections import defaultdict
from functools import partial

import brainpy as bp
import brainpy.math as bm
import jax
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

from config import load_config
from tPCN import tPCN


def seq_generate(input, seq_len):
    """Sample a random sequence of 8x8 patches from a 4x4 grid over ``input``.

    For each of ``seq_len`` steps a grid cell (row, col) is drawn uniformly
    from ``range(4) x range(4)``. The matching 8x8 patch
    ``input[:, :, 8*row:8*row+8, 8*col:8*col+8]`` is flattened to
    ``(-1, 3 * 8 * 8)``. It is paired with a 16-element one-hot vector that
    marks cell ``row * 4 + col``.

    Args:
        input: image batch indexed as ``input[:, :, h, w]`` (presumably
            ``(batch, 3, H, W)`` with ``H, W >= 32`` -- TODO confirm).
        seq_len: number of patches to draw.

    Returns:
        ``(ob_seq, action_seq)``: a list of flattened patches and a list of
        one-hot location vectors, both of length ``seq_len``.
    """
    patch = 8
    observations = []
    actions = []
    for _ in range(seq_len):
        # Row is drawn before column, matching the original RNG consumption.
        row = random.choice(range(4))
        col = random.choice(range(4))
        top = row * patch
        left = col * patch
        window = input[:, :, top:top + patch, left:left + patch]
        observations.append(window.reshape(-1, 3 * patch * patch))

        one_hot = bm.zeros(16)
        one_hot[row * 4 + col] = 1
        actions.append(one_hot)
    return observations, actions


def test_seq_generate(input, seq_len,num):
    """Sample 8x8 patches restricted to the first ``num`` cells of a 4x4 grid.

    Cells are numbered row-major (``cell = row * 4 + col``). Each step draws
    ``cell`` uniformly from ``range(num)``, flattens the corresponding 8x8
    patch to ``(-1, 3 * 8 * 8)``, and records a 16-element one-hot vector
    for the cell.

    Args:
        input: image batch indexed as ``input[:, :, h, w]`` (presumably
            ``(batch, 3, H, W)`` -- TODO confirm).
        seq_len: number of patches to draw.
        num: number of admissible cells; presumably ``1 <= num <= 16``
            (larger values index past the 4x4 grid and the one-hot vector).

    Returns:
        ``(ob_seq, action_seq)`` lists of length ``seq_len``.
    """
    observations = []
    actions = []
    for _ in range(seq_len):
        cell = random.choice(range(num))
        row, col = divmod(cell, 4)
        top, left = row * 8, col * 8
        observations.append(
            input[:, :, top:top + 8, left:left + 8].reshape(-1, 3 * 8 * 8))

        one_hot = bm.zeros(16)
        one_hot[row * 4 + col] = 1
        actions.append(one_hot)
    return observations, actions


def make_phase(l_duration, duration, dt, L):
    """Return step indices bounding the two sub-phases of each of ``L`` segments.

    Each of the ``L`` segments contributes a first sub-phase of
    ``duration // L * l_duration // dt`` steps and a second of
    ``(duration / L - duration // L * l_duration) // dt`` steps. The result
    starts at 0 and lists the cumulative boundaries, so it has ``2 * L + 1``
    entries, each converted to ``int``.

    NOTE(review): the first length scales the *floored* ``duration // L``,
    while the two raw lengths sum to the true ``duration / L``. When
    ``duration`` is not a multiple of ``L``, the split ratio therefore
    deviates from ``l_duration``; confirm this is intended. ``l_duration``
    appears to be a fraction of each segment (TODO confirm).

    >>> make_phase(0.5, 100, 1, 2)
    [0, 25, 50, 75, 100]
    """
    first_len = duration // L * l_duration // dt
    second_len = (duration / L - duration // L * l_duration) // dt
    boundaries = [0]
    cursor = 0
    for _ in range(L):
        for step in (first_len, second_len):
            cursor += step
            boundaries.append(cursor)
    return [int(b) for b in boundaries]


def show_grad(mon_list, num, phase):
    """Plot each layer's gradient trace for every monitor record and show it.

    Draws a ``num`` x ``len(mon_list)`` subplot grid. Panel ``(row, col)``
    plots ``mon_list[col][f'avg_grad{row}']`` against normalized time in
    [0, 1]. The trace is cut at consecutive ``phase`` boundaries, and
    segment ``k`` is coloured red when ``k`` is even and blue when it is odd.

    Args:
        mon_list: sequence of monitor mappings containing ``avg_grad<i>``
            traces for ``i`` in ``range(num)``.
        num: number of layers (subplot rows).
        phase: increasing list of step indices delimiting the segments.
    """
    n_cols = len(mon_list)
    plt.figure(figsize=(3 * n_cols, num * 2))
    for row in range(num):
        key = f'avg_grad{row}'
        for col, mon in enumerate(mon_list):
            trace = mon[key]
            t = bm.linspace(0, 1, len(trace))
            plt.subplot(num, n_cols, row * n_cols + col + 1)
            for k, (start, stop) in enumerate(zip(phase[:-1], phase[1:])):
                plt.plot(t[start:stop],
                         trace[start:stop],
                         color='red' if k % 2 == 0 else 'blue')
            plt.xticks([])
    plt.show()


def show_grad2(mon_list, num, phase, max_cols=10, save_path='error_450.png'):
    """Plot summed adjacent-layer gradients for the most recent monitor records.

    Panel ``(i, j)`` of a ``num`` x ``n_cols`` grid plots
    ``mon[f'avg_grad{i}'] + mon[f'avg_grad{i+1}']`` for the j-th of the last
    ``n_cols = min(max_cols, len(mon_list))`` records, against normalized
    time in [0, 1]. The trace is cut at consecutive ``phase`` boundaries,
    and segment ``k`` is red when ``k`` is even and blue when it is odd.
    The figure is saved to ``save_path`` and then closed.

    When fewer than ``max_cols`` records exist (e.g. when a NaN is detected
    early in an epoch), only the available records are plotted. Nothing is
    drawn when ``mon_list`` is empty.

    Args:
        mon_list: sequence of monitor mappings containing ``avg_grad<i>``
            for ``i`` in ``range(num + 1)``.
        num: number of subplot rows.
        phase: increasing list of step indices delimiting the segments.
        max_cols: maximum number of (most recent) records to plot.
        save_path: output image path.
    """
    # Guard max_cols <= 0: a slice [-0:] would select the whole list.
    recent = list(mon_list)[-max_cols:] if max_cols > 0 else []
    if not recent:
        return
    n_cols = len(recent)
    fig = plt.figure(figsize=(3 * n_cols, num * 2))
    for i in range(num):
        for j, mon in enumerate(recent):
            grad = mon[f'avg_grad{i}'] + mon[f'avg_grad{i+1}']
            time = bm.linspace(0, 1, len(grad))
            plt.subplot(num, n_cols, i * n_cols + j + 1)
            for k in range(len(phase) - 1):
                color = 'red' if k % 2 == 0 else 'blue'
                plt.plot(time[phase[k]:phase[k + 1]],
                         grad[phase[k]:phase[k + 1]],
                         color=color)
            plt.xticks([])
    fig.savefig(save_path)
    # Release the figure; this helper is called from inside the training loop.
    plt.close(fig)


def dict2array(mon_list):
    """Stack a sequence of monitor mappings into one array.

    Row ``i`` holds the values of ``mon_list[i]`` in the mapping's iteration
    order. The nested list is converted with ``bm.array``, so all mappings
    are expected to yield values that stack to a rectangular array.

    Args:
        mon_list: sequence of mappings (e.g. records returned by ``tpcn.run``).

    Returns:
        ``bm.array`` of the stacked values.
    """
    rows = [list(mon.values()) for mon in mon_list]
    return bm.array(rows)


def generate(tpcn, batch_size, test_dataloader,
             save_path='pic_final/lab/layer3/100_40_700.png'):
    """Generate images from location actions alone and compare them with the data.

    Only samples whose raw labels satisfy ``labely % 10 == 2`` and
    ``labelx % 7 == 0`` are used; they map to grid coordinates
    ``x = labelx // 7`` and ``y = labely // 10``. For each sample, the model
    is driven with an all-zero observation and a one-hot 100-element action
    at ``x + 20 * y``. The prediction ``-tpcn.e[0]`` is reshaped to
    ``(-1, 3, 60, 80)`` and compared pixel-wise with the input image.

    A 3 x 12 panel grid shows, for locations with ``x % 5 == 0`` and
    ``y % 2 == 0``:
      * row 0: the original image;
      * row 1: the generated image;
      * row 2: their absolute difference.
    Column index is ``x // 5 + 4 * (y // 2)``. The figure is saved to
    ``save_path`` and closed.

    Args:
        tpcn: trained model; ``next_predict`` and ``e`` are used.
        batch_size: batch size of the zero observation fed to the model.
        test_dataloader: iterable of ``(image, labely, labelx)`` batches.
        save_path: output image path.

    Returns:
        Mean absolute pixel error averaged over all evaluated samples
        (0.0 when no sample matches). It is also printed and used as the
        figure title.
    """
    no_input = bm.zeros((batch_size, 3 * 60 * 80))
    n_cols = 12
    fig, axes = plt.subplots(3, n_cols, figsize=(12, 4))

    total_diff = 0
    n_samples = 0
    for input, labely, labelx in test_dataloader:
        raw_y = int(labely[0])
        raw_x = int(labelx[0])
        if raw_y % 10 != 2 or raw_x % 7 != 0:
            continue
        x = raw_x // 7   # [0-19]
        y = raw_y // 10  # [0-4]
        action = np.zeros(100)
        action[x + 20 * y] = 1
        tpcn.next_predict(no_input, action)

        ob_i = (-1 * tpcn.e[0]).reshape(-1, 3, 60, 80)
        generated = bm.transpose(ob_i[0], (1, 2, 0))
        original = input[0].permute(1, 2, 0)
        diff_i = abs(generated - original.numpy())
        total_diff = total_diff + diff_i.mean()
        n_samples += 1

        # Only a sparse subset of locations gets a plot column.
        if x % 5 == 0 and y % 2 == 0:
            frag = x // 5 + 4 * (y // 2)
            panels = ((original, f'O {frag}'),
                      (generated, f'G {frag}'),
                      (diff_i, f'D {frag}'))
            for row, (image, title) in enumerate(panels):
                axes[row, frag].imshow(image)
                axes[row, frag].set_title(title)
                axes[row, frag].axis('off')

    # Average over the samples actually evaluated rather than a fixed count.
    diff = total_diff / n_samples if n_samples else 0.0
    print(diff)
    fig.suptitle(f'Average Difference: {diff:.6f}')
    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)
    return diff
    
def draw_s(tpcn, batch_size, test_dataloader):
    """Plot input, prediction, error and hidden states for 12 grid locations.

    Samples with an odd raw ``labely`` or ``labelx`` are skipped, and the
    remaining labels are halved to grid coordinates ``(gx, gy)``. The action
    vector has ``26 * 68`` entries and is indexed at ``gx + 68 * gy``. For
    locations with ``gy % 9 == 0`` and ``gx % 17 == 0``, the model is driven
    with a zero observation and that one-hot action. Column
    ``gx // 17 + 4 * (gy // 9)`` of a 7 x 12 grid then shows:
      * row 0: the flattened input image;
      * row 1: the prediction ``-tpcn.e[0][0]``;
      * row 2: ``input + tpcn.e[0][0]``;
      * rows 3-6: ``tpcn.s[1..4][0]``.
    The mean absolute value of row 2's signal is printed per plotted sample.
    The figure is saved to ``pic_68_26/68_26_s_test.png``.
    """
    no_input = bm.zeros((batch_size, 3 * 60 * 80))
    fig, axes = plt.subplots(7, 12, figsize=(20, 8))

    for input, labely, labelx in test_dataloader:
        raw_y = int(labely[0])
        raw_x = int(labelx[0])
        if raw_y % 2 == 1 or raw_x % 2 == 1:
            continue
        gx = raw_x // 2
        gy = raw_y // 2
        if gy % 9 != 0 or gx % 17 != 0:
            continue

        col = gx // 17 + 4 * (gy // 9)
        action = np.zeros(26 * 68)
        action[gx + 68 * gy] = 1
        tpcn.next_predict(no_input, action)

        pred_err = tpcn.e[0][0]
        xs = range(len(pred_err))
        flat_input = input[0].numpy().flatten()
        axes[0, col].plot(xs, input.reshape(-1, 3 * 60 * 80)[0].numpy())
        axes[1, col].plot(xs, -1 * pred_err)
        axes[2, col].plot(xs, flat_input + pred_err)
        print(abs(flat_input + pred_err).mean())
        for layer in range(1, 5):
            state = tpcn.s[layer][0]
            axes[layer + 2, col].plot(range(len(state)), state)

    # Adjust subplot layout, then save (display is disabled).
    plt.tight_layout()
    plt.savefig('pic_68_26/68_26_s_test.png')


def generate_multi(tpcn, batch_size, test_dataloader):
    """Infer a latent state from one image, then generate a strip of neighbours.

    1. Row 0 of a 2 x 16 panel grid shows the ground-truth images with raw
       ``labely == 30`` and ``labelx`` in ``[115, 131)``, one per column.
    2. The sample at ``(labely, labelx) == (30, 134)`` is fed once, then
       ``tpcn.test_init()`` runs, followed by 1000
       ``next_predict``/``test_run`` iterations. These use an action with
       index 30 and index ``51 + 134`` set. The mean monitored gradient is
       printed after every iteration.
    3. Row 1 shows the prediction ``-tpcn.e[0]`` for a zero observation and
       the action for y = 30, x = 115 + column.

    The figure is saved to ``s_first_3.png`` and closed.

    The action vector has ``51 + 135`` entries, and indices ``y`` and
    ``51 + x`` are set. This is presumably a one-hot y block followed by a
    one-hot x block -- TODO confirm.
    """
    no_input = bm.zeros((batch_size, 3 * 60 * 80))
    n_cols = 16
    n_iters = 1000
    y_size, x_size = 51, 135
    row_y = 30      # raw labely of the strip being visualized
    anchor_x = 134  # raw labelx of the sample used for inference
    first_x = 115   # raw labelx shown in column 0
    fig, axes = plt.subplots(2, n_cols, figsize=(12, 4))

    # Row 0: ground truth. Bound x on both sides so it always maps to a valid
    # column (labelx 131..134 used to index past the 16 columns).
    for input, labely, labelx in test_dataloader:
        x = int(labelx[0])
        y = int(labely[0])
        if y == row_y and first_x <= x < first_x + n_cols:
            col = x - first_x
            axes[0, col].imshow(input[0].permute(1, 2, 0))
            axes[0, col].set_title(f'O {x}')
            axes[0, col].axis('off')

    action0 = bm.zeros(x_size + y_size)
    action0[row_y] = 1
    action0[y_size + anchor_x] = 1

    # Test-time inference on the anchor sample.
    for input, labely, labelx in test_dataloader:
        if int(labely[0]) == row_y and int(labelx[0]) == anchor_x:
            flat_input = input.reshape(-1, 3 * 60 * 80)
            tpcn.next_predict(flat_input, action0)
            tpcn.test_init()
            for step in range(1, n_iters + 1):
                tpcn.next_predict(flat_input, action0)
                mon = tpcn.test_run()
                print(step, '/', n_iters, ': grad =',
                      bm.mean(dict2array([mon])))
            break

    # Row 1: generated images from location alone.
    for col in range(n_cols):
        action_i = bm.zeros(x_size + y_size)
        action_i[y_size + first_x + col] = 1
        action_i[row_y] = 1

        tpcn.next_predict(no_input, action_i)
        ob_i = (-1 * tpcn.e[0]).reshape(-1, 3, 60, 80)
        axes[1, col].imshow(bm.transpose(ob_i[0], (1, 2, 0)))
        axes[1, col].set_title(f'G {col}')
        axes[1, col].axis('off')

    fig.tight_layout()
    fig.savefig('s_first_3.png')
    plt.close(fig)

def draw_feature(tpcn, batch_size, test_dataloader,layer_index,neuron_index):
    """Plot one neuron's activity as a function of location on the 68 x 26 grid.

    Samples with an odd raw ``labely`` or ``labelx`` are skipped, and labels
    are halved to ``x`` in [0, 67] and ``y`` in [0, 25]. For each location
    ``frag = x + 68 * y``:
      * the model is driven with a zero observation and a one-hot action at
        ``frag``;
      * ``tpcn.s[layer_index][0][neuron_index]`` is recorded;
      * the reconstruction error of ``-tpcn.e[0]`` against the image is
        accumulated.

    The mean error over evaluated samples is printed. The activity curve is
    plotted on a fresh figure titled with its mean and variance, saved under
    ``pic_68_26/``, and closed.

    Args:
        tpcn: trained model.
        batch_size: batch size of the zero observation.
        test_dataloader: iterable of ``(image, labely, labelx)`` batches.
        layer_index: index into ``tpcn.s``.
        neuron_index: neuron within that layer's first sample.
    """
    no_input = bm.zeros((batch_size, 3 * 60 * 80))
    grid_w, grid_h = 68, 26
    neuron_firing = bm.zeros(grid_w * grid_h)

    total_diff = 0
    n_samples = 0
    for input, labely, labelx in test_dataloader:
        if int(labely[0]) % 2 == 1 or int(labelx[0]) % 2 == 1:
            continue
        labelx = int(labelx[0]) // 2  # [0-67]
        labely = int(labely[0]) // 2  # [0-25]
        frag = labelx + grid_w * labely
        action = np.zeros(grid_w * grid_h)
        action[frag] = 1
        tpcn.next_predict(no_input, action)
        neuron_firing[frag] = tpcn.s[layer_index][0][neuron_index]

        ob_i = (-1 * tpcn.e[0]).reshape(-1, 3, 60, 80)
        diff_i = abs(bm.transpose(ob_i[0], (1, 2, 0))
                     - input[0].permute(1, 2, 0).numpy())
        total_diff = total_diff + diff_i.mean()
        n_samples += 1

    # Average over the samples actually evaluated, not a fixed grid size.
    diff = total_diff / n_samples if n_samples else 0.0
    print(diff)

    # Use a dedicated figure so the curve is not drawn onto whatever
    # axes happen to be current.
    fig = plt.figure()
    plt.plot(range(len(neuron_firing)), neuron_firing)
    neuron_mean = neuron_firing.mean()
    neuron_var = neuron_firing.var()
    plt.title(f'Mean: {neuron_mean:.6f} , Var: {neuron_var:.6f}')
    plt.savefig('pic_68_26/firing_'+str(layer_index)+'_'+str(neuron_index)+'_fft.png')
    plt.close(fig)
    
def draw_feature_sort(tpcn, batch_size, test_dataloader,layer_index,neuron_index):
    """Plot sampled predicted pixels as a function of location on the 68 x 26 grid.

    Samples with an odd raw ``labely`` or ``labelx`` are skipped, and labels
    are halved to ``x`` in [0, 67] and ``y`` in [0, 25]. For each location
    ``frag = x + 68 * y``, the model is driven with a zero observation and a
    one-hot action at ``frag``.

    From the prediction ``-tpcn.e[0][0]``, laid out as ``(3, 60, 80)``, the
    pixel at column 0 of every 6th image row is recorded for each channel.
    That is 3 x 10 sample points, each plotted against ``frag`` in its own
    panel. The grid matches the sampled points, so it has no always-empty
    panels.

    The mean reconstruction error over evaluated samples is printed.
    ``layer_index`` and ``neuron_index`` are only used in the output
    filename.
    """
    no_input = bm.zeros((batch_size, 3 * 60 * 80))
    grid_w, grid_h = 68, 26
    n_channels, n_rows = 3, 10
    # sample_idx[c, r] = c*60*80 + r*480, i.e. row 6*r, column 0 of channel c.
    sample_idx = (np.arange(n_channels)[:, None] * (60 * 80)
                  + np.arange(n_rows)[None, :] * 480)

    neuron_firing = bm.zeros((n_channels, n_rows, grid_w * grid_h))
    total_diff = 0
    n_samples = 0
    for input, labely, labelx in test_dataloader:
        if int(labely[0]) % 2 == 1 or int(labelx[0]) % 2 == 1:
            continue
        labelx = int(labelx[0]) // 2  # [0-67]
        labely = int(labely[0]) // 2  # [0-25]
        frag = labelx + grid_w * labely
        action = np.zeros(grid_w * grid_h)
        action[frag] = 1
        tpcn.next_predict(no_input, action)
        # One vectorized gather/scatter instead of 30 scalar device updates.
        neuron_firing[:, :, frag] = -1 * tpcn.e[0][0, sample_idx]

        ob_i = (-1 * tpcn.e[0]).reshape(-1, 3, 60, 80)
        diff_i = abs(bm.transpose(ob_i[0], (1, 2, 0))
                     - input[0].permute(1, 2, 0).numpy())
        total_diff = total_diff + diff_i.mean()
        n_samples += 1

    diff = total_diff / n_samples if n_samples else 0.0
    print(diff)

    fig, axes = plt.subplots(n_channels, n_rows, figsize=(40, 10))
    for c in range(n_channels):
        for r in range(n_rows):
            axes[c, r].plot(range(grid_w * grid_h), neuron_firing[c, r, :])
    fig.savefig('pic_68_26/sortnon_list'+str(layer_index)+'_'+str(neuron_index)+'_e.png')
    plt.close(fig)
    
def draw_theta(tpcn, batch_size, test_dataloader,layer_index,neuron_index):
    """Visualize what unit 0 of layer 1 decodes to in image space.

    Steps:
      1. Overwrite the first sample's layer-1 state ``tpcn.s[1][0, :]`` with
         a one-hot vector (unit 0 set to 1).
      2. Decode it with ``tpcn.f(tpcn.s[1]) @ tpcn.theta[0]``.
      3. Reshape the result to ``(-1, 3, 60, 80)``.
      4. Show channel 1 of the first sample with the ``Greens`` colormap,
         clipped to [0, 1] (on the current figure).
    The image is saved to ``pic_68_26/theta_1_0_green.png``.

    Side effect: ``tpcn.s[1]`` is modified in place and not restored.

    ``batch_size``, ``test_dataloader``, ``layer_index`` and ``neuron_index``
    are currently unused and kept for call compatibility. The decoded layer
    and unit are fixed to 1 and 0.
    """
    # One-hot layer-1 state for sample 0. The width is taken from the state
    # itself instead of a hard-coded 4000.
    tpcn.s[1][0, :] = bm.zeros(tpcn.s[1].shape[-1])
    tpcn.s[1][0, 0] = 1

    ob_i = (tpcn.f(tpcn.s[1]) @ tpcn.theta[0]).reshape(-1, 3, 60, 80)
    print(ob_i.shape)
    ob_i = bm.transpose(ob_i[0], (1, 2, 0))

    ob_i_green = ob_i[:, :, 1]
    print(ob_i_green)
    plt.imshow(ob_i_green, cmap='Greens', vmin=0, vmax=1)
    plt.savefig('pic_68_26/theta_'+str(1)+'_'+str(0)+'_green.png')

if __name__ == '__main__':
    # JAX reads XLA_PYTHON_CLIENT_PREALLOCATE only when its backend is first
    # initialized, so set it before any device work below (platform
    # selection, load_config, model construction).
    # NOTE(review): if importing brainpy at module level already initializes
    # the backend, this is still too late. Export the variable in the shell
    # to be sure.
    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'
    # States (per the section comments below):
    #   0 = train on random 8x8 patch sequences
    #   1 = train on the lab location grid, then generate
    #   2 = single-image run (CIFAR10)
    #   3 = test
    #   4 = feature plots
    #   5 = loss / error curve
    #   6 = theta visualization
    config = 'Lab_layer4'  ### 'CIFAR10'  'Fashionmnist'
    State = 1
    batch_size = 1
    seq_len = 100
    ob_size = 3 * 80 * 60
    action_size = 1 * seq_len
    # training_Epoch
    Epoch = 40
    if torch.cuda.is_available():
        bm.set_platform('gpu')
        print('use gpu')
    else:
        print('use cpu')
    #bm.set_platform('cpu')
    # hyper-parameters
    neuron_size, l_duration, duration, eta, dt, f, noise, \
    model_name, using_epoch, training_data, normalize, inv_normalize, _transforms = load_config(config)
    # load_data
    data_lenth = len(training_data)
    train_dataloader = DataLoader(training_data,
                                  batch_size=batch_size,
                                  shuffle=True)
    for input, labely,labelx in train_dataloader:
        image_shape = input.shape[1:]
        image_size = 1
        print(labelx)
        for j in image_shape:
            image_size *= j
        break
    print('image_shape = ', image_shape)
    print('image_size = ', image_size)
    print('data_length = ', data_lenth)
    print('simulation_length = ', int(duration / dt))
    # model
    neuron_size[0] = ob_size
    L = len(neuron_size) - 1
    print('pcn_ob_size = ', neuron_size[0])
    tpcn = tPCN(neuron_size,
                action_size,
                eta,
                l_duration,
                duration,
                f,
                noise=noise,
                dt=dt)
    phase = make_phase(l_duration, duration, dt, L)

    if State == 0:
        seq_len=2000
        if os.path.exists('./modelcif/' + model_name + '_' + str(using_epoch) +
                          '_big.bp'):
            states = bp.checkpoints.load_pytree('./modelcif/' + model_name + '_' +
                                                str(using_epoch) + '_big.bp')
            tpcn.load_state_dict(states)

        for epoch in range(Epoch):
            p = 0
            # The dataset yields (image, labely, labelx); only the image is used.
            for input, *_ in train_dataloader:
                ob_seq, action_seq = seq_generate(input, seq_len)
                tpcn.init_neuron(batch_size)
                mon_list = []
                for i in range(seq_len):
                    tpcn.next_predict(ob_seq[i], action_seq[i])
                    mon_list.append(tpcn.run())

                p += batch_size / data_lenth
                print(epoch+28, ' : ', int(p * 100), '%, grad=',
                       bm.mean(dict2array(mon_list)))
                if p >= 0.01:
                    break
                # print('grad=', bm.as_numpy(bm.mean(dict2array(mon_list), axis=(0, 2))))
                # show_grad(mon_list, L + 1, phase)
                # show_grad2(mon_list, L, phase)
            bp.checkpoints.save_pytree(
                './modelcif/' + model_name + '_' + str(epoch+28) + '_big.bp',
                tpcn.state_dict())

    if State == 1:
        #seq_len = 10

        seq_len = 100
        Epoch = 30
        count = 0
        #ob_seq, action_seq = seq_generate(input, seq_len)
        tpcn.init_neuron(batch_size)
        if os.path.exists('model_final/lab/Lab_layer4_net_9_100_700.bp'):
            states = bp.checkpoints.load_pytree('model_final/lab/Lab_layer4_net_9_100_700.bp')
            tpcn.load_state_dict(states)
        loss=[]
        for epoch in range(Epoch):
            tpcn.eta=tpcn.eta*0.95
            tpcn.eta_s=tpcn.eta_s*0.95
            mon_list = []
            flag=False
            for input, labely,labelx in train_dataloader:
                #print(labely,labelx)
                if int(labely[0])%10!=2 or int(labelx[0])%7!=0:
                #    print(1)
                    continue
                action = np.zeros(action_size)
                #print(int(labelx[0])+135*int(labely[0]))
                action[int(labelx[0])//7+int(labely[0])//10*20] = 1
                #print(int(labelx[0])//2+68*int(labely[0])//2)
                tpcn.next_predict(input.reshape(-1, 3 * 60 * 80), action)
                mon = tpcn.run()
                mon_list.append(mon)
                #count += 1
                #print(count)
                #if count==20:
                #    break
                #print(count, '/', seq_len, ': grad =',
                #      bm.mean(dict2array([mon])))
                if math.isnan(bm.mean(dict2array([mon]))):
                    show_grad2(mon_list, L, phase)
                    flag=True
                    break

            #plt.plot(np.arange(48), tpcn.s[-1][0,16:])
            #plt.plot(np.arange(48), tpcn.s[-1][1,16:])
            #plt.plot(np.arange(48), tpcn.s[-1][3,16:])
            #plt.show()
            #print(int(labelx[0])//2+68*int(labely[0])//2)
            #print(count)
            grad=bm.mean(dict2array(mon_list))
            print(epoch, ' : ', epoch, '%, grad=',
                  grad)
            loss.append(grad)
            bp.checkpoints.save_pytree(
                './model_final/lab/' + model_name + '_' + str(epoch) + '_100_700.bp',
                tpcn.state_dict())

            #show_grad(mon_list, L + 1, phase)
            #show_grad2(mon_list, L, phase)
            if flag:
                break

        test_dataloader = DataLoader(training_data,
                                     batch_size=batch_size,
                                     shuffle=False)
        generate(tpcn, batch_size, test_dataloader)
        #plt.close()
        #plt.plot(range(Epoch),loss)
        #plt.savefig('pic_68_26/loss_68_26.png')
    #CIFAR10
    if State == 2:
        seq_len = 3000

        count = 0
        # The dataset yields (image, labely, labelx); only the image is used.
        for input, *_ in train_dataloader:
            ob_seq, action_seq = seq_generate(input, seq_len)
            tpcn.init_neuron(batch_size)
            for i in range(seq_len):
                tpcn.next_predict(ob_seq[i], action_seq[i])
                mon = tpcn.run()
                count += 1
                print(count, '/', seq_len, ': grad =',
                      bm.mean(dict2array([mon])))

            bp.checkpoints.save_pytree(
                './modeltest/' + model_name + '_' + str(seq_len) + '_128_1024.bp',
                tpcn.state_dict())
            generate(tpcn, batch_size, input)
            break

    #test_4
    if State == 3:
        seq_len = 2000
        if os.path.exists('model_lab4/Lab_empty_net_99_68_26.bp'):
            states = bp.checkpoints.load_pytree('model_lab4/Lab_empty_net_99_68_26.bp')
            tpcn.load_state_dict(states)
            test_dataloader = DataLoader(training_data,
                                     batch_size=batch_size,
                                     shuffle=False)
            #print(tpcn.s[-1][0,16:])
            count = 0
            generate(tpcn, batch_size, test_dataloader)
            plt.close()
            draw_s(tpcn, batch_size, test_dataloader)

    #feature
    if State == 4:
        seq_len = 2000
        if os.path.exists('model_lab4/Lab_empty_net_99_68_26.bp'):
            states = bp.checkpoints.load_pytree('model_lab4/Lab_empty_net_99_68_26.bp')
            tpcn.load_state_dict(states)
            test_dataloader = DataLoader(training_data,
                                     batch_size=batch_size,
                                     shuffle=False)
            #print(tpcn.s[-1][0,16:])
            count = 0
            layer_index=0
            neuron_index=0

            draw_feature_sort(tpcn, batch_size, test_dataloader,layer_index,neuron_index)

    #loss
    if State==5:
        seq_len = 2000
        error=0
        if os.path.exists('./modeltest/' + model_name + '_'
                          + '3000_128_1024.bp'):
            states = bp.checkpoints.load_pytree('./modeltest/' + model_name + '_'
                          + '3000_128_1024.bp')
            tpcn.load_state_dict(states)
            #print(tpcn.s[-1][0,16:])
            count = 0
            # The dataset yields (image, labely, labelx); only the image is used.
            for input, *_ in train_dataloader:
                error_list=bm.zeros(15)
                for num in range(15):
                    ob_seq, action_seq = test_seq_generate(input, seq_len,num+1)
                    tpcn.init_neuron(batch_size)
                    #action0 = bm.zeros(16)
                    #action0[0] = 1
                    #tpcn.next_predict(input[:, :, 0:8, 0:8].reshape(-1, 3 * 8 * 8), action0)
                    #tpcn.test_init()
                    for i in range(seq_len):
                        tpcn.next_predict(ob_seq[i], action_seq[i])
                        mon = tpcn.test_run()
                        #print(count, '/', seq_len, ': grad =',
                        #      bm.mean(dict2array([mon])))
                    error_list[num]=generate(tpcn, batch_size, input)
                    plt.close()
                #print(error_list.shape)
                plt.plot(range(15),error_list)
                plt.savefig('./pictest/errors.png')
                break

    #theta
    if State == 6:
        seq_len = 2000
        if os.path.exists('model_lab4/Lab_empty_net_99_68_26.bp'):
            states = bp.checkpoints.load_pytree('model_lab4/Lab_empty_net_99_68_26.bp')
            tpcn.load_state_dict(states)
            test_dataloader = DataLoader(training_data,
                                     batch_size=batch_size,
                                     shuffle=False)
            #print(tpcn.s[-1][0,16:])
            count = 0
            layer_index=0
            neuron_index=0

            draw_theta(tpcn, batch_size, test_dataloader,layer_index,neuron_index)