import os
import pdb
import json
import pickle
import pprint
import random
import codecs
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import sys
from datetime import datetime
from data.mixture_loader.mixture import get

# Shared pretty-printer used by the reporting helpers (e.g. visualize).
pp = pprint.PrettyPrinter(indent=4)

# Dataset identifiers: consecutive integers starting at 0
# (mapped to names by get_dataset_name).
(CIFAR10, CIFAR100, MNIST, SVHN, FMNIST,
 TRAFFICSIGN, FACESCRUB, NMNIST) = range(8)

# Model identifiers: consecutive integers starting at 0
# (matched by get_setting / get_model).
CNN, APD, ABC_L2T, ABC_APD, EWC = range(5)

# Test-mode flags. They are distinct powers of two so several modes can be
# combined into one value by summing them (see get_test_mode).
TASK1 = 1 << 0
CURRENT_TASK = 1 << 1
CUMULATIVE = 1 << 2

def get_setting(opt):
    """Build the short run-setting tag for the options object *opt*.

    The tag is a model part ('-cnn' for CNN, '-apd' for APD, empty for any
    other model id) followed by '-fcl' when ``opt.federated`` is truthy or
    '-cl' otherwise.
    """
    model_tag = '-cnn' if opt.model == CNN else ('-apd' if opt.model == APD else '')
    learning_tag = '-fcl' if opt.federated else '-cl'
    return ''.join((model_tag, learning_tag))

def get_dataset_name(did):
    """Return the name string for dataset id *did*, or None for an unknown id."""
    id_to_name = (
        (CIFAR10, 'cifar10'),
        (CIFAR100, 'cifar100'),
        (MNIST, 'mnist'),
        (SVHN, 'svhn'),
        (FMNIST, 'fashion_mnist'),
        (TRAFFICSIGN, 'traffic_sign'),
        (FACESCRUB, 'face_scrub'),
        (NMNIST, 'not_mnist'),
    )
    # Linear scan with '==' keeps the exact comparison order of an if/elif chain.
    for dataset_id, name in id_to_name:
        if did == dataset_id:
            return name
    return None

def get_test_mode(value, test_mode):
    """Return whether the evaluation named by *test_mode* is enabled in *value*.

    *value* is a sum of the flags TASK1, CURRENT_TASK and CUMULATIVE. Each mode
    is checked against its own bit, so every combination enables exactly the
    evaluations it contains. For example, TASK1 + CUMULATIVE enables 'task1'
    and 'cumulative'.

    Args:
        value: integer bit mask of enabled evaluation modes (assumed int).
        test_mode: one of 'task1', 'current' or 'cumulative'.

    Returns:
        True if the flag for *test_mode* is set in *value*; False otherwise,
        including for an unrecognised *test_mode*.
    """
    mode_flags = {
        'task1': TASK1,
        'current': CURRENT_TASK,
        'cumulative': CUMULATIVE,
    }
    flag = mode_flags.get(test_mode)
    if flag is None:
        return False
    # Bit test instead of listing sums, so that combinations such as
    # TASK1 + CUMULATIVE or CURRENT_TASK + CUMULATIVE are honoured too.
    return bool(value & flag)

def get_model(value, model):
    """Return whether the numeric model id *value* matches the model name *model*.

    Args:
        value: model id, compared against CNN, APD, ABC_L2T, ABC_APD or EWC.
        model: model name, one of 'cnn', 'apd', 'abc_l2t', 'abc_apd', 'ewc'.

    Returns:
        True if *value* is the id of *model*, False otherwise.

    Raises:
        SystemExit: if *model* is not a recognised model name.
    """
    if model == 'cnn':
        return value == CNN
    if model == 'apd':
        return value == APD
    if model == 'abc_l2t':
        return value == ABC_L2T
    if model == 'abc_apd':
        return value == ABC_APD
    if model == 'ewc':
        return value == EWC
    # Unknown model name: the exception must actually be raised to stop the run.
    raise SystemExit('SystemExit: no proper model was given. see help: -h')

def syslog(pid, message):
    """Print *message* prefixed with a local timestamp and the emitting worker.

    pid -2 is labelled 'server', -1 'data', and any other value 'client:<pid>'.
    """
    worker = ('server' if pid == -2
              else 'data' if pid == -1
              else 'client:' + str(pid))
    stamp = datetime.now().strftime("%Y/%m/%d-%H:%M:%S")
    print('[%s][%s] %s' % (stamp, worker, message))

def random_shuffle(seed, l):
    """Shuffle list *l* in place, reproducibly for a given *seed*.

    This reseeds the module-level ``random`` generator, so it also resets
    the state seen by any later ``random`` calls. Returns None.
    """
    random.seed(seed)  # reset the shared generator so the order is reproducible
    random.shuffle(l)  # in-place permutation
    return None

def random_sample(seed, l, num_pick):
    """Return *num_pick* distinct elements of *l*, reproducibly for *seed*.

    Like random_shuffle, this reseeds the module-level ``random`` generator.
    """
    random.seed(seed)
    picked = random.sample(l, num_pick)
    return picked

def obj_to_pickle_string(x):
    """Pickle *x* and return the bytes as base64 text (codecs 'base64' codec)."""
    pickled = pickle.dumps(x)
    encoded = codecs.encode(pickled, "base64")
    return encoded.decode()

def pickle_string_to_obj(s):
    """Decode base64 text *s* and unpickle it back into a Python object.

    Security: ``pickle.loads`` can execute arbitrary code, so only pass strings
    that come from a trusted peer.
    """
    raw = codecs.decode(s.encode(), "base64")
    return pickle.loads(raw)

def get_serialized_weights(weights):
    """Serialise a sequence of arrays to a JSON string of nested lists."""
    as_lists = [layer.tolist() for layer in weights]
    return json.dumps(as_lists)

def get_weights_from_string(str_weights):
    """Parse a JSON list of nested lists into a list of NumPy arrays."""
    return list(map(np.array, json.loads(str_weights)))

def write_file(filepath, filename, data):
    """Serialise *data* as JSON to ``filepath/filename``, creating *filepath* if needed.

    Args:
        filepath: target directory. It is created with parents when missing;
            an empty string means the current directory.
        filename: file name inside *filepath*. An existing file is overwritten.
        data: JSON-serialisable object.
    """
    if filepath:
        # exist_ok avoids failing when the directory appears between a check
        # and the creation (e.g. another process created it).
        os.makedirs(filepath, exist_ok=True)
    with open(os.path.join(filepath, filename), 'w') as outfile:
        json.dump(data, outfile)

def _to_saveable_array(data):
    """Convert *data* to an ndarray, boxing ragged lists/tuples into a 1-D object array."""
    try:
        return np.asanyarray(data)
    except ValueError:
        # NumPy raises ValueError when it cannot build a regular array from a
        # ragged sequence (always since 1.24, per NEP 34).
        if not isinstance(data, (list, tuple)):
            raise
        boxed = np.empty(len(data), dtype=object)
        for i, item in enumerate(data):
            boxed[i] = item  # element-wise so NumPy does not try to broadcast
        return boxed


def np_save(base_dir, filename, data):
    """Save *data* with ``np.save`` to ``base_dir/filename``, creating *base_dir* if needed.

    ``np.save`` appends '.npy' to *filename* when it lacks that suffix. A
    list/tuple whose items have different shapes (e.g. per-layer weights) is
    stored as a 1-D object array. Loading it back requires ``allow_pickle=True``.

    Args:
        base_dir: target directory. It is created with parents when missing;
            an empty string means the current directory.
        filename: output file name inside *base_dir*.
        data: array-like to store.
    """
    if base_dir:
        # exist_ok avoids the check-then-create race on shared directories.
        os.makedirs(base_dir, exist_ok=True)
    np.save(os.path.join(base_dir, filename), _to_saveable_array(data))

def np_load(path):
    """Load a ``.npy`` file from *path*, permitting pickled object arrays.

    ``allow_pickle=True`` means the file may execute code on load, so only read
    files from trusted sources.
    """
    loaded = np.load(path, allow_pickle=True)
    return loaded

def pickle_save(base_dir, filename, data):
    """Pickle *data* to ``base_dir/filename``, creating *base_dir* if needed.

    Uses ``pickle.HIGHEST_PROTOCOL``. Behaves like np_save/write_file in
    creating the target directory instead of failing when it is missing.

    Args:
        base_dir: target directory. It is created with parents when missing;
            an empty string means the current directory.
        filename: output file name inside *base_dir*.
        data: picklable object.
    """
    if base_dir:
        os.makedirs(base_dir, exist_ok=True)
    with open(os.path.join(base_dir, filename), 'wb') as out:
        pickle.dump(data, out, protocol=pickle.HIGHEST_PROTOCOL)

def pickle_load(path):
    """Unpickle and return the object stored at *path* (trusted files only)."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)

def save_task(base_dir, filename, data):
    """Persist task *data* to ``base_dir/filename`` in NumPy format via np_save."""
    np_save(base_dir=base_dir, filename=filename, data=data)

def load_task(base_dir, task):
    """Load the task file ``base_dir/task`` via np_load."""
    task_path = os.path.join(base_dir, task)
    return np_load(task_path)

def save_weights(base_dir, filename, weights):
    """Store *weights* at ``base_dir/filename`` via np_save."""
    np_save(base_dir=base_dir, filename=filename, data=weights)

def load_weights(path):
    """Load weights saved with ``np.save`` from *path* (pickled object arrays allowed)."""
    weights = np.load(file=path, allow_pickle=True)
    return weights

def compare_weights(w1, w2):
    """Report whether two weight collections are identical, layer by layer.

    Every pair ``(w1[i], w2[i])`` is compared with ``np.array_equal``, which
    checks both shape and values (NaN counts as unequal). Layers may have
    different shapes, and collections of different length are not equal.

    Args:
        w1, w2: sequences of array-likes, e.g. lists of per-layer arrays.

    Returns:
        True if every layer matches, False otherwise. The verdict is also printed.
    """
    # Compare per layer: differently shaped layers cannot be stacked into one
    # array for a single subtraction, and every layer must be checked.
    same = len(w1) == len(w2) and all(
        np.array_equal(a, b) for a, b in zip(w1, w2))
    print('weights are equal' if same else 'weights are not equal')
    return same

def _visualizae(x_data, y_data, info):
    """Plot each series of *y_data* against the matching *x_data* on the current figure.

    The y axis is fixed to [0, 1] with a grid. X ticks are derived from the
    first x series. Per-series style comes from ``info['color']``,
    ``info['marker']`` and ``info['legends']``, which are indexed like
    *y_data*. Other ``info`` keys read: 'size', 'xtick', 'ytick', 'title',
    'x_axis_label', 'y_axis_label' and 'markersize'. The figure is neither
    shown nor saved here.
    """
    matplotlib.rcParams.update({'font.size': 16})
    figure = matplotlib.pyplot.gcf()
    width, height = info['size'][0], info['size'][1]
    figure.set_size_inches(width, height)
    plt.gca().set_ylim([0, 1.00])

    plt.grid(True)
    first_xs = x_data[0]
    x_ticks = np.arange(min(first_xs) - 1, max(first_xs) + 1, info['xtick'])
    plt.xticks(x_ticks, rotation=45)
    plt.yticks(np.arange(0.0, 1.0, info['ytick']))

    heading_size = 20
    plt.title(info['title'], fontsize=heading_size)
    plt.xlabel(info['x_axis_label'], fontsize=heading_size)
    plt.ylabel(info['y_axis_label'], fontsize=heading_size)

    for series_idx, ys in enumerate(y_data):
        plt.plot(
            x_data[series_idx], ys, info['color'][series_idx],
            marker=info['marker'][series_idx],
            markersize=info['markersize'],
            label=info['legends'][series_idx],
        )

    plt.legend()

def visualize(info):
    """Load per-model result JSON files, print optional summaries, then plot accuracy.

    ``info`` keys read here:
      * 'path' / 'models': parallel lists of result-file paths and model labels.
      * 'cumulative': record ``test_accuracy_cumulative[0][2]`` per model and
        pretty-print those values.
      * 'task_0' / 'current': plot ``test_accuracy_task1`` / ``test_accuracy``.
      * 'setting': pretty-print each model's 'setting' dict (without 'class_info').
      * 'task_details': print each model's 'class_info'.
    Plot styling keys are consumed by ``_visualizae``. Each plotted series
    uses column 2 of its accuracy table as y values and 1..len as x values.
    """
    avg_perfs, settings, classes = {}, {}, {}
    results = []
    for idx, path in enumerate(info['path']):
        with open(path) as fh:
            record = json.load(fh)
            setting = record['setting']
            model_name = info['models'][idx]
            settings[model_name] = setting
            if 'class_info' in setting:
                # Moved out of the settings dict so it is printed separately.
                classes[model_name] = setting.pop('class_info')
            if info['cumulative']:
                avg_perfs[model_name] = record['test_accuracy_cumulative'][0][2]
            model_series = []
            if info['task_0']:
                model_series.append(record['test_accuracy_task1'])
            if info['current']:
                model_series.append(record['test_accuracy'])
            results.append(model_series)

    if info['setting']:
        pp.pprint(settings)
    if info['cumulative']:
        pp.pprint(avg_perfs)
    if info['task_details']:
        print()
        for class_info in classes.values():
            print(class_info)
            print()

    y_data = [np.array(series)[:, 2]
              for model_series in results
              for series in model_series]
    x_data = [range(1, len(ys) + 1) for ys in y_data]

    _visualizae(x_data, y_data, info)

def get_memory_ratio(info):
    """Average the final recorded capacity of each model's clients.

    Args:
        info: maps a model label to a dict with 'clients' (paths of per-client
            JSON logs, each containing a 'capacity' list) and 'full-base'. When
            'full-base' compares equal to False, each capacity is divided by
            a fixed parameter count.

    Returns:
        dict mapping each model label to ``(mean_capacity, per_client_capacities)``.
    """
    # Fixed denominator. The terms look like layer parameter counts of a small
    # conv net with a 5-class x 10 head; TODO(review): confirm against the model.
    base_params = 5*5*3*20 + 5*5*20*50 + 3200*800 + 800*500 + 500*5 *(10)
    model_capa = {}
    for model, model_info in info.items():
        per_client = []
        for client_path in model_info['clients']:
            with open(client_path) as fh:
                capa = json.load(fh)['capacity'][-1]
                # Deliberately '==' (not 'not'): 0/False normalise, None does not.
                if model_info['full-base'] == False:
                    capa /= base_params
                per_client.append(capa)
        model_capa[model] = (np.mean(per_client), per_client)
    return model_capa

def get_final_performance(info):
    """Average each model's last cumulative test accuracy over its clients.

    Args:
        info: maps a model label to a dict whose 'clients' entry lists paths of
            per-client JSON logs. The value read from each log is
            ``performance.epoch_test_acc_cumulative[0][-1]``.

    Returns:
        dict mapping each model label to ``(mean_accuracy, per_client_accuracies)``.
    """
    model_perf = {}
    for model, model_info in info.items():
        scores = []
        for client_path in model_info['clients']:
            with open(client_path) as fh:
                log = json.load(fh)
            scores.append(log['performance']['epoch_test_acc_cumulative'][0][-1])
        model_perf[model] = (np.mean(scores), scores)
    return model_perf

def get_single_task_performance(info):
    """Average, per model, the last field of 'test_accuracy' records whose second field is 19.

    Args:
        info: maps a model label to the path of a JSON log with a
            'test_accuracy' list of records.
            TODO(review): 19 presumably marks a specific round/epoch; confirm.

    Returns:
        dict mapping each model label to ``(mean, selected_values)``. With no
        matching record the mean is ``np.mean([])``, which is nan and emits a
        RuntimeWarning.
    """
    model_perf = {}
    for model, path in info.items():
        with open(path) as fh:
            records = json.load(fh)['test_accuracy']
        selected = [rec[-1] for rec in records if rec[1] == 19]
        model_perf[model] = np.mean(selected), selected
    return model_perf

class ForkedPdb(pdb.Pdb):
    """A Pdb subclass that may be used from a forked multiprocessing child.

    For the duration of each interactive prompt, ``sys.stdin`` is replaced by
    a freshly opened ``/dev/stdin``. It is restored (and that handle closed)
    afterwards.
    """

    def interaction(self, *args, **kwargs):
        saved_stdin = sys.stdin
        # Close the handle after the prompt so repeated breakpoints do not leak
        # one open file descriptor each.
        with open('/dev/stdin') as forked_stdin:
            sys.stdin = forked_stdin
            try:
                pdb.Pdb.interaction(self, *args, **kwargs)
            finally:
                # Restore before the handle is closed by the with-block.
                sys.stdin = saved_stdin
