from imp import reload
from typing import List
from collections import OrderedDict
import os
import json
from pprint import pprint


class ObjectView:
    """Expose the keys of a dictionary as instance attributes.

    The dictionary passed in is installed as the instance's attribute storage
    itself (it is not copied). Attribute writes on the view therefore show up
    in the dictionary, and writes to the dictionary show up as attributes.
    """

    def __init__(self, d):
        # Install ``d`` directly as the instance namespace (aliasing, no copy).
        object.__setattr__(self, "__dict__", d)


def append_record(record, file_name):
    """Append ``record`` to the JSON array stored in ``file_name``.

    The file must already contain a JSON array (for example ``[]``). The whole
    array is loaded, extended by one element and written back.

    Args:
        record: JSON-serialisable object to append.
        file_name: path of a file holding a JSON array.

    Raises:
        TypeError / ValueError: if the updated data cannot be serialised. In
            that case the file on disk is left untouched.
    """
    # 1. Read file contents
    with open(file_name, "r", encoding="utf-8") as file:
        data = json.load(file)

    # 2. Update json object
    data.append(record)

    # 3. Serialise *before* opening for write. Opening with "w" truncates the
    # file immediately, so a serialisation error during json.dump would
    # otherwise destroy every previously logged record.
    payload = json.dumps(data)

    # 4. Write json file
    with open(file_name, "w", encoding="utf-8") as file:
        file.write(payload)


def turn_dict_to_file_name(dictionary: dict, keys: List) -> str:
    """Build a file-name stem from selected dictionary values.

    The values of ``dictionary`` at ``keys`` are converted with ``str()`` and
    joined with underscores, in the order given by ``keys``.

    Raises:
        KeyError: if any key in ``keys`` is missing from ``dictionary``.
    """
    selected_values = (dictionary[name] for name in keys)
    return "_".join(map(str, selected_values))


def load_params(
    args,
    wandb,
    n_s,
    df_train,
    delta_data=None,
    scaler=None,
    mu_gmm=None,
    sigma_gmm=None,
    beta=None,
):
    """Assemble the ordered keyword-argument mapping for the selected model.

    Every model receives the shared entries ``df_``, ``seed``, ``epsilon``,
    ``delta``, ``n_s`` and ``wandb``. ``args.model_class`` then selects which
    model-specific entries are appended after them.

    Args:
        args: run configuration. Always read for ``seed``, ``epsilon``,
            ``epsilon_iw_perc`` and ``model_class``; other fields are read
            only for the matching model class.
        wandb: stored unchanged under ``wandb``.
        n_s: stored unchanged under ``n_s``.
        df_train: stored unchanged under ``df_``.
        delta_data: stored under ``delta``.
        scaler: stored under ``scaler`` for PRIVBAYES and UNIFORM.
        mu_gmm, sigma_gmm, beta: stored for UNIFORM only.

    Returns:
        OrderedDict with the shared keys first, then the model-specific ones.
        An unrecognised model class yields only the shared keys.
    """
    params = OrderedDict()
    params["df_"] = df_train
    params["seed"] = args.seed
    # args.epsilon reduced by the fraction epsilon_iw_perc (presumably that
    # share of the budget is spent elsewhere -- TODO confirm).
    params["epsilon"] = args.epsilon * (1 - args.epsilon_iw_perc)
    params["delta"] = delta_data
    params["n_s"] = n_s
    params["wandb"] = wandb

    model_class = args.model_class
    if model_class == "PRIVBAYES":
        extra = OrderedDict(
            category_threshold=args.category_threshold,
            k=args.num_parents,
            histogram_bins=args.histogram_bins,
            scaler=scaler,
            in_path=f"{args.main_dir}/data/{args.fn_csv}.csv",
            out_dir=args.main_dir,
        )
    elif model_class == "UNIFORM":
        extra = OrderedDict(
            mu_gmm=mu_gmm,
            sigma_gmm=sigma_gmm,
            beta=beta,
            scaler=scaler,
            target=args.fn_target_csv is not None,
            args=args,
        )
    elif model_class in ("DPCGAN", "CGAN", "DPGAN", "GAN"):
        extra = OrderedDict(
            Z_DIM=args.gan_hidden_dim,
            reload=args.reload,
        )
    else:
        extra = OrderedDict()

    params.update(extra)
    return params


def log(args, method, task, loss):
    """Append one run record to ``<args.main_dir>/logs/<args.log_file>.json``.

    The record is the run's arguments merged with ``loss``, ``task`` and
    ``method``. Entries whose value is ``None`` are dropped, and every
    remaining value is stored as ``str(value)``. The log directory and an
    empty JSON array file are created on first use.

    Args:
        args: run configuration; must provide ``main_dir`` and ``log_file``.
        method: logged under the key ``"method"``.
        task: logged under the key ``"task"``.
        loss: logged under the key ``"loss"``.
    """
    log_dir = f"{args.main_dir}/logs"
    # Create the directory we actually write into. Creating a bare relative
    # "logs" would only work when the CWD happens to equal main_dir.
    # exist_ok avoids a failure if the directory already exists.
    os.makedirs(log_dir, exist_ok=True)

    full_log_name = f"{log_dir}/{args.log_file}.json"
    if not os.path.exists(full_log_name):
        # Seed the file with an empty JSON array so append_record can load it.
        with open(full_log_name, "w") as file:
            file.write("[]")

    # NOTE(review): "_content" presumably targets config objects (e.g. OmegaConf
    # DictConfig) that keep their values there; plain namespaces fall back to
    # __dict__ -- confirm against callers.
    try:
        params = args.__dict__["_content"]
    except KeyError:
        params = args.__dict__

    dict_to_save = {**params, **{"loss": loss, "task": task, "method": method}}
    # Stringify so arbitrary config values are JSON-serialisable.
    dict_to_save = {k: str(v) for k, v in dict_to_save.items() if v is not None}
    append_record(dict_to_save, full_log_name)


def checkpoint_name(title, checkpoint_dir):
    """Return the checkpoint path prefix ``<checkpoint_dir>/ckpt__<title>``.

    ``title`` is converted with ``str()``. The directory and file stem are
    joined with ``os.path.join``, so the platform path separator is used.
    """
    return os.path.join(checkpoint_dir, "ckpt__{}".format(str(title)))
