# Expose GPU indices 0-7 to this process. The assignment is placed ahead of
# the `torch` import further down; presumably so that the device restriction
# is already in the environment when CUDA is initialised.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"

import os
import json
import math
import torch
import numpy
import argparse
from scipy.io import arff
import re
import copy
from collections import Counter
from collections import defaultdict
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn import metrics
from scipy.spatial.distance import cdist
from numpy import dot
from numpy.linalg import norm
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

import scikit_wrappers_0 as scikit_wrappers
import pickle
from sklearn.model_selection import train_test_split
from sklearn.cluster import kmeans_plusplus
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import rand_score
from Common_functions import *
from sklearn.metrics import hinge_loss

import argparse
def str2bool(v):
    """Convert a command-line flag value to a bool.

    Meant to be used as an ``argparse`` ``type=`` callable.

    Args:
        v: Either a ``bool`` (returned unchanged) or a string.

    Returns:
        ``True`` for 'true', 't', 'yes', 'y', '1' and ``False`` for
        'false', 'f', 'no', 'n', '0' (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If the string is none of the accepted
            spellings.
    """
    if isinstance(v, bool):
        return v

    # These must be real tuples of whole tokens. Writing ('true') yields a
    # plain string, which turns `in` into a substring test and lets inputs
    # such as '' or 'ru' silently parse as True.
    truthy = ('true', 't', 'yes', 'y', '1')
    falsy = ('false', 'f', 'no', 'n', '0')

    token = v.lower()
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    # NOTE: unpickling executes arbitrary code; these files are expected to
    # be produced by this pipeline itself, never by an untrusted party.
    with open(path, 'rb') as fp:
        return pickle.load(fp)


def _summed_squared_hinge_loss(classifier, features, labels, num_classes):
    """Sum, over all samples, the squared per-sample hinge loss."""
    class_ids = numpy.arange(num_classes)
    total = 0
    for feature, label in zip(features, labels):
        # Each sample is scored on its own so that the square is applied to
        # the individual loss rather than to a batch average.
        scores = classifier.decision_function(numpy.expand_dims(feature, 0))
        total += hinge_loss([label], scores, labels=class_ids) ** 2
    return total


def _cluster_losses(template, weights, intercepts, num_clusters,
                    features, labels, num_classes):
    """Return the loss of every cluster's aggregated SVM on the local data."""
    losses = []
    for cluster in range(num_clusters):
        # Work on a copy so the locally fitted classifier keeps its own
        # coefficients.
        classifier = copy.deepcopy(template)
        classifier.coef_ = weights[cluster]
        classifier.intercept_ = intercepts[cluster]
        losses.append(_summed_squared_hinge_loss(
            classifier, features, labels, num_classes))
    return losses


def _rank_clusters_by_loss(losses, how_many):
    """Return up to ``how_many`` cluster ids ordered by ascending loss."""
    # Only losses strictly below +inf are eligible (this also drops NaN).
    candidates = [k for k, loss in enumerate(losses) if loss < float('inf')]
    # list.sort is stable, so equal losses keep the lowest cluster id first.
    candidates.sort(key=losses.__getitem__)
    return candidates[:max(how_many, 0)]


def main(args):
    """Fit one client's SVM output layer and optionally pick its clusters.

    Steps performed:

    1. Load the labelled training data and the token set from the pickle
       files 'train_labeled_x', 'train_labeled_y', 'token_x' and 'token_y'
       in the working directory and keep the first
       ``args.num_labeled_examples_available`` examples of client
       ``args.client_id``.
    2. Build a ``CausalCNNEncoderClassifier`` from
       'default_hyperparameters.json' and load the pickled encoder weights
       found at ``args.save_path_encoder``.
    3. If ``args.batching`` is set, draw a random subset of at most
       ``args.batch_size`` examples without replacement.
    4. For every class id in 0..9 absent from the local labels, append the
       corresponding token-set example.
    5. Encode the data, fit the classifier with
       ``fit_classifier_hyperparameters`` and save the whole model with
       ``torch.save`` under ``args.save_path``.
    6. If ``args.perform_baseline_exp`` and ``args.t > 0``: evaluate the
       aggregated IFCA and FLSC weights of round ``t - 1`` on the local data
       and pickle the selected cluster id (IFCA) and the list of the
       ``args.num_clusters_for_overlap`` lowest-loss cluster ids (FLSC).

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).

    Raises:
        ValueError: If no cluster yields a finite loss for the IFCA choice.
    """
    client_id = args.client_id
    batch_size = args.batch_size
    save_path = args.save_path
    num_clusters = args.num_clusters
    t = args.t
    num_classes = 10

    train = _load_pickle('train_labeled_x')
    train_labels = _load_pickle('train_labeled_y')
    tokenset = _load_pickle('token_x')
    tokenset_labels = _load_pickle('token_y')

    local_train = train[client_id][:args.num_labeled_examples_available]
    local_train_labels = train_labels[client_id][:args.num_labeled_examples_available]
    # Only this client's slice is needed from here on.
    del train, train_labels

    with open("default_hyperparameters.json", 'r') as hf:
        params = json.load(hf)
    # Override the stored defaults with the command-line settings.
    params['in_channels'] = args.in_channel
    params['cuda'] = args.cuda
    params['gpu'] = args.gpu

    local_model = scikit_wrappers.CausalCNNEncoderClassifier(**params)
    local_model.encoder.load_state_dict(_load_pickle(args.save_path_encoder))

    if args.batching:
        # Sampling without replacement cannot draw more items than exist, so
        # clients holding fewer examples than batch_size use all of them.
        sample_size = min(batch_size, len(local_train))
        chosen = numpy.random.choice(len(local_train), sample_size, replace=False)
        local_train = local_train[chosen]
        local_train_labels = local_train_labels[chosen]

    # Add the token-set example of every class missing from the local labels.
    for k in range(num_classes):
        if k not in local_train_labels:
            local_train = numpy.concatenate(
                (local_train, numpy.expand_dims(tokenset[k, :, :], axis=0)), axis=0)
            local_train_labels = numpy.concatenate(
                (local_train_labels, numpy.expand_dims(tokenset_labels[k], axis=0)), axis=0)

    features = local_model.encode(local_train)

    local_model.classifier = fit_classifier_hyperparameters(
        local_model, features, local_train_labels, C=args.C)
    torch.save(local_model, save_path + 'SVM_output_layer_Model_for_client_' + str(client_id) + '.pt')

    if not (args.perform_baseline_exp and t > 0):
        return

    previous_round = str(t - 1) + '_bs_' + str(batch_size) + '.pkl'
    current_round = str(client_id) + 'and_Round_' + str(t) + '_bs_' + str(batch_size) + '.pkl'

    # IFCA: the single cluster whose aggregated SVM has the lowest loss.
    ifca_weights = _load_pickle(save_path + 'IFCA_SVM_Aggregated_Weights' + previous_round)
    ifca_intercepts = _load_pickle(save_path + 'IFCA_SVM_Aggregated_Intercepts' + previous_round)
    ifca_losses = _cluster_losses(local_model.classifier, ifca_weights, ifca_intercepts,
                                  num_clusters, features, local_train_labels, num_classes)
    ifca_choice = _rank_clusters_by_loss(ifca_losses, 1)
    if not ifca_choice:
        raise ValueError('No IFCA cluster produced a finite loss.')
    with open(save_path + 'IFCA_cluster_choice_' + current_round, 'wb') as fp:
        pickle.dump(ifca_choice[0], fp)

    # FLSC: the num_clusters_for_overlap lowest-loss clusters, best first.
    # Every cluster is scored once; a cluster is never listed twice, so fewer
    # ids are returned when not enough clusters have a finite loss.
    flsc_weights = _load_pickle(save_path + 'FLSC_SVM_Aggregated_Weights' + previous_round)
    flsc_intercepts = _load_pickle(save_path + 'FLSC_SVM_Aggregated_Intercepts' + previous_round)
    flsc_losses = _cluster_losses(local_model.classifier, flsc_weights, flsc_intercepts,
                                  num_clusters, features, local_train_labels, num_classes)
    C_t_vec = _rank_clusters_by_loss(flsc_losses, args.num_clusters_for_overlap)
    with open(save_path + 'FLSC_cluster_choice_' + current_round, 'wb') as fp:
        pickle.dump(C_t_vec, fp)






if __name__ == "__main__":
    # Command-line entry point: parse the per-client settings and run main().
    parser = argparse.ArgumentParser(
        description='Fit the SVM output layer for one client and, optionally, '
                    'select its IFCA/FLSC clusters.')
    parser.add_argument('--save_path', type=str, default='Save_models/',
                        help='Prefix prepended to every model and cluster file name.')
    parser.add_argument('--in_channel', type=int, default=3,
                        help="Value stored as 'in_channels' in the encoder hyper-parameters.")
    parser.add_argument('--cuda', type=str2bool, default=True,
                        help="Value stored as 'cuda' in the encoder hyper-parameters.")
    parser.add_argument('--batch_size', type=int, default=64,
                        help='Number of local examples sampled when --batching is on; '
                             'also part of the cluster file names.')
    parser.add_argument('--batching', type=str2bool, default=True,
                        help='Whether to train on a random subset of the local examples.')
    # main() always indexes the data by client id and always opens the
    # encoder weight file, so both are mandatory rather than failing later
    # with an obscure error on a None value.
    parser.add_argument('--client_id', type=int, required=True,
                        help='Index of the client whose data is used.')
    parser.add_argument('--gpu', type=int,
                        help="Value stored as 'gpu' in the encoder hyper-parameters.")
    parser.add_argument('--save_path_encoder', type=str, required=True,
                        help='Path of the pickled encoder state dict to load.')
    parser.add_argument('--C', type=float, default=float('inf'),
                        help='C value forwarded to fit_classifier_hyperparameters.')
    parser.add_argument('--num_labeled_examples_available', type=int, default=1944,
                        help='Maximum number of labelled examples taken from the client.')
    parser.add_argument('--perform_baseline_exp', type=str2bool, default=True,
                        help='Whether to run the IFCA/FLSC cluster selection (needs --t > 0).')
    parser.add_argument('--num_clusters', type=int, default=3,
                        help='Number of aggregated cluster models to evaluate.')
    parser.add_argument('--t', type=int,
                        help='Current round; cluster selection reads the aggregates of round t-1.')
    parser.add_argument('--num_clusters_for_overlap', type=int,
                        help='Number of clusters chosen per client in the FLSC selection.')

    args = parser.parse_args()
    main(args)