from prepare_hetero_datasets import prepare_hetero_datasets
from utils.load_dataset import load_dataset
from predict import predict
from optimizers import LocalSGD_Server
from utils.get_model import create_model
import numpy as np
import torch
import sys
import pickle
from itertools import accumulate
import time
import argparse
import os
from csv import DictWriter
# Limit OpenMP to a single thread per process.
# NOTE(review): numpy and torch are imported above, before this assignment;
# OpenMP runtimes typically read OMP_NUM_THREADS when they initialise, so this
# may not take effect here — confirm, or move it above those imports.
os.environ["OMP_NUM_THREADS"] = "1"
# NOTE(review): this runs after all imports above, so it does not influence how
# they were resolved; it only affects imports performed later at runtime.
sys.path.append('./')


def _log_path_stem(kargs):
    """Return the shared, extension-less path of this configuration's run logs.

    The ``.pkl`` and ``.txt`` logs live under ``results/logging`` and are keyed
    only by ``q``, ``n_local_iters``, ``eta`` and ``epsilon``. Runs that differ
    in any other argument (e.g. ``seed``) therefore share, and are merged
    into, the same files.
    """
    return os.path.join(
        "results", "logging",
        f"q_{kargs['q']}_n_local_iters_{kargs['n_local_iters']}"
        f"_eta_{kargs['eta']}_epsilon_{kargs['epsilon']}",
    )


def _merge_saved_info(old_data, saved_info):
    """Merge one run's ``saved_info`` into the results of earlier runs.

    - ``"args"`` and ``"defect_workers"`` are kept from ``old_data``.
    - ``"worker_loss"`` / ``"worker_acc"`` become a 2-D array with one row per
      run, shaped ``(n_runs, n_workers)``.
    - Every other entry is a list; the new values are appended to the old.

    Returns a new dict with the same keys, in the same order, as ``saved_info``.
    """
    merged = {}
    for key in saved_info:
        if key in ("args", "defect_workers"):
            # NOTE(review): only the first run's args / defect list are kept;
            # the current run's defect workers are discarded. Confirm intent.
            merged[key] = old_data[key]
        elif key in ("worker_loss", "worker_acc"):
            # vstack promotes a 1-D history (the first run's plain list) to a
            # single row, so this works for any number of previous runs.
            # np.stack required equal shapes and failed from the 3rd run on.
            merged[key] = np.vstack([np.asarray(old_data[key]),
                                     np.asarray(saved_info[key])])
        else:
            merged[key] = old_data[key] + saved_info[key]
    return merged


def _save_results(saved_info, kargs):
    """Write ``saved_info`` to the ``.pkl``/``.txt`` logs of this configuration.

    If a ``.pkl`` log already exists, it is loaded and merged with
    ``saved_info`` (see ``_merge_saved_info``). Both files are then rewritten
    with the result; the ``.txt`` file holds ``str()`` of the same dict.

    Example of reading the log back::

        with open("results/logging/q_0.1_n_local_iters_1_eta_0.1_epsilon_2.0.pkl", "rb") as f:
            saved_info = pickle.load(f)
        print(saved_info["train_loss"], saved_info["train_acc"])
    """
    os.makedirs(os.path.join("results", "logging"), exist_ok=True)
    stem = _log_path_stem(kargs)
    pkl_path = stem + ".pkl"
    txt_path = stem + ".txt"

    if os.path.isfile(pkl_path):
        # The file is produced by this script; do not point this at
        # untrusted pickles.
        with open(pkl_path, "rb") as f:
            old_data = pickle.load(f)
        saved_info = _merge_saved_info(old_data, saved_info)

    with open(pkl_path, "wb") as f:
        pickle.dump(saved_info, f)
    # Overwrite (not append) so the text dump always mirrors the pickle.
    with open(txt_path, "w", encoding="utf-8") as f:
        f.write(str(saved_info))


def train(save_path, model_path, **kargs):
    """Train a model with ``LocalSGD_Server`` on heterogeneous splits and log it.

    Args:
        save_path: Accepted for interface compatibility; not used by this
            function (results are written under ``results/logging``).
        model_path: Optional path to a saved ``state_dict`` used to initialise
            the model; falsy values skip loading.
        **kargs: Experiment configuration. Keys read here:
            ``dataset_name``, ``n_workers``, ``q``, ``train_batch_size``,
            ``pred_batch_size``, ``model_name``, ``eta``, ``weight_decay``,
            ``gpu_id``, ``n_local_iters``, ``epsilon``, ``n_global_iters``, and
            for the log file name ``q``/``n_local_iters``/``eta``/``epsilon``.
            The whole dict is also stored under ``"args"`` in the log.

    Training runs for ``n_global_iters`` server updates, or stops early once
    every worker is reported as defect. The final model is then evaluated on
    the full train set, the test set, and each worker's loader. The results
    are merged into this configuration's log files.
    """
    train_data, test_data, n_classes, n_channels = load_dataset(
        kargs["dataset_name"])

    train_loaders, pred_loader_on_train_data, pred_loader_on_test_data = \
        prepare_hetero_datasets(train_data, test_data, n_classes,
                                n_workers=kargs["n_workers"],
                                q=kargs["q"],
                                train_batch_size=kargs["train_batch_size"],
                                pred_batch_size=kargs["pred_batch_size"])

    model = create_model(kargs["model_name"], n_classes, n_channels)
    if model_path:
        # map_location="cpu" lets a checkpoint saved on any GPU load on a
        # CPU-only host; load_state_dict then copies it onto the model.
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.train()

    optimizer = LocalSGD_Server(model=model,
                                eta=kargs["eta"],
                                weight_decay=kargs["weight_decay"],
                                train_loaders=train_loaders,
                                gpu_id=kargs["gpu_id"],
                                n_local_iters=kargs["n_local_iters"],
                                n_workers=kargs["n_workers"],
                                epsilon=kargs['epsilon'])

    for _ in range(kargs["n_global_iters"]):
        optimizer.update()
        # Nothing left to train once every worker is reported as defect.
        if len(optimizer._defect_workers) == optimizer._n_workers:
            print("All workers are defect")
            break

    # Final evaluation; the test set is evaluated with weight_decay=0.0.
    train_loss, train_acc, train_grad_norm = predict(optimizer.get_model(),
                                                     pred_loader_on_train_data,
                                                     kargs["weight_decay"])
    test_loss, test_acc, test_grad_norm = predict(optimizer.get_model(),
                                                  pred_loader_on_test_data,
                                                  0.0)
    worker_loss, worker_acc = [], []
    for data_loader in train_loaders:
        loss, acc, _ = predict(optimizer.get_model(),
                               data_loader,
                               kargs["weight_decay"])
        worker_loss.append(loss)
        worker_acc.append(acc)

    # Key order matters: it is reflected in the .txt dump.
    saved_info = {
        "train_loss": [train_loss],
        "train_acc": [train_acc],
        "train_grad_norm": [train_grad_norm],
        "test_loss": [test_loss],
        "test_acc": [test_acc],
        "test_grad_norm": [test_grad_norm],
        "worker_loss": worker_loss,
        "worker_acc": worker_acc,
        "defect_workers": list(optimizer._defect_workers),
        "args": kargs,
    }
    _save_results(saved_info, kargs)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Parameters
    parser.add_argument("--eta", type=float, default=0.1)
    # General experimental info
    parser.add_argument("--n_workers", type=int, default=10)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--dataset_name", type=str, default='cifar10')
    parser.add_argument("--model_name", type=str, default='fc')
    parser.add_argument("--exp_name", type=str, default='test')
    parser.add_argument("--epsilon", type=float, default=2.0)
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument("--n_global_iters", type=int, default=100)
    parser.add_argument("--n_local_iters", type=int, default=1)
    parser.add_argument("--train_batch_size", type=int, default=256)
    # homogeneity parameter (refer to the paper)
    parser.add_argument("--q", type=float, default=0.0)
    # Fixed args
    parser.add_argument("--use_pretrain_model", type=int, default=0)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--pred_batch_size", type=int, default=1024)
    # NOTE(review): --save_intvl is parsed but not read by train() in this file.
    parser.add_argument("--save_intvl", type=int, default=10)
    args = parser.parse_args()

    # NOTE(review): "_eta=..." and "_epsilon=..." are separate path components
    # (nested directories). They are kept as-is so existing result paths stay valid.
    save_dir = os.path.join("results", args.exp_name, args.model_name, args.dataset_name, "homogeneity="+str(args.q),
                            "lsgd" + "_K="+str(args.n_local_iters) + "_b="+str(args.train_batch_size), "_eta="+str(args.eta), "_epsilon="+str(args.epsilon))
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "seed="+str(args.seed)+".pickle")
    model_path = None
    if args.use_pretrain_model:
        model_path = os.path.join(
            "saved_models", "pretrain", args.model_name, args.dataset_name, "model.pth")

    # Seed torch from --seed on every host. torch.manual_seed seeds the CPU
    # generator and all CUDA devices. Previously it was seeded with
    # --gpu_id, only when CUDA was available, and the CUDA seed was applied
    # before the target device was selected.
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        # NOTE(review): if CUDA_VISIBLE_DEVICES takes effect here, only one
        # device is visible (index 0) and set_device(gpu_id) fails for
        # gpu_id > 0; if CUDA was already initialised by is_available(), the
        # assignment is ignored. Confirm the intended device-selection scheme.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
        torch.cuda.set_device(args.gpu_id)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        print("GPU Enabled")
    else:
        print("GPU Not Enabled")
    np.random.seed(args.seed)
    train(save_path=save_path, model_path=model_path, **vars(args))
