import os
import argparse
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import wandb
import sys
import numpy as np
sys.path.append('..')
from datasets.Pendulum import PendulumDataset
from datasets.LotkaVolterra import LotkaVolterraDataset
from datasets.NBody import NBodyDataset
from datasets.PixelPendulum import PixelPendulumDataset
from utilities.losses import *

class BaseDynamicsModule(pl.LightningModule):
    def __init__(self, *args, **kwargs):
        """Set up the reconstruction loss, rollout sizes and dataset dimensions.

        Every constructor argument is captured by ``save_hyperparameters()``
        and read back through ``self.hparams``. The hyperparameters consumed
        here are ``rec_loss_type``, ``model_output_size`` and ``dataset``.

        Raises:
            Exception: if ``rec_loss_type`` is not one of ``'MSE'``, ``'L1'``,
                ``'BCE_LOGITS'`` or ``'CNN_MSE'``.
        """
        super().__init__()
        # Must run before anything below reads self.hparams.
        self.save_hyperparameters()

        # Pick the reconstruction loss. The chain is kept lazy on purpose:
        # only the selected loss function name is ever looked up.
        loss_type = self.hparams.rec_loss_type
        if loss_type == 'MSE':
            loss_fn = mse_loss
        elif loss_type == 'L1':
            loss_fn = l1_loss
        elif loss_type == 'BCE_LOGITS':
            loss_fn = bce_with_logits_loss
        elif loss_type == 'CNN_MSE':
            loss_fn = cnn_vae_mse_loss
        else:
            raise Exception(f'Wrong loss type {loss_type}')
        self.reconstruction_loss = loss_fn

        # Rollout lengths at which reconstruction losses are reported.
        val_sizes = [1, 5, 10, 20]
        output_size = self.hparams.model_output_size
        if output_size not in val_sizes:
            # Always report the loss at the model's own output size too.
            val_sizes.append(output_size)
        self.val_rec_loss_sizes = val_sizes
        self.test_rec_loss_sizes = [1, 5, 10, 20, 50, 100, 200]

        # Cached random-split indices; filled lazily by prepare_data().
        self.train_ind = None

        self.coord_dim = self.get_coords_dim()
        self.num_factors = self.get_factors_dim()

    def get_coords_dim(self):
        """Return the coordinate dimension for ``self.hparams.dataset``.

        The dataset name is matched by substring.

        Returns:
            ``None`` for pixel-pendulum datasets, ``2`` for pendulum and
            ``lv`` datasets, ``6`` for ``3body`` datasets.

        Raises:
            ValueError: if the dataset name matches none of the known
                families (previously this surfaced as an
                ``UnboundLocalError``).
        """
        dataset = self.hparams.dataset
        # 'pixel_pendulum' contains 'pendulum', so it has to be tested first.
        if 'pixel_pendulum' in dataset:
            return None
        if 'pendulum' in dataset:
            return 2
        if 'lv' in dataset:
            return 2
        if '3body' in dataset:
            return 6
        raise ValueError(f'Wrong dataset: {dataset}')

    def get_factors_dim(self):
        """Return the number of factors for ``self.hparams.dataset``.

        The dataset name is matched by substring.

        Returns:
            ``1`` for pixel and pendulum datasets, ``4`` for ``lv`` and
            ``3body`` datasets.

        Raises:
            ValueError: if the dataset name matches none of the known
                families (previously this surfaced as an
                ``UnboundLocalError``).
        """
        dataset = self.hparams.dataset
        if 'pixel' in dataset:
            return 1
        if 'pendulum' in dataset:
            return 1
        if 'lv' in dataset:
            return 4
        if '3body' in dataset:
            return 4
        raise ValueError(f'Wrong dataset: {dataset}')


    def prepare_data(self, dataset_dir=None):
        """Build the train/val/test datasets for ``self.hparams.dataset``.

        Populates ``self.datasets`` with a ``'train'`` dataset, a ``'val'``
        dataset and a ``'test'`` list of datasets. Three layouts are used:

        * fixed index split (``pendulum_var_length``): train and val are index
          ranges of the training file; test is ``[val, extra files...]``.
        * random split (``pendulum-2``, ``lv-2``, ``lv-3``, ``3body-3``,
          ``3body-4``, ``3body-5``): train/val/test indices of the training
          file are drawn at random (about 80/10/10); test is
          ``[train, val, in-distribution test, extra files...]``.
        * separate validation file (``lv-1``, ``3body-1``, ``3body-2``,
          ``pixel_pendulum-1``): test is ``[val, extra files...]``.

        Args:
            dataset_dir: directory holding the ``.hd5`` files. When ``None``
                (the default) the ``DATA_DIR_PHYSICS`` environment variable is
                read at call time, so importing this module no longer
                requires the variable to be set.

        Raises:
            KeyError: if ``dataset_dir`` is ``None`` and ``DATA_DIR_PHYSICS``
                is not set.
            Exception: if ``self.hparams.dataset`` is not a known name.
        """
        # NOTE(review): Lightning documents prepare_data() as a hook that is
        # not run on every process in distributed training and advises
        # against assigning state in it; confirm before multi-process use.
        if dataset_dir is None:
            dataset_dir = os.environ['DATA_DIR_PHYSICS']

        print('Preparing data')
        self.datasets = {}
        dt = self.hparams.dataset_dt
        dataset = self.hparams.dataset

        def loader_for(dataset_cls):
            """Return ``load(file_name, **kwargs)`` building ``dataset_cls``."""
            def load(file_name, **kwargs):
                path = os.path.join(dataset_dir, file_name)
                if dataset_cls is PendulumDataset:
                    # Only the pendulum dataset takes the coordinate system.
                    return dataset_cls(path, self.hparams.coordinates,
                                       self.hparams.noise_std, **kwargs)
                return dataset_cls(path, self.hparams.noise_std, **kwargs)
            return load

        # A single if/elif chain: the original started a second `if` chain
        # after the first branch, which sent 'pendulum_var_length' into the
        # final `else` and made it raise 'Wrong dataset'.
        if dataset == 'pendulum_var_length':
            load = loader_for(PendulumDataset)
            train_name = f'pendulum_n_10000_steps_2000_dt_{dt}_len_1.0-1.5_angle_10-170_g_9.81.hd5'
            self.datasets['train'] = load(train_name, indexes=list(range(0, 9000)))
            self.datasets['val'] = load(train_name, indexes=list(range(9000, 10000)))
            self.datasets['test'] = [
                self.datasets['val'],
                load(f'pendulum_n_1000_steps_2000_dt_{dt}_len_0.5-1.0_angle_10-170_g_9.81.hd5'),
                load(f'pendulum_n_1000_steps_2000_dt_{dt}_len_1.5-2.0_angle_10-170_g_9.81.hd5'),
            ]
        elif dataset == 'pendulum-2':
            self._use_random_split(
                loader_for(PendulumDataset),
                f'pendulum_n_10000_steps_2000_dt_{dt}_len_1.0-1.5_angle_10-170_g_9.81.hd5',
                [f'pendulum_n_1000_steps_2000_dt_{dt}_len_0.90-1.00_angle_10-170_g_9.81.hd5',
                 f'pendulum_n_1000_steps_2000_dt_{dt}_len_1.50-1.60_angle_10-170_g_9.81.hd5'])
        elif dataset == 'lv-1':
            self._use_separate_val_file(
                loader_for(LotkaVolterraDataset),
                f'lv_n_10000_steps_2000_dt_{dt}_X0_5.0_3.0_means_2.0_1.0_4.0_1.0_delta_0.05.hd5',
                f'lv_n_1296_steps_2000_dt_{dt}_X0_5.0_3.0_means_2.0_1.0_4.0_1.0_delta_0.05_away_0.00_0.01.hd5',
                [f'lv_n_1296_steps_2000_dt_{dt}_X0_5.0_3.0_means_2.0_1.0_4.0_1.0_delta_0.05_away_0.01_0.03.hd5',
                 f'lv_n_1296_steps_2000_dt_{dt}_X0_5.0_3.0_means_2.0_1.0_4.0_1.0_delta_0.05_away_0.03_0.05.hd5'])
        elif dataset == 'lv-2':
            self._use_random_split(
                loader_for(LotkaVolterraDataset),
                'lv-r_n_10000_steps_1000_dt_0.01_X0_5.0_3.0_means_2.0_1.0_4.0_1.0_delta_0.10.hd5',
                ['lv-r_n_1000_steps_1000_dt_0.01_X0_5.0_3.0_means_2.0_1.0_4.0_1.0_delta_0.10_away_0.00_0.02.hd5',
                 'lv-r_n_1000_steps_1000_dt_0.01_X0_5.0_3.0_means_2.0_1.0_4.0_1.0_delta_0.10_away_0.02_0.04.hd5'])
        elif dataset == 'lv-3':
            self._use_random_split(
                loader_for(LotkaVolterraDataset),
                'lv-r_n_10000_steps_1000_dt_0.01_X0_5.0_3.0_means_2.0_1.0_4.0_1.0_delta_0.05.hd5',
                ['lv-r_n_1000_steps_1000_dt_0.01_X0_5.0_3.0_means_2.0_1.0_4.0_1.0_delta_0.05_away_0.00_0.01.hd5',
                 'lv-r_n_1000_steps_1000_dt_0.01_X0_5.0_3.0_means_2.0_1.0_4.0_1.0_delta_0.05_away_0.01_0.02.hd5'])
        elif dataset == '3body-1':
            self._use_separate_val_file(
                loader_for(NBodyDataset),
                f'3body_n_10000_steps_1000_dt_{dt}_params_K1_1.00_K2_0.80-1.20_m1_0.80-1.20_m2_0.80-1.20_m3_0.80-1.20.hd5',
                f'3body_n_1296_steps_1000_dt_{dt}_params_K1_1.00_K2_0.80-1.20_m1_0.80-1.20_m2_0.80-1.20_m3_0.80-1.20_away_0.00-0.02.hd5',
                [f'3body_n_1296_steps_1000_dt_{dt}_params_K1_1.00_K2_0.80-1.20_m1_0.80-1.20_m2_0.80-1.20_m3_0.80-1.20_away_0.02-0.10.hd5'])
        elif dataset == '3body-2':
            self._use_separate_val_file(
                loader_for(NBodyDataset),
                '3body_n_10000_steps_1000_dt_0.01_params_K1_1.00_K2_1.00-1.10_m1_1.00-1.10_m2_1.00-1.10_m3_1.00-1.10.hd5',
                '3body_n_1296_steps_1000_dt_0.01_params_K1_1.00_K2_1.00-1.10_m1_1.00-1.10_m2_1.00-1.10_m3_1.00-1.10_away_0.00-0.01.hd5',
                ['3body_n_1296_steps_1000_dt_0.01_params_K1_1.00_K2_1.00-1.10_m1_1.00-1.10_m2_1.00-1.10_m3_1.00-1.10_away_0.01-0.05.hd5',
                 '3body_n_1296_steps_1000_dt_0.01_params_K1_1.00_K2_1.00-1.10_m1_1.00-1.10_m2_1.00-1.10_m3_1.00-1.10_away_0.05-0.10.hd5'])
        elif dataset == '3body-3':
            self._use_random_split(
                loader_for(NBodyDataset),
                '3body_n_10000_steps_1000_dt_0.01_params_K2_1.00_m1_1.00_m2_1.00_m3_1.00_delta_0.05.hd5',
                ['3body_n_1000_steps_1000_dt_0.01_params_K2_1.00_m1_1.00_m2_1.00_m3_1.00_delta_0.05_away_0.00_0.01.hd5',
                 '3body_n_1000_steps_1000_dt_0.01_params_K2_1.00_m1_1.00_m2_1.00_m3_1.00_delta_0.05_away_0.01_0.02.hd5'])
        elif dataset == '3body-4':
            self._use_random_split(
                loader_for(NBodyDataset),
                '3body_n_10000_steps_1000_dt_0.01_params_K2_2.00_m1_2.00_m2_2.00_m3_2.00_delta_0.05.hd5',
                ['3body_n_1000_steps_1000_dt_0.01_params_K2_2.00_m1_2.00_m2_2.00_m3_2.00_delta_0.05_away_0.00_0.01.hd5',
                 '3body_n_1000_steps_1000_dt_0.01_params_K2_2.00_m1_2.00_m2_2.00_m3_2.00_delta_0.05_away_0.01_0.02.hd5'])
        elif dataset == '3body-5':
            self._use_random_split(
                loader_for(NBodyDataset),
                '3body_n_10000_steps_1000_dt_0.01_params_K2_2.00_m1_2.00_m2_2.00_m3_2.00_delta_0.10.hd5',
                ['3body_n_1000_steps_1000_dt_0.01_params_K2_2.00_m1_2.00_m2_2.00_m3_2.00_delta_0.10_away_0.00_0.02.hd5',
                 '3body_n_1000_steps_1000_dt_0.01_params_K2_2.00_m1_2.00_m2_2.00_m3_2.00_delta_0.10_away_0.02_0.05.hd5'])
        elif dataset == 'pixel_pendulum-1':
            self._use_separate_val_file(
                loader_for(PixelPendulumDataset),
                f'pixel_pendulum_n_1296_steps_100_dt_{dt:.2f}_angle_30-170_vel_-2.00-2.00_len_1.20-1.40_g_8.00-12.00.hd5',
                f'pixel_pendulum_n_81_steps_100_dt_{dt:.2f}_angle_30-170_vel_-2.00-2.00_len_1.40-1.45_g_12.00-12.50.hd5',
                [f'pixel_pendulum_n_256_steps_100_dt_{dt:.2f}_angle_30-170_vel_-2.00-2.00_len_1.45-1.50_g_12.50-13.00.hd5'])
        else:
            raise Exception(f'Wrong dataset: {self.hparams.dataset}')

        print('Train Dataset length: ', len(self.datasets['train']))
        print('Val Dataset length: ', len(self.datasets['val']))
        print('Test Datasets length: ', [len(dset) for dset in self.datasets['test']])

    def _random_split_indices(self, n_samples=10000):
        """Return cached (train, val, test) index arrays, drawing them once.

        Each of the ``n_samples`` items gets a uniform draw and lands in train
        (``<= 0.80``), val (``(0.80, 0.90]``) or test (``> 0.90``). The test
        bound is strict so val and test can never share an index.
        """
        if self.train_ind is None:
            rand_ind = np.random.uniform(0.0, 1.0, n_samples)
            self.train_ind = np.where(rand_ind <= 0.80)[0]
            self.val_ind = np.where((rand_ind > 0.80) & (rand_ind <= 0.90))[0]
            self.test_ind = np.where(rand_ind > 0.90)[0]
        return self.train_ind, self.val_ind, self.test_ind

    def _use_random_split(self, load, train_name, test_names):
        """Fill ``self.datasets`` from a random split of one training file."""
        train_ind, val_ind, test_ind = self._random_split_indices()
        train_set = load(train_name, indexes=train_ind)
        val_set = load(train_name, indexes=val_ind)
        in_dist_test_set = load(train_name, indexes=test_ind)

        self.datasets['train'] = train_set
        self.datasets['val'] = val_set
        self.datasets['test'] = [train_set, val_set, in_dist_test_set] \
            + [load(name) for name in test_names]

    def _use_separate_val_file(self, load, train_name, val_name, test_names):
        """Fill ``self.datasets`` from separate train, val and test files."""
        self.datasets['train'] = load(train_name)
        self.datasets['val'] = load(val_name)
        self.datasets['test'] = [self.datasets['val']] \
            + [load(name) for name in test_names]

    def train_dataloader(self):
        """Return a shuffling DataLoader over ``self.datasets['train']``."""
        train_set = self.datasets['train']
        hparams = self.hparams
        return DataLoader(
            train_set,
            batch_size=hparams.batch_size,
            num_workers=hparams.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        """Return a non-shuffling DataLoader over ``self.datasets['val']``.

        Uses ``batch_size_val`` rather than the training batch size.
        """
        val_set = self.datasets['val']
        hparams = self.hparams
        return DataLoader(
            val_set,
            batch_size=hparams.batch_size_val,
            num_workers=hparams.num_workers,
            shuffle=False,
        )

    def test_dataloader(self):
        """Return one non-shuffling DataLoader per entry of ``self.datasets['test']``.

        The returned list keeps the order of ``self.datasets['test']``.
        """
        return [
            DataLoader(
                test_set,
                batch_size=self.hparams.batch_size_val,
                num_workers=self.hparams.num_workers,
                shuffle=False,
            )
            for test_set in self.datasets['test']
        ]

    def setup(self, stage=None):
        """Cache one sample batch per split and the training label range.

        Fills ``self.batch_sample`` with the first batch of the train and val
        dataloaders plus one or two test batches whose keys depend on the
        dataset, then stores ``self.labels_min`` / ``self.labels_max`` from
        the training dataset's ``get_labels_min_max()``.

        Args:
            stage: accepted for the Lightning hook signature; not used here.
        """
        # prepare a train and validation batches
        self.batch_sample = {}
        self.batch_sample['train'] = next(iter(self.train_dataloader()))
        self.batch_sample['val'] = next(iter(self.val_dataloader()))

        # Map batch_sample key -> index into the test dataloader list.
        # This is one if/elif chain: the original used a separate `if` for
        # 'pendulum-2', so 'pendulum_var_length' also fell into the generic
        # `else` branch and got an extra 'test' sample.
        dataset = self.hparams.dataset
        if dataset == 'pendulum_var_length':
            test_samples = {'test_short': 1, 'test_long': 2}
        elif dataset == 'pendulum-2':
            test_samples = {'test_out_1': 3, 'test_out_2': 4}
        elif dataset == '3body-2':
            test_samples = {'test_easy': 1, 'test_hard': 2}
        else:
            # just get a sample from the first test dataset
            # index is 1 because 0 is the val dataset
            # NOTE(review): for the randomly split datasets prepare_data()
            # orders the test list as [train, val, in-dist test, ...], so
            # index 1 is the val set there — confirm this is intended.
            test_samples = {'test': 1}

        # Build the test dataloaders once, not once per sample.
        test_loaders = self.test_dataloader()
        for key, loader_idx in test_samples.items():
            self.batch_sample[key] = next(iter(test_loaders[loader_idx]))

        self.labels_min, self.labels_max = self.datasets['train'].get_labels_min_max()

    def configure_optimizers(self):
        """Return the Adam optimizer, plateau LR scheduler and monitored metric.

        Returns:
            A dict with keys ``'optimizer'``, ``'lr_scheduler'`` and
            ``'monitor'``; the latter is ``self.hparams.monitor``.
        """
        hparams = self.hparams
        adam = torch.optim.Adam(
            self.parameters(),
            lr=hparams.learning_rate,
            weight_decay=hparams.weight_decay,
        )

        # TODO: try CosineAnnealingLR
        plateau_kwargs = dict(
            mode='min',
            factor=hparams.scheduler_factor,
            patience=hparams.scheduler_patience,
            verbose=True,
            threshold=hparams.scheduler_threshold,
            min_lr=hparams.scheduler_min_lr,
        )
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(adam, **plateau_kwargs)

        return dict(optimizer=adam, lr_scheduler=plateau, monitor=hparams.monitor)

    def log_histogram(self, name, params):
        """Log ``params`` as a histogram under ``name``.

        Args:
            name: tag of the histogram.
            params: tensor whose values are histogrammed.
        """
        if self.hparams.use_wandb:
            # Detach and move to host memory first so the histogram can be
            # built from tensors that live on an accelerator (no-op on CPU).
            values = params.detach().cpu()
            self.logger.experiment.log({name: wandb.Histogram(values)})
        else:
            # Assumes the non-W&B logger exposes a TensorBoard SummaryWriter;
            # its add_histogram() takes `global_step`, not `step` (the old
            # keyword raised TypeError).
            self.logger.experiment.add_histogram(name, params, global_step=self.global_step)

    def log_image(self, name, plt_image):
        """Log ``plt_image`` under ``name`` when W&B logging is enabled.

        With ``self.hparams.use_wandb`` false this is a no-op: no image
        logging is implemented for other loggers.
        """
        if not self.hparams.use_wandb:
            # Nothing to do; an add_image() based path was never enabled.
            return
        wandb_image = wandb.Image(plt_image)
        self.logger.experiment.log({name: wandb_image})

    def log_rec_losses(self, batch, stage, rec_loss_sizes, on_epoch=True, on_step=False):
        """Log reconstruction losses of one rollout at several lengths.

        Performs a single rollout of ``max(rec_loss_sizes)`` steps and, for
        each size ``s``, logs the loss over the first ``s`` entries along
        axis 1 (``{stage}/rec/cumm/{s}``) and the loss of entry ``s - 1``
        alone (``{stage}/rec/{s}``).

        Args:
            batch: batch forwarded to ``get_start`` and ``rollout``.
            stage: prefix of the logged metric names.
            rec_loss_sizes: rollout lengths to report.
            on_epoch: forwarded to ``self.log``.
            on_step: forwarded to ``self.log``.
        """
        # reconstruction losses for longer trajectories.
        longest = np.max(rec_loss_sizes)
        first_step = self.get_start(batch, rec_loss_sizes)

        prediction, truth = self.rollout(batch, start=first_step, rollout_size=longest)
        prediction = prediction.to(self.device)
        truth = truth.to(self.device)

        log_options = dict(on_step=on_step, on_epoch=on_epoch)
        for horizon in rec_loss_sizes:
            # Loss accumulated over everything up to `horizon`.
            head = slice(None, horizon)
            cumulative = self.reconstruction_loss(prediction[:, head], truth[:, head])
            self.log(f'{stage}/rec/cumm/{horizon:04d}', cumulative, **log_options)

            # Loss of the single entry at position `horizon - 1`.
            last = slice(horizon - 1, horizon)
            single = self.reconstruction_loss(prediction[:, last], truth[:, last])
            self.log(f'{stage}/rec/{horizon:04d}', single, **log_options)


    def _on_after_backward(self): # remove preceding underscore to enable
        """Log histograms of every parameter and of its gradient.

        Used to log parameter grads; pl/wandb gradient-norm logging is the
        easier/lighter alternative. Runs only when ``self.hparams.debug`` is
        set and ``global_step`` is a multiple of ``self.hparams.log_freq``.
        """
        if not (self.hparams.debug and (self.global_step % self.hparams.log_freq) == 0):
            return
        for name, params in self.named_parameters():
            self.log_histogram(name, params)
            # Parameters that took no part in the backward pass (frozen or
            # unused) have no gradient; skip them to avoid an AttributeError.
            if params.grad is not None:
                self.log_histogram(f'grads/{name}', params.grad.detach())

    def get_start(self, batch, rec_loss_sizes):
        """Return the index at which a rollout should start.

        Args:
            batch: mapping whose ``'trajectory'`` entry provides the
                trajectory length via ``.size(1)``.
            rec_loss_sizes: rollout lengths; only the largest one matters.

        Returns:
            ``0`` when random starts are disabled or when the trajectory
            leaves no room for an offset; otherwise a random integer in
            ``[0, max_start)`` with
            ``max_start = length - model_input_size - max(rec_loss_sizes)``.
        """
        # `== True` is kept on purpose: truthy non-boolean values (e.g. a
        # string hyperparameter) must keep selecting the deterministic start.
        if not (self.hparams.use_random_start == True):
            return 0

        length = batch['trajectory'].size(1)
        max_rollout = np.max(rec_loss_sizes)
        max_start = length - self.hparams.model_input_size - max_rollout
        if max_start <= 0:
            # No room for a random offset: np.random.choice would raise a
            # ValueError on the empty range, so fall back to the same start
            # as the non-random path.
            return 0
        # An int argument samples as if from np.arange(max_start).
        return np.random.choice(max_start)

    def _compute_label_loss(self, labels, mu):
        """Return the supervised loss between ``labels`` and the latent means.

        Only the first ``self.num_factors`` columns of ``mu`` are supervised.
        How they are mapped to label space depends on
        ``self.hparams.sup_loss_type``:

        * ``'sigmoid'``: sigmoid, rescaled to ``[labels_min, labels_max]``, L1.
        * ``'sigmoid_parametrized'``: like ``'sigmoid'`` with the learnable
          scales ``self.w1`` / ``self.w2`` (both are also logged), L1.
        * ``'linear'``: used as is, L1.
        * ``'linear_scaled'``: rescaled with the label range, L1.
        * ``'BCE'``: binary cross-entropy with logits against the labels
          normalised with the label range.

        Args:
            labels: ground-truth factor values.
            mu: tensor indexed as ``mu[:, :num_factors]``, so at least 2-D.

        Raises:
            Warning: if ``sup_loss_type`` is unknown. The exception class is
                unchanged for backward compatibility; only the message is
                now a single formatted string instead of a tuple.
        """
        # loss for labels
        labels_min, labels_max = self.labels_min, self.labels_max
        loss_type = self.hparams.sup_loss_type
        factors = mu[:, :self.num_factors]

        if loss_type == 'BCE':
            # The epsilon guards against a zero label range.
            labels_norm = (labels - labels_min) / (labels_max - labels_min + 1e-6)
            BCE = torch.nn.BCEWithLogitsLoss(reduction='mean')
            return BCE(factors, labels_norm)

        if loss_type == 'sigmoid':
            pred_scaled = torch.sigmoid(factors) \
                            * (labels_max - labels_min) + labels_min
        elif loss_type == 'sigmoid_parametrized':
            pred_scaled = self.w1 * torch.sigmoid(self.w2 * factors) * \
                        (labels_max - labels_min) + labels_min
            self.log('train/sup/w2', self.w2, prog_bar=False, on_step=False, on_epoch=True)
            self.log('train/sup/w1', self.w1, prog_bar=False, on_step=False, on_epoch=True)
        elif loss_type == 'linear':
            pred_scaled = factors
        elif loss_type == 'linear_scaled':
            pred_scaled = factors * \
                            (labels_max - labels_min) + labels_min
        else:
            raise Warning(f'Wrong supervised loss type: {loss_type}')

        # NOTE(review): `F` is not imported explicitly at the top of this
        # file; presumably it arrives through `from utilities.losses import *`
        # — confirm.
        return F.l1_loss(pred_scaled, labels)