#!/bin/python
#-----------------------------------------------------------------------------
#-----------------------------------------------------------------------------
import matplotlib
matplotlib.use('Agg')
from hybrid_beta_vae import Reshape, VAE
from decolle.utils import parse_args, train, test, accuracy, save_checkpoint, load_model_from_checkpoint, prepare_experiment, write_stats, cross_entropy_one_hot
from utils import save_checkpoint, load_model_from_checkpoint
import datetime, os, socket, tqdm
import numpy as np
import torch
from torch import nn
import importlib
from itertools import chain
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from decolle.utils import MultiOpt
from torchneuromorphic import transforms
from tqdm import tqdm
import math
import sys

# Machine epsilon of Python floats.
epsilon = sys.float_info.epsilon

# Human-readable names for gesture class indices 0..10.
# NOTE(review): the plotting code in this file reads `dataset.mapping`, not
# this dict, so this copy looks unused in the visible code — confirm.
mapping = { 0 :'Hand Clapping'  ,
            1 :'Right Hand Wave',
            2 :'Left Hand Wave' ,
            3 :'Right Arm CW'   ,
            4 :'Right Arm CCW'  ,
            5 :'Left Arm CW'    ,
            6 :'Left Arm CCW'   ,
            7 :'Arm Roll'       ,
            8 :'Air Drums'      ,
            9 :'Air Guitar'     ,
            10:'Other'}

np.set_printoptions(precision=4)
# Parse the command line; the string is the parameter file handed to parse_args
# (presumably the default YAML path — confirm against decolle.utils.parse_args).
args = parse_args('parameters/params_hybridvae_dvsgestures-guidedbeta-noaug.yml')

# The experiment name is this script's file name without directory or extension.
params, writer, dirs = prepare_experiment(name=__file__.split('/')[-1].split('.')[0], args = args)
log_dir = dirs['log_dir']
checkpoint_dir = dirs['checkpoint_dir']

starting_epoch = params['start_epoch']

# The parameter file overrides whatever resume_from value parse_args produced.
# "No resume" is encoded as the literal string 'None', not the None object.
args.resume_from = params['resume_from']#
if args.resume_from != 'None':
    checkpoint_dir = args.resume_from


verbose = args.verbose

# Expose only the GPU whose index follows ':' in params['device']
# (e.g. 'cuda:1' -> '1'). A device string without ':' raises IndexError here.
os.environ["CUDA_VISIBLE_DEVICES"] = params['device'].split(':')[1]
device = args.device

## Load Data

# The dataset module is chosen by name from the parameter file; it must expose
# either `create_data` or `create_dataloader`.
dataset = importlib.import_module(params['dataset'])
try:
    create_data = dataset.create_data
except AttributeError:
    create_data = dataset.create_dataloader

# return_meta=True makes every batch a 4-tuple (see the unpacking just below).
train_dl, test_dl = create_data(
                              root='data/dvsgesture/dvs_gestures.hdf5',
                              chunk_size_train=params['chunk_size_train'],
                              chunk_size_test=params['chunk_size_test'],
                              batch_size=params['batch_size'],
                              dt=params['deltat'],
                              num_workers=params['num_dl_workers'],
                              return_meta=True,
                              time_shuffle=True)#True)

# One reference batch, kept around for network initialisation and for the
# "original"/"reconstruction" figures.
data_batch, target_batch, light_batch, user_batch = next(iter(train_dl))

# Drop samples whose label at the last index of axis 1 is class 10.
# NOTE(review): only data_batch is filtered; target_batch keeps every row, so
# the two can differ in length afterwards — confirm this is intended.
data_batch = data_batch[target_batch[:,-1,:].argmax(1)!=10]

data_batch = torch.Tensor(data_batch).to(device)
target_batch = torch.Tensor(target_batch).to(device)

def generate_process_target(params, aug_epoch=0):
    """Build the event-filtering pipeline and the target-frame extractor.

    Two exponential filters are chained (time constants derived from the last
    entries of ``params['beta']`` and ``params['alpha']``, in that order),
    followed by a ``Rescale(50.)``. When ``aug_epoch`` is at least 1 a default
    ``Jitter`` is appended to the pipeline.

    Args:
        params: experiment parameters; reads ``params['alpha'][-1]`` and
            ``params['beta'][-1]``.
        aug_epoch: values >= 1 add a ``Jitter`` stage to the returned pipeline.

    Returns:
        ``(filter_data, process_target)`` where ``filter_data`` is the composed
        transform and ``process_target(data, aug_epoch=0)`` applies it (plus an
        extra ``Jitter(xs=2, ys=2, th=5)`` when its own ``aug_epoch > 0``) and
        returns the slice at index ``data.shape[1]`` along axis 1.
    """
    tau_alpha = 1 / (1 - params['alpha'][-1])
    tau_beta = 1 / (1 - params['beta'][-1])

    def exp_filter(tau):
        # Filter length and padding are both six time constants.
        span = int(6 * tau)
        return transforms.ExpFilterEvents(tau=tau, length=span, tpad=span, device='cuda')

    # Construction order matches the pipeline order: beta filter, alpha filter, rescale.
    stages = [exp_filter(tau_beta), exp_filter(tau_alpha), transforms.Rescale(50.)]
    if not aug_epoch < 1:
        stages.append(transforms.Jitter())
    filter_data = transforms.Compose(stages)

    def process_target(data, aug_epoch=0):
        # Index taken along axis 1 of the filtered output is the input's own
        # axis-1 length (presumably valid because the filters pad — TODO confirm).
        n_steps = data.shape[1]
        pipeline = filter_data
        if aug_epoch > 0:
            pipeline = transforms.Compose([filter_data, transforms.Jitter(xs=2, ys=2, th=5)])
        return pipeline(data)[:, n_steps]

    return filter_data, process_target

# Module-level filter pipeline and target-frame extractor, built with the
# default aug_epoch=0.
filter_data, process_target = generate_process_target(params)



class Guide(nn.Module):
    """Small classifier head operating on one part of the latent vector.

    The latent vector is split into the first ``num_classes`` entries (the
    "excited" part) and the remaining entries (the "inhibited" part);
    ``excite_z`` / ``inhibit_z`` extract those parts.

    Args:
        dimz: total latent dimensionality (kept for interface compatibility;
            the layer sizes are fully defined by ``hidden_layers``).
        num_classes: number of leading latent entries reserved for classes.
        excite: whether this guide works on the excited part (kept for
            interface compatibility; not used by the module itself).
        hidden_layers: ``OrderedDict`` of named layers (or a single module)
            wrapped in an ``nn.Sequential`` and stored as ``self.model``.
    """

    def __init__(self, dimz, num_classes, excite, hidden_layers):
        super(Guide, self).__init__()
        self.num_classes = num_classes
        self.model = nn.Sequential(hidden_layers)

        # Kaiming-uniform initialisation for every linear layer's weight.
        for layer in self.model:
            if isinstance(layer, nn.Linear):
                torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='leaky_relu')

    def forward(self, x):
        """Run ``x`` through the model.

        Returns:
            ``(output, soft_inp, soft_inp_mean)``. ``soft_inp`` is row 0 of the
            output of the most recent ``nn.Linear`` after the first one, and
            ``soft_inp_mean`` its mean (computed without gradient). Both are
            ``None`` when the model has fewer than two ``nn.Linear`` layers
            (this case used to raise ``UnboundLocalError``).
        """
        soft_inp = None
        soft_inp_mean = None
        seen_first_linear = False
        for layer in self.model:
            x = layer(x)
            if not isinstance(layer, nn.Linear):
                continue
            if not seen_first_linear:
                seen_first_linear = True
                continue
            soft_inp = x[0]
            with torch.no_grad():
                soft_inp_mean = torch.mean(soft_inp)
        return x, soft_inp, soft_inp_mean

    def excite_z(self, z):
        """Return the first ``num_classes`` columns of ``z``.

        The result is a new tensor created by ``torch.zeros`` (default dtype,
        CPU) regardless of ``z``'s dtype/device; gradients still flow back to
        ``z`` through the copy.
        """
        exc_z = torch.zeros((z.shape[0], self.num_classes))
        # One vectorised copy replaces the former per-row Python loop.
        exc_z.copy_(z[:, :self.num_classes])
        return exc_z

    def inhibit_z(self, z):
        """Return the columns of ``z`` after the first ``num_classes``.

        Same dtype/device/gradient behaviour as ``excite_z``.
        """
        inhib_z = torch.zeros((z.shape[0], z.shape[1] - self.num_classes))
        inhib_z.copy_(z[:, self.num_classes:])
        return inhib_z

def compute_mig(latzs, mus, logvars, n_samples=100):
    """Estimate the marginal entropy H(z_j) of every latent dimension.

    This is the first ingredient of the mutual information gap (MIG) defined
    in Chen et al., "Isolating Sources of Disentanglement in VAEs". The
    remaining MIG steps (conditional entropies per ground-truth factor and the
    gap itself) are not implemented yet; the function currently returns the
    entropy estimates.

    The aggregate posterior q(z_j) is approximated by the Monte-Carlo mixture
    of the per-sample Gaussian posteriors q(z_j | x_n) over the whole dataset.

    Args:
        latzs: tensor ``(len_dataset, latent_dim)`` of samples z ~ q(z|x).
        mus: tensor ``(len_dataset, latent_dim)`` of posterior means.
        logvars: tensor ``(len_dataset, latent_dim)`` of posterior log-variances.
        n_samples: number of latent samples used for the expectation; clipped
            to ``len_dataset``.

    Returns:
        Tensor of shape ``(latent_dim,)`` with the estimated H(z_j) in nats.

    Raises:
        ValueError: if the dataset is empty.
    """
    len_dataset, latent_dim = latzs.shape
    if len_dataset == 0:
        raise ValueError("compute_mig needs at least one latent sample")
    n_samples = min(n_samples, len_dataset)

    mini_batch_size = 10
    log_N = math.log(len_dataset)
    log_2pi = math.log(2 * math.pi)

    with torch.no_grad():
        H_z = torch.zeros(latent_dim, device=latzs.device, dtype=latzs.dtype)

        # sample from p(x)
        samples_x = torch.randperm(len_dataset, device=latzs.device)[:n_samples]
        # sample from p(z|x); transpose so the layout is (latent_dim, n_samples)
        samples_zCx = latzs.index_select(0, samples_x).t()

        # Broadcast views (no copies): every sampled z against every posterior.
        samples_zCx = samples_zCx.unsqueeze(0).expand(len_dataset, latent_dim, n_samples)
        mean = mus.unsqueeze(-1).expand(len_dataset, latent_dim, n_samples)
        log_var = logvars.unsqueeze(-1).expand(len_dataset, latent_dim, n_samples)

        # Work in slices along the sample axis to bound peak memory.
        for k in range(0, n_samples, mini_batch_size):
            idcs = slice(k, k + mini_batch_size)
            z_k = samples_zCx[..., idcs]
            mean_k = mean[..., idcs]
            log_var_k = log_var[..., idcs]
            # log q(z_j | x_n): element-wise Gaussian log-density.
            log_q_zCx = (-0.5 * (log_2pi + log_var_k)
                         - 0.5 * (z_k - mean_k) ** 2 * torch.exp(-log_var_k))
            # numerically stable log q(z_j) = -log N + logsumexp_n log q(z_j | x_n)
            log_q_z = -log_N + torch.logsumexp(log_q_zCx, dim=0, keepdim=False)
            # H(z_j) = E_{z_j}[-log q(z_j)]; accumulate the sum over samples.
            H_z += (-log_q_z).sum(1)

        H_z /= n_samples

    return H_z

    

def loss_fn(recon_x, x, mu, logvar, vae_beta = 4.0):
    """Beta-VAE objective: MSE reconstruction plus weighted KL divergence.

    The KL term is summed over axis 1, averaged over axis 0 and divided by the
    number of batches in the module-level ``train_dl``.

    Args:
        recon_x: reconstruction produced by the decoder.
        x: reconstruction target.
        mu: posterior means.
        logvar: posterior log-variances.
        vae_beta: weight of the KL term.

    Returns:
        Scalar loss tensor.
    """
    reconstruction = torch.nn.functional.mse_loss(recon_x, x)
    kld_per_sample = -0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1)
    kld_term = torch.mean(kld_per_sample, dim=0) / len(train_dl)
    return reconstruction + vae_beta * kld_term

def loss_fn_guided(recon_x, x, mu, logvar, excite_loss, inhib_loss, vae_beta=1):
    """Guided beta-VAE objective.

    Same reconstruction and KL terms as the plain beta-VAE loss in this file
    (KL summed over axis 1, averaged over axis 0, divided by
    ``len(train_dl)``), plus the difference ``excite_loss - inhib_loss``.

    Args:
        recon_x: reconstruction produced by the decoder.
        x: reconstruction target.
        mu: posterior means.
        logvar: posterior log-variances.
        excite_loss: loss term that is added.
        inhib_loss: loss term that is subtracted.
        vae_beta: weight of the KL term.

    Returns:
        Scalar loss tensor.
    """
    reconstruction = torch.nn.functional.mse_loss(recon_x, x)
    kld_per_sample = -0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1)
    kld_term = torch.mean(kld_per_sample, dim=0) / len(train_dl)
    guidance = excite_loss - inhib_loss
    return reconstruction + vae_beta * kld_term + guidance

def train_step(s, x, net, opt_fn, loss_fn):
    """Run one unguided optimisation step on a single batch.

    Args:
        s: network input batch.
        x: reconstruction target for ``loss_fn``.
        net: the VAE; its forward pass must return at least
            ``(reconstruction, mu, logvar)``.
        opt_fn: optimizer exposing ``step()`` and ``zero_grad()``.
        loss_fn: callable ``loss_fn(y, x, mu, logvar, vae_beta)``.

    Returns:
        The detached loss tensor of this step.
    """
    net.train()
    # Other call sites in this file unpack four values from net(...)
    # (reconstruction, mu, logvar, class output); tolerate the extra outputs
    # instead of failing with "too many values to unpack".
    y, mu, logvar, *_ = net(s)
    loss = loss_fn(y, x, mu, logvar, params['vae_beta'])
    loss.backward()
    opt_fn.step()
    # Gradients are cleared after the step so the next call starts clean.
    opt_fn.zero_grad()
    return loss.detach()

def _print_linear_grads(layers):
    """Print the mean absolute weight gradient of each layer, or None.

    None is printed for layers that are not ``nn.Linear`` and for linear
    layers whose weight has no gradient yet.
    """
    for layer in layers:
        if isinstance(layer, nn.Linear) and layer.weight.grad is not None:
            print(layer.weight.grad.abs().mean())
        else:
            print(None)


def print_grads(excite, inhib, net):
    """Debug helper: print gradient magnitudes of the guide nets and encoder head.

    Args:
        excite: object whose ``model`` attribute is an iterable of layers.
        inhib: object whose ``model`` attribute is an iterable of layers.
        net: object whose ``encoder_head`` supports ``.items()``; only its
            ``nn.Linear`` values are reported.
    """
    print("excite")
    _print_linear_grads(excite.model)
    print("inhib")
    _print_linear_grads(inhib.model)
    print("encoder head")
    for _, v in net.encoder_head.items():
        if isinstance(v, nn.Linear):
            # Before the first backward pass grad is None; print None like the
            # loops above instead of raising AttributeError.
            grad = v.weight.grad
            print(grad.abs().mean() if grad is not None else None)
            
def calc_entropy(soft_inp, bins):
    """Shannon entropy (in bits) of the histogram of ``soft_inp``.

    Args:
        soft_inp: tensor of values; detached and moved to the CPU here.
        bins: bin specification forwarded to ``numpy.histogram`` (an int or a
            sequence of bin edges).

    Returns:
        Entropy as a Python float; ``0.0`` when no value falls into any bin.
    """
    values = soft_inp.detach().cpu().numpy()

    counts, _ = np.histogram(values, bins=bins, density=False)

    # Normalise by the number of counted values. Dividing by the number of
    # bins only yields probabilities when both happen to be equal.
    total = counts.sum()
    if total == 0:
        return 0.0

    # Empty bins contribute nothing (lim p->0 of p*log p is 0), so drop them
    # rather than padding the logarithm with an epsilon.
    probs = counts[counts > 0] / total
    return float(np.sum(probs * np.log2(1.0 / probs)))
    
    
def batch_one_hot(targets, num_classes=10):
    """One-hot encode a batch of class indices.

    Args:
        targets: indexable batch of integer class indices with a ``shape``
            attribute; ``targets.shape[0]`` is the batch size.
        num_classes: width of the encoding.

    Returns:
        Float tensor of shape ``(batch, num_classes)`` created with
        ``torch.zeros`` (default dtype and device), with a 1 at each target.
    """
    encoded = torch.zeros((targets.shape[0], num_classes))

    # Iterating a tensor yields row views, so writing into `row` updates `encoded`.
    for row, label in zip(encoded, targets):
        row[label] = 1

    return encoded
    

def train_step_guided(s, x, net, opt_fn, opt_excite, opt_inhib, loss_fn_guided, t, vae_beta=1): # t is the target class
    """Run one three-stage guided-VAE update on a single batch.

    Stages (each with its own backward pass and ``opt_fn.step()``):

    1. VAE + classifier: reconstruction/KL loss from the module-level
       ``loss_fn`` plus the weighted classification loss on the class output
       of ``net``.
    2. The module-level ``inhib`` guide is trained to predict the true labels
       from the non-class part of a *detached* latent sample.
    3. A fresh encoding is pushed through ``inhib`` with a constant 0.5
       target; the latent is not detached here, so gradients reach the encoder.

    Args:
        s: network input batch.
        x: reconstruction target.
        net: the VAE (forward returns reconstruction, mu, logvar, class output).
        opt_fn: optimizer stepped in every stage.
        opt_excite: additional optimizer zeroed and stepped in stage 3.
        opt_inhib: unused; kept for interface compatibility.
        loss_fn_guided: unused; the module-level ``loss_fn`` is used instead.
        t: integer class index per sample.
        vae_beta: KL weight forwarded to ``loss_fn``.

    Returns:
        ``(vae_loss, excite_loss, inhib_loss, clas_loss, soft_entropy, 0)``
        with the four losses detached. ``vae_loss`` is the reconstruction/KL
        part only; ``soft_entropy`` and the last entry are placeholders (0).
    """
    net.train()
    #excite.train()
    inhib.train()

    class_weight = params['class_weight']

    # --- Stage 1: VAE + classification head -------------------------------
    opt_fn.zero_grad()
    hot_ts = batch_one_hot(t, num_classes=params['num_classes'])
    hot_ts_gpu = hot_ts.cuda()
    y, mu, logvar, clas = net(s)
    vae_objective = loss_fn(y, x, mu, logvar, vae_beta)
    clas_loss = inhib_criterion(clas, hot_ts_gpu)*class_weight
    # Snapshot the VAE part *before* combining. The previous in-place
    # `loss += clas_loss` mutated the very tensor that was reported as the VAE
    # loss, so the logged value silently included the classification loss.
    vae_loss = vae_objective.detach()
    (vae_objective + clas_loss).backward()
    opt_fn.step()

    # --- Stage 2: guide learns the labels from detached latents -----------
    opt_fn.zero_grad()
    z = net.reparameterize(mu,logvar).detach()
    inhib_z = inhib.inhibit_z(z)
    excite_output = inhib.model(inhib_z.cuda())
    excite_loss = inhib_criterion(excite_output, hot_ts_gpu)*class_weight # now MultiLabelSoftMarginLoss
    excite_loss.backward()
    opt_fn.step()

    # --- Stage 3: push the guide's output towards an uninformative target -
    opt_fn.zero_grad()
    opt_excite.zero_grad()
    mu, logvar = net.encode(s)
    z = net.reparameterize(mu,logvar)
    inhib_z = inhib.inhibit_z(z)
    inhib_output = inhib.model(inhib_z.cuda())
    soft_entropy = 0 # entropy diagnostic currently disabled

    # Every class gets target 0.5.
    inhib_hot_ts = torch.empty_like(hot_ts).fill_(0.5)
    inhib_loss = inhib_criterion(inhib_output, inhib_hot_ts.cuda())*class_weight # now MultiLabelSoftMarginLoss
    inhib_loss.backward()
    opt_fn.step()
    opt_excite.step()

    return vae_loss, excite_loss.detach(), inhib_loss.detach(), clas_loss.detach(), soft_entropy, 0 #soft_mean=0

def train(train_dl, net, opt_fn, loss_fn):
    """Train the unguided VAE for one pass over ``train_dl``.

    Args:
        train_dl: iterable of batches whose first element is the input tensor.
        net: the VAE.
        opt_fn: optimizer forwarded to ``train_step``.
        loss_fn: loss callable forwarded to ``train_step``.

    Returns:
        Mean of the per-batch losses (``numpy`` scalar).
    """
    net.train()
    loss_batch = []
    # The loaders in this file are built with return_meta=True and yield
    # (data, target, light, user); only the data is needed here, so accept any
    # number of trailing items instead of failing on a two-name unpack.
    for x, *_ in tqdm(train_dl):
        x_c = x.cuda()
        frames = process_target(x_c)
        loss_ = train_step(x_c, frames.cuda(), net, opt_fn, loss_fn)
        loss_batch.append(loss_.detach().cpu().numpy())
    return np.mean(loss_batch)

def train_guided_aug(train_dl, net, opt_fn, opt_exc, opt_inhib, loss_fn):
    """Guided training for ``params['num_augs']`` passes over ``train_dl``.

    Samples whose label (argmax at the last index of axis 1 of the target) is
    class 10 are dropped from every batch. Pass number ``i`` calls
    ``process_target`` with ``aug_epoch = i - 1``.

    Args:
        train_dl: loader yielding ``(data, target, light, user)`` batches.
        net: the VAE.
        opt_fn, opt_exc, opt_inhib: optimizers forwarded to ``train_step_guided``.
        loss_fn: loss callable forwarded to ``train_step_guided``.

    Returns:
        Six ``float64`` means over all processed batches, in the order the
        values are returned by ``train_step_guided``.
    """
    net.train()
    vae_hist, excite_hist, inhib_hist, clas_hist, entropy_hist, mean_hist = [], [], [], [], [], []

    for aug_idx in range(params['num_augs']):
        for x, t, _light, _user in tqdm(iter(train_dl)):
            labels = t[:, -1, :].argmax(1)
            keep = labels != 10
            new_t = labels[keep]
            x_c = x[keep].cuda()
            frames = process_target(x_c, aug_idx - 1)
            step_result = train_step_guided(
                x_c, frames.cuda(), net, opt_fn, opt_exc, opt_inhib,
                loss_fn, new_t.long(), params['vae_beta'])
            loss_, excite_loss_, inhib_loss_, loss_abs_, soft_entropy, soft_mean = step_result

            vae_hist.append(loss_.detach().cpu().numpy())
            excite_hist.append(excite_loss_.detach().cpu().numpy())
            inhib_hist.append(inhib_loss_.detach().cpu().numpy())
            clas_hist.append(loss_abs_.detach().cpu().numpy())
            entropy_hist.append(soft_entropy)
            mean_hist.append(soft_mean)

    histories = (vae_hist, excite_hist, inhib_hist, clas_hist, entropy_hist, mean_hist)
    return tuple(np.mean(values, dtype=np.float64) for values in histories)


def get_latent_train(dl, net, iterations=3):
    """Collect sampled latent vectors and their metadata from a loader.

    The loader is traversed ``iterations`` times. In every batch the samples
    whose label (argmax at the last index of axis 1 of the target) is class 10
    are dropped; the rest are encoded and a latent is drawn with
    ``net.reparameterize``.

    Args:
        dl: loader yielding ``(data, target, light, user)`` batches.
        net: model exposing ``encode`` and ``reparameterize``.
        iterations: number of passes over ``dl``.

    Returns:
        ``(latents, targets, users, lights)`` as numpy arrays.
    """
    lats, tgts, usrs, lights = [], [], [], []

    for _ in range(iterations):
        for x, t, l, u in tqdm(iter(dl)):
            labels = t[:, -1, :].argmax(1)
            keep = labels != 10
            kept_lights = np.asarray(l)[keep]
            kept_users = np.asarray(u)[keep]
            with torch.no_grad():
                mu, logvar = net.encode(x[keep].cuda())
                lat = net.reparameterize(mu, logvar).detach().cpu().numpy()
            lats.extend(lat.tolist())
            tgts.extend(labels[keep].tolist())
            usrs.extend(kept_users)
            lights.extend(kept_lights)

    return np.array(lats), np.array(tgts), np.array(usrs), np.array(lights)

def tsne_project(lats, tgts, usrs, lights, net, do_plot = True, use_user=False, use_light=False):
    """Project latents to 2-D with t-SNE and optionally scatter-plot them.

    Three embeddings are computed: the full latent, the part returned by
    ``net.inhibit_z`` and the part returned by ``net.excite_z``.

    Args:
        lats: numpy array of latent vectors.
        tgts: numpy array of class indices, aligned with ``lats``.
        usrs: numpy array of user identifiers, aligned with ``lats``.
        lights: numpy array of lighting identifiers, aligned with ``lats``.
        net: object exposing ``inhibit_z`` and ``excite_z``.
        do_plot: when False only the full-latent embedding is returned.
        use_user: colour the inhibited embedding by user on the second figure.
        use_light: colour the inhibited embedding by lighting on the third figure.

    Returns:
        ``lat_tsne`` when ``do_plot`` is False, otherwise
        ``(lat_tsne, fig, fig2, fig3, fig4, fig5)``: full latent by class,
        inhibited part by user, inhibited part by lighting, excited part by
        class, inhibited part by class.
    """
    from sklearn.manifold import TSNE

    lat_tsne = TSNE(n_components=2).fit_transform(lats)
    inhib_tsne = TSNE(n_components=2).fit_transform(net.inhibit_z(torch.from_numpy(lats)).numpy())
    exc_tsne = TSNE(n_components=2).fit_transform(net.excite_z(torch.from_numpy(lats)).numpy())

    if not do_plot:
        return lat_tsne

    # All five figures are created first, then one axes per figure.
    figures = [plt.figure(figsize=(16,10)) for _ in range(5)]
    axes = [figure.add_subplot() for figure in figures]
    fig, fig2, fig3, fig4, fig5 = figures
    ax, ax2, ax3, ax4, ax5 = axes

    usernames = list(set(usrs))
    lightnames = list(set(lights))

    # Per-class scatter on the full, excited and inhibited embeddings.
    for class_idx in range(params['num_classes']):
        selected = tgts==class_idx
        for axis, embedding in ((ax, lat_tsne), (ax4, exc_tsne), (ax5, inhib_tsne)):
            axis.scatter(embedding[selected,0], embedding[selected,1], label = dataset.mapping[class_idx])
    for axis in (ax, ax4, ax5):
        axis.legend()

    if use_user:
        for username in usernames:
            selected = usrs==username
            ax2.scatter(inhib_tsne[selected,0], inhib_tsne[selected,1], label = username)
        ax2.legend()

    if use_light:
        for lightname in lightnames:
            selected = lights==lightname
            ax3.scatter(inhib_tsne[selected,0], inhib_tsne[selected,1], label = lightname)
        ax3.legend()

    return lat_tsne, fig, fig2, fig3, fig4, fig5

def activity_regularization(u, mask=1.0):
    """Activity regulariser summed over a sequence of state tensors.

    For every tensor the state is flattened to ``(batch, -1)`` and two terms
    are added: ``1e-4 * mean(relu(u + 0.01) * mask)`` and
    ``6e-6 * relu(mean(mask * (0.1 - sigmoid(u))))``.

    Args:
        u: iterable of tensors whose first axis is the batch axis.
        mask: multiplicative mask broadcast against each flattened state. The
            default ``1.0`` applies no masking (the function previously
            referenced an undefined ``mask`` and an undefined ``relu`` and
            could not be called).

    Returns:
        The summed regularisation loss (the int ``0`` when ``u`` is empty).
    """
    loss_tv = 0
    for ui in u:
        uflat = ui.reshape(ui.shape[0], -1)
        reg1_loss = 1e-4 * (torch.relu(uflat + .01) * mask).mean()
        reg2_loss = 6e-6 * torch.relu((mask * (.1 - torch.sigmoid(uflat))).mean())
        loss_tv += reg1_loss + reg2_loss
    return loss_tv

def eval_accuracy(lats, tgts, is_excite, excite, net):
    """Classification accuracy of ``net`` on one part of the latents.

    Args:
        lats: numpy array of latent vectors, one row per sample.
        tgts: sequence of true class indices aligned with ``lats``.
        is_excite: when truthy the part returned by ``excite.excite_z`` is
            classified, otherwise the part returned by ``excite.inhibit_z``.
        excite: object exposing ``excite_z`` / ``inhibit_z``.
        net: classifier; it is switched to eval mode and left in it.

    Returns:
        Fraction of samples whose arg-max prediction equals the target;
        ``0.0`` for an empty ``tgts`` (previously a ``ZeroDivisionError``;
        ``net`` is not touched in that case).
    """
    all_count = len(tgts)
    if all_count == 0:
        return 0.0

    latents = torch.from_numpy(lats)
    zs = excite.excite_z(latents) if is_excite else excite.inhibit_z(latents)

    net.eval()
    with torch.no_grad():
        # One batched forward pass replaces the former per-sample loop.
        logps = net(zs[:all_count].to(device)) # was net.model(...)

    # exp() is monotonic, so the arg-max of the raw outputs equals the arg-max
    # of the exponentiated ones that were compared before.
    pred_labels = logps.argmax(dim=1).cpu().numpy()
    correct_count = int((pred_labels == np.asarray(tgts)).sum())

    return correct_count/all_count


def latent_traversal(lats,tgts,net,clas,n_plots=10):
    """Decode a sweep of one class latent and plot the results side by side.

    The first latent vector whose target equals ``clas`` is the baseline. Its
    entry ``clas`` is set to ``max - i * (max - min) / n_plots`` for
    ``i = 0 .. n_plots-1``, where min/max are taken over column ``clas`` of
    ``lats``; every variant is decoded and shown in its own subplot.

    Args:
        lats: numpy array of latent vectors, one row per sample.
        tgts: numpy array of class indices aligned with ``lats``.
        net: model exposing ``decode``.
        clas: class index; selects the baseline sample and the swept dimension.
        n_plots: number of traversal steps / subplots.

    Returns:
        The matplotlib figure holding the ``n_plots`` decoded images.
    """
    # baseline latent example of the class we will traverse along (private copy)
    lat = lats[tgts==clas][0].copy()

    # Range of the swept dimension over all samples. The span is max - min;
    # abs(min) + max is only correct when min < 0 < max.
    column = lats[:, clas]
    min_lat = column.min()
    max_lat = column.max()
    step = (max_lat - min_lat)/n_plots

    # squeeze=False keeps an indexable axes array even for n_plots == 1.
    fig, axs = plt.subplots(1,n_plots,figsize=(16,10),squeeze=False)
    axs = axs[0]

    for i in range(n_plots):
        lat[clas] = max_lat - step*i

        with torch.no_grad():
            # Build a fresh tensor per step; the numpy baseline stays the
            # single source of truth instead of being rebound to a tensor.
            lat_t = torch.tensor(lat,dtype=torch.float).cuda()
            decoded = net.decode(lat_t).cpu()

        axs[i].imshow(decoded[0,0].T)

    return fig
    
    
def latent_traversal_switch(lats,tgts,net,clas1,clas2,n_plots=10):
    """Decode a cross-fade between two class latents and plot every step.

    Starting from the first latent whose target equals ``clas1``, entry
    ``clas1`` is swept from its dataset minimum up to its maximum while entry
    ``clas2`` is swept from its dataset maximum down to its minimum, in
    ``n_plots`` equal steps (``n_plots + 1`` images including the start).
    ``clas1`` and ``clas2`` are expected to be different dimensions.

    Args:
        lats: numpy array of latent vectors, one row per sample.
        tgts: numpy array of class indices aligned with ``lats``.
        net: model exposing ``decode``.
        clas1: class index of the baseline sample and the rising dimension.
        clas2: class index of the falling dimension.
        n_plots: number of steps.

    Returns:
        The matplotlib figure holding the ``n_plots + 1`` decoded images.
    """
    column1 = lats[:, clas1]
    column2 = lats[:, clas2]
    min_lat1, max_lat1 = column1.min(), column1.max()
    min_lat2, max_lat2 = column2.min(), column2.max()

    # Span is max - min; abs(min) + max is only correct when min < 0 < max.
    step1 = (max_lat1 - min_lat1)/n_plots
    step2 = (max_lat2 - min_lat2)/n_plots

    lat = lats[tgts==clas1][0].copy()

    fig, axs = plt.subplots(1,n_plots+1,figsize=(16,10))

    for i in range(n_plots+1):
        # Positions are computed from the step index rather than accumulated,
        # so the last image sits exactly on (max of clas1, min of clas2).
        lat[clas1] = min_lat1 + step1*i
        lat[clas2] = max_lat2 - step2*i

        with torch.no_grad():
            lat_t = torch.tensor(lat,dtype=torch.float).cuda()
            decoded = net.decode(lat_t).cpu()

        axs[i].imshow(decoded[0,0].T)

    return fig

def latent_traversal_inhib(lats,tgts,net,clas1,n_plots=10):
    """Decode a baseline latent while chunks of its non-class part are maximised.

    The baseline is the first latent whose target equals ``clas1``, with entry
    ``clas1`` set to that dimension's minimum over ``lats``. For plot ``i > 0``
    a chunk of ``(dimz - num_classes) // num_classes`` entries starting at
    index ``num_classes * i`` is overwritten with the per-dimension maxima of
    the same latent positions; changes accumulate from plot to plot.

    Returns the matplotlib figure with ``n_plots + 1`` decoded images.
    """
    min_lat1 = min(lats.T[clas1])
    max_lat1 = max(lats.T[clas1])
    
    min_lats = [] # list of min value of all latent variables
    max_lats = [] # list of max value of all latent variables
    
    num_classes = params['num_classes']
    
    lat = lats[tgts==clas1][0]
    
    # collect min/max of every non-class latent dimension (index >= num_classes);
    # position j in these lists corresponds to latent index j + num_classes
    for i in range(params['dimz']):
        #lat[i] = min(lats.T[i])
        
        if i >= num_classes:
            min_lats.append(min(lats.T[i]))
            max_lats.append(max(lats.T[i]))
            
    min_lats = torch.from_numpy(np.asarray(min_lats))
    max_lats = torch.from_numpy(np.asarray(max_lats))
                
    # start with the target class latent at its minimum
    # NOTE(review): an earlier comment said "maximize", but the code assigns the minimum
    lat[clas1] = min_lat1 #max_lat1
    
    fig, axs = plt.subplots(1,n_plots+1,figsize=(16,10))
    
    for i in range(n_plots+1):
        if i>0:
            # lat is indexed in full latent coordinates and max_lats in coordinates
            # shifted by num_classes, so both slices address the same dimensions.
            # NOTE(review): chunks start every num_classes entries but are
            # (dimz-num_classes)//num_classes long — confirm the stride is intended.
            lat[num_classes*i:(num_classes*i+(params['dimz']-num_classes)//num_classes)] = max_lats[num_classes*(i-1):(num_classes*(i-1)+(params['dimz']-num_classes)//num_classes)]
        
        with torch.no_grad():
            # lat is rebound to a CUDA tensor here, so from the second iteration on
            # the slice assignment above writes into that tensor
            lat = torch.tensor(lat,dtype=torch.float).cuda()
            decoded = net.decode(lat).cpu()
        
        axs[i].imshow(decoded[0,0].T)
        
    return fig
    
    

#d, t = next(iter(train_dl))
input_shape = data_batch.shape[-3:]

#Backward compatibility
if 'dropout' not in params.keys():
    params['dropout'] = [.5]

## Create Model, Optimizer and Loss
net = VAE(input_shape=params['input_shape'], seq_len=params['chunk_size_train'], dimz=params['dimz'], encoder_params=params).cuda()
from decolle.init_functions import init_LSUV
# Data-dependent initialisation of the encoder using the reference batch.
init_LSUV(net.encoder,data_batch.cuda())

if params['is_guided']:
    from collections import OrderedDict

    layer_size = 100
    layer_size2 = 100 #300
    layer_size3 = 100 #400

    # Guide on the first num_classes latent entries.
    excite_layers = OrderedDict([
        ('lin1', nn.Linear(params['num_classes'],layer_size)),
        ('norm1', nn.BatchNorm1d(layer_size)),
        ('relu1', nn.LeakyReLU(negative_slope=0.2,inplace=True)),
        #('lin2', nn.Linear(layer_size,layer_size)),
        #('norm2', nn.BatchNorm1d(layer_size)),
        #('relu2', nn.LeakyReLU(negative_slope=0.2,inplace=True)),
        #('droput',nn.Dropout(0.05)),
        ('lin3', nn.Linear(layer_size, params['num_classes']))#,
        #('soft', nn.LogSoftmax(dim=1))
    ])

    # Guide on the remaining dimz - num_classes latent entries.
    inhib_layers = OrderedDict([
        ('lin1', nn.Linear(params['dimz']-params['num_classes'],layer_size)),
        ('norm1', nn.BatchNorm1d(layer_size)),
        ('relu1', nn.LeakyReLU(negative_slope=0.2,inplace=True)),
        ('lin2', nn.Linear(layer_size,layer_size2)),
        ('norm2', nn.BatchNorm1d(layer_size2)),
        ('relu2', nn.LeakyReLU(negative_slope=0.2,inplace=True)),
        #('lin3', nn.Linear(layer_size2,layer_size3)),
        #('norm3', nn.BatchNorm1d(layer_size3)),
        #('relu3', nn.LeakyReLU(negative_slope=0.2,inplace=True)),
        # The layer feeding lin4 is lin2 (lin3 is disabled), so its input
        # width is layer_size2; layer_size3 only worked because both are 100.
        ('lin4', nn.Linear(layer_size2, params['num_classes']))
    ])
    excite = Guide(params['dimz'],params['num_classes'],True,excite_layers).cuda()

    inhib = Guide(params['dimz'],params['num_classes'],False,inhib_layers).cuda()

    exc_criterion = nn.MultiLabelSoftMarginLoss(reduction='sum') #nn.CrossEntropyLoss() #nn.NLLLoss()

    inhib_criterion = nn.MultiLabelSoftMarginLoss(reduction='sum') #nn.CrossEntropyLoss() #nn.NLLLoss()

    # The "excitatory" optimizer drives the VAE's own classification head
    # (net.cls_sq); the "inhibitory" one drives the inhib guide.
    opt_excititory = torch.optim.Adam(net.cls_sq.parameters(), lr=params['learning_rate'][2])
    opt_inhibitory = torch.optim.Adam(inhib.model.parameters(), lr=params['learning_rate'][3])

# DECOLLE needs different learning rates
opt1 = torch.optim.Adamax(net.encoder.get_trainable_parameters(), lr=params['learning_rate'][0], betas=params['betas'], eps=1e-4)
opt2 = torch.optim.Adam(chain(*[net.encoder_head.parameters(),net.decoder.parameters()]), lr=params['learning_rate'][1])
if params['is_guided']:
    opt = MultiOpt(opt1,opt2, opt_excititory, opt_inhibitory)
else:
    # Without guidance there are no guide optimizers; the previous code
    # referenced undefined names here and raised NameError.
    # NOTE(review): `excite` and `inhib` are also undefined in this branch
    # while later checkpoint calls pass them — confirm how unguided runs
    # should be checkpointed.
    opt = MultiOpt(opt1,opt2)
#opt = torch.optim.Adamax(net.parameters(), lr = params['learning_rate'], eps=1e-6)

##Resume if necessary
# 'None' (the string) means "start fresh"; anything else is treated as the
# checkpoint directory to resume from.
if args.resume_from != 'None':
    print("Checkpoint directory " + checkpoint_dir)
    # NOTE(review): a missing checkpoint directory is created right before
    # loading from it — confirm load_model_from_checkpoint copes with an empty directory.
    if not os.path.exists(checkpoint_dir) and not args.no_save:
        os.makedirs(checkpoint_dir)
    starting_epoch = load_model_from_checkpoint(checkpoint_dir, net, opt, excite, inhib)
    print('Learning rate = {}. Resumed from checkpoint'.format(opt.param_groups[-1]['lr']))
    
    # Log the filtered target frames of the reference batch (first channel only).
    orig = process_target(data_batch).detach().cpu().view(*[[-1]+params['output_shape']])[:,0:1]
    print(orig.shape)
    figure2 = plt.figure(99)
    plt.imshow(make_grid(orig, scale_each=True, normalize=True).transpose(0,2).numpy())
    if not args.no_save:
        writer.add_figure('original_train',figure2,global_step=1)

# Printing parameters
if args.verbose:
    print('Using the following parameters:')
    # m is the longest key length, used to align the ' : ' separators
    m = max(len(x) for x in params)
    for k, v in zip(params.keys(), params.values()):
        print('{}{} : {}'.format(k, ' ' * (m - len(k)), v))

print('\n------Starting training Hybrid VAE-------')



# --------TRAINING LOOP----------
# Per epoch: (1) set the learning rates, (2) every params['test_interval']
# epochs save a checkpoint and log traversal / t-SNE / reconstruction figures,
# (3) train for one epoch with the routine selected by params['is_guided'] and
# params['use_aug'], (4) log the training loss and close all figures.
num_classes = params['num_classes']
if not args.no_train:
    # Log the filtered target frames of the reference batch (first channel only).
    orig = process_target(data_batch).detach().cpu().view(*[[-1]+params['output_shape']])[:,0:1]
    figure2 = plt.figure(99)
    plt.imshow(make_grid(orig, scale_each=True, normalize=True).transpose(0,2).numpy())
    if not args.no_save:
        writer.add_figure('original_train',figure2,global_step=1)
    
    for e in tqdm(range(starting_epoch , params['num_epochs'] )):
        # Step schedule: from the first full interval on, each optimizer's base
        # rate is divided by interval * lr_drop_factor.
        interval = e // params['lr_drop_interval']
        for i,opt_ in enumerate(opt):
            # NOTE(review): the "from" value is read from `opt` (all optimizers
            # together), not from `opt_`, so the printed old rate may belong to
            # a different optimizer — confirm.
            lr = opt.param_groups[-1]['lr']
            if interval > 0:
                opt_.param_groups[-1]['lr'] = np.array(params['learning_rate'][i]) / (interval * params['lr_drop_factor'])
                print('Changing learning rate from {} to {}'.format(lr, opt_.param_groups[-1]['lr']))
            else:
                opt_.param_groups[-1]['lr'] = np.array(params['learning_rate'][i])
                print('Changing learning rate from {} to {}'.format(lr, opt_.param_groups[-1]['lr']))

        if (e % params['test_interval']) == 0 and e!=0:
            print('---------------Epoch {}-------------'.format(e))
            if not args.no_save:
                print('---------Saving checkpoint---------')
                # NOTE(review): excite/inhib only exist when params['is_guided'] is set.
                save_checkpoint(e, checkpoint_dir, net, opt, excite, inhib)

            #test here

            # tsne
            lats, tgts, usrs, lights = get_latent_train(train_dl, net, iterations=1)
            lats_test, tgts_test, usrs_test, lights_test = get_latent_train(test_dl, net, iterations=3)
            
            #latent space traversal (class 1, and a switch from class 1 to class 2)
            fig = latent_traversal(lats, tgts, net, 1)
            
            fig_test = latent_traversal(lats_test, tgts_test, net, 1)
            
            fig_switch = latent_traversal_switch(lats, tgts, net, 1, 2)
            
            fig_inhib = latent_traversal_inhib(lats, tgts, net, 1)
            
            # `excite` is passed only for its excite_z / inhibit_z splitting helpers.
            _, figure, fig2, fig3, fig6, fig8 = tsne_project(lats, tgts, usrs, lights, excite, use_user=True, use_light=True)
            _, figure2, fig4, fig5, fig7, fig9 = tsne_project(lats_test, tgts_test, usrs_test, lights_test, excite, use_user=True, use_light=True)
            
            if not args.no_save:
                writer.add_figure('latent_traversal',fig,global_step=e)
                writer.add_figure('latent_traversal_test',fig_test,global_step=e)
                writer.add_figure('latent_traversal_switch',fig_switch,global_step=e)
                writer.add_figure('latent_traversal_inhib',fig_inhib,global_step=e)
                writer.add_figure('tsne_train',figure,global_step=e)
                writer.add_figure('tsne_test',figure2,global_step=e)
                writer.add_figure('exc_train',fig6,global_step=e)
                writer.add_figure('exc_test',fig7,global_step=e)
                writer.add_figure('inhib_train',fig8,global_step=e)
                writer.add_figure('inhib_test',fig9,global_step=e)
                writer.add_figure('tsne_users_train',fig2,global_step=e)
                writer.add_figure('tsne_users_test',fig4,global_step=e)
                writer.add_figure('tsne_lights_train',fig3,global_step=e)
                writer.add_figure('tsne_lights_test',fig5,global_step=e)
            
            # excititory part of guided_vae, which should be disentangling and be a classifier for pseudo labeling?
                
            
                
            
            # reconstruction of the reference batch
            # NOTE(review): this forward pass runs outside torch.no_grad() and
            # without an explicit net.eval() — confirm this is intended.
            recon_batch, mu, logvar, clas = net(data_batch.cuda())
            recon_batch_c = recon_batch.detach().cpu()
            figure = plt.figure()
            img = recon_batch_c.view(*[[-1]+params['output_shape']])[:,0:1]
            plt.imshow(make_grid(img, scale_each=True, normalize=True).transpose(0,2).numpy())
            if not args.no_save:
                writer.add_figure('recon_train',figure,global_step=e)

        #train_here
        if params['is_guided'] and params['use_aug']:
            loss_, excite_loss_, inhib_loss_, loss_abs_, entropy, means = train_guided_aug(train_dl, net, opt, opt_excititory, opt_inhibitory, loss_fn_guided)
            if not args.no_save:
                writer.add_scalar('inhibitory_net_loss_1', excite_loss_, e)
                writer.add_scalar('inhibitory_net_loss_2', inhib_loss_, e)
                writer.add_scalar('clas_loss', loss_abs_, e)
            #writer.add_scalar('entropy', entropy, e)
            #writer.add_scalar('mean_values_softmax_inp', means, e)
            
            #for i in range()
            # tsne
            lats, tgts, usrs, lights = get_latent_train(train_dl, net, iterations=1)
            lats_test, tgts_test, usrs_test, lights_test = get_latent_train(test_dl, net, iterations=1)
            
            # Accuracy of the VAE's own classification head on the class latents.
            train_acc = eval_accuracy(lats, tgts, True, excite, net.cls_sq)
            test_acc = eval_accuracy(lats_test, tgts_test, True, excite, net.cls_sq)
            if not args.no_save:
                writer.add_scalar('vaeclas_net_train_acc', train_acc, e)
                writer.add_scalar('vaeclas_net_test_acc', test_acc, e)
            
            """
            train_acc = eval_accuracy(lats, tgts, True, excite)
            test_acc = eval_accuracy(lats_test, tgts_test, True, excite)
            writer.add_scalar('excititory_net_train_acc', train_acc, e)
            writer.add_scalar('excititory_net_test_acc', test_acc, e)
            
            inhib_train_acc = eval_accuracy(lats, tgts, False, inhib)
            inhib_test_acc = eval_accuracy(lats_test, tgts_test, False, inhib)
            writer.add_scalar('inhibitory_net_train_acc', inhib_train_acc, e)
            writer.add_scalar('inhibitory_net_test_acc', inhib_test_acc, e)
            """
            
        elif params['is_guided']:
            # NOTE(review): train_guided is not defined in the visible part of
            # this file; this branch raises NameError unless it is provided elsewhere.
            loss_, excite_loss_, inhib_loss_, loss_abs_ = train_guided(train_dl, net, opt, opt_excititory, opt_inhibitory, loss_fn_guided)
            if not args.no_save:
                writer.add_scalar('excititory_net_loss', excite_loss_, e)
                writer.add_scalar('inhibitory_net_loss', inhib_loss_, e)
                writer.add_scalar('abs_loss', loss_abs_, e)
            
            
            #for i in range()
            # tsne
            lats, tgts, usrs, lights = get_latent_train(train_dl, net, iterations=1)
            lats_test, tgts_test, usrs_test, lights_test = get_latent_train(test_dl, net, iterations=3)
            
            # NOTE(review): eval_accuracy takes five positional arguments
            # (lats, tgts, is_excite, excite, net); these calls pass three.
            train_acc = eval_accuracy(lats, tgts, excite)
            test_acc = eval_accuracy(lats_test, tgts_test, excite)
            if not args.no_save:
                writer.add_scalar('excititory_net_train_acc', train_acc, e)
                writer.add_scalar('excititory_net_test_acc', test_acc, e)
        elif params['use_aug']:
            # NOTE(review): train_aug is not defined in the visible part of this file.
            loss_ = train_aug(train_dl,net, opt, loss_fn)
        else:
            loss_ = train(train_dl,net, opt, loss_fn)
        if not args.no_save:
            writer.add_scalar('train_loss', loss_, e)
        
        plt.close('all') # close figures so they don't use too much memory...
        
        
        #writer.add_graph('graph', net, e)
          

