# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from .nasbench2_ops import *


def gen_searchcell_mask_from_arch_str(arch_str):
    """Translate an architecture string into a flat SearchCell keep-mask.

    The string is a "+"-separated sequence of nodes, each written as
    "|op~src|op~src|...|". Node ``k`` (0-based position in the sequence) must
    name exactly one op for every source index ``0..k`` inclusive. For each
    such (node, source) edge the mask receives one boolean per key of ``OPS``
    (in ``OPS.keys()`` order), True only for the op named on that edge.

    Args:
        arch_str: architecture string, e.g.
            "|nor_conv_3x3~0|+|none~0|none~1|+|avg_pool_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|".

    Returns:
        list[bool]: the per-edge one-hot masks over ``OPS``, concatenated.

    Raises:
        AssertionError: if some edge is missing or listed more than once
            (note: ``assert`` is skipped when Python runs with ``-O``).
    """
    # Parse every node first: drop the enclosing "|" characters, then split
    # each "op~src" token into an [op, src] pair.
    parsed_nodes = [
        [token.split("~") for token in node_text[1:-1].split("|")]
        for node_text in arch_str.split("+")
    ]

    mask = []
    for node_idx, edges in enumerate(parsed_nodes):
        for src_idx in range(node_idx + 1):
            matching_ops = [edge[0] for edge in edges if int(edge[1]) == src_idx]
            assert (
                len(matching_ops) == 1
            ), "The arch string does not follow the assumption of 1 connection between two nodes."
            chosen_op = matching_ops[0]
            mask.extend(chosen_op == candidate for candidate in OPS.keys())
    return mask


def get_model_from_arch_str(arch_str, num_classes, use_bn=True, init_channels=16):
    """Create a NAS201Model whose SearchCells follow ``arch_str``.

    Args:
        arch_str: architecture string parsed by
            ``gen_searchcell_mask_from_arch_str`` into the cells' keep-mask.
        num_classes: number of output classes for the classifier head.
        use_bn: forwarded to NAS201Model as ``use_bn``.
        init_channels: stem width, forwarded as ``stem_ch``.

    Returns:
        NAS201Model: the newly constructed network.
    """
    # Keyword arguments are evaluated left to right, so the mask is built
    # before the model constructor runs.
    return NAS201Model(
        arch_str=arch_str,
        num_classes=num_classes,
        use_bn=use_bn,
        keep_mask=gen_searchcell_mask_from_arch_str(arch_str),
        stem_ch=init_channels,
    )


def get_super_model(num_classes, use_bn=True, arch_str=None, init_channels=16):
    """Build a NAS201Model without a keep-mask (the "super" network).

    No ``keep_mask`` is passed, so every SearchCell receives
    ``keep_mask=None``. Presumably this keeps all candidate ops; the exact
    meaning is defined by SearchCell in nasbench2_ops, so confirm there.

    The previous version referenced a global ``arch_str`` that only exists
    when this file is run as a script, so library callers got a NameError.
    The arch string is now an optional argument instead.

    Args:
        num_classes: number of output classes for the classifier head.
        use_bn: forwarded to NAS201Model as ``use_bn``.
        arch_str: optional architecture string stored on the model. Methods
            that re-parse it (e.g. ``get_prunable_copy``) require a real
            string, not the default ``None``.
        init_channels: stem width, forwarded as ``stem_ch`` (default 16,
            matching NAS201Model's default).

    Returns:
        NAS201Model: the newly constructed super network.
    """
    return NAS201Model(
        arch_str=arch_str,
        num_classes=num_classes,
        use_bn=use_bn,
        stem_ch=init_channels,
    )


class NAS201Model(nn.Module):
    """NAS-Bench-201 style network built from SearchCells.

    Layout (see ``forward``): stem -> 5 SearchCells (``stem_ch`` channels)
    -> reduction -> 5 SearchCells (``2 * stem_ch``) -> reduction
    -> 5 SearchCells (``4 * stem_ch``) -> classifier head (``top``).

    Args:
        arch_str: architecture string, kept so ``get_prunable_copy`` can
            rebuild the same architecture.
        num_classes: number of output classes for the head.
        use_bn: forwarded to the stem, every SearchCell, and the head
            (the reduction blocks are built without a ``use_bn`` argument).
        keep_mask: mask shared by every SearchCell; see
            ``gen_searchcell_mask_from_arch_str``.
        stem_ch: channel width of the first stack.
    """

    def __init__(self, arch_str, num_classes, use_bn=True, keep_mask=None, stem_ch=16):
        super().__init__()
        self.arch_str = arch_str
        self.num_classes = num_classes
        self.use_bn = use_bn
        # Remembered so get_prunable_copy rebuilds a model of the same width;
        # otherwise load_state_dict would hit size mismatches for stem_ch != 16.
        self.stem_ch = stem_ch

        # Submodules are created in the same order as the original hand-written
        # version, so parameter registration (and random init) order is unchanged.
        self.stem = stem(out_channels=stem_ch, use_bn=use_bn)
        self.stack_cell1 = self._make_stack(stem_ch, use_bn, keep_mask)
        self.reduction1 = reduction(in_channels=stem_ch, out_channels=stem_ch * 2)
        self.stack_cell2 = self._make_stack(stem_ch * 2, use_bn, keep_mask)
        self.reduction2 = reduction(in_channels=stem_ch * 2, out_channels=stem_ch * 4)
        self.stack_cell3 = self._make_stack(stem_ch * 4, use_bn, keep_mask)
        self.top = top(in_dims=stem_ch * 4, num_classes=num_classes, use_bn=use_bn)

    @staticmethod
    def _make_stack(channels, use_bn, keep_mask, num_cells=5):
        """Return an nn.Sequential of ``num_cells`` stride-1 SearchCells of width ``channels``."""
        return nn.Sequential(
            *[
                SearchCell(
                    in_channels=channels,
                    out_channels=channels,
                    stride=1,
                    affine=False,
                    track_running_stats=False,
                    use_bn=use_bn,
                    keep_mask=keep_mask,
                )
                for _ in range(num_cells)
            ]
        )

    def forward(self, x):
        x = self.stem(x)

        x = self.stack_cell1(x)
        x = self.reduction1(x)

        x = self.stack_cell2(x)
        x = self.reduction2(x)

        x = self.stack_cell3(x)

        x = self.top(x)
        return x

    def get_prunable_copy(self, bn=False):
        """Return a new model with the same architecture and width, in train mode.

        The copy is rebuilt from ``self.arch_str`` with ``use_bn=bn``, and this
        model's state dict is loaded into it with ``strict=False``.

        Args:
            bn: whether the copy should be built with ``use_bn=True``.

        Returns:
            NAS201Model: the new model.
        """
        # getattr fallback covers instances unpickled from before stem_ch was stored.
        model_new = get_model_from_arch_str(
            self.arch_str,
            self.num_classes,
            use_bn=bn,
            init_channels=getattr(self, "stem_ch", 16),
        )

        # TODO this is quite brittle and doesn't work with nn.Sequential when bn is different
        # it is only required to maintain initialization -- maybe init after get_prunable_copy?
        model_new.load_state_dict(self.state_dict(), strict=False)
        model_new.train()

        return model_new


def get_arch_str_from_model(net):
    """Reconstruct an architecture string from a model's first SearchCell.

    Walks ``keep_mask`` in the same order that
    ``gen_searchcell_mask_from_arch_str`` produces it. For each node, source
    index, and op in ``OPS.keys()``, it emits "op~src" wherever the mask entry
    is truthy. A node is written as "|edge|edge|...|" and nodes are joined
    with "+".

    If an edge has several truthy entries, every one of them is emitted. If
    an edge has none, it is simply omitted.

    Args:
        net: a model exposing ``stack_cell1[0]`` with ``keep_mask`` and
            ``num_nodes`` attributes (e.g. NAS201Model).

    Returns:
        str: the reconstructed architecture string.
    """
    first_cell = net.stack_cell1[0]
    keep_mask = first_cell.keep_mask
    num_nodes = first_cell.num_nodes

    nodes = []
    idx = 0  # position in the flat keep_mask
    # Node 0 is the cell input, so only num_nodes - 1 nodes are described.
    for curr_node in range(num_nodes - 1):
        edges = []
        # Node curr_node can take input from nodes 0..curr_node inclusive.
        for prev_node in range(curr_node + 1):
            for op_name in OPS.keys():
                if keep_mask[idx]:
                    edges.append(f"{op_name}~{prev_node}")
                idx += 1
        nodes.append(f"|{'|'.join(edges)}|")
    return "+".join(nodes)


if __name__ == "__main__":
    # Round-trip demo: arch string -> model -> arch string, then compare.
    arch_str = "|nor_conv_3x3~0|+|none~0|none~1|+|avg_pool_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|"

    model = get_model_from_arch_str(arch_str=arch_str, num_classes=10)
    print(model.stack_cell1[0])

    recovered = get_arch_str_from_model(model)
    for shown in (arch_str, recovered):
        print(shown)
    print(f"Are the two arch strings same? {arch_str == recovered}")
