from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import range
import sklearn
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegressionCV
from sklearn.model_selection import KFold

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Download URL of the Inception weights used for FID computation. The weights
# were ported to PyTorch from the original TensorFlow release at
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
# NOTE(review): in the visible part of this file the constant is only
# referenced by the commented-out fid_inception_v3() further down.
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'


def generate_batch_factor_code(ground_truth_data, representation_function,
                               num_points, random_state, batch_size):
  """Sample factors and compute their representations in mini-batches.

  Repeatedly draws at most `batch_size` points from `ground_truth_data`,
  encodes the sampled observations with `representation_function`, and
  returns all codes and factors with points laid out along the last axis.

  Args:
    ground_truth_data: GroundTruthData to be sampled from. Must provide
      `sample(num, random_state)` returning `(factors, observations)`.
    representation_function: Function that takes observation as input and
      outputs a representation.
    num_points: Number of points to sample. Must be positive.
    random_state: Numpy random state used for randomness.
    batch_size: Batchsize to sample points. Must be positive.
  Returns:
    representations: Codes (num_codes, num_points)-np array.
    factors: Factors generating the codes (num_factors, num_points)-np array.
  Raises:
    ValueError: If `num_points` or `batch_size` is not positive.
  """
  if num_points <= 0:
    raise ValueError(
        "num_points must be positive, got {}.".format(num_points))
  # A non-positive batch size would never advance the sampling loop.
  if batch_size <= 0:
    raise ValueError(
        "batch_size must be positive, got {}.".format(batch_size))

  # Collect the batches and stack once at the end; stacking inside the loop
  # re-copies everything accumulated so far on every iteration.
  factor_batches = []
  representation_batches = []
  num_sampled = 0
  while num_sampled < num_points:
    num_points_iter = min(num_points - num_sampled, batch_size)
    current_factors, current_observations = ground_truth_data.sample(
        num_points_iter, random_state)
    factor_batches.append(current_factors)
    representation_batches.append(
        representation_function(current_observations))
    num_sampled += num_points_iter

  if len(factor_batches) == 1:
    # A single batch is passed through untouched (no stacking), exactly as
    # when everything fits into the first mini-batch.
    factors = factor_batches[0]
    representations = representation_batches[0]
  else:
    factors = np.vstack(factor_batches)
    representations = np.vstack(representation_batches)
  return np.transpose(representations), np.transpose(factors)


def obtain_representation(observations, representation_function, batch_size):
  """Obtain representations from observations, one mini-batch at a time.

  Args:
    observations: Observations for which we compute the representation.
      Points are indexed along the first axis.
    representation_function: Function that takes observation as input and
      outputs a representation.
    batch_size: Batch size to compute the representation. Must be positive.
  Returns:
    representations: Codes (num_codes, num_points)-Numpy array.
  Raises:
    ValueError: If `observations` is empty or `batch_size` is not positive.
  """
  num_points = observations.shape[0]
  if num_points <= 0:
    raise ValueError("observations must contain at least one point.")
  # A non-positive batch size would never advance through the observations.
  if batch_size <= 0:
    raise ValueError(
        "batch_size must be positive, got {}.".format(batch_size))

  # Collect per-batch codes and stack once at the end; stacking inside the
  # loop re-copies everything accumulated so far on every iteration.
  representation_batches = []
  start = 0
  while start < num_points:
    stop = min(start + batch_size, num_points)
    representation_batches.append(
        representation_function(observations[start:stop]))
    start = stop

  if len(representation_batches) == 1:
    # A single batch is passed through untouched (no stacking).
    representations = representation_batches[0]
  else:
    representations = np.vstack(representation_batches)
  return np.transpose(representations)


def discrete_mutual_info(mus, ys):
  """Compute pairwise discrete mutual information between codes and factors.

  Args:
    mus: Array whose rows are the (discretized) code dimensions.
    ys: Array whose rows are the ground-truth factors.
  Returns:
    Array of shape (mus.shape[0], ys.shape[0]) whose entry [i, j] is the
    mutual information score between code row i and factor row j.
  """
  n_codes, n_factors = mus.shape[0], ys.shape[0]
  mi_matrix = np.zeros([n_codes, n_factors])
  # np.ndindex walks the (code, factor) grid in row-major order.
  for code_idx, factor_idx in np.ndindex(n_codes, n_factors):
    mi_matrix[code_idx, factor_idx] = sklearn.metrics.mutual_info_score(
        ys[factor_idx, :], mus[code_idx, :])
  return mi_matrix


def discrete_entropy(ys):
  """Compute the discrete entropy of every factor.

  The entropy of each row is obtained as the mutual information score of
  that row with itself.

  Args:
    ys: Array whose rows are the discrete factors.
  Returns:
    1-D array with one entropy value per row of `ys`.
  """
  n_factors = ys.shape[0]
  entropies = np.zeros(n_factors)
  for factor_idx in range(n_factors):
    factor_values = ys[factor_idx, :]
    entropies[factor_idx] = sklearn.metrics.mutual_info_score(
        factor_values, factor_values)
  return entropies

def _histogram_discretize(target, num_bins):
  """Discretize every row of `target` with a per-row histogram.

  Each row is binned independently using the bin edges that np.histogram
  computes for that row. The last edge is dropped before digitizing, so the
  row maximum lands in the final bin instead of one past it; the resulting
  bin indices start at 1.

  Args:
    target: 2-D array; each row is discretized on its own.
    num_bins: Bin specification forwarded to np.histogram.
  Returns:
    Array with the shape and dtype of `target` holding the bin indices.
  """
  binned = np.zeros_like(target)
  for row_idx in range(target.shape[0]):
    row = target[row_idx, :]
    bin_edges = np.histogram(row, num_bins)[1]
    binned[row_idx, :] = np.digitize(row, bin_edges[:-1])
  return binned

def make_discretizer(target, num_bins=20,
                     discretizer_fn=_histogram_discretize):
  """Discretize `target` with the selected discretizer function.

  Args:
    target: Data to discretize; passed through to `discretizer_fn`.
    num_bins: Number of bins handed to `discretizer_fn`.
    discretizer_fn: Callable taking `(target, num_bins)`; defaults to the
      histogram-based discretizer.
  Returns:
    Whatever `discretizer_fn(target, num_bins)` returns.
  """
  discretized = discretizer_fn(target, num_bins)
  return discretized





def normalize_data(data, mean=None, stddev=None):
  """Standardize every row of `data`.

  Args:
    data: 2-D array; statistics are taken along axis 1 (per row).
    mean: Optional per-row means. Computed from `data` when None.
    stddev: Optional per-row standard deviations. Computed from `data` when
      None. No guard against zero entries is applied before dividing.
  Returns:
    Tuple `(normalized, mean, stddev)` with the standardized data and the
    statistics that were used.
  """
  mean = np.mean(data, axis=1) if mean is None else mean
  stddev = np.std(data, axis=1) if stddev is None else stddev
  # Turn the per-row statistics into columns so they broadcast over each row.
  centered = data - mean[:, np.newaxis]
  normalized = centered / stddev[:, np.newaxis]
  return normalized, mean, stddev



def make_predictor_fn(predictor_fn):
  """Return the given predictor factory unchanged.

  Args:
    predictor_fn: Object (typically a callable building a classifier) to
      hand back.
  Returns:
    The very same `predictor_fn` object that was passed in.
  """
  selected_predictor_fn = predictor_fn
  return selected_predictor_fn


def logistic_regression_cv():
  """Build a cross-validated logistic regression classifier.

  Returns:
    An unfitted LogisticRegressionCV configured with Cs=10 and a 5-split
    KFold cross-validation scheme.
  """
  five_fold_splitter = KFold(n_splits=5)
  return LogisticRegressionCV(Cs=10, cv=five_fold_splitter)


def gradient_boosting_classifier():
  """Build a gradient boosting classifier with default settings.

  Returns:
    An unfitted GradientBoostingClassifier constructed without arguments.
  """
  classifier = GradientBoostingClassifier()
  return classifier

#
# ## Modified Inception Network for FID score
# class InceptionV3(nn.Module):
#   """Pretrained InceptionV3 network returning feature maps"""
#
#   # Index of default block of inception to return,
#   # corresponds to output of final average pooling
#   DEFAULT_BLOCK_INDEX = 3
#
#   # Maps feature dimensionality to their output blocks indices
#   BLOCK_INDEX_BY_DIM = {
#       64: 0,   # First max pooling features
#       192: 1,  # Second max pooling features
#       768: 2,  # Pre-aux classifier features
#       2048: 3  # Final average pooling features
#   }
#
#   def __init__(self,
#                  output_blocks=[DEFAULT_BLOCK_INDEX],
#                  resize_input=True,
#                  normalize_input=True,
#                  requires_grad=False,
#                  use_fid_inception=True):
#     """Build pretrained InceptionV3
#     Parameters
#     ----------
#     output_blocks : list of int
#         Indices of blocks to return features of. Possible values are:
#             - 0: corresponds to output of first max pooling
#             - 1: corresponds to output of second max pooling
#             - 2: corresponds to output which is fed to aux classifier
#             - 3: corresponds to output of final average pooling
#     resize_input : bool
#         If true, bilinearly resizes input to width and height 299 before
#         feeding input to model. As the network without fully connected
#         layers is fully convolutional, it should be able to handle inputs
#         of arbitrary size, so resizing might not be strictly needed
#     normalize_input : bool
#         If true, scales the input from range (0, 1) to the range the
#         pretrained Inception network expects, namely (-1, 1)
#     requires_grad : bool
#         If true, parameters of the model require gradients. Possibly useful
#         for finetuning the network
#     use_fid_inception : bool
#         If true, uses the pretrained Inception model used in Tensorflow's
#         FID implementation. If false, uses the pretrained Inception model
#         available in torchvision. The FID Inception model has different
#         weights and a slightly different structure from torchvision's
#         Inception model. If you want to compute FID scores, you are
#         strongly advised to set this parameter to true to get comparable
#         results.
#     """
#     super(InceptionV3, self).__init__()
#
#     self.resize_input = resize_input
#     self.normalize_input = normalize_input
#     self.output_blocks = sorted(output_blocks)
#     self.last_needed_block = max(output_blocks)
#
#     assert self.last_needed_block <= 3, \
#         'Last possible output block index is 3'
#
#     self.blocks = nn.ModuleList()
#
#     if use_fid_inception:
#         inception = fid_inception_v3()
#     else:
#         inception = models.inception_v3(pretrained=True)
#
#     # Block 0: input to maxpool1
#     block0 = [
#         inception.Conv2d_1a_3x3,
#         inception.Conv2d_2a_3x3,
#         inception.Conv2d_2b_3x3,
#         nn.MaxPool2d(kernel_size=3, stride=2)
#     ]
#     self.blocks.append(nn.Sequential(*block0))
#
#     # Block 1: maxpool1 to maxpool2
#     if self.last_needed_block >= 1:
#         block1 = [
#             inception.Conv2d_3b_1x1,
#             inception.Conv2d_4a_3x3,
#             nn.MaxPool2d(kernel_size=3, stride=2)
#         ]
#         self.blocks.append(nn.Sequential(*block1))
#
#     # Block 2: maxpool2 to aux classifier
#     if self.last_needed_block >= 2:
#         block2 = [
#             inception.Mixed_5b,
#             inception.Mixed_5c,
#             inception.Mixed_5d,
#             inception.Mixed_6a,
#             inception.Mixed_6b,
#             inception.Mixed_6c,
#             inception.Mixed_6d,
#             inception.Mixed_6e,
#         ]
#         self.blocks.append(nn.Sequential(*block2))
#
#     # Block 3: aux classifier to final avgpool
#     if self.last_needed_block >= 3:
#         block3 = [
#             inception.Mixed_7a,
#             inception.Mixed_7b,
#             inception.Mixed_7c,
#             nn.AdaptiveAvgPool2d(output_size=(1, 1))
#         ]
#         self.blocks.append(nn.Sequential(*block3))
#
#     for param in self.parameters():
#         param.requires_grad = requires_grad
#
#   def forward(self, inp):
#     """Get Inception feature maps
#     Parameters
#     ----------
#     inp : torch.autograd.Variable
#         Input tensor of shape Bx3xHxW. Values are expected to be in
#         range (0, 1)
#     Returns
#     -------
#     List of torch.autograd.Variable, corresponding to the selected output
#     block, sorted ascending by index
#     """
#     outp = []
#     x = inp
#
#     if self.resize_input:
#         x = F.interpolate(x,
#                           size=(299, 299),
#                           mode='bilinear',
#                           align_corners=False)
#
#     if self.normalize_input:
#         x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)
#
#     for idx, block in enumerate(self.blocks):
#         x = block(x)
#         if idx in self.output_blocks:
#             outp.append(x)
#
#         if idx == self.last_needed_block:
#             break
#
#     return outp
#
# class FIDInceptionA(models.inception.InceptionA):
#   """InceptionA block patched for FID computation"""
#   def __init__(self, in_channels, pool_features):
#     super(FIDInceptionA, self).__init__(in_channels, pool_features)
#
#   def forward(self, x):
#     branch1x1 = self.branch1x1(x)
#
#     branch5x5 = self.branch5x5_1(x)
#     branch5x5 = self.branch5x5_2(branch5x5)
#
#     branch3x3dbl = self.branch3x3dbl_1(x)
#     branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
#     branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
#
#     # Patch: Tensorflow's average pool does not use the padded zeros in
#     # its average calculation
#     branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
#                                count_include_pad=False)
#     branch_pool = self.branch_pool(branch_pool)
#
#     outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
#     return torch.cat(outputs, 1)
#
#
# class FIDInceptionC(models.inception.InceptionC):
#   """InceptionC block patched for FID computation"""
#   def __init__(self, in_channels, channels_7x7):
#     super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
#
#   def forward(self, x):
#     branch1x1 = self.branch1x1(x)
#
#     branch7x7 = self.branch7x7_1(x)
#     branch7x7 = self.branch7x7_2(branch7x7)
#     branch7x7 = self.branch7x7_3(branch7x7)
#
#     branch7x7dbl = self.branch7x7dbl_1(x)
#     branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
#     branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
#     branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
#     branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
#
#     # Patch: Tensorflow's average pool does not use the padded zeros in
#     # its average calculation
#     branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
#                                count_include_pad=False)
#     branch_pool = self.branch_pool(branch_pool)
#
#     outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
#     return torch.cat(outputs, 1)
#
#
# class FIDInceptionE_1(models.inception.InceptionE):
#   """First InceptionE block patched for FID computation"""
#   def __init__(self, in_channels):
#     super(FIDInceptionE_1, self).__init__(in_channels)
#
#   def forward(self, x):
#     branch1x1 = self.branch1x1(x)
#
#     branch3x3 = self.branch3x3_1(x)
#     branch3x3 = [
#         self.branch3x3_2a(branch3x3),
#         self.branch3x3_2b(branch3x3),
#     ]
#     branch3x3 = torch.cat(branch3x3, 1)
#
#     branch3x3dbl = self.branch3x3dbl_1(x)
#     branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
#     branch3x3dbl = [
#         self.branch3x3dbl_3a(branch3x3dbl),
#         self.branch3x3dbl_3b(branch3x3dbl),
#     ]
#     branch3x3dbl = torch.cat(branch3x3dbl, 1)
#
#     # Patch: Tensorflow's average pool does not use the padded zeros in
#     # its average calculation
#     branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
#                                count_include_pad=False)
#     branch_pool = self.branch_pool(branch_pool)
#
#     outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
#     return torch.cat(outputs, 1)
#
#
# class FIDInceptionE_2(models.inception.InceptionE):
#   """Second InceptionE block patched for FID computation"""
#   def __init__(self, in_channels):
#     super(FIDInceptionE_2, self).__init__(in_channels)
#
#   def forward(self, x):
#     branch1x1 = self.branch1x1(x)
#
#     branch3x3 = self.branch3x3_1(x)
#     branch3x3 = [
#         self.branch3x3_2a(branch3x3),
#         self.branch3x3_2b(branch3x3),
#     ]
#     branch3x3 = torch.cat(branch3x3, 1)
#
#     branch3x3dbl = self.branch3x3dbl_1(x)
#     branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
#     branch3x3dbl = [
#         self.branch3x3dbl_3a(branch3x3dbl),
#         self.branch3x3dbl_3b(branch3x3dbl),
#     ]
#     branch3x3dbl = torch.cat(branch3x3dbl, 1)
#
#     # Patch: The FID Inception model uses max pooling instead of average
#     # pooling. This is likely an error in this specific Inception
#     # implementation, as other Inception models use average pooling here
#     # (which matches the description in the paper).
#     branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
#     branch_pool = self.branch_pool(branch_pool)
#
#     outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
#     return torch.cat(outputs, 1)
#
# def fid_inception_v3():
#   """Build pretrained Inception model for FID computation
#   The Inception model for FID computation uses a different set of weights
#   and has a slightly different structure than torchvision's Inception.
#   This method first constructs torchvision's Inception and then patches the
#   necessary parts that are different in the FID Inception model.
#   """
#   inception = models.inception_v3(num_classes=1008,
#                                   aux_logits=False,
#                                   pretrained=False)
#   inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
#   inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
#   inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
#   inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
#   inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
#   inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
#   inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
#   inception.Mixed_7b = FIDInceptionE_1(1280)
#   inception.Mixed_7c = FIDInceptionE_2(2048)
#
#   state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
#   inception.load_state_dict(state_dict)
#   return inception
#
#

