import copy
import math

import numpy as np
import tensorflow as tf
import deepobs.tensorflow
from scipy.stats.distributions import uniform

import tf_optimizers


class log_uniform:
    """Log-uniform distribution on ``[a, b]`` with a scipy-like ``rvs`` method.

    A sample is ``base ** U`` with ``U ~ Uniform(log_base(a), log_base(b))``,
    so every order of magnitude between ``a`` and ``b`` receives the same
    probability mass. The ``rvs`` interface lets instances be mixed with
    frozen ``scipy.stats`` distributions in the same search-space dict.

    Args:
        a: Lower bound; must be strictly positive.
        b: Upper bound; must satisfy ``b >= a``.
        base: Logarithm base used for the internal uniform draw.

    Raises:
        ValueError: If ``a <= 0`` or ``b < a``.
    """

    def __init__(self, a, b, base=10):
        # Fail at construction time with a clear message instead of a bare
        # "math domain error" (a <= 0) or an error raised later inside scipy
        # because of a negative scale (b < a).
        if a <= 0 or b < a:
            raise ValueError(
                "log_uniform requires 0 < a <= b, got a={!r}, b={!r}".format(a, b)
            )
        self.loc = math.log(a, base)
        self.scale = math.log(b, base) - math.log(a, base)
        self.base = base

    def rvs(self, size=1, random_state=None):
        """Draw samples from the distribution.

        Args:
            size: Number of samples or output shape, forwarded to scipy's
                ``rvs``. ``None`` requests a single scalar, as in scipy.
            random_state: Seed or random generator forwarded to scipy.

        Returns:
            A scalar when a single value is requested (``size`` is ``None``,
            ``1`` or ``(1,)``); otherwise an ndarray of shape ``size``.
        """
        exponents = uniform(loc=self.loc, scale=self.scale).rvs(
            size=size, random_state=random_state
        )
        values = np.power(self.base, exponents)
        # ``len()`` is deliberately not used here: it raises on the 0-d result
        # of ``size=None`` and would drop the leading axis of shapes such as
        # (1, 3). Only a 0-d or one-element 1-D result collapses to a scalar.
        if np.ndim(values) == 0:
            return values
        if np.shape(values) == (1,):
            return values[0]
        return values


class one_minus_log_uniform:
    """Distribution of ``1 - X`` where ``X`` follows ``log_uniform(a, b, base)``.

    Suited to parameters that should sit just below 1 (momentum, Adam betas):
    samples fall between ``1 - b`` and ``1 - a``, with the gap to 1 spread
    log-uniformly.
    """

    def __init__(self, a, b, base=10):
        self.dist = log_uniform(a, b, base)

    def rvs(self, size=1, random_state=None):
        """Draw samples; the return shape follows ``log_uniform.rvs``."""
        gaps = self.dist.rvs(size=size, random_state=random_state)
        return 1 - gaps


# Optimizer key -> optimizer class to instantiate for that key.
# Built-in TensorFlow optimizers from the ``tf.train`` API:
optimizers = {
    "cm": tf.train.MomentumOptimizer,
    "adam": tf.train.AdamOptimizer,
}
# Optimizers from the local ``tf_optimizers`` module. "rgd" and "rgd_eu" use
# the same class on purpose; within this module they differ only through
# their hyperparameter entries (e.g. the default "integrator").
optimizers.update(
    {
        "rgd": tf_optimizers.RGD,
        "rgd_eu": tf_optimizers.RGD,
        "pd": tf_optimizers.PowerDescent,
    }
)

# Sentinel meaning "this hyperparameter has no default value".
_NO_DEFAULT = object()


def _hp(kind, default=_NO_DEFAULT):
    """Return a new spec dict ``{"type": kind}``, plus ``"default"`` if given."""
    spec = {"type": kind}
    if default is not _NO_DEFAULT:
        spec["default"] = default
    return spec


# Optimizer key -> {hyperparameter name -> spec}. Each spec gives the Python
# type of the hyperparameter and, optionally, its default value.
hyperparams = {
    "cm": {
        "learning_rate": _hp(float),
        "momentum": _hp(float),
        "use_nesterov": _hp(bool, default=False),
    },
    "adam": {
        "learning_rate": _hp(float),
        "beta1": _hp(float),
        "beta2": _hp(float),
    },
    "rgd": {
        "learning_rate": _hp(float),
        "momentum": _hp(float),
        "delta": _hp(float),
        "integrator": _hp(str, default="leapfrog"),
        "alpha": _hp(float),
    },
    "rgd_eu": {
        "learning_rate": _hp(float),
        "momentum": _hp(float),
        "delta": _hp(float),
        "integrator": _hp(str, default="symplectic_euler"),
    },
    "pd": {
        "learning_rate": _hp(float),
        "momentum": _hp(float),
        "delta": _hp(float),
        # NOTE(review): the default is the int 2 although the declared type is
        # float; kept as-is in case callers depend on it (e.g. run naming).
        "little_a": _hp(float, default=2),
        "big_a": _hp(float),
    },
}

# Learning rate / momentum pair shared by the momentum-style optimizers below;
# each entry gets its own copy so the entries stay independent dicts.
_MOMENTUM_ONESHOT = {"learning_rate": 1e-2, "momentum": 0.99}

# Optimizer key -> one fixed hyperparameter setting (presumably for single
# "one-shot" runs without search — confirm against the caller).
# NOTE(review): there is no "pd" entry, and "rgd"/"rgd_eu" omit "delta"
# (which has no default in ``hyperparams``) — confirm this is intended.
hps_oneshot = {
    "cm": dict(_MOMENTUM_ONESHOT),
    "adam": {"learning_rate": 1e-3, "beta1": 0.9, "beta2": 0.999},
    "rgd": dict(_MOMENTUM_ONESHOT),
    "rgd_eu": dict(_MOMENTUM_ONESHOT),
}

def _momentum_space():
    """Return fresh samplers for the learning rate and momentum used by the
    momentum-style optimizers ("cm", "rgd", "rgd_eu", "pd")."""
    return {
        "learning_rate": log_uniform(1e-4, 0.1),
        "momentum": one_minus_log_uniform(1e-4, 0.5),
    }


# Random-search space: optimizer key -> {hyperparameter name -> object with an
# ``rvs`` method}. Hyperparameters that are not listed here (e.g. "integrator",
# "little_a") are not sampled.
hp_distributions = {
    "cm": _momentum_space(),
    "adam": {
        "learning_rate": log_uniform(1e-4, 0.1),
        "beta1": one_minus_log_uniform(1e-3, 0.5),
        "beta2": one_minus_log_uniform(1e-3, 0.2),
    },
    "rgd": {
        **_momentum_space(),
        # Alternative previously considered for delta: uniform(0, 20).
        "delta": log_uniform(0.1, 10),
        "alpha": uniform(loc=0, scale=1),
    },
    "rgd_eu": {
        **_momentum_space(),
        "delta": log_uniform(0.1, 10),
    },
    "pd": {
        **_momentum_space(),
        "delta": log_uniform(0.1, 10),
        "big_a": uniform(loc=1, scale=1),
    },
}

# DeepOBS test problem name -> training budget (epochs) and batch size.
problems = {
    name: {"num_epochs": epochs, "batch_size": batch}
    for name, epochs, batch in (
        ("quadratic_deep", 100, 128),
        ("mnist_mlp", 100, 128),
        ("fmnist_vae", 100, 64),
        ("cifar100_allcnnc", 350, 256),
    )
}
