import os
import json
import math
import torch
import numpy
import argparse
from scipy.io import arff
# import weka.core.jvm
# import weka.core.converters
import re
import copy
from collections import Counter
from collections import defaultdict
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn import metrics
from scipy.spatial.distance import cdist
from numpy import dot
from numpy.linalg import norm
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

import scikit_wrappers_0 as scikit_wrappers
import pickle
from Common_functions import fit_encoder_hyperparameters
from sklearn.model_selection import train_test_split
from sklearn.cluster import kmeans_plusplus
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import rand_score
from sklearn.metrics import hinge_loss
import argparse

def str2bool(v):
    """Convert a command-line value to a bool (intended as an argparse ``type=``).

    Booleans are returned unchanged. Strings are compared case-insensitively
    after stripping surrounding whitespace:
    'true'/'t'/'yes'/'y'/'1' -> True, 'false'/'f'/'no'/'n'/'0' -> False.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    if isinstance(v, bool):
        return v
    # Exact membership in tuples. The previous ``in ('true')`` was a
    # *substring* test against the string 'true', so '' or 'e' became True.
    value = v.strip().lower()
    if value in ('true', 't', 'yes', 'y', '1'):
        return True
    if value in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')



def main(args):
    """Fit the encoder for one client and save the resulting model.

    Loads the pickled training data from ``train_encoder_x`` and
    ``train_encoder_y`` (current working directory) and keeps only the entry
    for ``args.client_id``. It then builds a
    ``scikit_wrappers.CausalCNNEncoderClassifier`` from
    ``default_hyperparameters.json``, with ``in_channels``, ``cuda`` and ``gpu``
    overridden from ``args``. The encoder is fitted via
    ``fit_encoder_hyperparameters`` and the whole model is written with
    ``torch.save`` to ``<save_path>_Model_for_client_<client_id>.pt``.

    Args:
        args: parsed namespace providing ``client_id``, ``gpu``,
            ``encoder_runs``, ``in_channel``, ``cuda``, ``save_path`` and
            ``save_path_encoder``.

    Raises:
        ValueError: if ``encoder_runs`` is non-zero but ``save_path_encoder``
            is not set, so there are no previous weights to load.
    """
    client_id = args.client_id
    gpu = args.gpu
    encoder_runs = args.encoder_runs
    in_channel = args.in_channel
    cuda = args.cuda
    save_path = args.save_path
    encoder_weights_save_dir = args.save_path_encoder

    print(['Client_id:', str(client_id), ' and gpu:',  str(gpu), 'and runs', str(encoder_runs)])

    # NOTE(review): pickle.load can execute arbitrary code; these files must
    # come from a trusted source (presumably written by the orchestrating
    # process — confirm).
    with open('train_encoder_x', 'rb') as fp:
        train = pickle.load(fp)

    with open('train_encoder_y', 'rb') as fp:
        train_labels = pickle.load(fp)

    # Keep only this client's entry (container type not visible here —
    # presumably a list or dict keyed by client id; confirm with the writer).
    local_train = train[client_id]
    local_train_labels = train_labels[client_id]  # NOTE: not used below
    # Release the full datasets so only this client's data stays in memory.
    del train, train_labels

    hyper = "default_hyperparameters.json"
    # Context manager closes the file even if json.load raises.
    with open(hyper, 'r') as hf:
        params = json.load(hf)
    # Command-line values override the JSON defaults.
    params['in_channels'] = in_channel
    params['cuda'] = cuda
    params['gpu'] = gpu

    local_model = scikit_wrappers.CausalCNNEncoderClassifier(**params)
    local_model.encoder.train()

    # encoder_runs of 0 (or None when the flag was omitted) starts without
    # initial weights; otherwise resume from the pickled weights.
    if not encoder_runs:
        encoder_i_weight = None
    else:
        if encoder_weights_save_dir is None:
            raise ValueError(
                '--save_path_encoder is required when --encoder_runs is non-zero')
        with open(encoder_weights_save_dir, 'rb') as fp:
            encoder_i_weight = pickle.load(fp)

    local_model.encoder = fit_encoder_hyperparameters(
        local_model, local_train, encoder_i_weight, cuda, gpu)
    # Plain concatenation (not os.path.join): save_path is a filename prefix,
    # e.g. the default 'Save_models/' gives 'Save_models/_Model_for_client_<id>.pt'.
    torch.save(local_model, save_path + '_Model_for_client_' + str(client_id) + '.pt')



if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Fit the encoder for a single client and save the model.')
    parser.add_argument('--save_path', type=str, default='Save_models/',
                        help='prefix of the saved model file')
    parser.add_argument('--in_channel', type=int, default=3,
                        help='number of input channels (sets params["in_channels"])')
    parser.add_argument('--cuda', type=str2bool, default=True,
                        help='whether to use CUDA (true/false)')
    # Required: without it main() would index the data with None.
    parser.add_argument('--client_id', type=int, required=True,
                        help='index of the client whose data is used')
    parser.add_argument('--gpu', type=int,
                        help='GPU id forwarded to the model and encoder fitting')
    parser.add_argument('--encoder_runs', type=int, default=0,
                        help='round number; 0 starts without initial encoder weights')
    parser.add_argument('--save_path_encoder', type=str,
                        help='pickled encoder weights to load when --encoder_runs is non-zero')

    args = parser.parse_args()
    main(args)