import yaml
import copy
from ofa.utils import download_url
from ofa.model_zoo import ofa_net


class LatencyEstimator(object):
    """Estimate network latency by summing per-layer entries of a lookup table.

    ``self.lut`` maps a layer key (built by :meth:`query`) to a record whose
    ``'mean'`` field is that layer's latency. Every lookup made by the most
    recent prediction is kept in ``self.query_log`` as ``(key, latency)`` pairs.
    """

    # Channel widths used by predict_network_latency_given_spec, keyed by width
    # multiplier. Index 0: first conv and first block; 1-5: outputs of the five
    # searchable stages; 6: final expand layer; 7: feature mix layer.
    _WIDTHS_BY_MULTIPLIER = {
        1.0: (16, 24, 40, 80, 112, 160, 960, 1280),
        1.2: (24, 32, 48, 96, 136, 192, 1152, 1536),
    }
    # Per-stage settings of the five searchable stages (stage = block index // 4).
    _STRIDE_STAGES = (2, 2, 2, 1, 2)
    _ACT_STAGES = ('relu', 'relu', 'h_swish', 'h_swish', 'h_swish')
    _SE_STAGES = (False, True, False, True, True)

    def __init__(self, url):
        """Load the latency lookup table.

        Args:
            url: an already-loaded lookup table (``dict``), an ``http(s)://``
                URL that is fetched with ``download_url``, or a local YAML path.

        Raises:
            OSError: if a remote table could not be downloaded or the file
                cannot be opened.
        """
        if isinstance(url, dict):
            self.lut = url
        else:
            fname = url
            if isinstance(url, str) and url.startswith(('http://', 'https://')):
                # open() cannot read URLs; fetch the table to a local file first.
                fname = download_url(url, overwrite=True)
                if fname is None:  # guard in case the download helper signals failure with None
                    raise OSError('failed to download latency table from %s' % url)
            with open(fname, 'r') as fp:
                # safe_load: the table is plain data and may come from the
                # network; yaml.load without an explicit Loader is unsafe and
                # is rejected by PyYAML >= 6.
                self.lut = yaml.safe_load(fp)
        self.query_log = []

    @staticmethod
    def repr_shape(shape):
        """Format a shape as the ``'AxBxC'`` string used inside lookup keys.

        A list/tuple is joined with ``'x'``; a string is returned unchanged.

        Raises:
            TypeError: if ``shape`` is neither a list/tuple nor a string.
        """
        if isinstance(shape, (list, tuple)):
            return 'x'.join(str(dim) for dim in shape)
        if isinstance(shape, str):
            return shape
        raise TypeError('unsupported shape type: %s' % type(shape).__name__)

    def query(self, l_type: str, input_shape, output_shape, mid=None, ks=None, stride=None, id_skip=None,
              se=None, h_swish=None):
        """Return the tabulated latency of one layer and append it to ``query_log``.

        Args:
            l_type: layer-type prefix of the key (e.g. ``'Conv'``, ``'expanded_conv'``).
            input_shape: input shape, formatted with :meth:`repr_shape`.
            output_shape: output shape, formatted with :meth:`repr_shape`.
            mid, ks, stride, id_skip, se, h_swish: required (as integers) only
                when ``l_type == 'expanded_conv'``; they become the ``expand``,
                ``kernel``, ``stride``, ``idskip``, ``se`` and ``hs`` key parts.

        Returns:
            ``self.lut[key]['mean']`` for the constructed key.

        Raises:
            KeyError: if the lookup table has no entry for the constructed key.
        """
        infos = [l_type, 'input:%s' % self.repr_shape(input_shape),
                 'output:%s' % self.repr_shape(output_shape)]
        if l_type == 'expanded_conv':
            assert None not in (mid, ks, stride, id_skip, se, h_swish), \
                'expanded_conv queries need mid, ks, stride, id_skip, se and h_swish'
            infos += ['expand:%d' % mid, 'kernel:%d' % ks, 'stride:%d' % stride,
                      'idskip:%d' % id_skip, 'se:%d' % se, 'hs:%d' % h_swish]
        key = '-'.join(infos)

        try:
            entry = self.lut[key]
        except KeyError:
            raise KeyError('latency lookup table has no entry for %r' % key) from None
        latency = entry['mean']
        self.query_log.append((key, latency))
        return latency

    def predict_network_latency(self, net, image_size=224):
        """Predict the latency of an instantiated network.

        Reads ``net.first_conv``, ``net.blocks`` (each with ``mobile_inverted_conv``
        and ``shortcut``), ``net.final_expand_layer``, ``net.feature_mix_layer``
        and ``net.classifier``. Blocks whose ``mobile_inverted_conv`` is ``None``
        are skipped.

        Returns:
            Sum of the tabulated latencies; the individual lookups are left in
            ``self.query_log``.
        """
        self.query_log = []
        predicted_latency = 0
        # first conv (halves the spatial size)
        fsize = (image_size + 1) // 2
        predicted_latency += self.query(
            'Conv', [image_size, image_size, 3],
            [fsize, fsize, net.first_conv.out_channels]
        )
        # blocks
        for block in net.blocks:
            mb_conv = block.mobile_inverted_conv
            if mb_conv is None:
                continue
            idskip = 0 if block.shortcut is None else 1
            out_fz = int((fsize - 1) / mb_conv.stride + 1)
            predicted_latency += self.query(
                'expanded_conv', [fsize, fsize, mb_conv.in_channels], [out_fz, out_fz, mb_conv.out_channels],
                mid=mb_conv.depth_conv.conv.in_channels, ks=mb_conv.kernel_size, stride=mb_conv.stride,
                id_skip=idskip, se=1 if mb_conv.use_se else 0,
                h_swish=1 if mb_conv.act_func == 'h_swish' else 0,
            )
            fsize = out_fz
        # final expand layer
        predicted_latency += self.query(
            'Conv_1', [fsize, fsize, net.final_expand_layer.in_channels],
            [fsize, fsize, net.final_expand_layer.out_channels],
        )
        # global average pooling
        predicted_latency += self.query(
            'AvgPool2D', [fsize, fsize, net.final_expand_layer.out_channels],
            [1, 1, net.final_expand_layer.out_channels],
        )
        # feature mix layer
        predicted_latency += self.query(
            'Conv_2', [1, 1, net.feature_mix_layer.in_channels],
            [1, 1, net.feature_mix_layer.out_channels]
        )
        # classifier
        predicted_latency += self.query(
            'Logits', [1, 1, net.classifier.in_features], [net.classifier.out_features]
        )
        return predicted_latency

    def predict_network_latency_given_spec(self, spec, width=1.0):
        """Predict latency from an architecture spec without building the network.

        Args:
            spec: dict with ``'r'`` (``spec['r'][0]`` is the input resolution),
                ``'d'`` (active depth of each of the 5 stages), and ``'ks'`` /
                ``'e'`` (kernel size / expand ratio for each of the 20 block slots).
            width: width multiplier; must be a key of ``_WIDTHS_BY_MULTIPLIER``
                (1.0 or 1.2).

        Returns:
            Sum of the tabulated latencies; lookups are left in ``self.query_log``.

        Raises:
            ValueError: if ``width`` is not a supported multiplier.
        """
        self.query_log = []
        try:
            wid = self._WIDTHS_BY_MULTIPLIER[width]
        except KeyError:
            raise ValueError(
                'unsupported width multiplier %r (supported: %s)'
                % (width, ', '.join(str(w) for w in self._WIDTHS_BY_MULTIPLIER))
            ) from None

        image_size = spec['r'][0]
        predicted_latency = 0
        # first conv (halves the spatial size)
        fsize = (image_size + 1) // 2
        predicted_latency += self.query(
            'Conv', [image_size, image_size, 3], [fsize, fsize, wid[0]]
        )
        # first block: fixed 3x3, stride 1, mid == in == out channels, id_skip=1
        predicted_latency += self.query(
            'expanded_conv', [fsize, fsize, wid[0]], [fsize, fsize, wid[0]],
            mid=wid[0], ks=3, stride=1, id_skip=1, se=0, h_swish=0,
        )
        in_channel = wid[0]
        # searchable blocks: 5 stages x 4 slots, only the first spec['d'][stage] active
        for i in range(20):
            stage, depth = divmod(i, 4)
            if depth + 1 > spec['d'][stage]:
                continue
            ks, e = spec['ks'][i], spec['e'][i]
            if depth == 0:
                # first block of a stage uses the stage stride and id_skip=0
                stride, idskip = self._STRIDE_STAGES[stage], 0
            else:
                stride, idskip = 1, 1
            out_channel = wid[stage + 1]
            out_fz = int((fsize - 1) / stride + 1)
            mid_channel = round(in_channel * e)
            predicted_latency += self.query(
                'expanded_conv', [fsize, fsize, in_channel], [out_fz, out_fz, out_channel],
                mid=mid_channel, ks=ks, stride=stride, id_skip=idskip,
                se=1 if self._SE_STAGES[stage] else 0,
                h_swish=1 if self._ACT_STAGES[stage] == 'h_swish' else 0,
            )
            fsize = out_fz
            in_channel = out_channel
        # final expand layer
        predicted_latency += self.query(
            'Conv_1', [fsize, fsize, wid[5]],
            [fsize, fsize, wid[6]],
        )
        # NOTE(review): unlike predict_network_latency, global average pooling
        # is deliberately not queried here (it was commented out); presumably
        # the target tables have no AvgPool2D entry — confirm.
        # feature mix layer
        predicted_latency += self.query(
            'Conv_2', [1, 1, wid[6]],
            [1, 1, wid[7]]
        )
        # classifier
        predicted_latency += self.query(
            'Logits', [1, 1, wid[7]], [1000]
        )
        return predicted_latency


class LatencyTable:
    """Holds one LatencyEstimator per input resolution for a single device.

    Each estimator is built from the per-resolution lookup-table URL under
    ``hanlab.mit.edu/files/OnceForAll/tutorial/latency_table@<device>/``.
    """

    def __init__(self, device='note10', resolutions=(160, 176, 192, 208, 224)):
        tables = {}
        for res in resolutions:
            table_url = ('https://hanlab.mit.edu/files/OnceForAll/tutorial/'
                         'latency_table@%s/%d_lookup_table.yaml' % (device, res))
            tables[res] = LatencyEstimator(url=table_url)
            print('Built latency table for image size: %d.' % res)
        self.latency_tables = tables

    def predict_efficiency(self, spec: dict):
        """Predict latency of ``spec`` using the table for resolution ``spec['r'][0]``."""
        resolution = spec['r'][0]
        estimator = self.latency_tables[resolution]
        return estimator.predict_network_latency_given_spec(spec)
