from prepare_hetero_datasets import prepare_hetero_datasets
from utils.load_dataset import load_dataset
from predict import predict
from optimizers import LocalSGD_Server
from utils.get_model import create_model
import numpy as np
import torch
import sys
import pickle
from itertools import accumulate
import time
import argparse
from tqdm import tqdm
import os
from csv import DictWriter
# Request a single OpenMP thread via the OMP_NUM_THREADS environment variable.
# NOTE(review): this assignment runs after numpy and torch have already been
# imported above — confirm the setting still takes effect at this point.
os.environ["OMP_NUM_THREADS"] = "1"
# Add the current working directory to the module search path.
# NOTE(review): the project imports at the top of the file have already been
# executed by the time this line runs, so it cannot influence them — confirm
# whether it is still needed.
sys.path.append('./')


def train(save_path, model_path, **kargs):
    """Train a model with ``LocalSGD_Server`` and persist the collected metrics.

    Args:
        save_path: File that receives a pickle of the metrics dictionary; it
            is rewritten after every global iteration.
        model_path: Optional path to a state dict loaded into the model before
            training. Skipped when falsy.
        **kargs: Experiment settings. The keys read here are
            ``dataset_name``, ``model_name``, ``n_workers``, ``q``,
            ``train_batch_size``, ``pred_batch_size``, ``eta``,
            ``weight_decay``, ``gpu_id``, ``n_local_iters``, ``epsilon``,
            ``n_global_iters`` and ``save_intvl``. The whole mapping is also
            stored under ``saved_info["args"]``.

    Returns:
        None. Results are written to ``save_path``, appended as one row to
        ``results/fine_tune_<q>.csv``, and dumped under
        ``logger/logging_1000_global_iters/``.
    """
    train_data, test_data, n_classes, n_channels = load_dataset(
        kargs["dataset_name"])

    train_loaders, pred_loader_on_train_data, pred_loader_on_test_data \
        = prepare_hetero_datasets(train_data, test_data, n_classes,
                                  n_workers=kargs["n_workers"],
                                  q=kargs["q"],
                                  train_batch_size=kargs["train_batch_size"],
                                  pred_batch_size=kargs["pred_batch_size"])

    model = create_model(kargs["model_name"], n_classes, n_channels)
    if model_path:
        model.load_state_dict(torch.load(model_path))
    model.train()

    optimizer = LocalSGD_Server(model=model,
                                eta=kargs["eta"],
                                weight_decay=kargs["weight_decay"],
                                train_loaders=train_loaders,
                                gpu_id=kargs["gpu_id"],
                                n_local_iters=kargs["n_local_iters"],
                                n_workers=kargs["n_workers"],
                                epsilon=kargs['epsilon'])

    saved_info = {"train_loss": [], "train_acc": [], "train_grad_norm": [],
                  "test_loss": [], "test_acc": [], "test_grad_norm": [],
                  "worker_loss": [], "worker_acc": [], "defect_workers": [], "args": kargs}

    # Baseline loss before any update. It is only consumed by an
    # early-stopping check (train_loss > 2 * init_train_loss) that is
    # currently disabled; the call itself is kept so the run is unchanged.
    init_train_loss, _, _ = predict(optimizer.get_model(),
                                    pred_loader_on_train_data,
                                    kargs["weight_decay"])

    s = time.time()
    for i in tqdm(range(kargs["n_global_iters"])):
        optimizer.update()
        if (i+1) % kargs["save_intvl"] == 0:
            train_loss, train_acc, train_grad_norm = predict(optimizer.get_model(),
                                                             pred_loader_on_train_data,
                                                             kargs["weight_decay"])
            # Test metrics are evaluated without the weight-decay term.
            test_loss,  test_acc,  test_grad_norm = predict(optimizer.get_model(),
                                                            pred_loader_on_test_data,
                                                            0.0)
            saved_info["train_loss"].append(train_loss)
            saved_info["train_acc"].append(train_acc)
            saved_info["train_grad_norm"].append(train_grad_norm)
            saved_info["test_loss"].append(test_loss)
            saved_info["test_acc"].append(test_acc)
            # Record the *test* gradient norm (the train value was stored
            # here by mistake before).
            saved_info["test_grad_norm"].append(test_grad_norm)
            # Per-worker metrics are appended flat, one entry per worker
            # loader at every evaluation step.
            for data_loader in train_loaders:
                loss, acc, _ = predict(optimizer.get_model(),
                                    data_loader,
                                    kargs["weight_decay"])
                saved_info["worker_loss"].append(loss)
                saved_info["worker_acc"].append(acc)
            saved_info["defect_workers"] = list(optimizer._defect_workers)
            # presumably predict() leaves the model in eval mode; switch back
            # to training mode before the next update — TODO confirm.
            model.train()
            print("Iter: {} | Train Loss: {}, Train Acc: {}, Train Grad Norm {}, | Test Loss: {}, Test Acc: {}, Test Grad Norm: {}, Elapsed Time: {}"
                  .format(i+1, train_loss, train_acc, train_grad_norm, test_loss, test_acc, test_grad_norm, time.time() - s))
            # Restart the timer so the next report covers one interval only.
            s = time.time()
        # Checkpoint the metrics after every global iteration.
        with open(save_path, "wb") as f:
            pickle.dump(saved_info, f)

    # Append one summary row for this run to the fine-tuning CSV.
    field_names = ["q", "k", "eta", "train_accuracy"]
    os.makedirs("results", exist_ok=True)
    path = f"results/fine_tune_{kargs['q']}.csv"
    # No evaluation happens when save_intvl exceeds n_global_iters; write an
    # empty accuracy cell instead of failing with IndexError.
    train_acc_history = saved_info["train_acc"]
    final_train_acc = train_acc_history[-1] if train_acc_history else None
    # newline='' is what the csv module expects for files it writes to.
    with open(path, 'a', newline='') as f:
        logger = DictWriter(f, fieldnames=field_names)
        logger.writerow({field_names[0]: kargs["q"],
                         field_names[1]: kargs["n_local_iters"],
                         field_names[2]: kargs["eta"],
                         field_names[3]: final_train_acc})
    print(optimizer._defect_workers)

    logging_dir = "logger/logging_1000_global_iters/"
    os.makedirs(logging_dir, exist_ok=True)

    log_stem = (f"{logging_dir}q_{kargs['q']}_n_local_iters_{kargs['n_local_iters']}"
                f"_eta_{kargs['eta']}_epsilon_{kargs['epsilon']}")
    pkl_path = log_stem + ".pkl"
    txt_path = log_stem + ".txt"

    first_run_for_config = not os.path.isfile(pkl_path)
    # "wb" replaces any pickle left by a previous run with the same settings.
    with open(pkl_path, "wb") as f:
        pickle.dump(saved_info, f)
    # NOTE(review): as before, the text log is only written when no pickle
    # existed for this configuration; on a re-run the file is merely
    # opened (and now properly closed). Confirm whether re-runs should also
    # refresh the text log.
    with open(txt_path, "a") as f:
        if first_run_for_config:
            f.write(str(saved_info))


if __name__ == '__main__':
    # Command-line entry point: parse the experiment settings, prepare the
    # output location, seed the random number generators and start training.
    parser = argparse.ArgumentParser()
    # Parameters
    # For eta, keep a record of: the largest value that monotonically
    # increases the accuracy; the value with the largest accuracy without
    # defection; and the value with the largest accuracy even with defection.
    parser.add_argument("--eta", type=float, default=0.0002)
    # General experimental info
    parser.add_argument("--n_workers", type=int, default=2)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--dataset_name", type=str, default='cifar10')
    parser.add_argument("--model_name", type=str, default='linear')
    parser.add_argument("--exp_name", type=str, default='test')
    parser.add_argument("--epsilon", type=float, default=0.5)
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument("--n_global_iters", type=int, default=1000)
    parser.add_argument("--n_local_iters", type=int, default=1)
    parser.add_argument("--train_batch_size", type=int, default=256)
    # homogeneity parameter (refer to the paper)
    parser.add_argument("--q", type=float, default=0.9)
    # Fixed args
    parser.add_argument("--use_pretrain_model", type=int, default=0)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--pred_batch_size", type=int, default=1024)
    parser.add_argument("--save_intvl", type=int, default=1)
    args = parser.parse_args()

    # One directory per hyper-parameter combination, one pickle per seed.
    save_dir = os.path.join("results", args.exp_name, args.model_name, args.dataset_name, "homogeneity="+str(args.q),
                            "lsgd" + "_K="+str(args.n_local_iters) + "_b="+str(args.train_batch_size), "_eta="+str(args.eta), "_epsilon="+str(args.epsilon))
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "seed="+str(args.seed)+".pickle")
    model_path = None
    if args.use_pretrain_model:
        model_path = os.path.join(
            "saved_models", "pretrain", args.model_name, args.dataset_name, "model.pth")

    # Seed torch with the experiment seed (previously the GPU id was passed
    # here) and do so on CPU-only machines as well, so runs are repeatable
    # regardless of the device.
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        # NOTE(review): CUDA_VISIBLE_DEVICES is set after torch has been
        # imported and queried; confirm it takes effect and that gpu_id is
        # still the correct index for set_device afterwards.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
        torch.cuda.set_device(args.gpu_id)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        print("GPU Enabled")
    else:
        print("GPU Not Enabled")
    np.random.seed(args.seed)
    train(save_path=save_path, model_path=model_path, **vars(args))