import numpy as np
import brainpy as bp
import brainpy.math as bm
import os

from model.simulation_eva import Net_sim, getdata
from model.intelligent_eva import M_2D, M_3D

from utils.training import training_sim, training_int, testing_int, label_int
from utils.visual import visual
from utils.mathematical_analysis import math_analysis



class OS_array:
    """Build and run one organoid-simulation experiment selected by config.

    ``args["Experiment_id"]`` selects the experiment:

    * ``0`` -- simulation similarity evaluation: ``Net_sim(64)`` is driven by
      the input returned from ``getdata()``; its spikes and the recorded spikes
      in ``./data/organoid_data_<group>.npy`` are passed to ``math_analysis``
      (and to ``visual`` when enabled).
    * ``1`` / ``2`` -- intelligent evaluation with ``M_2D`` / ``M_3D``: the
      model is trained (or loaded from ``./weight``), neurons are labelled and
      the test accuracy is printed.

    Args:
        args: Configuration mapping. Keys read: ``"Platform"``,
            ``"Experiment_id"``, ``"cuda_id"`` and the section
            ``"Experiment_<id>"`` of the chosen experiment.

    Raises:
        ValueError: If ``args["Experiment_id"]`` is not 0, 1 or 2.
    """

    # Experiment ids handled by this driver.
    _SUPPORTED_EXPERIMENTS = (0, 1, 2)

    def __init__(self, args):
        self.args = args
        exp_id = self.args["Experiment_id"]
        # Validate before touching global BrainPy/JAX state so that a bad
        # configuration fails fast instead of yielding a half-built object.
        exp_key = self._experiment_key(exp_id)

        # CUDA_VISIBLE_DEVICES only takes effect if it is set before the
        # accelerator backend initialises, so set it before choosing the
        # platform. NOTE(review): if importing brainpy already initialises the
        # JAX backend, this has to move before the imports -- confirm.
        os.environ['CUDA_VISIBLE_DEVICES'] = self.args["cuda_id"]
        bm.set_platform(self.args["Platform"])
        bm.set_environment(bm.training_mode)

        if exp_id == 0:
            bm.set_dt(0.1)
            self.net = Net_sim(64)
            self.trainer = bp.DSRunner(self.net, monitors=['n.spike', 'n1.spike'],
                                       data_first_axis='T', progress_bar=False)
            # Simulated duration; each unit spans round(1 / dt) integration
            # steps in run(). NOTE(review): originally annotated "s" -- confirm
            # units against BrainPy's dt convention.
            self.sim_T = 10
            return

        # Experiments 1 and 2 share the same constructor arguments and only
        # differ in the model class (2D vs 3D).
        cfg = self.args[exp_key]
        model_cls = M_2D if exp_id == 1 else M_3D
        self.net = model_cls(
            tau=cfg["tau"],
            alpha=cfg["alpha"], beta=cfg["beta"],
            T_dur=cfg["T_dur"], K=cfg["K"],
            delay=cfg["delay"],
        )
        self.trainer = bp.DSRunner(self.net, monitors=['n.spike'],
                                   data_first_axis='T', progress_bar=False)
        self.sim_T = cfg["sim_T"]
        self.path = cfg["path"]
        self.batch_size = cfg["batch_size"]
        self.I_max = cfg["I_max"]
        self.save_weight = cfg["save_weight"]
        self.save_label = cfg["save_label"]

    @classmethod
    def _experiment_key(cls, exp_id):
        """Return the config section name ``"Experiment_<exp_id>"``.

        Raises:
            ValueError: If ``exp_id`` is not a supported experiment id.
        """
        if exp_id not in cls._SUPPORTED_EXPERIMENTS:
            raise ValueError(
                f"Unsupported Experiment_id {exp_id!r}; "
                f"expected one of {cls._SUPPORTED_EXPERIMENTS}"
            )
        return "Experiment_" + str(exp_id)

    def run(self):
        """Run the experiment selected by ``self.args["Experiment_id"]``.

        Raises:
            ValueError: If the experiment id is not 0, 1 or 2.
        """
        exp_id = self.args["Experiment_id"]
        cfg = self.args[self._experiment_key(exp_id)]
        if exp_id == 0:
            self._run_similarity(cfg)
        else:
            self._run_intelligence(exp_id, cfg)

    def _run_similarity(self, cfg):
        """Drive ``Net_sim`` with ``getdata()`` input and analyse its spikes.

        Args:
            cfg: The ``"Experiment_0"`` section; reads ``"group"`` and
                ``"visualization"``.
        """
        # =============  Simulation similarity evaluation =================
        input_MEA = getdata()
        # Integration steps per unit of sim_T. round() guards against 1 / dt
        # evaluating to just below an integer in floating point.
        steps = int(round(1 / bm.get_dt()))
        simulation = bm.zeros((steps * self.sim_T, 1, 8, 8))
        # Write the scaled input at the first step of every unit of sim_T;
        # all other steps stay zero. Presumably input_MEA broadcasts to
        # (sim_T, 1, 8, 8) -- TODO confirm against getdata().
        simulation[np.arange(self.sim_T) * steps, :, :, :] = input_MEA * 500
        simulation = simulation.reshape((-1, 1, 64))

        spikes = training_sim(BNNnet=self.net, inputdata=simulation,
                              trainer=self.trainer, epoch=2)
        real_spikes = np.load(f"./data/organoid_data_{cfg['group']}.npy")

        if cfg["visualization"]:
            visual(real_spikes, spikes)
        math_analysis(real_spikes, spikes)

    def _run_intelligence(self, exp_id, cfg):
        """Load or train the M_2D/M_3D model, then print its test accuracy.

        Args:
            exp_id: 1 (2D model) or 2 (3D model); selects the weight files.
            cfg: The ``"Experiment_<exp_id>"`` section; reads ``"use_weight"``.
        """
        # =============  Simulation intelligent evaluation =================
        if cfg["use_weight"]:
            # Baseline files are named by dimensionality: id 1 -> 2D, id 2 -> 3D.
            weight_stem = f"./weight/{exp_id + 1}D_baseline"
            state_dict = bp.checkpoints.load_pytree(weight_stem + ".bp")
            self.net.load_state_dict(state_dict['net'])
            del state_dict  # drop the checkpoint reference before testing
            neuron_label = np.load(weight_stem + ".npy")
        else:
            training_int(BNNnet=self.net, during=self.sim_T,
                         trainer=self.trainer, path=self.path,
                         batch_size=self.batch_size,
                         I_max=self.I_max,
                         save=self.save_weight)

            neuron_label = label_int(
                BNNnet=self.net, during=self.sim_T,
                trainer=self.trainer, path=self.path,
                I_max=self.I_max, save=self.save_label)

        res = testing_int(
            BNNnet=self.net, during=self.sim_T,
            trainer=self.trainer, path=self.path,
            I_max=self.I_max,
            neuron_label=neuron_label)
        print("acc: ", res)