# coding=utf-8
# Copyright 2018 The DisentanglementLib Authors.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Weakly supervised downstream classification task."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from disentanglement_lib.evaluation.metrics import utils
from sklearn.metrics import balanced_accuracy_score
import numpy as np
from six.moves import range
import gin.tf

def compute_kl(z_1, z_2, logvar_1, logvar_2):
  """Elementwise divergence between two diagonal Gaussians.

  Evaluates, element by element,

      exp(logvar_1 - logvar_2) + (z_2 - z_1)^2 / (exp(logvar_2) + 1e-6)
          - 1 + logvar_2 - logvar_1

  and clips the result at zero. Up to the small epsilon, this is the
  closed-form KL divergence between univariate Gaussians without its usual
  factor of 1/2.

  Args:
    z_1: Means of the first distribution.
    z_2: Means of the second distribution.
    logvar_1: Log-variances of the first distribution.
    logvar_2: Log-variances of the second distribution.

  Returns:
    Array of non-negative divergences, one per element of the (broadcast)
    inputs.
  """
  # The epsilon keeps the division finite when exp(logvar_2) underflows.
  variance_2 = np.exp(logvar_2) + 1e-6
  variance_ratio = np.exp(logvar_1 - logvar_2)
  scaled_sq_distance = np.square(z_2 - z_1) / variance_2
  divergence = variance_ratio + scaled_sq_distance - 1 + logvar_2 - logvar_1
  # Rounding can push a value that should be zero slightly below it.
  return np.maximum(0, divergence)

def get_factor_partition(samples, representation_function):
    """Split the latent means of each observation pair into shared/independent parts.

    Both views of every pair are encoded, a symmetrised per-dimension KL
    divergence between the two posteriors is computed, and for every pair the
    latent dimensions are ranked by that divergence. The `k` dimensions with
    the largest divergence are treated as "independent" and the remaining
    `d - k` dimensions as "shared", where `k` is read from the gin binding
    `weak_downstream_task.k_est`.

    Args:
      samples: Array of observation pairs; the two views are taken as
        `samples[:, 0]` and `samples[:, 1]`.
      representation_function: Callable mapping a batch of observations to a
        `(mean, logvar)` tuple of 2-D arrays shaped (batch, d).

    Returns:
      A tuple `(shared, independent)`:
        shared: (batch, d) array holding the average of the two views' means
          on the shared dimensions and zeros elsewhere.
        independent: (batch, 2, d) array holding each view's means on the
          independent dimensions and zeros elsewhere; axis 1 indexes the view.
    """
    k = gin.query_parameter("weak_downstream_task.k_est")
    z_mean, z_logvar = representation_function(samples[:, 0])
    z_mean_2, z_logvar_2 = representation_function(samples[:, 1])

    # Average KL(1 || 2) and KL(2 || 1) so the ranking does not depend on the
    # order of the two views.
    kl_per_point = 0.5 * (
        compute_kl(z_mean, z_mean_2, z_logvar, z_logvar_2)
        + compute_kl(z_mean_2, z_mean, z_logvar_2, z_logvar))
    d = kl_per_point.shape[-1]

    # Ascending sort: the first d - k columns are the dimensions that changed
    # least between the views, the last k the ones that changed most.
    ind_sorted_kl_per_point = np.argsort(kl_per_point, axis=-1)
    shared_dims = ind_sorted_kl_per_point[:, :d - k]
    independent_dims = ind_sorted_kl_per_point[:, d - k:]

    shared_z_mean = _select_latent_dims(z_mean, shared_dims)
    shared_z_mean_2 = _select_latent_dims(z_mean_2, shared_dims)
    independent_z_mean = _select_latent_dims(z_mean, independent_dims)
    independent_z_mean_2 = _select_latent_dims(z_mean_2, independent_dims)

    # The masked arrays always keep the full (batch, d) layout, so no reshape
    # is needed when only a single shared dimension remains; reshaping to
    # (-1, 1) here would multiply the sample axis by d.
    independent_representations = np.stack(
        (independent_z_mean, independent_z_mean_2), axis=1)
    return 0.5 * (shared_z_mean + shared_z_mean_2), independent_representations


def _select_latent_dims(values, dims):
    """Return a copy of `values` that is zero except at the per-row columns `dims`."""
    row_index, _ = np.indices(dims.shape)
    selected = np.zeros_like(values)
    selected[row_index, dims] = values[row_index, dims]
    return selected

def simple_dynamics(z, ground_truth_data, random_state,
                    return_index=False, k=None, n_factors=7):
    """Create the second view of a pair by resampling the last `k` factors.

    The first `n_factors - k` factor columns are left untouched ("shared");
    each of the last `k` columns ("independent") is overwritten with a fresh
    value drawn uniformly from `range(factors_num_values[column])`.

    Note that `z` is modified in place and that one drawn value is written to
    every row of a column, so the function is meant for a single-row `z`.

    Args:
      z: 2-D array of ground-truth factors, one row per sample. Overwritten in
        place on the independent columns.
      ground_truth_data: Object exposing `factors_num_values`, indexable by
        factor column.
      random_state: Numpy random state used to draw the new factor values.
      return_index: Unused; kept so existing call sites keep working.
      k: Number of independent (resampled) factors. When None, the value bound
        to `weak_downstream_task.k_true` in gin is used.
      n_factors: Number of leading factor columns that are partitioned into
        shared and independent ones.

    Returns:
      A tuple `(z, shared_labels, independent_labels)`:
        z: The same array object that was passed in, after resampling.
        shared_labels: Copy of the shared columns of `z`.
        independent_labels: Array with a leading axis of size 1 whose next
          axis stacks the independent columns before resampling followed by
          the same columns after resampling.
    """
    del return_index  # Accepted for interface compatibility only.
    if k is None:
        k = gin.query_parameter("weak_downstream_task.k_true")

    factor_indices = np.arange(n_factors)
    shared_index_list = factor_indices[:n_factors - k]
    independent_index_list = factor_indices[n_factors - k:]

    # Indexing with an index array copies, so this snapshot is not affected by
    # the in-place resampling below.
    independent_labels_1 = z[:, independent_index_list]
    for index in independent_index_list:
        num_values = ground_truth_data.factors_num_values[index]
        # Draw from the supplied random state (not the global numpy RNG) so
        # that pairs are reproducible for a given seed.
        z[:, index] = random_state.choice(num_values)
    shared_labels = z[:, shared_index_list]
    independent_labels_2 = z[:, independent_index_list]

    independent_labels = np.vstack((independent_labels_1, independent_labels_2))
    independent_labels = np.expand_dims(independent_labels, axis=0)
    return z, shared_labels, independent_labels

def weak_dataset_generator(ground_truth_data, random_state):
    """Yield an endless stream of weakly supervised observation pairs.

    Each item is built from one sampled factor vector: the first view is
    rendered from it, `simple_dynamics` then derives the factors of the second
    view, and the second view is rendered from those.

    Args:
      ground_truth_data: Data source providing `sample_factors` and
        `sample_observations_from_factors`.
      random_state: Numpy random state forwarded to every sampling call.

    Yields:
      A tuple `(pair, shared_labels, independent_labels)` where `pair` is the
      two observations concatenated along axis 0 and the two label arrays are
      the outputs of `simple_dynamics` with their leading axis squeezed away.
    """
    while True:
        factors = ground_truth_data.sample_factors(1, random_state)
        # Render the first view before calling simple_dynamics, which
        # overwrites `factors` in place.
        first_view = ground_truth_data.sample_observations_from_factors(
            factors, random_state)

        factors_2, shared, independent = simple_dynamics(
            factors, ground_truth_data, random_state)
        second_view = ground_truth_data.sample_observations_from_factors(
            factors_2, random_state)

        pair = np.concatenate((first_view, second_view), axis=0)
        yield (pair, shared.squeeze(0), independent.squeeze(0))

def sample_batch(batch_size, sampler):
    """Draw `batch_size` items from `sampler` and stack them into batch arrays.

    Args:
      batch_size: Number of items to pull from the sampler.
      sampler: Iterator whose items are indexable as
        `(observation, shared_factors, independent_factors)`.

    Returns:
      A tuple `(independent_factors, shared_factors, observations)`; each entry
      stacks the corresponding item component along a new leading axis. Note
      that the order differs from the order inside a sampler item.
    """
    draws = [next(sampler) for _ in range(batch_size)]
    observations = np.stack([draw[0] for draw in draws])
    shared = np.stack([draw[1] for draw in draws])
    independent = np.stack([draw[2] for draw in draws])
    return independent, shared, observations

def generate_batch_factor_code(ground_truth_data, representation_function,
                               num_points, random_state, batch_size):
    """Sample observation pairs and split their codes and labels.

    Pairs are drawn from `weak_dataset_generator` in chunks of at most
    `batch_size`, encoded and partitioned with `get_factor_partition`, and the
    per-chunk results are stacked along the first axis.

    Args:
      ground_truth_data: GroundTruthData to be sampled from.
      representation_function: Function that takes observations as input and
        outputs a `(mean, logvar)` representation.
      num_points: Number of pairs to sample.
      random_state: Numpy random state used for randomness.
      batch_size: Maximum number of pairs processed per chunk.

    Returns:
      A tuple `(shared_representations, independent_representations,
      shared_factors, independent_factors)`, each with `num_points` entries
      along the first axis. All four entries are None when `num_points` is
      not positive, because no chunk is sampled in that case.
    """
    sampler = weak_dataset_generator(ground_truth_data, random_state)
    # Collect the chunks and stack once at the end; re-stacking the running
    # result inside the loop would copy all earlier data on every iteration.
    shared_representation_chunks = []
    independent_representation_chunks = []
    shared_factor_chunks = []
    independent_factor_chunks = []
    i = 0
    while i < num_points:
        num_points_iter = min(num_points - i, batch_size)
        (current_independent_factors, current_shared_factors,
         current_observations) = sample_batch(num_points_iter, sampler)
        (curr_shared_representations,
         curr_independent_representations) = get_factor_partition(
             current_observations, representation_function)

        shared_factor_chunks.append(current_shared_factors)
        independent_factor_chunks.append(current_independent_factors)
        shared_representation_chunks.append(curr_shared_representations)
        independent_representation_chunks.append(
            curr_independent_representations)
        i += num_points_iter

    if not shared_representation_chunks:
        return None, None, None, None

    return (np.vstack(shared_representation_chunks),
            np.vstack(independent_representation_chunks),
            np.vstack(shared_factor_chunks),
            np.vstack(independent_factor_chunks))

@gin.configurable(
    "weak_downstream_task",
    blacklist=["ground_truth_data", "representation_function", "random_state",
               "artifact_dir"])
def compute_weak_downstream_task(ground_truth_data,
                            representation_function,
                            random_state,
                            artifact_dir=None,
                            num_train=gin.REQUIRED,
                            num_test=gin.REQUIRED,
                            k_est=gin.REQUIRED,
                            k_true=gin.REQUIRED,
                            batch_size=16):
  """Computes downstream accuracy on shared and independent factors.

  For every training-set size, predictors are fitted from the shared and from
  the independent part of the representation to the shared and to the
  independent ground-truth factors. The independent parts are evaluated once
  per view of the pair.

  Args:
    ground_truth_data: GroundTruthData to be sampled from.
    representation_function: Function that takes observations as input and
      outputs a representation for each observation.
    random_state: Numpy random state used for randomness.
    artifact_dir: Optional path to directory where artifacts can be saved.
      Unused.
    num_train: Iterable of training-set sizes; one group of scores is produced
      per size.
    num_test: Number of points used for testing.
    k_est: Number of latent dimensions treated as independent. Not used
      directly here; `get_factor_partition` reads the same gin binding.
    k_true: Number of ground-truth factors resampled between the two views.
      Not used directly here; `simple_dynamics` reads the same gin binding.
    batch_size: Batch size for sampling.

  Returns:
    Dictionary with scores, keyed as
    `<predictor>:<train_size>:<evaluation>:<statistic>`.
  """
  del artifact_dir
  scores = {}
  for train_size in num_train:
    (mus_shared_train, mus_independent_train, y_shared_train,
     y_independent_train) = generate_batch_factor_code(
         ground_truth_data, representation_function, train_size, random_state,
         batch_size)
    (mus_shared_test, mus_independent_test, y_shared_test,
     y_independent_test) = generate_batch_factor_code(
         ground_truth_data, representation_function, num_test, random_state,
         batch_size)

    predictor_model = utils.make_predictor_fn()
    predictor_name = predictor_model().__class__.__name__
    size_string = f'{predictor_name}:{train_size}'

    # Shared representation, shared factors.
    _record_task_scores(
        scores, size_string + ":shared_factors",
        mus_shared_train, y_shared_train,
        mus_shared_test, y_shared_test, predictor_model)

    for j in range(2):
      # Axis 1 of the independent arrays indexes the view of the pair.
      mus_train_view = mus_independent_train[:, j, :]
      mus_test_view = mus_independent_test[:, j, :]
      y_train_view = y_independent_train[:, j, :]
      y_test_view = y_independent_test[:, j, :]

      # Shared representation, independent factors.
      _record_task_scores(
          scores,
          size_string + f":shared_representation_independent_factors_view{j}",
          mus_shared_train, y_train_view,
          mus_shared_test, y_test_view, predictor_model)
      # Independent representation, shared factors.
      _record_task_scores(
          scores,
          size_string + f":independent_representation_shared_factors_view{j}",
          mus_train_view, y_shared_train,
          mus_test_view, y_shared_test, predictor_model)
      # Independent representation, independent factors.
      _record_task_scores(
          scores,
          size_string +
          f":independent_representation_independent_factors_view{j}",
          mus_train_view, y_train_view,
          mus_test_view, y_test_view, predictor_model)

  return scores


def _record_task_scores(scores, prefix, x_train, y_train, x_test, y_test,
                        predictor_fn):
  """Fits per-factor predictors and stores their accuracies under `prefix`."""
  # With no target factors (all factors shared, or all independent) there is
  # nothing to predict; a placeholder accuracy of 1 is reported and the
  # per-factor entries are omitted.
  has_targets = y_train.shape[1] > 0
  if has_targets:
    train_acc, test_acc = _compute_loss(
        x_train, y_train, x_test, y_test, predictor_fn)
  else:
    train_acc, test_acc = 1., 1.

  scores[prefix + ":mean_train_accuracy"] = np.mean(train_acc)
  scores[prefix + ":mean_test_accuracy"] = np.mean(test_acc)
  scores[prefix + ":min_train_accuracy"] = np.min(train_acc)
  scores[prefix + ":min_test_accuracy"] = np.min(test_acc)
  if has_targets:
    for i, (factor_train_acc, factor_test_acc) in enumerate(
        zip(train_acc, test_acc)):
      scores[prefix + f":train_accuracy_factor_{i}"] = factor_train_acc
      scores[prefix + f":test_accuracy_factor_{i}"] = factor_test_acc


def _compute_loss(x_train, y_train, x_test, y_test, predictor_fn):
  """Compute average accuracy for train and test set."""
  num_factors = y_train.shape[1]
  train_loss = []
  test_loss = []
  for i in range(num_factors):
    model = predictor_fn()
    model.fit(x_train, y_train[:,i])
    train_loss.append(balanced_accuracy_score(y_train[:,i],model.predict(x_train), adjusted=True))
    test_loss.append(balanced_accuracy_score(y_test[:,i],model.predict(x_test), adjusted=True))
    
  return train_loss, test_loss
