"""
AutoML 2022 Conference Submission - Code Supplement
Evolutionary Search Algorithm Comparison

SuperNetwork: OFA MobileNetV3 (ofa_mbv3_d234_e346_k357_w1.0)
"""

# Imports
import argparse
import csv
import json
from datetime import datetime
import numpy as np
import pandas as pd
import copy
import pickle
import uuid
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
import numpy as np
import os

# OFA Specific Imports
from ofa.tutorial.latency_table import LatencyEstimator
from ofa.imagenet_codebase.data_providers.imagenet import ImagenetDataProvider
from ofa.imagenet_codebase.run_manager import ImagenetRunConfig
from ofa.imagenet_codebase.run_manager import RunManager
import ofa

# LINAS Specific Imports
from linas.manager import ParameterManager
from linas.evaluation_module.predictor import MobileNetAccuracyPredictor, MobileNetLatencyPredictor
from linas.search_module.search import SearchAlgoManager, ProblemMultiObjective
from linas.analytics_module.visualize import collect_hv, load_csv_to_df


class OFARunner:
    '''
    The OFARunner is responsible for 'running' the subnetwork evaluation.

    It bundles the OFA supernetwork, a latency look-up table and the
    (optional) predictor objects, and exposes one method per way of obtaining
    an accuracy or latency figure for a subnetwork configuration.

    Parameters
    ----------
    supernet : string
        Name of the OFA supernetwork. The last three characters are parsed as
        the width multiplier (e.g. the trailing '1.0' of
        'ofa_mbv3_d234_e346_k357_w1.0').
    model_dir : string
        Directory handed to ``ofa.model_zoo.ofa_net`` as ``model_dir``.
    lut : dict or string
        Latency look-up table, either already loaded or a path to a JSON file.
    acc_predictor : object
        Object exposing ``predict_single``; used by the accuracy estimators.
    macs_predictor : object
        Stored on the instance; not used by any method of this class.
    latency_predictor : object
        Object exposing ``predict_single``; used by ``estimate_latency``.
    imagenetpath : string
        Assigned to ``ImagenetDataProvider.DEFAULT_PATH``.
    '''
    def __init__(self, supernet, model_dir, lut, acc_predictor, macs_predictor,
                 latency_predictor, imagenetpath):

        self.supernet = supernet
        self.model_dir = model_dir
        self.acc_predictor = acc_predictor
        self.macs_predictor = macs_predictor
        self.latency_predictor = latency_predictor

        # Accept either an in-memory table or a path to its JSON file
        if isinstance(lut, dict):
            self.lut = lut
        else:
            with open(lut, 'r') as f:
                self.lut = json.load(f)
        self.latencyEstimator = LatencyEstimator(url=self.lut)

        # Width multiplier is encoded as the last three characters of the name
        self.width = float(supernet[-3:])

        # Validation setup
        self.target = 'cpu'
        self.test_size = None
        ImagenetDataProvider.DEFAULT_PATH = imagenetpath
        self.ofa_network = ofa.model_zoo.ofa_net(supernet, pretrained=True, model_dir=model_dir)
        self.run_config = ImagenetRunConfig(test_batch_size=64, n_worker=20)

    def get_subnet(self, subnet_cfg):
        '''
        Activate the subnetwork described by ``subnet_cfg`` and return it.

        ``subnet_cfg`` must provide the keys 'ks', 'e' and 'd'. The extracted
        subnetwork is switched to eval mode and also kept on ``self.subnet``.
        '''
        self.ofa_network.set_active_subnet(ks=subnet_cfg['ks'],
                                           e=subnet_cfg['e'],
                                           d=subnet_cfg['d'])
        self.subnet = self.ofa_network.get_active_subnet(preserve_weight=True)
        self.subnet.eval()
        return self.subnet

    def validate_accuracy_top1(self, subnet_cfg, target=None):
        '''
        Measure the top-1 accuracy of a subnetwork with the OFA RunManager.

        ``subnet_cfg`` must provide 'ks', 'e', 'd' and 'r'; the first entry of
        'r' is used as the active image size. Returns the ``top1`` value
        reported by ``RunManager.validate``.
        '''
        # NOTE(review): `target` is resolved here but not used further below.
        if target is None:
            target = self.target
        subnet = self.get_subnet(subnet_cfg)

        # Unique scratch folder name per call
        folder_name = '.torch/tmp-{}'.format(uuid.uuid1().hex)
        run_manager = RunManager('{}/eval_subnet'.format(folder_name), subnet,
                                 self.run_config, init=False, print_info=False)
        # Presumably re-calibrates the subnet's running statistics before
        # validation - confirm against the OFA RunManager implementation.
        run_manager.reset_running_statistics(net=subnet)

        # Test sampled subnet
        self.run_config.data_provider.assign_active_img_size(subnet_cfg['r'][0])
        _, top1, _ = run_manager.validate(net=subnet, test_size=self.test_size, no_logs=True)
        return top1

    def estimate_accuracy_top1(self, subnet_cfg):
        '''Predict top-1 accuracy by passing ``subnet_cfg`` straight to the accuracy predictor.'''
        # Ridge Predictor - 135 vector
        return self.acc_predictor.predict_single(subnet_cfg)

    def estimate_latency(self, subnet_cfg):
        '''Predict latency by passing ``subnet_cfg`` to the latency predictor.'''
        return self.latency_predictor.predict_single(subnet_cfg)

    def measure_latency_lut(self, subnet_cfg):
        '''Look up the latency of ``subnet_cfg`` in the latency table at this runner's width.'''
        # LUT Latency Predictor
        return self.latencyEstimator.predict_network_latency_given_spec(subnet_cfg, width=self.width)

    def estimate_accuracy_custom(self, subnet_cfg):
        '''
        Predict top-1 accuracy from the custom one-hot encoding.

        The 'ks', 'e' and 'd' entries of ``subnet_cfg`` are encoded with
        ``onehot_custom`` and the resulting vector is given to the accuracy
        predictor. ``subnet_cfg`` is not modified.
        '''
        # Ridge Predictor - 120 vector
        encoded = self.onehot_custom(subnet_cfg['ks'], subnet_cfg['e'], subnet_cfg['d'])
        return self.acc_predictor.predict_single(encoded)

    def construct_maps(self, keys):
        '''
        Map each distinct key to a consecutive index starting at 0.

        Indices follow the sorted order of the distinct keys, so the mapping
        does not depend on set iteration order. Keys must be mutually
        orderable.
        '''
        return {key: index for index, key in enumerate(sorted(set(keys)))}

    def onehot_custom(self, ks_list, ex_list, d_list):
        '''
        Encode kernel sizes and expansion ratios as a 120-element 0/1 vector.

        Parameters
        ----------
        ks_list : sequence
            20 kernel sizes, each one of 3, 5 or 7.
        ex_list : sequence
            20 expansion ratios, each one of 3, 4 or 6.
        d_list : sequence
            Depth per stage; each stage covers 4 consecutive layers and only
            the first ``d`` layers of a stage are treated as active.

        Returns
        -------
        numpy.ndarray
            60 kernel-size entries followed by 60 expansion-ratio entries
            (3 slots per layer). Inactive layers contribute all zeros.

        The input sequences are left unmodified.
        '''
        ks_map = self.construct_maps(keys=(3, 5, 7))
        ex_map = self.construct_maps(keys=(3, 4, 6))

        # Work on copies: masking below must not leak into the caller's config
        ks_list = list(ks_list)
        ex_list = list(ex_list)

        # Zero out the layers beyond each stage's depth
        for stage, depth in enumerate(d_list):
            stage_start = stage * 4
            for j in range(stage_start + depth, stage_start + 4):
                ks_list[j] = 0
                ex_list[j] = 0

        # convert to onehot
        ks_onehot = [0] * 60
        ex_onehot = [0] * 60

        for i in range(20):
            offset = i * 3
            if ks_list[i] != 0:
                ks_onehot[offset + ks_map[ks_list[i]]] = 1
            if ex_list[i] != 0:
                ex_onehot[offset + ex_map[ex_list[i]]] = 1

        return np.array(ks_onehot + ex_onehot)


class UserEvaluationInterface:
    '''
    The interface class needs to be updated for each unique SuperNetwork
    framework as it controls how evaluation calls are made from DyNAS-T

    Parameters
    ----------
    evaluator : class
        The 'runner' that performs the validation or prediction
    manager : class
        The DyNAS-T manager that translates between PyMoo and the parameter dict
    csv_path : string
        (Optional) The csv file that get written to during the subnetwork search
    validate : bool
        When truthy, accuracy is obtained with the evaluator's
        ``validate_accuracy_top1``; otherwise with ``estimate_accuracy_custom``.
    '''

    def __init__(self, evaluator, manager, csv_path=None, validate=True):
        self.evaluator = evaluator
        self.manager = manager
        self.csv_path = csv_path
        self.validate = validate

    def eval_subnet(self, x):
        '''
        Evaluate one search vector and optionally log the result.

        Parameters
        ----------
        x : sequence
            PyMoo search vector; translated with ``manager.translate2param``.

        Returns
        -------
        tuple
            ``(sample, latency, 100/(top1-70))`` where ``sample`` is the
            configuration dict that was passed to the evaluator.

        Raises
        ------
        ZeroDivisionError
            If the accuracy is exactly 70.
        '''
        # PyMoo vector to Elastic Parameter Mapping
        param_dict = self.manager.translate2param(x)

        sample = {
            'wid': None,
            'ks': param_dict['ks'],
            'e': param_dict['e'],
            'd': param_dict['d'],
            'r': [224]
        }
        # Snapshot taken before evaluation so the logged row is unaffected if
        # an evaluator modifies `sample` in place
        subnet_sample = copy.deepcopy(sample)

        # Note for reproducibility, stage-wise look-up-table used for latency
        if self.validate:
            print('[Info] Making validation measurement.')
        latency = self.evaluator.measure_latency_lut(sample)
        if self.validate:
            top1 = self.evaluator.validate_accuracy_top1(sample)
        else:
            top1 = self.evaluator.estimate_accuracy_custom(sample)

        if self.csv_path:
            # newline='' lets the csv module control line endings itself
            with open(self.csv_path, 'a', newline='') as f:
                writer = csv.writer(f)
                date = str(datetime.now())
                writer.writerow([subnet_sample, date, float(latency), float(top1)])

        # PyMoo only minimizes objectives, thus accuracy needs to be negative/inverse
        # 100/(top1-70) scales the top-1 results in the objective search space for better ref dir performance
        # NOTE(review): the value is negative for top1 < 70 and undefined at
        # top1 == 70 - confirm accuracies stay above 70 in practice.
        return sample, latency, 100/(top1-70)


def _hv_mean_and_stderr(hv_runs):
    '''
    Aggregate the hypervolume traces of several runs.

    Parameters
    ----------
    hv_runs : list
        One hypervolume trace per run; the traces are stacked, so they must
        all have the same length.

    Returns
    -------
    tuple of pandas.Series
        Row-wise mean across runs, and the sample standard deviation across
        runs divided by sqrt(number of runs). With a single run the second
        series is all zeros.
    '''
    df_hv = pd.DataFrame(np.vstack([np.array(trace).squeeze() for trace in hv_runs]).T)
    # Both statistics are taken over the per-run columns only; adding the mean
    # as a column first would let it count as an extra run in the deviation.
    hv_mean = df_hv.mean(axis=1)
    hv_err = (df_hv.std(axis=1) / (len(hv_runs) ** 0.5)).fillna(0.0)
    return hv_mean, hv_err


def _plot_hypervolume(args, search_algo_list, seed_list):
    '''
    Plot mean hypervolume per algorithm from the trial CSV files in ./results
    and save the figure to results/hypervolume_ea_tournament.png.
    '''
    print('[Info] Generating hypervolume plot...')

    fig, ax = plt.subplots(figsize=(6,4))

    colors = ['#1f77b4','#ff7f0e','#2ca02c','#d62728','#9467bd','#8c564b','#e377c2','#7f7f7f','#bcbd22','#17becf']
    lines = ['-','--','dashdot',(0,(1,1)),(0,(3,1,1,1)),(0, (3,1,1,1,1,1)),'-','--','dashdot','-','--']

    for c_idx, search_algo in enumerate(search_algo_list):
        hv_runs = []
        for seed in seed_list:
            df_results = load_csv_to_df(f'./results/trial_{search_algo}_seed{seed}_{args.num_evals}.csv')
            hv_trace, interval = collect_hv(df_results, max_idx=args.num_evals, ref_point=[50, -70])
            hv_runs.append(hv_trace)

        hv_mean, hv_err = _hv_mean_and_stderr(hv_runs)

        # `interval` from the last seed is used as the shared x-axis
        ax.plot(interval, hv_mean, label=search_algo, color=colors[c_idx], linewidth=2, linestyle=lines[c_idx])
        ax.fill_between(interval, hv_mean-hv_err, hv_mean+hv_err,
            color=colors[c_idx], alpha=0.2)

    ax.set_xlim(args.population, args.num_evals)
    ax.set_ylim(150,230)

    ax.set_xlabel('Evaluation Count', fontsize=13)
    ax.set_ylabel('Hypervolume', fontsize=13)
    ax.legend(fancybox=True, fontsize=10, framealpha=1, borderpad=0.2, loc='upper left')
    ax.grid(True, alpha=0.2)
    ax.set_xscale('log')
    formatter = ScalarFormatter()
    formatter.set_scientific(False)
    ax.xaxis.set_major_formatter(formatter)
    plt.savefig('results/hypervolume_ea_tournament.png', bbox_inches='tight', pad_inches=0, dpi=150)
    # Release the figure once it is on disk
    plt.close(fig)


def main(args):
    '''
    Run every search algorithm for ``args.num_runs`` seeds and optionally plot
    the resulting hypervolume curves.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide: supernet, lut_path, acc_predictor_path, model_dir,
        dataset_path, seed, num_runs, num_evals, population, acc_validate and
        plot_results.

    Raises
    ------
    ValueError
        If the supernetwork named in the latency table's metadata differs from
        ``args.supernet``.

    Each trial appends its evaluations to
    ./results/trial_<algo>_seed<seed>_<num_evals>.csv (truncated at the start
    of the trial).
    '''

    # --------------------------------
    # OFA MobileNetV3 <-> LINAS Interface Setup
    # --------------------------------

    # Define SuperNetwork Parameter Dictionary and Instantiate Manager
    supernet_parameters = {'ks'  :  {'count' : 20, 'vars' : [3, 5, 7]},
                           'e'   :  {'count' : 20, 'vars' : [3, 4, 6]},
                           'd'   :  {'count' : 5,  'vars' : [2, 3, 4]} }
    # NOTE(review): this instance is replaced by a per-seed manager inside the
    # trial loop below; kept in case construction has seeding side effects.
    supernet_manager = ParameterManager(param_dict=supernet_parameters,
                                        seed=args.seed)

    print('[Info] Loading Latency LUT.')
    with open(args.lut_path, 'r') as f:
        lut = json.load(f)
    supernet = lut['metadata']['_net']
    # Explicit check instead of `assert`, which is skipped under `python -O`
    if supernet != args.supernet:
        raise ValueError(
            f"Latency LUT was built for '{supernet}' but --supernet is '{args.supernet}'.")

    print('[Info] Loading pre-trained accuracy predictor.')
    # SECURITY: unpickling executes arbitrary code; only load predictor files
    # from a trusted source.
    with open(args.acc_predictor_path, 'rb') as f:
        acc_pred = pickle.load(f)

    os.makedirs('results', exist_ok=True)

    # Instantiate objective 'runner'
    runner = OFARunner(supernet=supernet,
                       model_dir=args.model_dir,
                       lut=lut,
                       acc_predictor=acc_pred,
                       macs_predictor=None,
                       latency_predictor=None,
                       imagenetpath=args.dataset_path)

    search_algo_list = ['nsga2', 'age', 'ctaea', 'moead', 'unsga3', 'random']
    seed_list = list(range(args.num_runs))

    for search_algo in search_algo_list:
        for seed in seed_list:

            print(f'[Info] Running {search_algo} with seed {seed} and {args.num_evals} evaluations.')

            # Clear/touch output file
            output_csv = f'./results/trial_{search_algo}_seed{seed}_{args.num_evals}.csv'
            with open(output_csv, 'w'):
                pass

            supernet_manager = ParameterManager(param_dict=supernet_parameters, seed=seed)

            evaluation_interface = UserEvaluationInterface(evaluator=runner, manager=supernet_manager,
                csv_path=output_csv, validate=args.acc_validate)

            problem = ProblemMultiObjective(evaluation_interface=evaluation_interface,
                                            param_count=supernet_manager.param_count,
                                            param_upperbound=supernet_manager.param_upperbound)

            search_manager = SearchAlgoManager(algorithm=search_algo, seed=seed)

            if search_algo == 'random':
                # Baseline: draw all samples first, then evaluate them in order
                samples = [supernet_manager.random_sample() for _ in range(args.num_evals)]
                for individual in samples:
                    evaluation_interface.eval_subnet(individual)
                continue

            if search_algo == 'nsga2':
                search_manager.configure_nsga2(population=args.population, num_evals=args.num_evals)
            elif search_algo == 'age':
                search_manager.configure_age(population=args.population, num_evals=args.num_evals)
            elif search_algo == 'ctaea':
                search_manager.configure_ctaea(num_evals=args.num_evals)
            elif search_algo == 'moead':
                search_manager.configure_moead(num_evals=args.num_evals)
            elif search_algo == 'unsga3':
                search_manager.configure_unsga3(population=args.population, num_evals=args.num_evals)

            # Run the search! Results are logged to the CSV by the evaluation interface.
            search_manager.run_search(problem)

    if args.plot_results:
        _plot_hypervolume(args, search_algo_list, seed_list)



if __name__ == '__main__':
    # Command-line options as (flag, keyword arguments) pairs, registered in
    # this order. Every option has a default, so none is mandatory.
    options = [
        ('--supernet', dict(default='ofa_mbv3_d234_e346_k357_w1.0')),
        ('--model_dir', dict(default='collateral/')),
        ('--dataset_path', dict(
            help='The path of dataset (e.g. ImageNet) https://image-net.org/download.php',
            type=str,
            default='/datasets/imagenet-ilsvrc2012')),
        ('--seed', dict(default=0, type=int, help='random seed')),
        ('--lut_path', dict(
            default='collateral/latency_lut_a100_b128.json',
            help='path to latency look-up table (required)')),
        ('--num_evals', dict(
            default=20000, type=int,
            help='Number of TOTAL evaluations to make.')),
        ('--population', dict(
            default=50, type=int,
            help='Population size for each generation, if applicable.')),
        ('--num_runs', dict(
            default=5, type=int,
            help='Number of runs per search algorithm each with a unique seed.')),
        ('--verbose', dict(action='store_true', help='Flag to control output')),
        ('--plot_results', dict(
            action='store_true',
            help='Flag to control whether to plot results')),
        ('--acc_validate', dict(
            action='store_true',
            help='Flag to control whether acc is predicted or validated')),
        ('--acc_predictor_path', dict(
            default='collateral/ofa_mbv3_accuracy_predictor_ridge.pkl',
            help='Path to pre-trained accuracy predictor.')),
    ]

    parser = argparse.ArgumentParser()
    for flag, spec in options:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()

    # Start-up banner
    banner_rule = '-' * 40
    print('\n' + banner_rule)
    print('Multi-Objective Search Starting:')
    print(banner_rule)

    main(args)

