from features_ae.network.train_functions import loss_function_ae, get_data, get_z,  \
    setup_housekeeping, housekeeping, get_l_probs
import global_config as gc
import os
import sys
from features_ae.network.encoder import EncoderFC
from features_ae.network.decoder import DecoderFC
from features_ae.network.classifier import ClassifierFC

# NOTE(review): this executes after the imports above, so it has no effect on
# how they are resolved; it only affects imports performed later in the
# process. Confirm whether it is still needed (or should precede the imports).
sys.path.append('..')


class AEConfig:
    """Experiment configuration for the paired feature autoencoders.

    Holds the training helpers imported from
    ``features_ae.network.train_functions``, parameter slots for two input
    streams (``*_1`` / ``*_2``), and a shared ``training_params`` dict.
    ``get_dataset`` builds a dataset-parameter dict and ``setup_network``
    instantiates the networks from the stored parameters.
    """

    def __init__(self):
        # Training helpers stored on the instance (presumably so an
        # experiment can override them).
        self.get_data = get_data
        self.get_z = get_z
        self.get_l_probs = get_l_probs
        self.housekeeping = housekeeping
        self.loss_function_ae = loss_function_ae

        # Dataset parameter dicts for the two streams (e.g. results of
        # get_dataset); setup_network reads their 'feature_dim'.
        self.dataset_params_train_1 = None
        self.dataset_params_train_2 = None

        # Per-stream model settings; placeholders to be filled by the caller.
        self.model_params_1 = {
            'method': None,
            'beta': None,
            'latent_dim': None,
        }

        self.model_params_2 = {
            'method': None,
            'beta': None,
            'latent_dim': None,
        }

        # NOTE: setup_network also needs training_params['hidden1'] and
        # training_params['hidden2'], which are not set here.
        self.training_params = {
            'config_version': 1,
            'output_dir': os.path.abspath(gc.cfg['results_path']),
            'experiment': 'nn_features',
            'device': 'cpu',
            'save_loss': True,
            'save_model': True,
            'verbose': True,
            'reverse': False,
            'invariant': True,
            # Overwrite existing model.pkl files (throws warning)
            'overwrite': True,
            'housekeeping_keys': [
                'batch_train_loss_rec_1', 'batch_train_loss_rec_2',
                'batch_train_loss_affinity', 'epoch_train_loss',
                'z_1', 'z_2', 'l_1', 'l_2', 'acc_nn', 'acc_knn',
            ],
            'keys_to_store': [
                'batch_train_loss_rec_1', 'batch_train_loss_rec_2',
                'batch_train_loss_affinity', 'epoch_train_loss',
                'z_1', 'z_2', 'l_1', 'l_2',
                'acc_knn_1', 'acc_knn_3', 'rev_1nn_ot', 'rev_3nn_ot',
                'acc_svc_ot', 'rev_svc_ot',
            ],
        }

    def get_dataset(self, domain, features, shuffle, batch_size, drop_last,
                    dataset_path=None):
        """Build the parameter dict describing one pre-extracted feature set.

        Args:
            domain: Domain name. ``'mnist'`` and ``'usps'`` are located at
                ``<root>/nn_features/np/<domain>/``; every other domain at
                ``<root>/nn_features/np/<features>/<domain>/``.
            features: Feature-set name. It must end in a run of digits,
                which is used as the feature dimensionality
                (e.g. ``'resnet50_2048'`` -> 2048).
            shuffle: Copied into the result unchanged.
            batch_size: Copied into the result unchanged.
            drop_last: Copied into the result unchanged.
            dataset_path: Dataset root directory; defaults to
                ``gc.cfg['dataset_path']``.

        Returns:
            dict with keys ``name``, ``shuffle``, ``feature_dim``,
            ``drop_last``, ``data_path`` (ends with a path separator),
            ``batch_size`` and ``verbose`` (always False).

        Raises:
            ValueError: if ``features`` does not end in digits.
        """
        import re

        root = gc.cfg['dataset_path'] if dataset_path is None else dataset_path
        if domain in ('mnist', 'usps'):
            sub_dirs = ('nn_features', 'np', domain)
        else:
            sub_dirs = ('nn_features', 'np', features, domain)
        # os.path.join inserts a separator only where needed, so the result
        # is correct whether or not the root ends in '/'. The final '' keeps
        # the trailing separator of the returned directory path.
        path = os.path.join(root, *sub_dirs, '')

        # The dimensionality is the trailing digit run of the feature name
        # (not just its last 4 characters, which broke for non-4-digit dims).
        digits = re.search(r'(\d+)$', features)
        if digits is None:
            raise ValueError(
                'Cannot infer feature dimension from features name {!r}: '
                'it must end in digits'.format(features))

        return {'name': 'Domain: {}, Features: {}'.format(domain, features),
                'shuffle': shuffle,
                'feature_dim': int(digits.group(1)),
                'drop_last': drop_last,
                'data_path': path,
                'batch_size': batch_size,
                'verbose': False,
                }

    def setup_network(self, model_params, num_classes=10):
        """Instantiate the two autoencoders and the latent classifier.

        Args:
            model_params: dict providing ``'latent_dim'``, shared by both
                encoder/decoder pairs and the classifier input.
            num_classes: Last entry of the ``ClassifierFC`` size list
                (presumably the number of classes); defaults to the
                previously hard-coded 10.

        Returns:
            ``(encoder, decoder, encoder2, decoder2, classifier)``. The first
            pair uses ``dataset_params_train_1`` / ``hidden1``, the second
            ``dataset_params_train_2`` / ``hidden2``.

        Raises:
            KeyError: if ``training_params`` lacks ``'hidden1'`` or
                ``'hidden2'`` (``__init__`` does not set them).
        """
        missing = [key for key in ('hidden1', 'hidden2')
                   if key not in self.training_params]
        if missing:
            raise KeyError(
                'training_params is missing {}; set them before calling '
                'setup_network'.format(missing))

        in_dim_1 = self.dataset_params_train_1['feature_dim']
        hidden_1 = self.training_params['hidden1']
        latent_dim = model_params['latent_dim']
        in_dim_2 = self.dataset_params_train_2['feature_dim']
        hidden_2 = self.training_params['hidden2']

        # Modules are constructed in the original order, which may matter
        # for reproducible (RNG-seeded) weight initialisation.
        encoder = EncoderFC([in_dim_1, hidden_1, latent_dim])
        decoder = DecoderFC([latent_dim, hidden_1, in_dim_1])

        encoder2 = EncoderFC([in_dim_2, hidden_2, latent_dim])
        decoder2 = DecoderFC([latent_dim, hidden_2, in_dim_2])

        classifier = ClassifierFC([latent_dim, num_classes])
        return encoder, decoder, encoder2, decoder2, classifier
