import os
from collections import defaultdict
from functools import partial

import brainpy as bp
import brainpy.math as bm
import jax
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

from config import load_config


class tPCN(bp.DynamicalSystem):
    """Layered predictive-coding network with staged, per-layer learning.

    For ``0 <= i < L`` layer ``i`` is predicted from the layer above as
    ``f(s[i + 1]) @ theta[i]`` and ``e[i] = s[i] - f(s[i + 1]) @ theta[i]``.
    The top state ``s[L]`` is compared against the memory buffer instead:
    ``e[L] = s[L] - mem``.

    Weight and state updates are gradient steps on these squared errors.
    Each step is repeatedly shrunk by 10x inside ``bm.while_loop`` until the
    error criterion of the method is met (a simple backtracking control).

    Args:
        neuron_size: Layer widths ``[n_0, ..., n_L]``; ``L = len(neuron_size) - 1``.
            At most 4 weight layers are supported because the per-layer
            factors ``lc`` and ``ls`` hold 4 entries.
        action_size: Number of leading top-layer columns that
            ``next_predict`` overwrites with the action.
        eta: Base learning rate. State updates use ``eta``; weight updates
            use ``10 * eta``.
        l_duration: Multiplier of the per-layer slice ``duration // L`` that
            gives the time spent updating weights in each layer.
        duration: Total simulated time of one ``run``, split across the layers.
        f: Activation applied to the upper layer. Its derivative is taken
            with ``bm.vector_grad``, so it is presumably elementwise -- confirm
            for custom activations.
        noise: Standard deviation of Gaussian noise added to state steps
            (scaled by ``sqrt(dt)``).
        dt: Integration step. ``None`` (default) reads ``bm.get_dt()`` when
            the network is constructed.
    """

    def __init__(self,
                 neuron_size,
                 action_size,
                 eta,
                 l_duration,
                 duration,
                 f,
                 noise=0,
                 dt=None):
        super().__init__(name=None)
        # Hyper-parameters.
        self.neuron_size = neuron_size
        self.L = len(self.neuron_size) - 1
        self.eta = eta * 10  # weight learning rate
        self.eta_s = eta  # state learning rate
        self.duration = duration
        self.f = f
        self.noise = noise
        # Resolve the step here: a ``dt=bm.dt`` default would be frozen at
        # import time and ignore any later ``bm.set_dt`` call.
        self.dt = bm.get_dt() if dt is None else dt
        # Per-layer step multipliers for weights (lc) and states (ls), indexed
        # by layer; NOTE(review): their length caps the network at L <= 4.
        self.lc = [0.2, 10, 50, 500]
        self.ls = [1, 1, 0.5, 1]

        # Not read by any method of this class; kept for external users.
        self.phase = bm.linspace(0, duration, self.L + 1)

        self.par_num = 0  # number of weight parameters
        self.action_size = action_size
        # Time spent updating theta inside each layer's slice of ``duration``.
        self.theta_time = self.duration // self.L * l_duration
        # Latest mean squared error per layer (errors, despite the name).
        self.avg_grad = bm.Variable(bm.zeros(self.L + 1))
        batch_size = 1  # initial batch size; init_neuron resizes the states

        # Weights: theta[i] has shape (n_{i+1}, n_i) and maps f(s[i + 1]) to a
        # prediction of s[i].
        self.theta = bm.var_list()
        for i in range(self.L):
            self.theta.append(
                bm.Variable(
                    bm.random.normal(0, 0.1,
                                     size=(neuron_size[i + 1], neuron_size[i]))))
            self.par_num += neuron_size[i + 1] * neuron_size[i]

        # States s[i] and errors e[i], each of shape (batch, n_i).
        self.s = bm.var_list()
        self.e = bm.var_list()
        for i in range(self.L + 1):
            self.s.append(
                bm.Variable(bm.zeros((batch_size, self.neuron_size[i])),
                            batch_axis=0))
            self.e.append(
                bm.Variable(bm.zeros((batch_size, self.neuron_size[i])),
                            batch_axis=0))
        # Memory: the target of the top state, shape (batch, n_L).
        self.mem = bm.Variable(bm.zeros((batch_size, self.neuron_size[-1])),
                               batch_axis=0)

        # Monitor lists: grad0..grad{L-1}, then s0..s{L-1}. Not populated by
        # any method of this class.
        self.mon = {f'{prefix}{i}': []
                    for prefix in ('grad', 's')
                    for i in range(self.L)}

    @staticmethod
    def _merge_history(mon, res):
        """Append the stacked per-step outputs ``res`` of a ``bm.for_loop`` to ``mon`` (axis 0)."""
        for k, v in res.items():
            if k not in mon:
                mon[k] = v
            else:
                mon[k] = bm.concat([mon[k], v], axis=0)

    def _refresh_errors(self, i):
        """Recompute ``e[i]`` and ``e[i + 1]`` from the current states, weights and memory."""
        self.e[i].value = self.s[i] - self.f(self.s[i + 1]) @ self.theta[i]
        if i == self.L - 1:
            # The top layer is compared against the memory buffer.
            self.e[-1].value = self.s[-1] - self.mem
        else:
            self.e[i + 1].value = (self.s[i + 1] -
                                   self.f(self.s[i + 2]) @ self.theta[i + 1])

    def _bottom_up_drive(self, i):
        """Return ``f'(s[i + 1]) * (e[i] @ theta[i].T)``, shape (batch, n_{i+1}).

        Assuming ``f`` is elementwise, this is the negative gradient of
        ``0.5 * ||e[i]||^2`` with respect to ``s[i + 1]``.
        """
        a = self.theta[i]  # (n_{i+1}, n_i)
        b = bm.vector_grad(self.f)(self.s[i + 1])  # (batch, n_{i+1})
        c = jax.vmap(lambda b1: bm.expand_dims(b1, axis=1) * a)(
            b)  # (batch, n_{i+1}, n_i)
        return bm.einsum('bjk,bk->bj', c, self.e[i])

    @bm.cls_jit
    def run(self):
        """Run one training pass through all layers, bottom to top.

        For each layer ``i``, first update ``theta[i]`` for
        ``theta_time // dt`` steps. Then update ``s[i + 1]`` for the rest of
        the layer's ``duration / L`` slice.

        Returns:
            dict mapping ``'avg_grad{j}'`` to the per-step history of layer
            ``j``'s mean squared error, concatenated over all phases.
        """
        mon = dict()
        for i_layer in range(self.L):
            res = bm.for_loop(partial(self._generate_update_theta, i_layer),
                              np.arange(self.theta_time // self.dt))
            self._merge_history(mon, res)

            # NOTE(review): true division here, while test_run/test_init use
            # ``self.duration // self.L``. Step counts differ when ``duration``
            # is not a multiple of ``L`` -- confirm which is intended.
            res = bm.for_loop(
                partial(self._generate_update_s, i_layer),
                np.arange(
                    (self.duration / self.L - self.theta_time) // self.dt))
            self._merge_history(mon, res)
        return dict(mon)

    @bm.cls_jit
    def test_run(self):
        """Infer states layer by layer with ``inference_s``; weights are not changed.

        Returns:
            The same ``'avg_grad{j}'`` history dict as ``run``.
        """
        mon = dict()
        n_steps = (self.duration // self.L - self.theta_time) // self.dt
        for i_layer in range(self.L):
            res = bm.for_loop(partial(self.inference_s, i_layer),
                              np.arange(n_steps))
            self._merge_history(mon, res)
        return dict(mon)

    @bm.cls_jit
    def test_init(self):
        """Update states only.

        Run ``inference_s`` for layers ``0..L-2``, then ``_generate_update_s``
        for the top layer, which includes the pull toward ``mem``.
        """
        n_steps = (self.duration // self.L - self.theta_time) // self.dt
        for i_layer in range(self.L - 1):
            bm.for_loop(partial(self.inference_s, i_layer), np.arange(n_steps))
        bm.for_loop(partial(self._generate_update_s, self.L - 1),
                    np.arange(n_steps))

    def inference_s(self, i, idx):
        """Take one inference step on ``s[i + 1]``, driven only by ``e[i]``.

        Unlike ``_generate_update_s``, there is no top-down ``-e[i + 1]`` term.
        The step uses ``2 * eta_s`` rather than ``eta_s * ls[i]``. It is shrunk
        by 10x while the new ``mean(e[i]**2)`` exceeds the current one.

        Args:
            i: Layer index in ``[0, L)``; ``s[i + 1]`` is updated.
            idx: Loop index supplied by ``bm.for_loop`` (unused).

        Returns:
            ``{'avg_grad{j}': ...}`` with the stored mean squared error of every layer.
        """
        batch_size = self.s[0].shape[0]
        self._refresh_errors(i)
        tmp_noise = bm.random.normal(0, self.noise,
                                     (batch_size, self.neuron_size[i + 1]))
        ds = (self._bottom_up_drive(i) * self.dt * self.eta_s * 2 +
              tmp_noise * bm.sqrt(self.dt))
        e_i1 = self.s[i] - self.f(self.s[i + 1] + ds) @ self.theta[i]

        def body(ds, e_i1):
            ds = ds * 0.1
            e_i1 = self.s[i] - self.f(self.s[i + 1] + ds) @ self.theta[i]
            return ds, e_i1

        def cond(ds, e_i1):
            return bm.mean(bm.square(e_i1)) - bm.mean(bm.square(self.e[i])) > 0

        # NOTE(review): if ``ds`` ever contains inf, shrinking cannot make it
        # finite and this loop may not terminate.
        ds, e_i1 = bm.while_loop(body_fun=body, cond_fun=cond,
                                 operands=(ds, e_i1))
        self.s[i + 1].value = self.s[i + 1] + ds
        self.avg_grad[i] = bm.mean(bm.square(self.e[i]))
        self.avg_grad[i + 1] = bm.mean(bm.square(self.e[i + 1]))

        return {f'avg_grad{j}': self.avg_grad[j] for j in range(self.L + 1)}

    def _generate_update_s(self, i, idx):
        """Take one learning-phase step on state ``s[i + 1]``.

        The step is ``(-e[i + 1] + bottom-up drive from e[i]) * dt * eta_s *
        ls[i]`` plus noise. It is shrunk by 10x while the summed squared error
        of the two terms that depend on ``s[i + 1]``, scaled by 0.9, still
        exceeds ``mean(e[i]**2) + mean(e[i + 1]**2)``.

        Args:
            i: Layer index in ``[0, L)``; ``s[i + 1]`` is updated.
            idx: Loop index supplied by ``bm.for_loop`` (unused).

        Returns:
            ``{'avg_grad{j}': ...}`` with the stored mean squared error of every layer.
        """
        batch_size = self.s[0].shape[0]
        self._refresh_errors(i)
        tmp_noise = bm.random.normal(0, self.noise,
                                     (batch_size, self.neuron_size[i + 1]))
        ds = ((-self.e[i + 1] + self._bottom_up_drive(i)) * self.dt *
              self.eta_s * self.ls[i] + tmp_noise * bm.sqrt(self.dt))

        def total_err(step):
            # Mean squared errors of layers i and i+1 with s[i + 1] moved by ``step``.
            err = bm.mean(bm.square(
                self.s[i] - self.f(self.s[i + 1] + step) @ self.theta[i]))
            if i == self.L - 1:
                return err + bm.mean(bm.square(self.s[-1] + step - self.mem))
            return err + bm.mean(bm.square(
                self.s[i + 1] + step - self.f(self.s[i + 2]) @ self.theta[i + 1]))

        def body(ds, e_i1):
            ds = ds * 0.1
            return ds, total_err(ds)

        def cond(ds, e_i1):
            return (0.9 * e_i1 - bm.mean(bm.square(self.e[i])) -
                    bm.mean(bm.square(self.e[i + 1])) > 0)

        # NOTE(review): if ``ds`` ever contains inf, shrinking cannot make it
        # finite and this loop may not terminate.
        ds, _ = bm.while_loop(body_fun=body, cond_fun=cond,
                              operands=(ds, total_err(ds)))
        self.s[i + 1].value = self.s[i + 1] + ds

        self.avg_grad[i] = bm.mean(bm.square(self.e[i]))
        self.avg_grad[i + 1] = bm.mean(bm.square(self.e[i + 1]))

        return {f'avg_grad{j}': self.avg_grad[j] for j in range(self.L + 1)}

    def _generate_update_theta(self, i, idx):
        """Take one learning step on ``theta[i]``.

        ``dtheta = mean_over_batch(outer(f(s[i + 1]), e[i])) * dt * eta * lc[i]``.
        It is shrunk by 10x while the resulting ``mean(e[i]**2)`` exceeds the
        current one.

        Args:
            i: Layer index in ``[0, L)``.
            idx: Loop index supplied by ``bm.for_loop`` (unused).

        Returns:
            ``{'avg_grad{j}': ...}`` with the stored mean squared error of every layer.
        """
        self.e[i].value = self.s[i] - self.f(self.s[i + 1]) @ self.theta[i]
        pre = self.f(self.s[i + 1])  # (batch, n_{i+1})

        dtheta = (bm.mean(bm.einsum('bn,bm->bnm', pre, self.e[i]), axis=0) *
                  self.dt * self.eta * self.lc[i])
        e_i1 = self.s[i] - pre @ (self.theta[i] + dtheta)

        def body(dtheta, e_i1):
            dtheta = dtheta * 0.1
            e_i1 = self.s[i] - pre @ (self.theta[i] + dtheta)
            return dtheta, e_i1

        def cond(dtheta, e_i1):
            return bm.mean(bm.square(e_i1)) - bm.mean(bm.square(self.e[i])) > 0

        dtheta, e_i1 = bm.while_loop(body_fun=body, cond_fun=cond,
                                     operands=(dtheta, e_i1))
        # Write through ``.value``, like every other state update in this class,
        # so the Variable held in ``self.theta`` is updated in place.
        self.theta[i].value = self.theta[i] + dtheta

        # Error measured before this weight update.
        self.avg_grad[i] = bm.mean(bm.square(self.e[i]))
        return {f'avg_grad{j}': self.avg_grad[j] for j in range(self.L + 1)}

    def init_neuron(self, batch_size):
        """Reset all states, errors and the memory to zeros of the given batch size; theta is unchanged."""
        for i, num in enumerate(self.neuron_size):
            self.e[i].value = bm.zeros((batch_size, num))
            self.s[i].value = bm.zeros((batch_size, num))
        self.mem.value = bm.zeros((batch_size, self.neuron_size[-1]))

    def next_predict(self, ob, action):
        """Prepare the network for the next observation/action pair.

        Steps, in order:
        - ``mem[:, :action_size]`` is set to ``action``.
        - The remaining columns of ``mem`` are copied from the current top state.
        - The top state is set to ``mem``.
        - Layers ``L-1..1`` are filled with top-down predictions.
        - ``s[0]`` is clamped to ``ob``.
        - Errors ``e[1..L]`` are zeroed, ``e[0]`` is recomputed, and
          ``avg_grad`` is reset to hold only layer 0's error.

        Args:
            ob: Observation written into ``s[0]`` (broadcast to its shape).
            action: Values written into ``mem[:, :action_size]``.
        """
        self.mem[:, 0:self.action_size] = action
        self.mem[:, self.action_size:] = self.s[-1][:, self.action_size:]

        self.s[-1].value = self.mem
        for i in range(self.L - 1, 0, -1):
            self.s[i].value = self.f(self.s[i + 1]) @ self.theta[i]
        self.s[0][:] = ob

        # Explicit zero fill: ``e * 0`` would keep NaN/inf from a diverged step.
        for i in range(self.L, 0, -1):
            self.e[i].value = bm.zeros_like(self.e[i].value)
        self.e[0].value = self.s[0] - self.f(self.s[1]) @ self.theta[0]
        self.avg_grad.value = bm.zeros(self.L + 1)
        self.avg_grad[0] = bm.mean(bm.square(self.e[0]))

    def __save_state__(self) -> dict:
        """Return the weights (``'theta_{i}'``) and the top-layer state rows (``'b_{j}'``, row j)."""
        r = {f'theta_{i}': p for i, p in enumerate(self.theta)}
        r.update({f'b_{j}': m for j, m in enumerate(self.s[-1])})
        return r

    def __load_state__(self, state_dict: dict):
        """Restore the weights and top-layer state rows written by ``__save_state__``.

        Rows are loaded for the current batch size of ``s[-1]``. The result
        ``((), ())`` presumably follows BrainPy's (missing keys, unexpected
        keys) convention -- confirm against the BrainPy version in use.
        """
        for i, theta_i in enumerate(self.theta):
            theta_i.value = bm.asarray(state_dict[f'theta_{i}'])
        for j in range(self.s[-1].shape[0]):
            self.s[-1][j] = bm.asarray(state_dict[f'b_{j}'])
        return (), ()
