from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import os
import sys

from scipy.io import loadmat
import numpy as np
import tensorflow as tf

from generic_utils import random_seed
from generic_utils import data_dir


def _dataset_name_from_argv(name, argv_index, role):
    """Return ``name``; when it is None, fall back to ``sys.argv[argv_index]``.

    Raises:
        Exception: if ``name`` is None and ``sys.argv`` is too short.
    """
    if name is not None:
        return name
    if len(sys.argv) > argv_index:
        return sys.argv[argv_index]
    raise Exception('Not specify {} dataset'.format(role))


def _densify(x):
    """Return ``x.toarray()`` for sparse matrices (objects with ``toarray``), else ``x`` unchanged."""
    return x.toarray() if hasattr(x, 'toarray') else x


def test_real_dataset(create_obj_func, src_name=None, trg_name=None, show=False, block_figure_on_end=False):
    """Load a source/target dataset pair and train the learner built by ``create_obj_func``.

    When dataset names are not passed explicitly, the command line is expected to be::

        <script> <choice> <src_name> <trg_name> [key value ...]

    The trailing ``key value`` pairs (parsed by ``parse_arguments``) become
    ``users_params``. Keys interpreted here:

    * ``format``: file extension, default ``'mat'``. It is removed from
      ``users_params`` before they are forwarded.
    * ``data_path``: data directory. A missing or empty value uses ``data_dir()``.
    * ``cast_data``: if truthy, rescale features with ``u2t``. Datasets whose
      name contains ``'mnist32_60_10'`` are not rescaled.
    * ``sparse``: if present, sparse feature matrices are not densified.

    All remaining keys are passed to ``create_obj_func``.

    For both domains the files ``<data_path>/<name>_train.<format>`` and
    ``<data_path>/<name>_test.<format>`` are loaded.

    Args:
        create_obj_func: callable taking the user-params dict and returning a
            learner that exposes ``_init``, ``_build_model`` and ``_fit_loop``.
        src_name: source dataset name (file prefix); read from ``sys.argv[2]`` if None.
        trg_name: target dataset name (file prefix); read from ``sys.argv[3]`` if None.
        show: not used in this function; kept for interface compatibility.
        block_figure_on_end: not used in this function; kept for interface compatibility.

    Raises:
        Exception: a dataset name is missing, or one of the four data files does not exist.
        ValueError: ``format`` is not ``'mat'`` (the only loader implemented here).
    """
    print('Running {} ...'.format(os.path.basename(__file__)))

    src_name = _dataset_name_from_argv(src_name, 2, 'source')
    trg_name = _dataset_name_from_argv(trg_name, 3, 'target')

    np.random.seed(random_seed())
    tf.set_random_seed(random_seed())
    tf.reset_default_graph()

    print("========== Test on real data ==========")
    users_params = parse_arguments(dict())

    # 'format' is consumed here so it is not forwarded to the learner.
    data_format, users_params = extract_param('format', 'mat', users_params)
    if data_format != 'mat':
        # Only the .mat loader exists; other formats previously failed later
        # with an UnboundLocalError on the never-assigned feature arrays.
        raise ValueError('Unsupported data format: {!r}'.format(data_format))

    # A missing or empty 'data_path' falls back to the project default directory.
    data_path = users_params.get('data_path') or data_dir()

    src_train_file_name = os.path.join(data_path, src_name + '_train.' + data_format)
    src_test_file_name = os.path.join(data_path, src_name + '_test.' + data_format)

    trg_train_file_name = os.path.join(data_path, trg_name + '_train.' + data_format)
    trg_test_file_name = os.path.join(data_path, trg_name + '_test.' + data_format)

    print(src_train_file_name)
    print(src_test_file_name)
    print(trg_train_file_name)
    print(trg_test_file_name)

    if not os.path.exists(src_train_file_name):
        raise Exception('File source train not found')
    if not os.path.exists(src_test_file_name):
        raise Exception('File source test not found')
    if not os.path.exists(trg_train_file_name):
        raise Exception('File target train not found')
    if not os.path.exists(trg_test_file_name):
        raise Exception('File target test not found')

    x_src_train, y_src_train = load_mat_file_single_label(src_train_file_name)
    x_src_test, y_src_test = load_mat_file_single_label(src_test_file_name)

    x_trg_train, y_trg_train = load_mat_file_single_label(trg_train_file_name)
    x_trg_test, y_trg_test = load_mat_file_single_label(trg_test_file_name)

    if 'sparse' not in users_params:
        # loadmat can return MATLAB sparse variables as scipy.sparse matrices.
        # Densify them up front so the casting arithmetic and np.concatenate
        # below operate on plain ndarrays. Dense arrays pass through unchanged.
        x_src_train = _densify(x_src_train)
        x_src_test = _densify(x_src_test)
        x_trg_train = _densify(x_trg_train)
        x_trg_test = _densify(x_trg_test)

    print('Source training set:', x_src_train.shape)
    print('Source testing set:', x_src_test.shape)
    print('Source training labels:', np.unique(y_src_train))
    print('Source testing labels:', np.unique(y_src_test))

    print('Target training set:', x_trg_train.shape)
    print('Target testing set:', x_trg_test.shape)
    print('Target training labels:', np.unique(y_trg_train))
    print('Target testing labels:', np.unique(y_trg_test))

    print('Train', x_src_train.min(), x_src_train.max(), x_trg_train.min(), x_trg_train.max())
    print('Test', x_src_test.min(), x_src_test.max(), x_trg_test.min(), x_trg_test.max())

    if users_params.get('cast_data', False):
        print('Before casting: Train', x_src_train.min(), x_src_train.max(), x_trg_train.min(), x_trg_train.max())
        print('Before casting: Test', x_src_test.min(), x_src_test.max(), x_trg_test.min(), x_trg_test.max())

        print('Casting x to [-1, 1] float ...')
        if 'mnist32_60_10' not in src_name:
            x_src_train = u2t(x_src_train)
            x_src_test = u2t(x_src_test)
        else:
            print("Do not cast mnist32_60_10 data")

        if 'mnist32_60_10' not in trg_name:
            x_trg_train = u2t(x_trg_train)
            x_trg_test = u2t(x_trg_test)
        else:
            print("Do not cast mnist32_60_10 data")

        print('After casting: Train', x_src_train.min(), x_src_train.max(), x_trg_train.min(), x_trg_train.max())
        print('After casting: Test', x_src_test.min(), x_src_test.max(), x_trg_test.min(), x_trg_test.max())

    print('users_params:', users_params)

    learner = create_obj_func(users_params)
    # Per-sample feature shape, i.e. everything after the sample axis.
    learner.dim_src = x_src_train.shape[1:]
    learner.dim_trg = x_trg_train.shape[1:]

    learner.x_trg_test = x_trg_test
    learner.y_trg_test = y_trg_test
    learner.x_src_test = x_src_test
    learner.y_src_test = y_src_test

    # Train and test splits stacked along the sample axis.
    learner.x_src_full = np.concatenate([x_src_train, x_src_test])
    learner.y_src_full = np.concatenate([y_src_train, y_src_test])
    learner.x_trg_full = np.concatenate([x_trg_train, x_trg_test])
    learner.y_trg_full = np.concatenate([y_trg_train, y_trg_test])

    learner._init(y_src_train, y_trg_train)
    learner._build_model()
    learner._fit_loop(x_src=x_src_train, y_src=y_src_train, x_trg=x_trg_train, y_trg=y_trg_train)


def main_func(
        create_obj_func,
        choice_default=0,
        src_name_default='svmguide1',
        trg_name_default='svmguide1',
        params_gridsearch=None,
        attribute_names=None,
        num_workers=4,
        file_config=None,
        run_exp=False,
        keep_vars=None,
        **kwargs):
    """Run one or more experiment "choices".

    Choice 0 is a no-op. Choice 1 calls ``test_real_dataset``. Any other value is ignored.

    * ``run_exp`` False: choices 0, 1 and 2 are tried in turn with the default
      dataset names.
    * ``run_exp`` True with at least one command-line argument: ``int(sys.argv[1])``
      selects the choice, and the dataset names are left as None so that
      ``test_real_dataset`` reads them from ``sys.argv``.
    * ``run_exp`` True without arguments: ``choice_default`` is used with the
      default dataset names.

    ``params_gridsearch``, ``attribute_names``, ``num_workers``, ``file_config``,
    ``keep_vars`` and ``**kwargs`` are accepted for interface compatibility but
    are not used in this function. ``keep_vars`` now defaults to None instead of
    a shared mutable list.
    """
    if not run_exp:
        choice_lst = [0, 1, 2]
        src_name = src_name_default
        trg_name = trg_name_default
    elif len(sys.argv) > 1:
        # sys.argv[1] is the first value after the script name.
        choice_lst = [int(sys.argv[1])]
        src_name = None
        trg_name = None
    else:
        choice_lst = [choice_default]
        src_name = src_name_default
        trg_name = trg_name_default

    for choice in choice_lst:
        # Only choice 1 has an action; 0 and unknown values do nothing.
        if choice == 1:
            test_real_dataset(create_obj_func, src_name, trg_name, show=False, block_figure_on_end=run_exp)


def parse_arguments(params, as_array=False, argv=None, start=4):
    """Parse trailing ``key value`` command-line pairs into ``params``.

    Each value is converted with ``parse_argument``. The default ``start=4`` skips
    ``argv[0]`` (the script), ``argv[1]`` (the choice) and ``argv[2:4]`` (the
    source and target dataset names).

    Args:
        params: dict updated in place and returned.
        as_array: forwarded to ``parse_argument``.
        argv: argument list to read; defaults to ``sys.argv``.
        start: index of the first key.

    Returns:
        ``params``.

    Raises:
        ValueError: if the trailing arguments do not form complete key/value
            pairs. This is raised before ``params`` is modified.
    """
    if argv is None:
        argv = sys.argv
    pairs = argv[start:]
    if len(pairs) % 2:
        raise ValueError(
            'Argument {!r} has no value; expected key/value pairs from position {}'.format(pairs[-1], start))
    for key, raw_value in zip(pairs[0::2], pairs[1::2]):
        params[key] = parse_argument(raw_value, as_array)
    return params


def parse_argument(string, as_array=False):
    """Convert one command-line token into a Python value.

    Interpretations are tried in this order:

    1. ``int``, then ``float``.
    2. Case-insensitive ``true`` / ``false`` become booleans.
    3. ``"[]"`` becomes an empty list.
    4. A token containing ``|``, ``[`` and ``]`` becomes a list of floats split
       on ``|`` (for example ``"[0.1|1|10]"``). The first and last characters are
       assumed to be the brackets.
    5. A token containing ``,``, ``(`` and ``)`` of the form
       ``"(base,start,stop,step)"`` becomes ``base ** np.arange(start, stop, step)``.
    6. Anything else is returned as the original string.

    With ``as_array=True``, scalar results (numbers, booleans, plain strings)
    are wrapped in a one-element list. The list and array forms (3-5) are
    always returned unwrapped.
    """
    for convert in (int, float):
        try:
            number = convert(string)
        except ValueError:
            continue
        return [number] if as_array else number

    lowered = str.lower(string)
    if lowered in ('true', 'false'):
        value = lowered == 'true'
    elif string == "[]":
        return []
    elif all(marker in string for marker in '|[]'):
        return [float(token) for token in string[1:-1].split('|')]
    elif all(marker in string for marker in ',()'):
        fields = string[1:-1].split(',')
        base = float(fields[0])
        start = float(fields[1])
        stop = float(fields[2])
        step = float(fields[3])
        return base ** np.arange(start, stop, step)
    else:
        value = string

    return [value] if as_array else value


def resolve_conflict_params(primary_params, secondary_params):
    """Remove from ``secondary_params`` every key that also occurs in ``primary_params``.

    ``secondary_params`` is modified in place and returned. ``primary_params``
    is only read.
    """
    duplicated = (name for name in primary_params.keys() if name in secondary_params.keys())
    for name in duplicated:
        del secondary_params[name]
    return secondary_params


def extract_param(key, value, params_gridsearch, scalar=False):
    """Pop ``key`` out of ``params_gridsearch``, falling back to ``value`` when absent.

    Args:
        key: parameter name to look up.
        value: returned unchanged when ``key`` is not present.
        params_gridsearch: dict modified in place; ``key`` is deleted if present.
        scalar: when True and the stored value is not None, return its first
            element instead. This is not applied to the fallback ``value``.

    Returns:
        ``(value, params_gridsearch)``.
    """
    if key not in params_gridsearch.keys():
        return value, params_gridsearch
    found = params_gridsearch[key]
    del params_gridsearch[key]
    if scalar and found is not None:
        found = found[0]
    return found, params_gridsearch


def load_mat_file_single_label(filename):
    """Load a feature matrix and a single label vector from a MATLAB ``.mat`` file.

    The file must contain a ``feas`` variable, which is returned as-is, and a
    ``labels`` variable. With ``loadmat``'s default ``squeeze_me=False``, MATLAB
    arrays come back at least 2-D, so a label vector is either 1 x N or N x 1.

    Args:
        filename: path to the ``.mat`` file.

    Returns:
        ``(x, y)``, where ``x`` is ``data['feas']`` and ``y`` is the label vector.
        An N x 1 column (N > 1) is flattened. Every other shape yields the first
        row, as before.

    Raises:
        KeyError: if ``feas`` or ``labels`` is missing from the file.
    """
    data = loadmat(filename)
    x = data['feas']
    labels = data['labels']
    if labels.ndim == 2 and labels.shape[0] > 1 and labels.shape[1] == 1:
        # Column vector: taking row 0 would silently keep only the first label,
        # leaving y far shorter than x.
        y = labels[:, 0]
    else:
        y = labels[0]
    return x, y


def u2t(x):
    """Map values assumed to lie in [0, 255] (e.g. uint8 pixels) linearly onto [-1, 1].

    The input is cast to float32 first. Values outside [0, 255] are not clipped.
    """
    as_float = x.astype('float32')
    unit_range = as_float / 255
    return unit_range * 2 - 1
