import json
import math
import os

import numpy as np
from scipy.stats.distributions import uniform


class log_uniform:
    """Log-uniform distribution on the interval ``[a, b]``.

    A sample ``x`` is produced as ``base ** u`` with ``u`` drawn uniformly
    from ``[log_base(a), log_base(b)]``, so ``log_base(x)`` is uniform. Only
    the ``rvs`` method of the scipy frozen-distribution API is provided.

    Args:
        a: Lower bound of the support; must be positive.
        b: Upper bound of the support; must be positive and not below ``a``.
        base: Logarithm base used for the intermediate uniform draw.

    Raises:
        ValueError: If ``a`` or ``b`` is not positive, or if ``b < a``.
    """

    def __init__(self, a, b, base=10):
        if a <= 0 or b <= 0:
            raise ValueError(
                f"log_uniform bounds must be positive, got a={a!r}, b={b!r}"
            )
        if b < a:
            # Fail here rather than later inside scipy's rvs (negative scale).
            raise ValueError(f"log_uniform requires a <= b, got a={a!r}, b={b!r}")
        # Parameters of the uniform distribution over the exponent.
        self.loc = math.log(a, base)
        self.scale = math.log(b, base) - self.loc
        self.base = base

    def rvs(self, size=1, random_state=None):
        """Draw samples from the distribution.

        Args:
            size: Number (or shape) of samples, forwarded to scipy's ``rvs``.
                ``None`` requests a single scalar sample.
            random_state: Seed, ``numpy.random.Generator``/``RandomState`` or
                ``None``; forwarded to scipy.

        Returns:
            A scalar when ``size`` is ``None``. If the leading axis has length
            1, the first element along that axis is returned instead; this is
            a scalar for the default ``size=1``. Otherwise the full ndarray of
            samples is returned.
        """
        exponents = uniform(loc=self.loc, scale=self.scale).rvs(
            size=size, random_state=random_state
        )
        samples = np.power(self.base, exponents)
        # scipy yields a 0-d result for size=None; len() would raise on it.
        if np.ndim(samples) == 0:
            return samples
        # Unwrap length-1 results so the default call yields a plain number.
        if len(samples) == 1:
            return samples[0]
        return samples


class one_minus_log_uniform:
    """Distribution of ``1 - X`` where ``X`` is log-uniform on ``[a, b]``.

    The gap to 1 is log-uniform. For example, ``a=1e-4, b=0.5`` yields
    values in ``[0.5, 1 - 1e-4]``.
    """

    def __init__(self, a, b, base=10):
        # Log-uniform distribution of the gap ``1 - value``.
        self.dist = log_uniform(a, b, base)

    def rvs(self, size=1, random_state=None):
        """Return ``1 - X`` for samples ``X`` drawn from ``self.dist``.

        ``size`` and ``random_state`` are passed through unchanged.
        """
        gaps = self.dist.rvs(size=size, random_state=random_state)
        return 1 - gaps


# Hyperparameter search spaces, keyed by optimizer name.
#
# Each inner mapping goes from hyperparameter name to one of two things:
#   - an object with an ``rvs(random_state=...)`` method, which ``main()``
#     samples once per trial;
#   - a plain constant, which ``main()`` copies unchanged into every trial.
#
# Resulting ranges: lr in [1e-4, 0.1]; momentum in [0.5, 1 - 1e-4];
# beta1 in [0.5, 1 - 1e-3]; beta2 in [0.8, 1 - 1e-3]; delta in [0.1, 10];
# alpha in [0, 1]; big_a in [1, 2].
hp_distributions = dict(
    cm=dict(
        lr=log_uniform(1e-4, 0.1),
        momentum=one_minus_log_uniform(1e-4, 0.5),
    ),
    adam=dict(
        lr=log_uniform(1e-4, 0.1),
        beta1=one_minus_log_uniform(1e-3, 0.5),
        beta2=one_minus_log_uniform(1e-3, 0.2),
    ),
    rgd=dict(
        lr=log_uniform(1e-4, 0.1),
        momentum=one_minus_log_uniform(1e-4, 0.5),
        # Earlier experiments sampled delta from uniform(0, 20) (noted "was 50").
        delta=log_uniform(0.1, 10),
        alpha=uniform(loc=0, scale=1),
    ),
    rgd_eu=dict(
        lr=log_uniform(1e-4, 0.1),
        momentum=one_minus_log_uniform(1e-4, 0.5),
        delta=log_uniform(0.1, 10),
    ),
    pd=dict(
        lr=log_uniform(1e-4, 0.1),
        momentum=one_minus_log_uniform(1e-4, 0.5),
        delta=log_uniform(0.1, 10),
        little_a=2,  # fixed value, not sampled
        big_a=uniform(loc=1, scale=1),
    ),
)


def _sample_hyperparams(space, num_trials, rng):
    """Draw ``num_trials`` hyperparameter configurations from ``space``.

    Entries that expose ``rvs`` are sampled with ``rvs(random_state=rng)``.
    Any other entry is treated as a fixed value and copied verbatim. Draws
    proceed trial by trial, and within a trial in ``space`` iteration order,
    so a given seed reproduces the same configurations.
    """
    return [
        {
            name: spec.rvs(random_state=rng) if hasattr(spec, "rvs") else spec
            for name, spec in space.items()
        }
        for _ in range(num_trials)
    ]


def main(
    *,
    opt="rgd",
    problem="cifar100_vit",
    num_trials=25,
    random_seed=42,
    output_dir="command_scripts",
    distributions=None,
):
    """Write a JSON job file of randomly sampled hyperparameters.

    The output path is ``<output_dir>/<problem>/jobs_<opt>.json``. The file
    holds a list with one entry per trial, recording the optimizer name, the
    test problem, the sampled hyperparameters and the random seed. An
    existing file is never overwritten; an error message is printed instead.

    Args:
        opt: Optimizer name; must be a key of ``distributions``.
        problem: Test-problem name, used in the output path and the entries.
        num_trials: Number of hyperparameter configurations to sample.
        random_seed: Seed for ``numpy.random.default_rng``; also recorded in
            every entry.
        output_dir: Root directory for the generated job files.
        distributions: Search spaces keyed by optimizer name. Defaults to the
            module-level ``hp_distributions``.

    Raises:
        KeyError: If ``opt`` is not a key of ``distributions``.
        TypeError: If a sampled or fixed value is not JSON-serializable. In
            that case no file is created.
    """
    if distributions is None:
        distributions = hp_distributions
    space = distributions[opt]

    out_dir = os.path.join(output_dir, problem)
    os.makedirs(out_dir, exist_ok=True)
    file_path = os.path.join(out_dir, f"jobs_{opt}.json")

    # Check before sampling so no work is wasted when we refuse to overwrite.
    if os.path.exists(file_path):
        print(f"Error: Command script already exists at {file_path}.")
        print("Please delete it if you would like to overwrite it.")
        return

    rng = np.random.default_rng(seed=random_seed)
    hps = _sample_hyperparams(space, num_trials, rng)

    contents = [
        {
            "optimizer_name": opt,
            "testproblem": problem,
            "hyperparams": sample,
            "random_seed": random_seed,
        }
        for sample in hps
    ]
    # Serialize before touching the filesystem. A failure must not leave a
    # truncated file behind, because that file would block every later run.
    text = json.dumps(contents, indent=4)
    # "x" refuses to clobber a file created after the existence check above.
    with open(file_path, "x", encoding="utf-8") as f:
        f.write(text)


if __name__ == "__main__":
    # Generate the job file with main()'s default settings only when this
    # module is executed as a script, not when it is imported.
    main()
