import os
import sys

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset

# Make modules in the parent of the current working directory (e.g. `rnn`) importable.
sys.path.append('../')
from rnn.vae import VAE
from rnn.train import train_VAE
from rnn.datasets import Basic_dataset

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
dim_z = 3      # passed to VAE as "dim_z"
dim_N = 512    # passed to VAE as "dim_N" (presumably the number of RNN units — confirm)
n_runs = 1     # number of independently constructed and trained models

data_eval_name = "EEG_train.npy"
data_name = "EEG_data_zscored.npy"

wandb = False  # forwarded to train_VAE as `sync_wandb`
n_epochs = 1000
bs = 10        # batch size; also scales task_params["n_trials"]

out_dir = ""
data_path = "data/"
# Only request CUDA when a device is actually available, so the script also
# runs on CPU-only machines (on GPU machines this is still True).
cuda = torch.cuda.is_available()

task_params = {
    "name": "EEG",
    "dur": 50,
    "n_trials": 50 * bs,
}

# Load the training and evaluation arrays as float32. os.path.join keeps the
# paths correct even if data_path is edited to drop its trailing separator.
data = np.load(os.path.join(data_path, data_name)).astype(np.float32)
data_eval = np.load(os.path.join(data_path, data_eval_name)).astype(np.float32)
task = Basic_dataset(task_params, data, data_eval)

# Observation dimensionality: axis 1 of the dataset's stored data.
# NOTE(review): assumes Basic_dataset keeps a 2-D+ array in `task.data` with
# features on axis 1 — confirm against rnn.datasets.
dim_x = task.data.shape[1]

# Train `n_runs` independent models; every run rebuilds its configuration
# dicts from scratch so no state is shared between runs.
for _run in range(n_runs):
    # Encoder settings, handed to the VAE via VAE_params["enc_params"].
    enc_params = dict(
        init_kernel_sizes=[11, 5, 5, 3],
        strides=[1] * 4,
        padding='valid',         # 'same' or 'valid'
        padding_mode='zeros',    # 'zeros', 'circular' or 'reflect'
        nonlinearity='gelu',
        dilations=[1] * 4,
        n_channels=[24, 16, 16, 16],
        init_scale=0.1,
        n_hidden=64,
        obs_grad=True,
    )

    # Prior (latent dynamics) settings, handed to the VAE via
    # VAE_params["prior_params"].
    prior_params = dict(
        clipped=True,
        train_noise_obs=True,
        train_noise_prior=True,
        train_noise_prior_t0=True,
        init_noise_z=0.1,
        init_noise_z_t0=1,
        init_noise_x=0.1,
        scalar_noise_z="Cov",
        scalar_noise_x=False,
        scalar_noise_z_t0="Cov",
        identity_readout=False,
        activation="relu",
        exp_par=True,
        shared_tau=0.9,
        readout_rates=False,     # alternative: "currents"
        train_obs_bias=True,
        train_obs_weights=True,
        train_latent_bias=False,
        train_neuron_bias=True,
        orth=False,
        m_norm=False,
        weight_dist="uniform",
        weight_scaler=1,         # alternative: 1/dim_N
        initial_state='trainable',
    )

    # Optimisation / evaluation settings consumed by train_VAE.
    training_params = dict(
        lr=1e-3,
        lr_end=1e-6,
        opt_eps=1e-8,
        CosineRestarts=0,
        beta=0.5,
        n_epochs=n_epochs,
        regularisation="none",
        regularisation_params=[0, 0],  # [ratio, lambda]
        annealing=False,
        annealing_epochs=500,
        grad_norm=0,
        eval_epochs=25,
        batch_size=bs,
        cuda=cuda,
        smoothing=20,
        freq_cut_off=-1,
        sim_obs_noise=0,
        sim_latent_noise=1,
        k=10,
        importance_weighting=False,    # alternative: "none"
        MC_q=True,
        dreg_q="none",                 # "all", "direct" or "none"
        MC_p=True,
        dreg_p="none",                 # "all", "direct" or "none"
        resample="systematic",
        loss_f="opt_VGTF",             # alternatives: "GTF" or "VGTF"
        L2_reg=0,
        observation_likelihood="Gauss",
        bootstrap=False,
        alpha=0.25,
        alpha_decay=0.999,
        alpha_method="constant",
        alpha_update_interval=5,
        run_eval=True,
        smooth_at_eval=True,
        t_forward=0,
    )

    # Top-level model description combining the pieces above.
    VAE_params = dict(
        dim_x=dim_x,
        dim_z=dim_z,
        dim_N=dim_N,
        enc_architecture="Inv_Obs",
        enc_params=enc_params,
        prior_architecture="PLRNN",
        prior_params=prior_params,
        causal=True,
    )

    vae = VAE(VAE_params)
    train_VAE(vae, training_params, task,
              sync_wandb=wandb, out_dir=out_dir, fname=None)

