import argparse
import logging
import os
import random
import sys
import datetime
import numpy as np
import torch
import json

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))

from data_loader import load_partition_data, load_data1
from model.resnet import FRJVE_Model
from frjve_inversion import FRJVE
from trainer.my_model_trainer_inver_frjve import MyModelTrainer as MyModelTrainerCLS
from opacus.grad_sample import GradSampleModule

import warnings

warnings.filterwarnings('ignore')


def prepare(config_path, argv=None):
    """Parse command-line options, seed the RNGs, and load the JSON config.

    Selected CLI options are copied into the loaded config dict, overriding any
    keys of the same name that the JSON file may already contain.

    Args:
        config_path: Path to a JSON config file.
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for tests or programmatic use). ``None``
            keeps the original behaviour of reading the real command line.

    Returns:
        A ``(args, config)`` tuple: the parsed ``argparse.Namespace`` and the
        config dict with the CLI overrides applied.
    """
    parser = argparse.ArgumentParser()

    # Training settings
    parser.add_argument('--model', type=str, default='MFBaseModel', metavar='N',
                        help='neural network used in training')

    parser.add_argument('--dataset', type=str, default='movielens100k', metavar='N',
                        help='dataset used for training')

    parser.add_argument('--data_dir', type=str, default='../../data',
                        help='data directory')

    parser.add_argument('--batch_size', type=int, default=32, metavar='N',
                        help='input batch size for training')

    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate')

    parser.add_argument('--wd', help='weight decay parameter;', type=float, default=0.0001)

    parser.add_argument('--client_num_in_total', type=int, default=5, metavar='NN',
                        help='number of workers in a distributed cluster')

    parser.add_argument('--client_num_per_round', type=int, default=3, metavar='NN',
                        help='number of workers')

    parser.add_argument('--users_per_client', type=int, default=100, metavar='NN',
                        help='number of users per client')

    parser.add_argument('--items_per_client', type=int, default=200, metavar='NN',
                        help='number of items per client')

    parser.add_argument('--hidden_dim', type=int, default=10, metavar='NN',
                        help='hidden dimension size')

    parser.add_argument('--baseline', default="FRJVE",
                        help='Training model')

    parser.add_argument('--comm_round', type=int, default=200,
                        help='how many round of communications we should use')

    parser.add_argument("--alpha", help="dirichlet", type=float, default=10)

    parser.add_argument('--gpu', type=int, default=3,
                        help='gpu')

    # Without an explicit type=, argparse hands CLI-supplied values through as
    # str. Non-string defaults are not run through `type`, so the defaults
    # below keep their original values.
    parser.add_argument('--beta', type=float, default=1)
    parser.add_argument('--personal_learning_rate', type=float, default=0.09)
    parser.add_argument('--lamda', type=float, default=15)
    parser.add_argument('--learning_rate', type=float, default=0.005)
    parser.add_argument('--task', default='ml_100k', help='ml_100k: movielens100k, '
                                                    'ml_1m: movielens1m')
    parser.add_argument('--base_model', default='MF')
    parser.add_argument('--seed', type=int, default=2020)
    # nargs=2 + type=float so `--ratio 0.7 0.3` yields [0.7, 0.3] rather than a string.
    parser.add_argument('--ratio', type=float, nargs=2, default=[0.8, 0.2],
                        metavar=('TRAIN', 'TEST'), help='[train_ratio, test_ratio]')
    parser.add_argument('--epoch', type=int, default=5)
    parser.add_argument('--lr_prototype', type=float, default=0.01, help='lr for prototype bridge function')
    args = parser.parse_args(argv)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # CLI values take precedence over anything in the JSON file.
    config['base_model'] = args.base_model
    config['task'] = args.task
    config['ratio'] = args.ratio
    config['epoch'] = args.epoch
    config['lr'] = args.lr
    config['lr_prototype'] = args.lr_prototype
    config['client_num'] = args.client_num_in_total
    config['dataset'] = args.dataset
    config['personal_learning_rate'] = args.personal_learning_rate
    config['lamda'] = args.lamda
    config['learning_rate'] = args.learning_rate
    return args, config


def load_data(args, dataset_name):
    """Load the partitioned dataset and return the parts the trainer consumes.

    ``load_data1`` is unpacked into eleven values. The first two
    (``rating_counts`` and ``co_uid``) are not used here. The remaining nine are
    returned in their original order, so ``class_num`` sits at index 8.

    Args:
        args: Parsed CLI namespace. Reads ``batch_size``,
            ``client_num_in_total``, ``users_per_client``,
            ``items_per_client``, ``alpha`` and ``data_dir``.
        dataset_name: Dataset identifier forwarded to ``load_data1``.

    Returns:
        list: ``[client_data_all, train_data_num, test_data_num,
        train_data_global, test_data_global, train_data_local_num_dict,
        train_data_local_dict, test_data_local_dict, class_num]``.
    """
    logging.info("load_data. dataset_name = %s" % dataset_name)
    (
        _rating_counts,
        _co_uid,
        client_data_all,
        train_data_num,
        test_data_num,
        train_data_global,
        test_data_global,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        class_num,
    ) = load_data1(
        args.batch_size,
        args.client_num_in_total,
        args.users_per_client,
        args.items_per_client,
        args.alpha,
        args.data_dir,
        dataset_name,
    )

    return [
        client_data_all,
        train_data_num,
        test_data_num,
        train_data_global,
        test_data_global,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        class_num,
    ]

def create_model(args, dataset_name, model_name, output_dim, hidden_dim):
    """Build the FRJVE model for the requested dataset.

    Args:
        args: Parsed CLI namespace (not used by this function).
        dataset_name: Dataset key; ``'ml_100k'`` or ``'ml_1m'``.
        model_name: Model family; only ``"MFBaseModel"`` is supported.
        output_dim: Output dimension. It is only logged here and does not
            affect the model built.
        hidden_dim: Not used. The embedding size is fixed at 10 below.

    Returns:
        An ``FRJVE_Model`` instance.

    Raises:
        ValueError: If ``model_name`` or ``dataset_name`` is not supported.
            Previously the function returned ``None``, which only failed later
            inside the trainer.
    """
    logging.info("create_model. model_name = %s, output_dim = %s" % (model_name, output_dim))
    if model_name != "MFBaseModel":
        raise ValueError(f"Unsupported model_name: {model_name!r}")

    # dataset -> (uid field size, iid field size, topk). The src and tgt field
    # dims are identical for each dataset.
    specs = {
        'ml_100k': (943, 1682, 400),
        'ml_1m': (6040, 3952, 3000),
    }
    try:
        num_uid, num_iid, topk = specs[dataset_name]
    except KeyError:
        raise ValueError(
            f"Unsupported dataset_name for {model_name}: {dataset_name!r}; "
            f"expected one of {sorted(specs)}"
        ) from None

    # NOTE: the model could be wrapped with opacus' GradSampleModule here for
    # per-sample gradients; it is intentionally left unwrapped.
    return FRJVE_Model(field_dims_src={"uid_src": num_uid, "iid_src": num_iid},
                       field_dims_tgt={"uid_tgt": num_uid, "iid_tgt": num_iid},
                       num_fields=5, emb_dim=10, topk=topk)

if __name__ == "__main__":
    # Script entry point: parse CLI/config, load data, build the model and
    # trainer, then run the selected baseline.
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    config_path = 'config.json'
    args, config = prepare(config_path)

    logger.info(args)
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    logger.info(device)

    # Seed every RNG from --seed. A hard-coded constant here would silently
    # override the seeding done in prepare() and turn --seed into a no-op.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True

    # Fail fast, before the expensive data load, if no runner exists for the baseline.
    if args.baseline != "FRJVE":
        raise ValueError(f"Unsupported baseline: {args.baseline!r}")

    # load data
    dataset = load_data(args, args.dataset)
    # dataset[8] is class_num, the last entry of the list load_data returns.
    model = create_model(args, dataset_name=args.task, model_name=args.model,
                         output_dim=dataset[8], hidden_dim=args.hidden_dim)
    model_trainer = MyModelTrainerCLS(model, args, config)

    frjve_api = FRJVE(dataset, device, args, model_trainer)
    frjve_api.train()
