# coding=utf-8
# Copyright 2018 The DisentanglementLib Authors.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Weakly supervised downstream classification task."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from disentanglement_lib.evaluation.metrics import utils
import numpy as np
from six.moves import range
import gin.tf

def compute_kl(z_1, z_2, logvar_1, logvar_2):
  """Elementwise divergence between two diagonal Gaussians.

  Treats (z_1, exp(logvar_1)) and (z_2, exp(logvar_2)) as the mean/variance of
  N1 and N2. The returned value is 2 * KL(N1 || N2): the conventional 0.5
  factor is omitted. A 1e-6 stabiliser is added to var_2, but only in the
  squared-mean term. Tiny negative values from rounding are clamped to zero.

  Args:
    z_1: Means of the first distribution.
    z_2: Means of the second distribution.
    logvar_1: Log-variances of the first distribution.
    logvar_2: Log-variances of the second distribution.

  Returns:
    Array broadcast from the inputs holding the non-negative divergence.
  """
  stabilised_var_2 = np.exp(logvar_2) + 1e-6
  variance_ratio = np.exp(logvar_1 - logvar_2)
  mean_term = np.square(z_2 - z_1) / stabilised_var_2
  # Summation order kept identical so floating-point results match exactly.
  divergence = variance_ratio + mean_term - 1 + logvar_2 - logvar_1
  return np.maximum(0, divergence)

def get_shared_factors(samples, representation_function, k=None, n_factors=7):
    """Encode observation pairs and keep the latent dims that look shared.

    Both observations of each pair are encoded. A symmetrised divergence is
    computed per latent dim. For every point, the `n_factors - k` dims with the
    smallest divergence are kept and averaged across the two encodings.

    Args:
      samples: Batch of observation pairs. The two members are read as
        `samples[:, 0]` and `samples[:, 1]`.
      representation_function: Callable returning a `(mean, logvar)` pair for
        a batch of observations. Assumed to be (batch, latent_dim) arrays; TODO
        confirm.
      k: Number of non-shared factors. When None (default), it is read from
        the gin binding `weak_downstream_task.k`.
      n_factors: Total number of ground-truth factors (default 7, the value
        previously hard-coded here).

    Returns:
      Array of shape (batch, n_factors - k) with the averaged means of the
      selected dims.

    Raises:
      ValueError: If `k` is not in [0, n_factors).
    """
    if k is None:
        k = gin.query_parameter("weak_downstream_task.k")
    if not 0 <= k < n_factors:
        raise ValueError(
            "k must satisfy 0 <= k < n_factors, got k={} with n_factors={}"
            .format(k, n_factors))
    num_shared = n_factors - k

    z_mean, z_logvar = representation_function(samples[:, 0])
    z_mean_2, z_logvar_2 = representation_function(samples[:, 1])
    # Symmetrise the (asymmetric) divergence so pair order does not matter.
    kl_per_point = 0.5 * (compute_kl(z_mean, z_mean_2, z_logvar, z_logvar_2) +
                          compute_kl(z_mean_2, z_mean, z_logvar_2, z_logvar))
    # Per point, indices of the num_shared least-divergent latent dims.
    # NOTE(review): columns of the result are ordered by per-point divergence
    # rank, not by latent index, so column j may come from a different latent
    # dim for different points -- confirm this is intended.
    shared_dims = np.argsort(kl_per_point, axis=-1)[:, :num_shared]
    shared_z_mean = np.take_along_axis(z_mean, shared_dims, axis=-1)
    shared_z_mean_2 = np.take_along_axis(z_mean_2, shared_dims, axis=-1)
    # Slicing a 2-D index array already yields (batch, num_shared), so no
    # special reshape is needed when num_shared == 1.
    return 0.5 * (shared_z_mean + shared_z_mean_2)

def simple_dynamics(z, ground_truth_data, random_state,
                    return_index=False, k=None, n_factors=7):
    """Create the second member of a pair by resampling the trailing factors.

    The first `n_factors - k` factor columns are kept (shared). Every column
    from index `n_factors - k` up to the last column of `z` is resampled
    uniformly, independently for each row, using `random_state`. The input `z`
    is not modified.

    Args:
      z: Factor matrix, indexed as (num_samples, num_factors).
      ground_truth_data: Object exposing `factors_num_values`, the number of
        values of each factor.
      random_state: Numpy RandomState used for resampling. Using it rather than
        the global numpy RNG keeps the pairs reproducible.
      return_index: Unused; kept for backward compatibility.
      k: Number of factors to resample. When None (default), it is read from
        the gin binding `weak_downstream_task.k`.
      n_factors: Total number of ground-truth factors (default 7, the value
        previously hard-coded here).

    Returns:
      Tuple `(new_z, labels)`: the resampled factor matrix and its first
      `n_factors - k` columns (the shared factors).

    Raises:
      ValueError: If `k` is not in [0, n_factors).
    """
    del return_index  # Unused; kept so existing callers keep working.
    if k is None:
        k = gin.query_parameter("weak_downstream_task.k")
    if not 0 <= k < n_factors:
        raise ValueError(
            "k must satisfy 0 <= k < n_factors, got k={} with n_factors={}"
            .format(k, n_factors))
    num_shared = n_factors - k

    # Copy so the caller's factor array is left untouched.
    new_z = np.copy(z)
    num_samples = new_z.shape[0]
    for index in range(num_shared, new_z.shape[-1]):
        new_z[:, index] = random_state.randint(
            ground_truth_data.factors_num_values[index], size=num_samples)
    # Slicing a 2-D array keeps it 2-D, even when num_shared == 1.
    labels = new_z[:, :num_shared]
    return new_z, labels

def weak_dataset_generator(ground_truth_data, random_state):
    """Yield an endless stream of (observation pair, shared-factor labels).

    Each step draws one factor vector and renders it. It then derives a
    partner factor vector through `simple_dynamics` and renders that too. The
    two observations are concatenated along axis 0 and yielded together with
    the shared-factor labels, with the leading axis squeezed away.

    Args:
      ground_truth_data: Data source providing `sample_factors` and
        `sample_observations_from_factors`.
      random_state: Numpy RandomState passed to every sampling call.
    """
    while True:
        anchor_factors = ground_truth_data.sample_factors(1, random_state)
        anchor_obs = ground_truth_data.sample_observations_from_factors(
            anchor_factors, random_state)
        partner_factors, shared_labels = simple_dynamics(
            anchor_factors, ground_truth_data, random_state)
        partner_obs = ground_truth_data.sample_observations_from_factors(
            partner_factors, random_state)
        observation_pair = np.concatenate((anchor_obs, partner_obs), axis=0)
        yield observation_pair, shared_labels.squeeze(0)

def sample_batch(batch_size, sampler):
    """Draw `batch_size` items from `sampler` and stack them into arrays.

    Each item from `sampler` is indexed as `item[0]` (observation) and
    `item[1]` (shared-factor labels).

    Args:
      batch_size: Number of items to pull from the iterator.
      sampler: Iterator yielding (observation, labels) pairs.

    Returns:
      Tuple `(labels, observations)`, each stacked along a new leading axis.
    """
    drawn = [next(sampler) for _ in range(batch_size)]
    observations = np.stack([item[0] for item in drawn])
    labels = np.stack([item[1] for item in drawn])
    return labels, observations

def generate_batch_factor_code(ground_truth_data, representation_function,
                               num_points, random_state, batch_size):
    """Sample `num_points` observation pairs and encode their shared dims.

    Pairs are drawn in mini-batches of at most `batch_size` from
    `weak_dataset_generator`. Each batch is encoded with `get_shared_factors`.
    Per-batch results are collected and stacked once at the end; this avoids
    re-copying the growing arrays on every batch.

    Args:
      ground_truth_data: GroundTruthData to be sampled from.
      representation_function: Function that takes observations as input and
        returns a `(mean, logvar)` pair (as consumed by `get_shared_factors`).
      num_points: Number of points to sample; must be positive.
      random_state: Numpy random state used for randomness.
      batch_size: Maximum number of points sampled per mini-batch; must be
        positive.

    Returns:
      representations: (num_codes, num_points) array of shared codes.
      factors: (num_shared_factors, num_points) array of shared-factor labels.

    Raises:
      ValueError: If `num_points` or `batch_size` is not positive.
    """
    if num_points <= 0:
        raise ValueError("num_points must be positive, got {}".format(num_points))
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got {}".format(batch_size))

    sampler = weak_dataset_generator(ground_truth_data, random_state)
    representation_chunks = []
    factor_chunks = []
    num_sampled = 0
    while num_sampled < num_points:
        num_points_iter = min(num_points - num_sampled, batch_size)
        current_factors, current_observations = sample_batch(
            num_points_iter, sampler)
        representation_chunks.append(
            get_shared_factors(current_observations, representation_function))
        # A single shared factor could arrive as a flat vector; make it a column.
        if current_factors.ndim == 1:
            current_factors = current_factors.reshape(-1, 1)
        factor_chunks.append(current_factors)
        num_sampled += num_points_iter

    representations = np.vstack(representation_chunks)
    factors = np.vstack(factor_chunks)
    return np.transpose(representations), np.transpose(factors)

@gin.configurable(
    "weak_downstream_task",
    blacklist=["ground_truth_data", "representation_function", "random_state",
               "artifact_dir"])
def compute_weak_downstream_task(ground_truth_data,
                                 representation_function,
                                 random_state,
                                 artifact_dir=None,
                                 num_train=gin.REQUIRED,
                                 num_test=gin.REQUIRED,
                                 k=gin.REQUIRED,
                                 batch_size=16):
  """Scores how well shared latent codes predict the shared factors.

  For every training size in `num_train`, the function does the following:
  - It samples fresh train and test sets of encoded observation pairs.
  - It fits one predictor per shared factor.
  - It reports train/test accuracies.

  Args:
    ground_truth_data: GroundTruthData to be sampled from.
    representation_function: Function that takes observations as input and
      returns their representation (consumed by `generate_batch_factor_code`).
    random_state: Numpy random state used for randomness.
    artifact_dir: Optional path to directory where artifacts can be saved
      (unused).
    num_train: Iterable of training-set sizes; one evaluation per size.
    num_test: Number of points used for testing (resampled per train size).
    k: Number of non-shared factors. Not read directly here; the helpers query
      the gin binding `weak_downstream_task.k` instead.
    batch_size: Batch size for sampling.

  Returns:
    Dictionary mapping "<PredictorClass>:<train_size>:<metric>" to accuracies.
  """
  del artifact_dir
  scores = {}
  for train_size in num_train:
    codes_train, labels_train = generate_batch_factor_code(
        ground_truth_data, representation_function, train_size, random_state,
        batch_size)
    codes_test, labels_test = generate_batch_factor_code(
        ground_truth_data, representation_function, num_test, random_state,
        batch_size)
    make_model = utils.make_predictor_fn()
    model_name = make_model().__class__.__name__

    train_acc, test_acc = _compute_loss(
        np.transpose(codes_train), labels_train, np.transpose(codes_test),
        labels_test, make_model)
    prefix = f'{model_name}:{train_size}'
    scores[prefix + ":mean_train_accuracy"] = np.mean(train_acc)
    scores[prefix + ":mean_test_accuracy"] = np.mean(test_acc)
    scores[prefix + ":min_train_accuracy"] = np.min(train_acc)
    scores[prefix + ":min_test_accuracy"] = np.min(test_acc)
    for factor_idx, train_value in enumerate(train_acc):
      scores[prefix + ":train_accuracy_factor_{}".format(factor_idx)] = train_value
      scores[prefix + ":test_accuracy_factor_{}".format(factor_idx)] = (
          test_acc[factor_idx])

  return scores


def _compute_loss(x_train, y_train, x_test, y_test, predictor_fn):
  """Fit one predictor per factor and measure its accuracy.

  Despite the name, the returned values are accuracies (fraction of exact
  label matches), not losses.

  Args:
    x_train: Training inputs, one row per sample.
    y_train: Training labels, one row per factor.
    x_test: Test inputs, one row per sample.
    y_test: Test labels, one row per factor.
    predictor_fn: Zero-argument factory returning an object with `fit` and
      `predict`.

  Returns:
    Tuple of two lists (train accuracies, test accuracies), one entry per
    factor.
  """
  train_accuracies = []
  test_accuracies = []
  for factor_idx in range(y_train.shape[0]):
    train_target = y_train[factor_idx, :]
    classifier = predictor_fn()
    classifier.fit(x_train, train_target)
    train_hits = classifier.predict(x_train) == train_target
    train_accuracies.append(np.mean(train_hits))
    test_hits = classifier.predict(x_test) == y_test[factor_idx, :]
    test_accuracies.append(np.mean(test_hits))
  return train_accuracies, test_accuracies
