import os
import pickle
import argparse

import numpy as np
import torch

from dataset import get_data_name, generate_data
from train import train
from plot import plot


# Seed applied to both the NumPy and PyTorch global RNGs at the start of main().
SEED = 0
# Root directory under which per-run result folders (named by results_name) are created.
SAVE_DIR = "results"


def main(
    input_d,
    client_samples,
    num_clients,
    rounds,
    interval,
    lr=0.1,
    bias=False,
    two_stage=None,
    data_type="gaussian",
    local_var=1.0,
    mu_het=1.0,
    label_het=0.9,
    results_name=None,
    quiet=False,
):
    """Seed the RNGs, ensure the dataset exists, train, and optionally save/plot.

    The dataset is identified by the pair of file names returned by
    ``get_data_name``. It is generated only when neither file exists.

    Args:
        input_d, client_samples, num_clients, local_var, mu_het, label_het,
        data_type: Dataset configuration. Forwarded to ``get_data_name`` and
            ``generate_data``, and also to ``train``.
        rounds, interval, lr, bias, two_stage, quiet: Training configuration,
            forwarded to ``train``.
        results_name: If not ``None``, the results are pickled to
            ``<SAVE_DIR>/<results_name>/results.pkl`` and ``plot`` is called
            with that directory.

    Returns:
        Whatever ``train`` returns.

    Raises:
        ValueError: If exactly one of the two dataset files exists.
    """
    # Seed both NumPy and PyTorch so that runs with identical arguments start
    # from the same RNG state (covers data generation and training below).
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # Generate data, if it hasn't yet been generated.
    x_name, y_name = get_data_name(
        num_clients, local_var, input_d, client_samples, mu_het, label_het, data_type
    )
    num_existing = sum(os.path.isfile(name) for name in (x_name, y_name))
    if num_existing == 0:
        generate_data(num_clients, local_var, input_d, client_samples, mu_het, label_het, data_type)
    elif num_existing == 1:
        # Only one of the pair exists: refuse to regenerate over it or train on
        # a partial dataset; the user must resolve it manually.
        raise ValueError(
            f"Incomplete dataset in files {x_name} and {y_name}. Restore or delete it before rerunning."
        )

    # Train. NOTE(review): arguments are passed positionally, so their order
    # must match train()'s signature exactly; consider keywords once confirmed.
    results = train(
        input_d,
        client_samples,
        num_clients,
        rounds,
        interval,
        lr,
        bias,
        two_stage,
        data_type,
        local_var,
        mu_het,
        label_het,
        quiet,
    )

    # Save and plot results.
    if results_name is not None:
        results_dir = os.path.join(SAVE_DIR, results_name)
        # exist_ok=True avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(results_dir, exist_ok=True)
        with open(os.path.join(results_dir, "results.pkl"), "wb") as f:
            pickle.dump(results, f)
        plot(results, results_dir)

    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_d", type=int, default=3)
    parser.add_argument("--client_samples", type=int, default=10)
    parser.add_argument("--num_clients", type=int, default=8)
    parser.add_argument("--rounds", type=int, default=1024)
    parser.add_argument("--interval", type=int, default=8)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--bias", type=bool, default=False, action="store_true")
    parser.add_argument("--two_stage", type=str, default=None)
    parser.add_argument("--data_type", choices=["gaussian", "easy", "weird", "hard", "mnist"], default="gaussian")
    parser.add_argument("--local_var", type=float, default=1.0)
    parser.add_argument("--mu_het", type=float, default=1.0)
    parser.add_argument("--label_het", type=float, default=0.9)
    parser.add_argument("--results_name", type=str, default=None)
    args = parser.parse_args()
    main(**vars(args))
