import openvino as ov
# import os
# import networkx as nx
# import numpy as np
# import tensorflow as tf
import sys
# import openvino as ov


from numpy.lib._iotools import str2bool

sys.path.append('./')
sys.path.append('progressive_placers/')
sys.path.append('sim/')
sys.path.append('model/')
from model.progressive_placer import *

import argparse
import model.rl_params
from itertools import chain
from model.mp_progressive_nn import MessagePassingProgressiveNN
from model.simple_nn import *
from model.simple_graphs import *
from model.pp_item import *
from transformers import BertModel, BertTokenizer
# import openvino as ov



# from utils import *

def compile_model(model, device, core):
    """Compile ``model`` for ``device`` by delegating to ``core.compile_model``.

    Args:
        model: The model object to compile (forwarded untouched).
        device: Target device identifier (forwarded untouched).
        core: Object exposing a ``compile_model(model, device)`` method.

    Returns:
        Whatever ``core.compile_model(model, device)`` returns.
    """
    compiled = core.compile_model(model, device)
    return compiled


class ProgressivePlacerTest(object):
    """Driver that builds a graph, selects a reward function and runs placement.

    ``test`` is the main entry point; ``mul_graphs`` hands ``test`` to a
    ``Coordinator``. The reward functions are static so they can be passed
    around as plain callables.
    """

    # OpenVINO IR file that is read, annotated with affinities and benchmarked.
    XML_FILE = "./bert-base-uncased.xml"

    # sim_reward runs the model this many times and discards the first
    # _N_WARMUP_RUNS timings before averaging.
    _N_TIMED_RUNS = 10
    _N_WARMUP_RUNS = 5

    @staticmethod
    def sim_reward(n_devs, p, G, xml_file):
        """Benchmark a placement on real devices and return its mean latency.

        The model in ``xml_file`` is read, every op gets an ``affinity``
        runtime attribute derived from ``p``, the model is compiled for
        ``HETERO:GPU.1,CPU`` and executed on one tokenized sample sentence.

        Args:
            n_devs: Unused; kept so the signature stays compatible.
            p: Mapping from node (as yielded by ``G.nodes()``) to device id.
            G: Graph object exposing ``nodes()``.
            xml_file: Path to the OpenVINO IR model.

        Returns:
            Mean wall-clock time (seconds, from ``time.perf_counter``) of the
            runs that follow the warm-up runs, or ``1e10`` when compiling or
            running the model raises.
        """
        import time

        # numpy is not imported under the name `np` at the top of this file,
        # so bind it locally instead of relying on a wildcard import.
        import numpy as np

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        texts = ["Hello, my dog is cute"]
        inputs = tokenizer(texts, return_tensors="pt", padding=True)

        core = ov.Core()
        model = core.read_model(model=xml_file)
        ops = model.get_ops()

        # The i-th node of G is paired with the i-th op of the model.
        # NOTE(review): assumes G.nodes() iterates in the same order as
        # model.get_ops() - confirm against how G is built.
        # NOTE(review): device ids 0 AND 1 both map to CPU, so with two
        # devices nothing is ever pinned to GPU.1 - confirm this is intended.
        for i, node in enumerate(G.nodes()):
            affinity = "CPU" if p[node] in (0, 1) else "GPU.1"
            ops[i].get_rt_info()["affinity"] = affinity

        try:
            compiled_model = core.compile_model(model, 'HETERO:GPU.1,CPU')

            input_layer_1 = compiled_model.input(0)
            input_layer_2 = compiled_model.input(1)
            output_layer = compiled_model.output(0)
            feed = {
                input_layer_1.any_name: inputs['input_ids'],
                input_layer_2.any_name: inputs['attention_mask'],
            }

            run_times = []
            for i in range(ProgressivePlacerTest._N_TIMED_RUNS):
                start_time = time.perf_counter()
                compiled_model(feed)[output_layer]
                elapsed = time.perf_counter() - start_time
                if i >= ProgressivePlacerTest._N_WARMUP_RUNS:
                    run_times.append(elapsed)

            return np.mean(run_times)
        except Exception as e1:
            # Deliberate best effort: a placement that cannot be compiled or
            # executed is reported as a huge run time rather than aborting.
            print(e1)
            return 1e10

    @staticmethod
    def sim_single_gpu(n_devs, sim, p, G):
        """Toy reward: count the nodes NOT placed on the last device.

        The reward is zero when everything is placed on device
        ``n_devs - 1``.

        Args:
            n_devs: Number of devices; ``n_devs - 1`` is the target device.
            sim: Unused.
            p: Mapping from node to device id.
            G: Graph object exposing ``nodes()``.

        Returns:
            ``(run_time, start_times)`` where ``run_time`` is the count
            described above and ``start_times`` maps every node of ``G``
            to ``0.``.
        """
        start_times = {n: 0. for n in G.nodes()}
        last_dev = n_devs - 1
        run_time = sum(1 for d in p.values() if d != last_dev)
        return run_time, start_times

    @staticmethod
    def sim_neigh_placement(n_devs, sim, p, G):
        """Toy reward: count (node, adjacent node) pairs on different devices.

        For every node in ``p`` both ``G.neighbors`` and ``G.predecessors``
        are visited, so an edge is seen once from each endpoint.

        Args:
            n_devs: Unused.
            sim: Unused.
            p: Mapping from node to device id.
            G: Graph object exposing ``nodes()``, ``neighbors(n)`` and
                ``predecessors(n)``.

        Returns:
            ``(run_time, start_times, [0, 0])`` where ``start_times`` maps
            every node of ``G`` to ``0.``.
        """
        start_times = {n: 0. for n in G.nodes()}
        run_time = 0
        for n in p:
            for neigh in chain(G.neighbors(n), G.predecessors(n)):
                if p[neigh] != p[n]:
                    run_time += 1
        return run_time, start_times, [0] * 2

    @staticmethod
    def choose_model(model_name):
        """Return the network class registered under ``model_name``.

        Args:
            model_name: ``'simple_nn'`` or ``'mp_nn'``.

        Returns:
            ``SimpleNN`` or ``MessagePassingProgressiveNN`` respectively.

        Raises:
            Exception: If ``model_name`` is neither of the supported names.
        """
        if model_name == 'simple_nn':
            return SimpleNN
        if model_name == 'mp_nn':
            return MessagePassingProgressiveNN
        raise Exception('%s not implemented model' % model_name)

    @staticmethod
    def str2bool(v):
        """Convert a command-line style truth value to ``bool``.

        Args:
            v: A ``bool`` (returned unchanged) or a string such as
                ``'yes'``/``'true'``/``'t'``/``'y'``/``'1'`` or
                ``'no'``/``'false'``/``'f'``/``'n'``/``'0'`` (any case).

        Returns:
            The corresponding boolean.

        Raises:
            argparse.ArgumentTypeError: If the string is not recognised
                (previously ``None`` was returned silently).
        """
        if isinstance(v, bool):
            return v
        lowered = v.lower()
        if lowered in ('yes', 'true', 't', 'y', '1'):
            return True
        if lowered in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (v,))

    def test(self, config):
        """Build the graph described by ``config`` and place or evaluate it.

        Args:
            config: Mapping that is read for the keys ``'graph'``,
                ``'graph_size'``, ``'n_devs'``, ``'m_name'``,
                ``'rew_singlegpu'``, ``'rew_neigh_pl'`` and ``'eval'``; it is
                also forwarded unchanged to ``ProgressivePlacer.place``.
        """
        graph = config['graph']
        N = config['graph_size']
        n_devices = config['n_devs']
        m_name = config['m_name']
        xml_file = self.XML_FILE

        # Create the graph: a synthetic one, or one derived from the IR file.
        if graph == 'chain':
            G = makeChainGraph(N, n_devices)
        elif graph == 'crown':
            G = makeCrownGraph(N, n_devices)
        elif graph == 'edge':
            G = makeEdgeGraph(N)
        else:
            progressive_graph = self.create_progressive_graph_from_xml(xml_file, n_devices)
            G = progressive_graph.G

        # NOTE(review): the three reward callables do not share a signature
        # (the toy rewards take (n_devs, sim, p, G), the measured one takes
        # only p) - confirm ProgressivePlacer.place calls each accordingly.
        if config['rew_singlegpu']:
            reward_fn = ProgressivePlacerTest.sim_single_gpu
        elif config['rew_neigh_pl']:
            reward_fn = ProgressivePlacerTest.sim_neigh_placement
        else:
            reward_fn = lambda p: ProgressivePlacerTest.sim_reward(n_devices, p, G, xml_file)

        if config['eval'] is not None:
            # The evaluation result is currently not reported anywhere.
            self.eval_placement(G, config['eval'], xml_file)
        else:
            ProgressivePlacer().place(
                G, n_devices, ProgressivePlacerTest.choose_model(m_name),
                reward_fn, config, xml_file)

    def create_progressive_graph_from_xml(self, xml_file, n_devices):
        """Turn an OpenVINO IR model into a ``ProgressiveGraph``.

        Every op becomes a node keyed by its friendly name with the default
        attributes ``cost``, ``out_size`` and ``mem`` all set to 1. An edge
        is added from each producer of an op's input values to the op itself.

        Args:
            xml_file: Path to the OpenVINO IR model.
            n_devices: Forwarded to ``ProgressiveGraph``.

        Returns:
            ``ProgressiveGraph(G, n_devices, 'topo')`` for the built graph.
        """
        core = ov.Core()
        model = core.read_model(model=xml_file)
        ops = model.get_ops()

        # NOTE(review): `nx` is not imported explicitly in this file; it is
        # presumably provided by one of the wildcard imports - verify.
        G = nx.DiGraph()

        # Add the nodes with their default attribute values.
        for op in ops:
            G.add_node(op.get_friendly_name(), cost=1, out_size=1, mem=1)

        # Add the edges: producer of each input value -> consuming op.
        for op in ops:
            node_id = op.get_friendly_name()
            for inp in op.input_values():
                src_id = inp.get_node().get_friendly_name()
                if src_id != '':
                    G.add_edge(src_id, node_id)

        print(f'xml node {len(G.nodes)}')

        return ProgressiveGraph(G, n_devices, 'topo')

    def eval_placement(self, G, placement, xml_file):
        """Measure the run time of a fixed ``placement``.

        Args:
            G: Graph object exposing ``nodes()``.
            placement: Mapping from node to device id.
            xml_file: Path to the OpenVINO IR model.

        Returns:
            ``(placement, run_time, None, None)``; the last two slots are the
            start-times and memory-utilisation placeholders.
        """
        # sim_reward never reads n_devs, so a graph object without an
        # `n_devs` attribute must not make the evaluation fail.
        n_devs = getattr(G, 'n_devs', None)
        run_time = ProgressivePlacerTest.sim_reward(n_devs, placement, G, xml_file)
        start_times = None
        mem_utils = None
        return placement, run_time, start_times, mem_utils

    def mul_graphs(self, config):
        """Start a ``Coordinator`` with ``config`` and ``self.test``."""
        from model.coord import Coordinator
        Coordinator().start(config, self.test)


if __name__ == '__main__':
    # Script entry point: parse the options stored in ./config/config.txt,
    # validate them and run the placer test once (or once per dataset file).

    # `os` is used below but is not imported at the top of this file, so it
    # is imported here instead of relying on a wildcard import to leak it.
    import os

    core = ov.Core()
    print(f'devices {core.available_devices}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--name', '-n', type=str, default='test')
    parser.add_argument('--graph', '-g', type=str, default=None)
    parser.add_argument('--id', type=int, default=None)
    # for synthetic non-tensorflow graphs
    parser.add_argument('--graph-size', '-N', type=int, default=4)
    parser.add_argument('--pickled-inp-file', '-i', type=str, default=None, nargs='+')
    parser.add_argument('--mul-graphs', type=str, default=None, nargs='+')
    parser.add_argument('--dataset-folder', type=str, default=None,
                        help='Use this to denote a folder containing a dataset like cifar10. '
                             'Each subfolder will be checked for input.pkl files')

    parser.add_argument('--dataset', type=str, default=None,
                        help='Use this to denote a folder containing a dataset like cifar10. '
                             'Each subfolder will be checked for input.pkl files')

    parser.add_argument('--n-devs', type=int, default=2)
    parser.add_argument('--model-folder-prefix', type=str, default='', dest='model_folder_prefix')

    # progressive placer model args
    parser.add_argument('--m-name', type=str, default='mp_nn')
    parser.add_argument('--n-peers', type=int, default=None)
    parser.add_argument('--agg-msgs', type=str2bool, dest='agg_msgs')
    parser.add_argument('--no-msg-passing', type=str2bool, dest='no_msg_passing')
    parser.add_argument('--radial-mp', type=int, default=None)
    parser.add_argument('--tri-agg', type=str2bool, dest='tri_agg')
    parser.add_argument('--sage', type=str2bool, dest='sage')
    parser.add_argument('--sage-hops', type=int, dest='sage_hops')
    parser.add_argument('--sage-sample-ratio', type=float, dest='sage_sample_ratio')
    parser.add_argument('--sage-dropout-rate', type=float, dest='sage_dropout_rate')
    parser.add_argument('--sage-aggregation', type=str, dest='sage_aggregation', default='mean')

    parser.add_argument('--sage-position-aware', type=str2bool, dest='sage_position_aware')
    parser.add_argument('--use-single-layer-perceptron', type=str2bool, dest='use_single_layer_perceptron')
    parser.add_argument('--pgnn-c', type=float, dest='pgnn_c')
    parser.add_argument('--pgnn-neigh-cutoff', type=int, dest='pgnn_neigh_cutoff')
    parser.add_argument('--pgnn-anchor-exponent', type=int, dest='pgnn_anchor_exponent')
    parser.add_argument('--pgnn-aggregation', type=str, dest='pgnn_aggregation', default='max')
    parser.add_argument('--reinit-model', type=str2bool, dest='reinit_model')

    # training args
    parser.add_argument('--n-eps', type=int, default=int(1e9))
    parser.add_argument('--max-rnds', type=int, default=None)
    parser.add_argument('--disc-factor', type=float, default=1.)
    parser.add_argument('--vary-init-state', dest='vary_init_state', action='store_true')
    parser.add_argument('--zero-placement-init', dest='zero_placement_init', action='store_true')
    parser.add_argument('--one-placement-init', dest='one_placement_init', action='store_true')
    parser.add_argument('--null-placement-init', dest='null_placement_init', action='store_true')
    parser.add_argument('--init-best-pl', dest='init_best_pl', action='store_true')
    parser.add_argument('--one-shot-episodic-rew', dest='one_shot_episodic_rew', action='store_true')
    parser.add_argument('--ep-decay-start', type=float, default=1e3)
    parser.add_argument('--bl-n-rnds', type=int, default=1000)
    parser.add_argument('--rew-singlegpu', dest='rew_singlegpu', action='store_true')
    parser.add_argument('--rew-neigh-pl', dest='rew_neigh_pl', action='store_true')
    parser.add_argument('--supervised', dest='supervised', action='store_true')
    parser.add_argument('--use-min-runtime', dest='use_min_runtime', action='store_true')
    parser.add_argument('--discard-last-rnds', dest='discard_last_rnds', action='store_true')
    parser.add_argument('--turn-based-baseline', dest='turn_based_baseline', action='store_true')
    parser.add_argument('--dont-repeat-ff', action='store_true', dest='dont_repeat_ff')
    parser.add_argument('--small-nn', action='store_true', dest='small_nn')
    parser.add_argument('--dont-restore-softmax', dest='dont_restore_softmax', action='store_true')
    parser.add_argument('--restore-from', type=str, default=None)

    # report/log args
    parser.add_argument('--print-freq', type=int, default=50)
    parser.add_argument('--save-freq', type=int, default=100)
    parser.add_argument('--eval-freq', type=int, default=999)
    parser.add_argument('--log-tb-workers', dest='log_tb_workers', action='store_true')
    parser.add_argument('--debug', dest='debug', action='store_true')
    parser.add_argument('--debug-verbose', dest='debug_verbose', action='store_true')
    parser.add_argument('--disamb-pl', dest='disamb_pl', action='store_true')
    parser.add_argument('--eval', type=str, default=None)
    parser.add_argument('--simplify-tf-rew-model', action='store_true', dest='simplify_tf_rew_model')
    parser.add_argument('--log-runtime', dest='log_runtime', action='store_true')
    parser.add_argument('--use-new-sim', action='store_true', dest='use_new_sim')
    parser.add_argument('--gen-profile-timeline', dest='gen_profile_timeline', action='store_true')
    parser.add_argument('--mem-penalty', type=float, default=0.)
    parser.add_argument('--max-mem', type=float, default=11., help='Default Max Memory of GPU (in GB)')
    parser.add_argument('--max-runtime-mem-penalized', type=float, default=10.,
                        help='Instantaneous runtime of the placement after adding the memory penalty has to be lower than this number. Note that improvement in this memory penalized runtime metric is used to compute intermediate rewards')

    # dist training params
    parser.add_argument('--use-threads', dest='use_threads', action='store_true')
    parser.add_argument('--scale-norm', dest='scale_norm', action='store_true')
    parser.add_argument('--dont-share-classifier', action='store_true', dest='dont_share_classifier')
    parser.add_argument('--use-gpus', type=str, nargs='+', default=None)
    parser.add_argument('--eval-on-transfer', type=int, default=None,
                        help='Number of episodes to transfer train before reporting eval runtime')
    parser.add_argument('--normalize-aggs', dest='normalize_aggs', action='store_true')
    parser.add_argument('--bn-pre-classifier', dest='bn_pre_classifier', action='store_true')
    parser.add_argument('--bs', type=int, default=None)
    parser.add_argument('--num-children', type=int, default=1)
    parser.add_argument('--disable-profiling', action='store_true', dest='disable_profiling')
    parser.add_argument('--n-async-sims', type=int, default=None)
    parser.add_argument('--baseline-mask', type=int, nargs='+', default=None)
    parser.add_argument('--n-workers', type=int, default=1)
    parser.add_argument('--node-traversal-order', default='topo', help='Options: topo, random')
    parser.add_argument('--prune-final-size', type=int, default=None)
    parser.add_argument('--dont-sim-mem', dest='dont_sim_mem', action='store_true')

    parser.add_argument('--remote-async-addrs', type=str, default=None, nargs='+')
    parser.add_argument('--remote-async-start-ports', type=int, default=None, nargs='+')
    parser.add_argument('--remote-async-n-sims', type=int, default=None, nargs='+')
    parser.add_argument('--local-prefix', type=str, default=None)
    parser.add_argument('--remote-prefix', type=str, default=None)
    parser.add_argument('--shuffle-gpu-order', dest='shuffle_gpu_order', action='store_true')

    # The real command line is replaced by the whitespace-separated tokens
    # stored in the config file; sys.argv itself is overwritten so anything
    # reading it afterwards sees the same options.
    with open('./config/config.txt', 'r') as file:
        file_args = file.read().strip().split()
    sys.argv[1:] = file_args

    args, unknown = parser.parse_known_args()

    if args.one_shot_episodic_rew and args.n_async_sims is not None:
        raise Exception('Input setting leads to deadlock')

    if args.eval_freq % 10 == 0:
        print('Eval freq cannot be divisible by 10')
        # This is a validation failure, so report it with a non-zero status.
        sys.exit(1)

    # Options the parser above does not know must at least be known to
    # model.rl_params. Tokens without a leading dash are option values and
    # are skipped, as are tokens consisting only of dashes.
    for option in unknown:
        stripped = option.lstrip('-')
        if stripped == option or not stripped:
            continue
        option_name = stripped.replace('-', '_')
        if option_name not in vars(model.rl_params.args):
            print(option_name)
            raise Exception("Passed unknown option in dict : %s" % option_name)

    if args.use_gpus is not None:
        # CUDA_VISIBLE_DEVICES takes a comma-separated list of device ids.
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(args.use_gpus)

    # vars(args) is the namespace's own dict, so the updates made below are
    # visible through `args` as well.
    config = vars(args)

    if args.n_workers > 1:
        ProgressivePlacerTest().mul_graphs(config)
    else:
        if config['dataset_folder'] is not None:
            # Run once per file found anywhere below the dataset folder; the
            # containing directory name is used as the dataset name.
            for root, _dirs, files in os.walk(config['dataset_folder']):
                for name in files:
                    config['dataset'] = os.path.basename(root)
                    config['pickled_inp_file'] = [os.path.join(root, name)]
                    ProgressivePlacerTest().test(config)

        # NOTE(review): this run also happens after the dataset loop above,
        # re-using the last dataset/pickled_inp_file values - confirm whether
        # it was meant to be an `else` branch.
        ProgressivePlacerTest().test(config)
