# 10.31: first attempt - learn the empty_room images along the horizontal axis in 16 steps.
import math
import os
import random
from collections import defaultdict
from functools import partial

import brainpy as bp
import brainpy.math as bm
import jax
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

from config import load_config
from tPCN import tPCN


def seq_generate(input, seq_len):
    """Sample a random sequence of 8x8 patches with one-hot grid locations.

    At every step a row index and a column index are drawn independently
    from 0..3. The matching 8x8 window is sliced from the last two axes of
    ``input`` and reshaped to rows of ``3 * 8 * 8`` values, and a length-16
    one-hot vector marking cell ``row * 4 + col`` is created.

    Args:
        input: Object indexed as ``input[:, :, r0:r1, c0:c1]`` whose slices
            support ``reshape``. Presumably a (batch, 3, 32, 32) image
            batch - confirm against the caller.
        seq_len: Number of steps to sample.

    Returns:
        ``(ob_seq, action_seq)``: two lists of length ``seq_len`` holding
        the flattened patches and the one-hot location vectors.
    """
    patches, locations = [], []
    for _ in range(seq_len):
        row = random.choice(range(4))
        col = random.choice(range(4))
        top, left = row * 8, col * 8
        window = input[:, :, top:top + 8, left:left + 8]
        patches.append(window.reshape(-1, 3 * 8 * 8))

        one_hot = bm.zeros(16)
        one_hot[row * 4 + col] = 1
        locations.append(one_hot)
    return patches, locations


def test_seq_generate(input, seq_len, num):
    """Sample patches restricted to the first ``num`` cells of the 4x4 grid.

    At every step a flat cell index is drawn from ``range(num)`` and split
    into a row and a column of a 4-wide grid. The matching 8x8 window of
    ``input`` is reshaped to rows of ``3 * 8 * 8`` values, and a length-16
    one-hot vector marking that cell is created.

    Args:
        input: Object indexed as ``input[:, :, r0:r1, c0:c1]`` whose slices
            support ``reshape``.
        seq_len: Number of steps to sample.
        num: Exclusive upper bound of the flat cell index that is drawn.

    Returns:
        ``(ob_seq, action_seq)``: two lists of length ``seq_len``.
    """
    patches, locations = [], []
    for _ in range(seq_len):
        cell = random.choice(range(num))
        row, col = divmod(cell, 4)
        top, left = row * 8, col * 8
        window = input[:, :, top:top + 8, left:left + 8]
        patches.append(window.reshape(-1, 3 * 8 * 8))

        one_hot = bm.zeros(16)
        one_hot[row * 4 + col] = 1
        locations.append(one_hot)
    return patches, locations


def make_phase(l_duration, duration, dt, L):
    """Build the cumulative step boundaries of ``L`` two-part phases.

    Each of the ``L`` repetitions contributes two consecutive segments whose
    lengths (in steps of ``dt``) are
    ``duration // L * l_duration // dt`` and
    ``(duration / L - duration // L * l_duration) // dt``.

    Args:
        l_duration: Factor applied to ``duration // L`` for the first segment.
        duration: Total duration that is divided by ``L``.
        dt: Step size used to convert durations into step counts.
        L: Number of repetitions.

    Returns:
        A list of ``2 * L + 1`` ints starting at 0, holding the running sum
        of the segment lengths.
    """
    first_len = duration // L * l_duration // dt
    second_len = (duration / L - duration // L * l_duration) // dt

    boundaries = [0]
    elapsed = 0
    for _ in range(L):
        for segment_len in (first_len, second_len):
            elapsed += segment_len
            boundaries.append(elapsed)
    return [int(edge) for edge in boundaries]


def show_grad(mon_list, num, phase):
    """Plot the ``avg_grad{i}`` traces of every monitor and show the figure.

    The figure has ``num`` rows (one per ``avg_grad`` index) and one column
    per entry of ``mon_list``. Each trace is drawn piecewise between
    consecutive ``phase`` boundaries, with even-numbered pieces in red and
    odd-numbered pieces in blue.

    Args:
        mon_list: Sequence of mappings containing ``'avg_grad0'`` ..
            ``'avg_grad{num - 1}'``.
        num: Number of ``avg_grad`` entries (rows) to plot.
        phase: Sequence of slice boundaries into each trace.
    """
    n_steps = len(mon_list)
    plt.figure(figsize=(3 * n_steps, num * 2))
    pieces = list(zip(phase, phase[1:]))
    for layer in range(num):
        for step, mon in enumerate(mon_list):
            grad = mon[f'avg_grad{layer}']
            time = bm.linspace(0, 1, len(grad))
            plt.subplot(num, n_steps, layer * n_steps + step + 1)
            for k, (start, stop) in enumerate(pieces):
                colour = 'red' if k % 2 == 0 else 'blue'
                plt.plot(time[start:stop], grad[start:stop], color=colour)

            plt.xticks([])
    plt.show()


def show_grad2(mon_list, num, phase):
    """Plot summed neighbouring gradients of the most recent monitors.

    For the last (up to) 12 entries of ``mon_list`` and every ``i`` in
    ``range(num)`` the trace ``avg_grad{i} + avg_grad{i + 1}`` is drawn
    piecewise between consecutive ``phase`` boundaries (even pieces red, odd
    pieces blue). The figure is written to ``error_450.png``.

    Args:
        mon_list: Sequence of mappings containing ``'avg_grad0'`` ..
            ``'avg_grad{num}'``.
        num: Number of rows; row ``i`` needs keys ``i`` and ``i + 1``.
        phase: Sequence of slice boundaries into each trace.

    Returns:
        None. Nothing is plotted or saved when ``mon_list`` is empty.
    """
    max_cols = 12
    if not mon_list:
        # Nothing recorded yet (e.g. failure on the very first sample).
        return None

    # Clamp the window: indexing ``len - 12 + j`` on a shorter list would
    # wrap around to unrelated entries or raise IndexError.
    n_cols = min(max_cols, len(mon_list))
    recent = mon_list[-n_cols:]

    plt.figure(figsize=(3 * n_cols, num * 2))
    pieces = list(zip(phase, phase[1:]))
    for i in range(num):
        for j, mon in enumerate(recent):
            grad = mon[f'avg_grad{i}'] + mon[f'avg_grad{i+1}']
            time = bm.linspace(0, 1, len(grad))
            plt.subplot(num, n_cols, i * n_cols + j + 1)
            for k, (start, stop) in enumerate(pieces):
                colour = 'red' if k % 2 == 0 else 'blue'
                plt.plot(time[start:stop], grad[start:stop], color=colour)
            plt.xticks([])
    plt.savefig('error_450.png')
    return None


def dict2array(mon_list):
    """Stack the values of every monitor mapping into one ``bm`` array.

    Args:
        mon_list: Sequence of mappings; the values of each mapping are taken
            in ``items()`` order.

    Returns:
        ``bm.array`` built from a nested list with one inner list of values
        per entry of ``mon_list``.
    """
    rows = [[value for _, value in mon.items()] for mon in mon_list]
    return bm.array(rows)


def generate(tpcn, batch_size, test_dataloader):
    """Compare model predictions with dataset images and save a summary plot.

    Only samples whose first label is a multiple of 12 are used; ``label //
    12`` selects the active entry of a length-150 one-hot action. For each
    such sample the model is driven with an all-zero observation, the
    prediction is read as ``-tpcn.e[0]`` reshaped to ``(-1, 3, 50, 100)``,
    and the mean absolute difference to the image is accumulated. Every 15th
    position is additionally drawn as a column of original / generated /
    difference images. The accumulated difference is divided by the fixed
    constant 150, printed, and used as the figure title.

    Args:
        tpcn: Model exposing ``next_predict(ob, action)`` and ``e``.
        batch_size: Number of rows of the all-zero observation.
        test_dataloader: Iterable yielding ``(image, label)`` pairs.
    """
    blank_ob = bm.zeros((batch_size, 3 * 50 * 100))
    # Canvas: three rows (original, generated, difference) by n_cols columns.
    n_cols = 10
    fig, axes = plt.subplots(3, n_cols, figsize=(12, 4))

    total_diff = 0
    for image, label in test_dataloader:
        raw_label = int(label[0])
        if raw_label % 12 != 0:
            continue
        position = raw_label // 12

        action = np.zeros(150)
        action[position] = 1
        tpcn.next_predict(blank_ob, action)

        predicted = (-1 * tpcn.e[0]).reshape(-1, 3, 50, 100)
        abs_err = abs(
            bm.transpose(predicted[0], (1, 2, 0))
            - image[0].permute(1, 2, 0).numpy()
        )
        total_diff = total_diff + abs_err.mean()

        if position % 15 != 0:
            continue
        col = position // 15

        # Original image.
        axes[0, col].imshow(image[0].permute(1, 2, 0))
        axes[0, col].set_title(f'O {col}')
        axes[0, col].axis('off')

        # Generated image.
        axes[1, col].imshow(bm.transpose(predicted[0], (1, 2, 0)))
        axes[1, col].set_title(f'G {col}')
        axes[1, col].axis('off')

        # Absolute difference.
        axes[2, col].imshow(abs_err)
        axes[2, col].set_title(f'D {col}')
        axes[2, col].axis('off')

    total_diff = total_diff / 150
    print(total_diff)
    fig.suptitle(f'Average Difference: {total_diff:.6f}')
    # Tidy the subplot layout before writing the figure to disk.
    plt.tight_layout()
    plt.savefig('pic_final/street/layer4/150_500.png')

def draw_s(tpcn, batch_size, test_dataloader):
    """Plot inputs, predictions, residuals and ``tpcn.s`` traces per sample.

    Samples are used when ``labelx`` is a multiple of 6 and ``labely`` a
    multiple of 25. For each one a length-118 action with entries
    ``labelx`` and ``labely + 18`` set is applied together with an all-zero
    observation. Column ``labelx // 6 + 3 * (labely // 25)`` of a 7x12 grid
    then shows: the flattened input, ``-tpcn.e[0][0]``, their residual, and
    ``tpcn.s[1][0]`` .. ``tpcn.s[4][0]``. The mean absolute residual is
    printed and the figure is saved to disk.

    Args:
        tpcn: Model exposing ``next_predict(ob, action)``, ``e`` and ``s``.
        batch_size: Number of rows of the all-zero observation.
        test_dataloader: Iterable yielding ``(image, labelx, labely)``.
    """
    blank_ob = bm.zeros((batch_size, 3 * 50 * 100))
    # Canvas: seven rows of line plots by n_cols columns.
    n_cols = 12
    fig, axes = plt.subplots(7, n_cols, figsize=(20, 8))

    for image, labelx, labely in test_dataloader:
        labelx = int(labelx[0])
        labely = int(labely[0])
        if labelx % 6 != 0 or labely % 25 != 0:
            continue

        col = labelx // 6 + 3 * (labely // 25)
        action = np.zeros(118)
        action[labelx] = 1
        action[labely + 18] = 1
        tpcn.next_predict(blank_ob, action)

        err = tpcn.e[0][0]
        xs = range(len(err))
        residual = image[0].numpy().flatten() + err
        axes[0, col].plot(xs, image.reshape(-1, 3 * 50 * 100)[0].numpy())
        axes[1, col].plot(xs, -1 * err)
        axes[2, col].plot(xs, residual)
        print(abs(residual).mean())

        for layer in range(1, 5):
            state = tpcn.s[layer][0]
            axes[layer + 2, col].plot(range(len(state)), state)

    # Tidy the subplot layout before writing the figure to disk.
    plt.tight_layout()
    plt.savefig('result_street/s/1800_s_xy_130.png')
    

def generate_multi(tpcn, batch_size, test_dataloader):
    """Run inference from one anchor image, then render neighbouring views.

    Steps, in order:

    1. Draw the dataset images with ``labely == 30`` and ``labelx`` in
       ``115 .. 115 + xx - 1`` in the first row of a 2 x ``xx`` grid.
    2. Find the sample with ``labely == 30`` and ``labelx == 134``, call
       ``next_predict`` / ``test_init`` once, then ``lenn`` rounds of
       ``next_predict`` + ``test_run``, printing the mean of each monitor.
    3. For each column ``i`` feed an all-zero observation with an action
       that sets entries ``30`` and ``i + 51 + 115``, and draw
       ``-tpcn.e[0]`` reshaped to ``(-1, 3, 60, 80)`` in the second row.

    The figure is saved to ``s_first_3.png``.

    Args:
        tpcn: Model exposing ``next_predict``, ``test_init``, ``test_run``
            and ``e``.
        batch_size: Number of rows of the all-zero observation.
        test_dataloader: Iterable yielding ``(image, labely, labelx)``.
    """
    no_input = bm.zeros((batch_size, 3 * 60 * 80))
    # Canvas: two rows (original, generated) by xx columns.
    xx = 16
    lenn = 1000
    x_start = 115  # labelx value shown in column 0
    fig, axes = plt.subplots(2, xx, figsize=(12, 4))

    # Row 0: original images.
    for input, labely, labelx in test_dataloader:
        flagx = int(labelx[0])
        flagy = int(labely[0])
        col = flagx - x_start
        # The grid only has xx columns; without the upper bound, labelx
        # values of 131..134 would index past the last column.
        if flagy == 30 and 0 <= col < xx:
            axes[0, col].imshow(input[0].permute(1, 2, 0))
            axes[0, col].set_title(f'O {flagx}')
            axes[0, col].axis('off')

    action0 = bm.zeros(135 + 51)
    action0[30] = 1
    action0[51 + 134] = 1

    # Inference on the single anchor sample (labely == 30, labelx == 134).
    for input, labely, labelx in test_dataloader:
        if int(labely[0]) == 30 and int(labelx[0]) == 134:
            tpcn.next_predict(input.reshape(-1, 3 * 60 * 80), action0)
            tpcn.test_init()
            for count in range(1, lenn + 1):
                tpcn.next_predict(input.reshape(-1, 3 * 60 * 80),
                                  action0)
                mon = tpcn.test_run()
                print(count, '/', lenn, ': grad =',
                          bm.mean(dict2array([mon])))
            break

    # Row 1: images generated from the action alone.
    for i in range(xx):
        action_i = bm.zeros(135 + 51)
        action_i[i + 51 + x_start] = 1
        action_i[30] = 1

        tpcn.next_predict(no_input, action_i)
        ob_i = (-1 * tpcn.e[0]).reshape(-1, 3, 60, 80)
        axes[1, i].imshow(bm.transpose(ob_i[0], (1, 2, 0)))
        axes[1, i].set_title(f'G {i}')
        axes[1, i].axis('off')

    # Tidy the subplot layout before writing the figure to disk.
    plt.tight_layout()
    plt.savefig('s_first_3.png')


if __name__ == '__main__':
    # Name of the configuration handed to load_config.
    config = 'street_layer4'  # ## 'CIFAR10'  'Fashionmnist'
    # Selects which branch below runs:
    #   0 - train on random patch sequences (seq_generate), save to ./modelcif
    #   1 - train on every 12th label, save per epoch, then call generate()
    #   2 - load a checkpoint and call generate()
    #   3 - load a checkpoint and call draw_s()
    #   4 - load a checkpoint and iterate the dataset (unfinished)
    State = 1
    batch_size = 1
    seq_len = 150
    ob_size = 3 * 100 * 50
    action_size = 1 * seq_len
    # training_Epoch
    Epoch = 50
    if torch.cuda.is_available():
        bm.set_platform('gpu')
        print('use gpu')
    else:
        print('use cpu')
    #bm.set_platform('cpu')
    # Hyper-parameters and dataset come from the selected configuration.
    neuron_size, l_duration, duration, eta, dt, f, noise, \
    model_name, using_epoch, training_data, normalize, inv_normalize, _transforms = load_config(config)
    # load_data
    data_lenth = len(training_data)
    train_dataloader = DataLoader(training_data,
                                  batch_size=batch_size,
                                  shuffle=True)
    # Peek at one batch only to report the per-sample image shape and size.
    for input, label in train_dataloader:
        image_shape = input.shape[1:]
        image_size = 1
        print(label)
        for j in image_shape:
            print(j)
            image_size *= j
        break
    print('image_shape = ', image_shape)
    print('image_size = ', image_size)
    print('data_length = ', data_lenth)
    print('simulation_length = ', int(duration / dt))
    # model
    # The first layer size is overridden with the observation size.
    neuron_size[0] = ob_size
    L = len(neuron_size) - 1
    print('pcn_ob_size = ', neuron_size[0])
    tpcn = tPCN(neuron_size,
                action_size,
                eta,
                l_duration,
                duration,
                f,
                noise=noise,
                dt=dt)
    phase = make_phase(l_duration, duration, dt, L)

    if State == 0:
        seq_len = 2000
        # Resume from an existing checkpoint when one is present.
        if os.path.exists('./modelcif/' + model_name + '_' + str(using_epoch) +
                          '_big.bp'):
            states = bp.checkpoints.load_pytree('./modelcif/' + model_name + '_' +
                                                str(using_epoch) + '_big.bp')
            tpcn.load_state_dict(states)

        for epoch in range(Epoch):
            p = 0
            for input, label in train_dataloader:
                ob_seq, action_seq = seq_generate(input, seq_len)
                tpcn.init_neuron(batch_size)
                mon_list = []
                for i in range(seq_len):
                    tpcn.next_predict(ob_seq[i], action_seq[i])
                    mon_list.append(tpcn.run())

                # p is the fraction of the dataset consumed in this epoch.
                p += batch_size / data_lenth
                print(epoch + 28, ' : ', int(p * 100), '%, grad=',
                       bm.mean(dict2array(mon_list)))
                # Each epoch stops after 1% of the dataset.
                if p >= 0.01:
                    break
                # print('grad=', bm.as_numpy(bm.mean(dict2array(mon_list), axis=(0, 2))))
                # show_grad(mon_list, L + 1, phase)
                # show_grad2(mon_list, L, phase)
            # Checkpoint names are offset by 28 epochs.
            bp.checkpoints.save_pytree(
                './modelcif/' + model_name + '_' + str(epoch + 28) + '_big.bp',
                tpcn.state_dict())

    if State == 1:
        seq_len = 150
        Epoch = 40
        count = 0
        tpcn.init_neuron(batch_size)
        #if os.path.exists('model_final/street/street_layer4_79_450.bp'):
        #    states = bp.checkpoints.load_pytree('model_final/street/street_layer4_79_450.bp')
        #    tpcn.load_state_dict(states)
        loss = []
        for epoch in range(Epoch):
            # Both learning rates decay by 5% at the start of every epoch.
            tpcn.eta = tpcn.eta * 0.95
            tpcn.eta_s = tpcn.eta_s * 0.95
            mon_list = []
            flag = False
            for input, label in train_dataloader:
                # Train only on samples whose label is a multiple of 12;
                # label // 12 is the active entry of the one-hot action.
                if label%12!=0:
                    continue
                action = np.zeros(action_size)
                action[label//12] = 1
                tpcn.next_predict(input.reshape(-1, 3 * 50 * 100), action)
                mon = tpcn.run()
                mon_list.append(mon)
                # On a NaN monitor: dump the recent gradients and stop training.
                if math.isnan(bm.mean(dict2array([mon]))):
                    show_grad2(mon_list, L, phase)
                    flag = True
                    break

            grad = bm.mean(dict2array(mon_list))
            # NOTE(review): the value printed before '%' is the epoch index,
            # not a percentage.
            print(epoch, ' : ', epoch, '%, grad=',
                  grad)
            loss.append(grad)
            # NOTE(review): the checkpoint is written even when the NaN flag
            # is set - confirm that this is intended.
            bp.checkpoints.save_pytree(
                './model_final/street/' + model_name + '_' + str(epoch) + '_150_200.bp',
                tpcn.state_dict())

            #show_grad(mon_list, L + 1, phase)
            #show_grad2(mon_list, L, phase)
            if flag:
                break

        test_dataloader = DataLoader(training_data,
                                     batch_size=batch_size,
                                     shuffle=False)
        generate(tpcn, batch_size, test_dataloader)
        #plt.close()
        #draw_s(tpcn, batch_size, test_dataloader)
        #plt.plot(range(Epoch),loss)
        #plt.savefig('result_street/loss/loss_layer1_xy_130.png')

    if State == 2:
        seq_len = 2000
        #tpcn.init_neuron(batch_size)
        # Nothing happens in this branch when the checkpoint file is missing.
        if os.path.exists('model_street/street_xy_129_xy.bp'):
            states = bp.checkpoints.load_pytree('model_street/street_xy_129_xy.bp')
            tpcn.load_state_dict(states)
            test_dataloader = DataLoader(training_data,
                                     batch_size=batch_size,
                                     shuffle=False)
            count = 0
            plt.close()
            generate(tpcn, batch_size, test_dataloader)

    if State ==3:
        if os.path.exists('model_street/street_xy_79_xy.bp'):
            states = bp.checkpoints.load_pytree('model_street/street_xy_79_xy.bp')
            tpcn.load_state_dict(states)
            test_dataloader = DataLoader(training_data,
                                     batch_size=batch_size,
                                     shuffle=False)
            draw_s(tpcn, batch_size, test_dataloader)

    if State ==4:
        if os.path.exists('model_street/street_xy_79_xy.bp'):
            states = bp.checkpoints.load_pytree('model_street/street_xy_79_xy.bp')
            tpcn.load_state_dict(states)
            test_dataloader = DataLoader(training_data,
                                     batch_size=batch_size,
                                     shuffle=False)
            no_input = bm.zeros((batch_size, 3 * 50 * 100))
            for i, (input, labelx,labely) in enumerate(test_dataloader):
                # Read the labels of the current sample; previously both were
                # taken from the stale `label` left by the shape-probing loop.
                labelx=int(labelx[0])
                labely=int(labely[0])
                
            