"""
AutoML 2022 Conference Submission - Code Supplement
LINAS Search Algorithm Comparison

SuperNetwork: OFA MobileNetV3 (ofa_mbv3_d234_e346_k357_w1.0)
"""

# Imports
import argparse
import csv
import json
from datetime import datetime
import numpy as np
import pandas as pd
import copy
import pickle
import uuid
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
import numpy as np
import os

# OFA Specific Imports
from ofa.tutorial.latency_table import LatencyEstimator
from ofa.imagenet_codebase.data_providers.imagenet import ImagenetDataProvider
from ofa.imagenet_codebase.run_manager import ImagenetRunConfig
from ofa.imagenet_codebase.run_manager import RunManager
import ofa

# LINAS Specific Imports
from linas.manager import ParameterManager
from linas.evaluation_module.predictor import MobileNetAccuracyPredictor, MobileNetLatencyPredictor
from linas.search_module.search import SearchAlgoManager, ProblemMultiObjective
from linas.analytics_module.visualize import collect_hv, load_csv_to_df, frontier_builder


class OFARunner:
    """
    The OFARunner is responsible for 'running' the subnetwork evaluation.

    It bundles the OFA super-network, the latency look-up table (LUT) and the
    optional accuracy/latency predictors, and exposes measured
    (``validate_*`` / ``measure_*``) as well as predicted (``estimate_*``)
    objectives for a sub-network configuration.

    Parameters
    ----------
    supernet : str
        Super-network name passed to ``ofa.model_zoo.ofa_net``. Its last three
        characters are parsed as the width multiplier (e.g. ``'1.0'``).
    model_dir : str
        Forwarded to ``ofa.model_zoo.ofa_net`` as ``model_dir``.
    lut : dict or str
        Latency look-up table, either already loaded or the path of a JSON file.
    acc_predictor : object or None
        Accuracy predictor exposing ``predict_single``.
    macs_predictor : object or None
        Stored on the instance; not used by any method of this class.
    latency_predictor : object or None
        Latency predictor exposing ``predict_single``.
    imagenetpath : str
        Assigned to ``ImagenetDataProvider.DEFAULT_PATH``.
    """

    def __init__(self, supernet, model_dir, lut, acc_predictor, macs_predictor,
                 latency_predictor, imagenetpath):

        self.supernet = supernet
        self.model_dir = model_dir
        self.acc_predictor = acc_predictor
        self.macs_predictor = macs_predictor
        self.latency_predictor = latency_predictor

        # Accept either an already-parsed look-up table or the path to its JSON file.
        if isinstance(lut, dict):
            self.lut = lut
        else:
            with open(lut, 'r') as f:
                self.lut = json.load(f)
        self.latencyEstimator = LatencyEstimator(url=self.lut)
        # The super-network name ends with its width multiplier (e.g. '...w1.0').
        self.width = float(supernet[-3:])

        # Validation setup
        self.target = 'cpu'
        self.test_size = None
        ImagenetDataProvider.DEFAULT_PATH = imagenetpath
        self.ofa_network = ofa.model_zoo.ofa_net(supernet, pretrained=True, model_dir=model_dir)
        self.run_config = ImagenetRunConfig(test_batch_size=64, n_worker=20)

    def get_subnet(self, subnet_cfg):
        """Activate the sub-network described by ``subnet_cfg`` and return it in eval mode.

        ``subnet_cfg`` must provide the keys ``'ks'``, ``'e'`` and ``'d'``. The
        extracted sub-network is also stored as ``self.subnet``.
        """
        self.ofa_network.set_active_subnet(ks=subnet_cfg['ks'],
                                           e=subnet_cfg['e'],
                                           d=subnet_cfg['d'])
        self.subnet = self.ofa_network.get_active_subnet(preserve_weight=True)
        self.subnet.eval()
        return self.subnet

    def validate_accuracy_top1(self, subnet_cfg, target=None):
        """Measure the top-1 accuracy of a sub-network with the OFA ``RunManager``.

        Parameters
        ----------
        subnet_cfg : dict
            Configuration with keys ``'ks'``, ``'e'``, ``'d'`` and ``'r'``; the
            first entry of ``'r'`` is used as the active image size.
        target : str, optional
            Defaults to ``self.target``. The value is currently not used by the
            measurement itself.

        Returns
        -------
        The top-1 value reported by ``RunManager.validate``.
        """
        if target is None:
            target = self.target
        subnet = self.get_subnet(subnet_cfg)
        # A unique scratch folder per call keeps evaluations from sharing a path.
        folder_name = f'.torch/tmp-{uuid.uuid1().hex}'
        run_manager = RunManager(f'{folder_name}/eval_subnet', subnet,
                                 self.run_config, init=False, print_info=False)
        run_manager.reset_running_statistics(net=subnet)

        # Test sampled subnet
        self.run_config.data_provider.assign_active_img_size(subnet_cfg['r'][0])
        loss, top1, top5 = run_manager.validate(net=subnet, test_size=self.test_size, no_logs=True)
        return top1

    def estimate_accuracy_top1(self, subnet_cfg):
        """Predict top-1 accuracy by passing ``subnet_cfg`` to ``acc_predictor.predict_single``."""
        # Ridge Predictor - 135 vector
        return self.acc_predictor.predict_single(subnet_cfg)

    def estimate_latency(self, subnet_cfg):
        """Predict latency by passing ``subnet_cfg`` to ``latency_predictor.predict_single``."""
        return self.latency_predictor.predict_single(subnet_cfg)

    def measure_latency_lut(self, subnet_cfg):
        """Look up the latency of ``subnet_cfg`` in the LUT-backed ``LatencyEstimator``."""
        return self.latencyEstimator.predict_network_latency_given_spec(subnet_cfg, width=self.width)

    def estimate_accuracy_custom(self, subnet_cfg):
        """Predict top-1 accuracy from the 120-element encoding built by ``onehot_custom``."""
        # Ridge Predictor - 120 vector
        encoding = self.onehot_custom(subnet_cfg['ks'], subnet_cfg['e'], subnet_cfg['d'])
        return self.acc_predictor.predict_single(encoding)

    def construct_maps(self, keys):
        """Map each distinct key to its rank in ascending order.

        Sorting makes the index assignment deterministic instead of depending
        on the iteration order of a ``set``.
        """
        return {key: index for index, key in enumerate(sorted(set(keys)))}

    def onehot_custom(self, ks_list, ex_list, d_list):
        """Encode kernel sizes and expansion ratios as one flat one-hot vector.

        The layers are treated as consecutive stages of four; in each stage only
        the first ``d`` layers (``d`` taken from ``d_list``) are active, and the
        remaining layers contribute all-zero slots.

        Parameters
        ----------
        ks_list, ex_list : sequence of int
            Kernel size (3, 5 or 7) and expansion ratio (3, 4 or 6) for each of
            the 20 layers. The inputs are not modified.
        d_list : sequence of int
            Number of active layers for each stage.

        Returns
        -------
        numpy.ndarray
            Vector of length 120: 60 kernel-size slots followed by 60
            expansion-ratio slots, three slots per layer.
        """
        ks_map = self.construct_maps(keys=(3, 5, 7))
        ex_map = self.construct_maps(keys=(3, 4, 6))

        # Work on copies so the caller's configuration keeps its real values.
        ks_list = list(ks_list)
        ex_list = list(ex_list)

        # Blank out the layers that fall beyond each stage's active depth.
        for stage, depth in enumerate(d_list):
            for j in range(stage * 4 + depth, (stage + 1) * 4):
                ks_list[j] = 0
                ex_list[j] = 0

        # convert to onehot
        ks_onehot = [0] * 60
        ex_onehot = [0] * 60

        for i in range(20):
            offset = i * 3
            if ks_list[i] != 0:
                ks_onehot[offset + ks_map[ks_list[i]]] = 1
            if ex_list[i] != 0:
                ex_onehot[offset + ex_map[ex_list[i]]] = 1

        return np.array(ks_onehot + ex_onehot)


class UserEvaluationInterface:
    '''
    This interface class must be updated for each unique SuperNetwork
    framework, as it controls how evaluation calls are made from DyNAS-T.

    Parameters
    ----------
    evaluator : class
        The 'runner' that performs the validation or prediction
    manager : class
        The DyNAS-T manager that translates between PyMoo and the parameter dict
    csv_path : string
        (Optional) The csv file that gets appended to during the subnetwork search
    validate : bool
        When true (and ``linas_pred`` is false) accuracy is measured with
        ``evaluator.validate_accuracy_top1``; otherwise it is predicted with
        ``evaluator.estimate_accuracy_custom``.
    linas_pred : bool
        When true both objectives come from the evaluator's predictors
        (``estimate_latency`` / ``estimate_accuracy_top1``). Takes precedence
        over ``validate``.
    '''

    def __init__(self, evaluator, manager, csv_path=None, validate=True, linas_pred=False):
        self.evaluator = evaluator
        self.manager = manager
        self.csv_path = csv_path
        self.validate = validate
        self.linas_pred = linas_pred

    def eval_subnet(self, x):
        '''
        Evaluate one PyMoo individual and optionally log the result.

        Parameters
        ----------
        x : sequence
            PyMoo decision vector, translated with ``manager.translate2param``.

        Returns
        -------
        tuple
            ``(sample, latency, 100/(top1-70))`` where ``sample`` is the
            sub-network configuration dict.
        '''
        # PyMoo vector to Elastic Parameter Mapping
        param_dict = self.manager.translate2param(x)

        sample = {
            'wid': None,
            'ks': param_dict['ks'],
            'e': param_dict['e'],
            'd': param_dict['d'],
            'r': [224]
        }
        # Snapshot taken before evaluation so the logged configuration is
        # unaffected if an evaluator modifies `sample` in place.
        subnet_sample = copy.deepcopy(sample)

        # Note for reproducibility, stage-wise look-up-table used for latency
        if self.linas_pred:
            # Encode once and feed the same vector to both predictors.
            encoded = self.manager.onehot_generic(x)
            latency = self.evaluator.estimate_latency(encoded)
            top1 = self.evaluator.estimate_accuracy_top1(encoded)
        elif self.validate:
            print('[Info] Making validation measurement.')
            latency = self.evaluator.measure_latency_lut(sample)
            top1 = self.evaluator.validate_accuracy_top1(sample)
        else:
            latency = self.evaluator.measure_latency_lut(sample)
            top1 = self.evaluator.estimate_accuracy_custom(sample)

        if self.csv_path:
            # newline='' lets the csv module control line endings; without it
            # extra blank rows appear on platforms that translate '\n'.
            with open(self.csv_path, 'a', newline='') as f:
                writer = csv.writer(f)
                date = str(datetime.now())
                writer.writerow([subnet_sample, date, float(latency), float(top1)])

        # PyMoo only minimizes objectives, thus accuracy needs to be negative/inverse
        # Requires format: subnetwork, objective x, objective y
        # NOTE(review): the transform is only monotonically decreasing for
        # top1 > 70; it is undefined at exactly 70 and changes sign below it.
        return sample, latency, 100/(top1-70)



def _reset_csv(path):
    """Create ``path`` as an empty file, discarding any previous contents."""
    with open(path, 'w'):
        pass


def _run_baseline_searches(args, runner, supernet_parameters, seed_list):
    """Run NSGA-II and random search once per seed, logging every evaluation to CSV."""
    for search_algo in ('nsga2', 'random'):
        for seed in seed_list:

            print(f'[Info] Running {search_algo} with seed {seed} and {args.num_evals} evaluations.')

            # Clear/touch output file
            output_csv = f'./results/compare_{search_algo}_seed{seed}_{args.num_evals}.csv'
            _reset_csv(output_csv)

            supernet_manager = ParameterManager(param_dict=supernet_parameters, seed=seed)

            evaluation_interface = UserEvaluationInterface(evaluator=runner, manager=supernet_manager,
                                                           csv_path=output_csv, validate=args.acc_validate)

            # The problem and search manager are built for both algorithms, as
            # before, so any seeding they perform still happens ahead of sampling.
            problem = ProblemMultiObjective(evaluation_interface=evaluation_interface,
                                            param_count=supernet_manager.param_count,
                                            param_upperbound=supernet_manager.param_upperbound)

            search_manager = SearchAlgoManager(algorithm=search_algo, seed=seed)

            if search_algo == 'nsga2':
                search_manager.configure_nsga2(population=args.population, num_evals=args.num_evals)
                search_manager.run_search(problem)

            elif search_algo == 'random':
                # Draw every sample first, then evaluate them in order.
                samples = [supernet_manager.random_sample() for _ in range(args.num_evals)]
                for individual in samples:
                    evaluation_interface.eval_subnet(individual)


def _run_linas_searches(args, runner, supernet, lut, supernet_parameters, seed_list):
    """Run the LINAS loop once per seed, logging every validated sub-network to CSV."""
    # Each loop validates one population, so this many loops are needed to reach
    # the requested evaluation budget (40 with the default 2000 / 50).
    num_loops = max(1, -(-args.num_evals // args.population))
    # Evaluation budget of the predictor-driven inner search.
    inner_search_evals = 20000

    # Building an OFARunner loads the pre-trained super-network, so the
    # prediction runner is created once and only its predictors are replaced
    # in every loop. Validation reuses `runner`, which is configured identically.
    runner_predictor = OFARunner(supernet=supernet,
                                 model_dir=args.model_dir,
                                 lut=lut,
                                 acc_predictor=None,
                                 macs_predictor=None,
                                 latency_predictor=None,
                                 imagenetpath=args.dataset_path)

    for seed in seed_list:

        # Clear/touch output file
        validated_population = f'./results/compare_linas_seed{seed}_{args.num_evals}.csv'
        _reset_csv(validated_population)

        supernet_manager = ParameterManager(param_dict=supernet_parameters, seed=seed)

        validation_interface = UserEvaluationInterface(evaluator=runner, manager=supernet_manager,
                                                       csv_path=validated_population,
                                                       validate=args.acc_validate)

        last_population = [supernet_manager.random_sample() for _ in range(args.population)]

        for loop in range(1, num_loops + 1):
            print(f'[Info] Starting LINAS loop {loop} of {num_loops} for seed {seed}.')

            for individual in last_population:
                print(individual)
                validation_interface.eval_subnet(individual)

            print('[Info] Training "approximate" latency predictor.')
            df = supernet_manager.import_csv(validated_population, config='config', objective='latency',
                                             column_names=['config', 'date', 'latency', 'top1'])
            features, labels = supernet_manager.create_training_set(df)
            lat_pred_linas = MobileNetLatencyPredictor()
            lat_pred_linas.train(features, labels)

            print('[Info] Training "approximate" accuracy predictor.')
            df = supernet_manager.import_csv(validated_population, config='config', objective='top1',
                                             column_names=['config', 'date', 'latency', 'top1'])
            features, labels = supernet_manager.create_training_set(df)
            acc_pred_linas = MobileNetAccuracyPredictor()
            acc_pred_linas.train(features, labels)

            runner_predictor.acc_predictor = acc_pred_linas
            runner_predictor.latency_predictor = lat_pred_linas

            prediction_interface = UserEvaluationInterface(evaluator=runner_predictor,
                                                           manager=supernet_manager,
                                                           linas_pred=True)

            # Instantiate Multi-Objective Problem Class
            problem = ProblemMultiObjective(evaluation_interface=prediction_interface,
                                            param_count=supernet_manager.param_count,
                                            param_upperbound=supernet_manager.param_upperbound)

            # Instantiate Search Manager. Seeded with the per-run seed (not
            # args.seed) so every run is seeded the same way as the NSGA-II baseline.
            search_manager = SearchAlgoManager(algorithm='nsga2', seed=seed)
            search_manager.configure_nsga2(population=args.population,
                                           num_evals=inner_search_evals)

            # Run the search!
            output = search_manager.run_search(problem)
            last_population = output.pop.get('X')


def _plot_hypervolume(args, seed_list):
    """Plot mean hypervolume (with a standard-error band) per search algorithm."""
    fig, ax = plt.subplots(figsize=(6, 4))

    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
              '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
    lines = ['-', '--', 'dashdot', (0, (1, 1)), (0, (3, 1, 1, 1)), (0, (3, 1, 1, 1, 1, 1)),
             '-', '--', 'dashdot', '-', '--']

    for c_idx, search_algo in enumerate(['linas', 'nsga2', 'random']):
        hv_per_seed = []
        for seed in seed_list:
            df_results = load_csv_to_df(f'./results/compare_{search_algo}_seed{seed}_{args.num_evals}.csv')
            hv, interval = collect_hv(df_results, max_idx=args.num_evals, ref_point=[50, -70])
            hv_per_seed.append(np.array(hv).squeeze())

        # One column per seed, one row per evaluation interval.
        df_hv = pd.DataFrame(np.vstack(hv_per_seed).T)
        hv_mean = df_hv.mean(axis=1)
        # Standard error over the seed columns only. Computing it after a
        # 'mean' column has been added would count the mean as one more sample
        # and shrink the band. With a single run the sample std is NaN, which
        # is drawn as a zero-width band.
        hv_sem = (df_hv.std(axis=1) / (len(hv_per_seed) ** 0.5)).fillna(0)

        ax.plot(interval, hv_mean, label=search_algo, color=colors[c_idx], linewidth=2,
                linestyle=lines[c_idx])
        ax.fill_between(interval, hv_mean - hv_sem, hv_mean + hv_sem,
                        color=colors[c_idx], alpha=0.2)

    ax.set_xlim(args.population, args.num_evals)
    ax.set_ylim(150, 230)

    ax.set_xlabel('Evaluation Count', fontsize=13)
    ax.set_ylabel('Hypervolume', fontsize=13)
    ax.legend(fancybox=True, fontsize=10, framealpha=1, borderpad=0.2, loc='lower right')
    ax.grid(True, alpha=0.2)
    formatter = ScalarFormatter()
    formatter.set_scientific(False)
    ax.xaxis.set_major_formatter(formatter)
    plt.savefig('results/hypervolume_linas_comparison.png', bbox_inches='tight', pad_inches=0, dpi=150)


def _plot_scatter(args):
    """Plot the seed-0 LINAS and NSGA-II samples against the random-search boundary."""
    # Plot-only dependencies are imported here so a run without plotting
    # does not require them.
    import alphashape
    from descartes import PolygonPatch
    from matplotlib.cm import ScalarMappable

    # Limit evaluation count
    index_slice = 250

    df_nsga = load_csv_to_df(f'./results/compare_nsga2_seed0_{args.num_evals}.csv')[:index_slice]
    df_linas = load_csv_to_df(f'./results/compare_linas_seed0_{args.num_evals}.csv')[:index_slice]
    df_linas_front = frontier_builder(df_linas)
    df_random = load_csv_to_df(f'./results/compare_random_seed0_{args.num_evals}.csv')[:index_slice]

    fig, axes = plt.subplots(1, 2, figsize=(10, 4), gridspec_kw={'width_ratios': [2.5, 3]})
    # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap works on
    # both old and new releases.
    cm = plt.get_cmap('viridis_r')

    panels = (
        (axes[0], 'LINAS', df_linas, 'LINAS Discovered DNN Model'),
        (axes[1], 'NSGA-II', df_nsga, 'NSGA-II Discovered DNN Model'),
    )
    for axis, title, df_points, label in panels:
        axis.set_title(title)
        # Points are coloured by the order in which they were evaluated.
        axis.scatter(df_points['latency'], df_points['accuracy'], marker='D', alpha=0.8,
                     c=list(range(len(df_points))), cmap=cm, label=label, s=6)
        # The LINAS front is drawn on both panels as a common reference.
        axis.plot(df_linas_front['latency'], df_linas_front['accuracy'],
                  color='red', linestyle='--', label='LINAS Pareto front')

    axes[0].set_ylabel('Top-1 Accuracy (%)', fontsize=13)
    axes[1].get_yaxis().set_ticklabels([])

    cloud = list(df_random[['latency', 'accuracy']].to_records(index=False))
    alpha_shape = alphashape.alphashape(cloud, 0)

    for axis in fig.get_axes():
        axis.add_patch(PolygonPatch(alpha_shape, fill=None, alpha=0.8, linewidth=1.5,
                                    label='Random search boundary', linestyle='--'))
        axis.legend(fancybox=True, fontsize=10, framealpha=1, borderpad=0.2, loc='lower right')
        axis.set_ylim(72, 77.5)
        axis.grid(True, alpha=0.3)
        axis.set_xlabel('Latency (ms)', fontsize=13)

    # Eval Count bar, attached to the right-hand panel (which is created wider
    # via width_ratios).
    norm = plt.Normalize(0, len(df_linas))
    sm = ScalarMappable(norm=norm, cmap=cm)
    cbar = fig.colorbar(sm, ax=axes[1], shrink=0.85)
    cbar.ax.set_title("         Evaluation\n  Count", fontsize=8)

    fig.tight_layout(pad=2)
    plt.subplots_adjust(wspace=0.07, hspace=0)
    plt.savefig('results/scatterplot_linas_nsga2_random.png', bbox_inches='tight', pad_inches=0, dpi=150)


def main(args):
    """
    Compare NSGA-II, random search and LINAS on the OFA MobileNetV3 super-network.

    Every run writes its evaluated sub-networks to
    ``./results/compare_<algorithm>_seed<seed>_<num_evals>.csv``; with
    ``args.plot_results`` set, a hypervolume plot and a scatter plot are
    written to ``results/`` as well.

    Parameters
    ----------
    args : argparse.Namespace
        Reads ``supernet``, ``model_dir``, ``dataset_path``, ``lut_path``,
        ``acc_predictor_path``, ``num_evals``, ``population``, ``num_runs``,
        ``acc_validate`` and ``plot_results``. Runs are seeded with
        ``0 .. num_runs - 1``; ``args.seed`` is not consulted.

    Raises
    ------
    ValueError
        If the super-network named in the LUT metadata differs from
        ``args.supernet``.
    """

    # --------------------------------
    # OFA MobileNetV3 <-> LINAS Interface Setup
    # --------------------------------

    # Define SuperNetwork Parameter Dictionary
    supernet_parameters = {'ks': {'count': 20, 'vars': [3, 5, 7]},
                           'e': {'count': 20, 'vars': [3, 4, 6]},
                           'd': {'count': 5, 'vars': [2, 3, 4]}}

    print('[Info] Loading Latency LUT.')
    with open(args.lut_path, 'r') as f:
        lut = json.load(f)
    supernet = lut['metadata']['_net']
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if supernet != args.supernet:
        raise ValueError(f'Latency LUT was built for supernet {supernet!r}, '
                         f'but --supernet is {args.supernet!r}.')

    print('[Info] Loading pre-trained accuracy predictor.')
    # NOTE: unpickling executes arbitrary code; only load predictor files
    # from a trusted source.
    with open(args.acc_predictor_path, 'rb') as f:
        acc_pred = pickle.load(f)

    os.makedirs('results', exist_ok=True)

    print('[Info] Generating search results for NSGA-II and Random Search.')

    runner = OFARunner(supernet=supernet,
                       model_dir=args.model_dir,
                       lut=lut,
                       acc_predictor=acc_pred,
                       macs_predictor=None,
                       latency_predictor=None,
                       imagenetpath=args.dataset_path)

    seed_list = list(range(args.num_runs))

    _run_baseline_searches(args, runner, supernet_parameters, seed_list)

    print('[Info] Generating search results for LINAS.')
    _run_linas_searches(args, runner, supernet, lut, supernet_parameters, seed_list)

    if args.plot_results:

        print('[Info] Generating hypervolume plot...')
        _plot_hypervolume(args, seed_list)

        print('[Info] Generating scatter plot...')
        _plot_scatter(args)






if __name__ == '__main__':
    # Command-line options as (flag, add_argument keywords), listed in the
    # order they are registered with the parser.
    option_table = [
        ('--supernet', {'default': 'ofa_mbv3_d234_e346_k357_w1.0'}),
        ('--model_dir', {'default': 'collateral/'}),
        ('--dataset_path', {'help': 'The path of dataset (e.g. ImageNet) https://image-net.org/download.php',
                            'type': str,
                            'default': '/datasets/imagenet-ilsvrc2012'}),
        ('--seed', {'default': 0, 'type': int, 'help': 'random seed'}),
        ('--lut_path', {'default': 'collateral/latency_lut_a100_b128.json',
                        'help': 'path to latency look-up table (required)'}),
        ('--num_evals', {'default': 2000, 'type': int,
                         'help': 'Number of TOTAL evaluations to make.'}),
        ('--population', {'default': 50, 'type': int,
                          'help': 'Population size for each generation, if applicable.'}),
        ('--num_runs', {'default': 5, 'type': int,
                        'help': 'Number of runs per search algorithm each with a unique seed.'}),
        ('--verbose', {'action': 'store_true', 'help': 'Flag to control output'}),
        ('--plot_results', {'action': 'store_true',
                            'help': 'Flag to control whether to plot results'}),
        ('--acc_validate', {'action': 'store_true',
                            'help': 'Flag to control whether acc is predicted or validated'}),
        ('--acc_predictor_path', {'default': 'collateral/ofa_mbv3_accuracy_predictor_ridge.pkl',
                                  'help': 'Path to pre-trained accuracy predictor.'}),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Start-up banner: a blank line, then the title framed by two 40-dash rules.
    rule = '-' * 40
    print('\n' + rule)
    print('Multi-Objective Search Starting:')
    print(rule)

    main(args)

