import os 
import pdb
import json
import glob
import pickle
import pprint
import random
import codecs
import matplotlib
import numpy as np
import pandas as pd
import matplotlib._color_data as mcd
import matplotlib.pyplot as plt

from datetime import datetime
from data.mixture_loader.mixture import get

from common.utils import *

# Module-level pretty-printer with a 4-space indent. Nothing in the functions
# shown here references it; presumably kept for interactive inspection of the
# nested result dicts (TODO confirm whether other code relies on it).
pp = pprint.PrettyPrinter(indent=4)
    
def visualize(info):
    """Aggregate per-client results for each model and draw the configured plot.

    Args:
        info: dict with
            'path'   -- directory containing each model's output folder,
            'config' -- plot settings; ``config['type']`` selects the plot
                        ('forgetting', 'current', 'perf_cap' or 'perf_comm')
                        and ``config['tasks']`` is either -1 (every task found
                        in the results) or an iterable of task ids,
            'models' -- list of dicts providing 'name', 'output' (sub-folder
                        holding the ``client*`` JSON files) and 'color'.

    For each model and task, the clients' performance curves are averaged
    element-wise (``np.mean(..., axis=0)``) and scaled by 100. The values
    ``num_rounds`` and ``num_tasks`` are read from the client files; the first
    non-zero value seen wins. An unrecognised ``config['type']`` draws nothing.
    """
    num_rounds = 0
    num_tasks = 0
    path = info['path']
    config = info['config']
    # Forgetting plots use the watched history of every task; all other plot
    # types use the live performance curve.
    perf_key = 'performance_watch' if config['type'] == 'forgetting' else 'performance'
    perf_models = {}
    model_clients = {}
    for model in info['models']:
        model_name = model['name']
        # Glob once, so that the files averaged here are exactly the files
        # handed to the capacity/communication plots below.
        client_files = glob.glob(os.path.join(path, model['output'], 'client*'))
        model_clients[model_name] = client_files
        perf_clients = {}
        for client_file in client_files:
            with open(client_file) as f:
                client = json.load(f)
            if num_rounds == 0:
                num_rounds = client['options']['num_rounds']
            if num_tasks == 0:
                num_tasks = client['data_info']['num_tasks']
            perf = client[perf_key]
            if config['tasks'] == -1:
                tasks = perf.keys()
            else:
                tasks = [str(t) for t in config['tasks']]
            for tid in tasks:
                if tid not in perf_clients:
                    perf_clients[tid] = {'color': model['color'], 'perf': []}
                perf_clients[tid]['perf'].append(perf[tid])
        # Average each task's curves across clients, expressed as percentages.
        for task_entry in perf_clients.values():
            task_entry['perf'] = np.mean(task_entry['perf'], axis=0) * 100
        perf_models[model_name] = perf_clients

    plot_type = config['type']
    if plot_type == 'forgetting':
        _visualizae(perf_models, config, num_rounds, num_tasks)
    elif plot_type == 'current':
        _visualize_current(perf_models, config, num_rounds, num_tasks)
    elif plot_type == 'perf_cap':
        _visualize_perf_cap(model_clients, perf_models, config, num_rounds, num_tasks)
    elif plot_type == 'perf_comm':
        _visualize_perf_comm(model_clients, perf_models, config, num_rounds, num_tasks)


def _visualize_perf_cap(model_clients, perf_models, config, num_rounds, num_tasks):
    """Scatter-plot each model's average capacity (x) against its average accuracy (y).

    Args:
        model_clients: dict mapping model name -> list of client result files.
        perf_models: accepted for signature parity with the other plot
            helpers; not used here.
        config: plot settings; reads 'size' ([width, height] in inches),
            'title', 'x_axis_label' and 'y_axis_label'.
        num_rounds, num_tasks: accepted for signature parity; not used here.

    Both coordinates are rounded to whole percentage points, and each model
    name and its point are printed. A model with no client files is reported
    and skipped: its averages would be NaN, which cannot be rounded to an int.
    """
    matplotlib.rcParams.update({'font.size': 16})
    fig = plt.gcf()
    fig.set_size_inches(config['size'][0], config['size'][1])
    plt.grid(False)
    plt.title(config['title'], fontsize=20)
    plt.xlabel(config['x_axis_label'], fontsize=20)
    plt.ylabel(config['y_axis_label'], fontsize=20)

    for mname, clients in model_clients.items():
        if not clients:
            print(mname, 'skipped: no client files')
            continue
        cap = round(float(get_capacity(clients)))
        perf = round(float(get_accuracy(clients)))
        print(mname, cap, perf)
        plt.scatter(cap, perf, label=mname)
    plt.legend()


def _visualize_perf_comm(model_clients, perf_models, config, num_rounds, num_tasks):
    """Scatter-plot each model's average communication cost (x) against its accuracy (y).

    Args:
        model_clients: dict mapping model name -> list of client result files.
        perf_models: accepted for signature parity with the other plot
            helpers; not used here.
        config: plot settings; reads 'size' ([width, height] in inches),
            'title', 'x_axis_label' and 'y_axis_label'.
        num_rounds, num_tasks: accepted for signature parity; not used here.

    Both coordinates are rounded to whole percentage points, and each model
    name and its point are printed. A model with no client files is reported
    and skipped: its averages would be NaN, which cannot be rounded to an int.
    """
    matplotlib.rcParams.update({'font.size': 16})
    fig = plt.gcf()
    fig.set_size_inches(config['size'][0], config['size'][1])
    plt.grid(False)
    plt.title(config['title'], fontsize=20)
    plt.xlabel(config['x_axis_label'], fontsize=20)
    plt.ylabel(config['y_axis_label'], fontsize=20)

    for mname, clients in model_clients.items():
        if not clients:
            print(mname, 'skipped: no client files')
            continue
        comm = round(float(get_comm_costs(clients)))
        perf = round(float(get_accuracy(clients)))
        print(mname, comm, perf)
        plt.scatter(comm, perf, label=mname)
    plt.legend()


def _visualizae(perf_models, config, num_rounds, num_tasks):
    """Plot every model's per-task curves along a shared round axis.

    Args:
        perf_models: dict model name -> {task id (str): {'color', 'perf'}}.
        config: plot settings; reads 'size', 'title', 'x_axis_label',
            'y_axis_label', 'markersize' and 'type'.
        num_rounds: rounds per task; task ``tid`` starts at x = num_rounds * tid.
        num_tasks: accepted for signature parity; not used here.

    When ``config['type']`` is not 'forgetting', each curve is truncated to
    its first ``num_rounds`` points. Each model gets its own marker (cycling
    through the marker list). Each model's name appears in the legend once,
    on the first task curve drawn for it.
    """
    matplotlib.rcParams.update({'font.size': 16})
    fig = plt.gcf()
    fig.set_size_inches(config['size'][0], config['size'][1])
    plt.grid(False)
    plt.title(config['title'], fontsize=20)
    plt.xlabel(config['x_axis_label'], fontsize=20)
    plt.ylabel(config['y_axis_label'], fontsize=20)

    marker = ['d', 's', '*', 'v', '<', '>', 'P', 'X', 'h', '^']
    for mid, (key, model) in enumerate(perf_models.items()):
        labelled = False
        for tid, info in model.items():
            tid = int(tid)
            offset = num_rounds * tid
            y = info['perf']
            if config['type'] != 'forgetting':
                # Only the rounds spent training on this task are shown.
                y = y[:num_rounds]
            x = range(offset, offset + len(y))
            # Label the first curve of each model, so the model still shows up
            # in the legend when task 0 is not among the plotted tasks.
            plt.plot(x, y, info['color'],
                     marker=marker[mid % len(marker)],
                     markersize=config['markersize'],
                     label=None if labelled else key)
            labelled = True
    plt.legend()


def _visualize_current(perf_models, config, num_rounds, num_tasks):
    """Draw tasks 0-4 in five side-by-side subplots, one line per model.

    Args:
        perf_models: dict model name -> {task id (str): {'color', 'perf'}}.
        config: plot settings; reads 'size', 'markersize' and 'type'.
        num_rounds: when ``config['type']`` is not 'forgetting', each curve is
            truncated to its first ``num_rounds`` points.
        num_tasks: accepted for signature parity; not used here.

    Each model gets its own marker (cycling through the marker list). Each
    non-empty panel gets a legend in its lower-right corner.
    """
    n_panels = 5
    matplotlib.rcParams.update({'font.size': 16})
    fig, axs = plt.subplots(1, n_panels)
    fig.set_size_inches(config['size'][0], config['size'][1])
    marker = ['o', '*', '^', 'v', '<', '>', 'P', 'X', 'h', '^']
    for mid, (key, model) in enumerate(perf_models.items()):
        for tid, info in model.items():
            tid = int(tid)
            if not 0 <= tid < n_panels:
                continue
            y = info['perf']
            if config['type'] != 'forgetting':
                y = y[:num_rounds]
            # Index panels by task id, so a model that lacks a task cannot
            # shift its later tasks into the wrong panel.
            axs[tid].plot(range(len(y)), y, info['color'], label=key,
                          marker=marker[mid % len(marker)],
                          markersize=config['markersize'])
    # Build each legend once, after all models are drawn. Empty panels are
    # skipped to avoid matplotlib's "no artists with labels" warning.
    for ax in axs:
        if ax.lines:
            ax.legend(loc='lower right')

def get_summary_table(info):
    """Build a DataFrame with one row per model summarising accuracy, capacity and costs.

    Args:
        info: dict with 'path' (results root), 'models' (list of dicts with
            'name' and 'output') and 'details' (truthy to append selected
            run options, taken from the first client file, as extra columns).

    Cells are '-' when a value is unavailable. Models whose name contains
    '(CL)' get no client communication cost. The server cost is read from
    ``server.txt`` when present: it is fixed at '30.00' for names containing
    'ABC'; otherwise it is the mean of ``comm_ratio`` x 100, rounded to 2
    places. Without ``server.txt``, the server cost is '100.00' for non-'(CL)'
    models and '-' otherwise.
    """
    path = info['path']
    show_details = info['details']
    detail_cols = ['lr', 'l1_hyp', 'federated', 'continual', 'model', 'base_architect',
                   'num_clients', 'num_rounds', 'num_epochs', 'overlapped', 'dataset']
    table = {
        'models': [],
        'Avg.Acc.': [],
        'Capacity.': [],
        'Costs(C)': [],
        'Costs(S)': [],
    }
    if show_details:
        table.update({col: [] for col in detail_cols})

    for m in info['models']:
        name = m['name']
        out_dir = os.path.join(path, m['output'])
        clients = glob.glob(os.path.join(out_dir, 'client*'))
        is_cl = '(CL)' in name
        table['models'].append(name)

        if clients:
            table['Avg.Acc.'].append(get_accuracy(clients))
            table['Capacity.'].append(get_capacity(clients))
            table['Costs(C)'].append('-' if is_cl else get_comm_costs(clients))
        else:
            table['Avg.Acc.'].append('-')
            table['Capacity.'].append('-')
            table['Costs(C)'].append('-')

        server_file = os.path.join(out_dir, 'server.txt')
        if glob.glob(server_file):
            with open(server_file) as fh:
                server = json.loads(fh.read())
            if 'ABC' in name:
                server_cost = '%.2f' % 30.0
            else:
                server_cost = round(np.mean(server['comm_ratio']) * 100.0, 2)
        else:
            server_cost = '-' if is_cl else '%.2f' % 100
        table['Costs(S)'].append(server_cost)

        if not show_details:
            continue
        if clients:
            with open(clients[0]) as fh:
                opt = json.loads(fh.read())['options']
            for col in detail_cols:
                table[col].append(opt.get(col, '-'))
        else:
            for col in detail_cols:
                table[col].append('-')
    return pd.DataFrame(data=table)

def get_accuracy(clients):
    """Return the mean final accuracy over tasks as a percentage string ('%.2f').

    For every client file, the last value of each task's curve is taken from
    'performance_final', or from 'performance_watch' when 'performance_final'
    is empty. Those values are averaged per task across clients, and the
    per-task means are then averaged.
    """
    last_by_task = {}
    for client_path in clients:
        with open(client_path) as fh:
            record = json.loads(fh.read())
        final = record['performance_final']
        curves = final if len(final) else record['performance_watch']
        for tid, curve in curves.items():
            last_by_task.setdefault(tid, []).append(curve[-1])
    task_means = [np.mean(values) for values in last_by_task.values()]
    overall = np.mean(task_means)
    return '%.2f' % round(overall * 100.0, 2)

def get_capacity(clients):
    """Return the clients' mean final memory ratio as a percentage string ('%.2f').

    Each client contributes the last entry of its 'mem_ratio' list, or 1.0
    when that list is empty.
    """
    def _final_ratio(client_path):
        with open(client_path) as fh:
            ratios = json.loads(fh.read())['mem_ratio']
        return ratios[-1] if len(ratios) else 1.0

    mean_ratio = np.mean([_final_ratio(p) for p in clients])
    return '%.2f' % round(mean_ratio * 100.0, 2)

def get_comm_costs(clients):
    """Return the clients' mean communication ratio as a percentage string ('%.2f').

    Each client contributes the mean of its 'comm_ratio' list, or 1.0 when
    that list is empty; the result is the mean of those values.
    """
    def _client_mean(client_path):
        with open(client_path) as fh:
            ratios = json.loads(fh.read())['comm_ratio']
        return np.mean(ratios) if len(ratios) else 1.0

    mean_comm = np.mean([_client_mean(p) for p in clients])
    return '%.2f' % round(mean_comm * 100.0, 2)

def get_used_tasks(path):
    """Print how many distinct tasks appear in 'task_info' across the client files in *path*, and which ones.

    Returns None; the result is only printed.
    """
    seen = set()
    for client_path in glob.glob(os.path.join(path, 'client*')):
        with open(client_path) as fh:
            seen.update(json.loads(fh.read())['task_info'])
    print('chosen tasks (%d):' % len(seen), seen)

def get_single_task_performance(path):
    """Return the mean final 'performance' value x 100 over every task of every client file in *path*."""
    finals = []
    for client_path in glob.glob(os.path.join(path, 'client*')):
        with open(client_path) as fh:
            curves = json.loads(fh.read())['performance']
        finals.extend(curve[-1] for curve in curves.values())
    return np.mean(finals) * 100
    
