import yaml
import json
import argparse
import logging
import pickle
from datetime import datetime
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter

from src.tasks import TASKS
from src.dataset import create_client_datasets
from src.compressors import COMPRESSORS
from src.utils import seed_all

# Log everything: with the root logger at NOTSET, records of every level are emitted.
logging.basicConfig(level=logging.NOTSET)

# Root directories: YAML configs are read from CONFIG_PATH, run outputs go under
# RESULTS_PATH, and TensorBoard event files go under TENSORBOARD_PATH.
CONFIG_PATH = Path("configs")
RESULTS_PATH = Path("results")
TENSORBOARD_PATH = Path("tensorboard")


# Command-line interface: which task to run, on which dataset, with which compressor.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--task",
    type=str,
    required=True,
    help="Task",
    choices=["linreg", "power_iter", "kmeans", "fedavg"],
)
parser.add_argument(
    "--dataset",
    type=str,
    required=True,
    help="Dataset",
    choices=[
        "synthetic",
        "ujindoorloc",
        "mnist",
        "har",
        "fashionmnist",
        "femnist",
        "cifar10",
        "cifar100",
    ],
)
parser.add_argument(
    "--compressor",
    type=str,
    required=True,
    help="Compressor",
    # Every compressor registered in COMPRESSORS is selectable.
    choices=list(COMPRESSORS),
)
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument(
    "--flag",
    type=int,
    default=0,
    help="Additional flag to distinguish training runs",
)

args = parser.parse_args()

# Wall-clock start of the run; its string form tags both output locations below.
start_time = datetime.now()
start_time_string = start_time.strftime("%Y-%m-%d_%H:%M:%S")
logging.info(
    f"Task : {args.task}, Dataset : {args.dataset}, Compressor: {args.compressor}, "
    f"Seed : {args.seed}, Start Time : {start_time_string}"
)

# Outputs go to results/<task>/<dataset>/seed_<seed>/<compressor>/start_time_<timestamp>/.
# NOTE(review): the timestamp contains ':' which is not a legal path character on
# Windows; left as-is because downstream tooling may rely on this directory layout.
results_dir = (
    RESULTS_PATH
    / args.task
    / args.dataset
    / f"seed_{args.seed}"
    / args.compressor
    / f"start_time_{start_time_string}"
)
results_dir.mkdir(parents=True, exist_ok=True)

# TensorBoard run directory; unlike results_dir, its name also includes --flag.
experiment_name = (
    f"task_{args.task}_dataset_{args.dataset}_compressor_{args.compressor}"
    f"_seed_{args.seed}_flag_{args.flag}_start_time_{start_time_string}"
)
tb_path = TENSORBOARD_PATH / experiment_name

# Seed via the project's seed_all helper (presumably covers Python/NumPy/torch RNGs —
# confirm in src.utils) before any client data or compressor is created.
seed_all(args.seed)

def _read_yaml(path):
    """Parse the YAML file at ``path`` with ``yaml.safe_load``.

    Returns whatever the document contains; an empty file yields ``None``.
    """
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)


def _merge_params(base, override):
    """Return ``base`` with the keys of ``override`` layered on top (override wins).

    ``None`` on either side means "nothing there": a ``None`` override leaves
    ``base`` untouched, and a ``None`` base is replaced by a copy of ``override``.
    """
    if override is None:
        return base
    if base is None:
        return dict(override)
    return {**base, **override}


def load_configs(args):
    """Build the task/dataset/compressor configuration for a run.

    Defaults are read from ``CONFIG_PATH``:

    * ``tasks/<task>/base.yaml``      -> task params
    * ``datasets/<dataset>.yaml``     -> dataset params
    * ``compressors.yaml`` entry named ``<compressor>`` -> compressor params

    If ``tasks/<task>/<dataset>.yaml`` exists, its optional ``task``, ``dataset``
    and ``compressor: {<compressor>: ...}`` sections are merged on top of the
    defaults (override values win). Empty files or empty sections are ignored.

    Args:
        args: Parsed CLI namespace providing ``task``, ``dataset`` and ``compressor``.

    Returns:
        dict with keys ``"task"``, ``"dataset"`` and ``"compressor"``, each mapping
        to ``{"name": <name>, "params": <params>}``. Task and dataset params are
        always mappings; compressor params may be ``None`` when the compressor has
        no parameters anywhere.

    Raises:
        FileNotFoundError: if one of the default config files is missing.
        KeyError: if ``compressors.yaml`` has no entry for ``args.compressor``.
    """
    task_dir = CONFIG_PATH / "tasks" / args.task

    # Defaults. An empty task/dataset file means "no params" ({}), so the result can
    # be merged and splatted later instead of failing on None.
    task_params = _read_yaml(task_dir / "base.yaml") or {}
    dataset_params = _read_yaml(CONFIG_PATH / "datasets" / f"{args.dataset}.yaml") or {}
    # A compressor entry may legitimately be null (compressor without params).
    compressor_params = (_read_yaml(CONFIG_PATH / "compressors.yaml") or {})[args.compressor]

    # Optional per-(task, dataset) overrides; any section may be absent or empty.
    override_path = task_dir / f"{args.dataset}.yaml"
    if override_path.exists():
        overrides = _read_yaml(override_path) or {}
        compressor_overrides = overrides.get("compressor") or {}
        compressor_params = _merge_params(compressor_params, compressor_overrides.get(args.compressor))
        dataset_params = _merge_params(dataset_params, overrides.get("dataset"))
        task_params = _merge_params(task_params, overrides.get("task"))

    return {
        "task": {"name": args.task, "params": task_params},
        "dataset": {"name": args.dataset, "params": dataset_params},
        "compressor": {"name": args.compressor, "params": compressor_params},
    }




# Load the merged task/dataset/compressor configuration.
config = load_configs(args)
logging.info("Config loaded")

# Build the per-client datasets for the selected task.
all_client_data = create_client_datasets(task=config["task"]["name"], config=config["dataset"])
logging.info("Generated all clients")

## Initialize compressor
# Every compressor receives the dataset's "m" and "d". Compressor-specific params
# from the config are layered on top and win on key clashes. Those params may be
# None or an empty mapping; a single truthiness test skips both cases.
compressor_kwargs = {"m": config["dataset"]["params"]["m"], "d": config["dataset"]["params"]["d"]}
if config["compressor"]["params"]:
    compressor_kwargs = {**compressor_kwargs, **config["compressor"]["params"]}
# These compressors take a ``seed`` kwarg; forward the run seed to them.
if args.compressor in ("sparsereg", "hadamard"):
    compressor_kwargs["seed"] = args.seed
compressor = COMPRESSORS[config["compressor"]["name"]](**compressor_kwargs)
logging.info("Compressor initialized")

## Create summary writer
# The writer is closed in the ``finally`` below so buffered TensorBoard events are
# flushed to disk even when the task raises; previously it was never closed.
tb_writer = SummaryWriter(log_dir=tb_path)
try:
    ## Initialize task server
    server = TASKS[config["task"]["name"]](
        d=config["dataset"]["params"]["d"],
        all_client_data=all_client_data,
        compressor=compressor,
        results_dir=results_dir,
        tb_writer=tb_writer,
        **config["task"]["params"],
    )

    logging.info("Starting task")

    ## Run task
    # server.run() is expected to return (results dict, model); both are saved below.
    results, model = server.run()

    logging.info("Task completed")
finally:
    tb_writer.close()

## Persist the run into results_dir: merged config, final results (with the
## elapsed wall-clock time added), and the pickled model.
with open(results_dir / "config.json", "w") as config_file:
    json.dump(config, config_file)
logging.info("Config saved")

# Elapsed time since start_time, stored as the str() of a timedelta.
results["run_time"] = str(datetime.now() - start_time)
with open(results_dir / "final_results.json", "w") as results_file:
    json.dump(results, results_file)
logging.info("Results saved")

with open(results_dir / "final_model.pickle", "wb") as model_file:
    pickle.dump(model, model_file)
logging.info("Model saved")