import os
import sys
import torch
import numpy
import random
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from torchvision.datasets import CIFAR10, CIFAR100
from torch.utils.data.sampler import SubsetRandomSampler
from layers import SpikeNode, Interpolate


# ------------ Some Functions ------------------------------------------------------

def acc_sum(n):
    """
        compute 1+2+3+...+n

        Uses integer floor division, which is exact here because (1 + n) * n is
        always even; an integer ``n`` therefore yields an ``int`` (previously a float).
    """
    return (1 + n) * n // 2


def _spike_stage(label):
    """Return the ResNet18 stage (1-4) that spike-node ``label`` (1-16) belongs to."""
    for stage, (low, high) in enumerate(((1, 4), (5, 8), (9, 12), (13, 16)), start=1):
        if low <= label <= high:
            return stage
    # raise instead of print()+exit(): exit() reports success (status 0) and is
    # meant for the interactive shell only
    raise ValueError("The label of the spike node is out of range: {!r}".format(label))


def trans2str(trans):
    """
        transform BCs' placement into string in ResNet18

        Spike nodes 1-4 belong to stage 1, 5-8 to stage 2, 9-12 to stage 3 and
        13-16 to stage 4.

        Args:
            trans: pair ``(from_label, to_label)`` of spike-node labels.

        Returns:
            str: ``'<from_stage>-<to_stage>'``, e.g. ``(13, 2)`` -> ``'4-1'``.

        Raises:
            ValueError: if either label lies outside 1..16.
    """
    return '{}-{}'.format(_spike_stage(trans[0]), _spike_stage(trans[1]))


# ------------ TET Loss ------------------------------------------------------------

def TET_loss(outputs, labels, criterion, means=1.0, lamb=0.05):
    """
        Temporal Efficient Training (TET) loss.

        Args:
            outputs: iterable of per-time-step outputs (tensors of equal shape).
            labels: targets passed unchanged to ``criterion`` at every step.
            criterion: callable ``criterion(step_output, labels)`` returning a loss.
            means: constant target of the MSE regulariser.
            lamb: regulariser weight; the per-step term is weighted by ``1 - lamb``.

        Returns:
            ``(1 - lamb) * mean_t criterion(outputs[t], labels)
              + lamb * MSE(all outputs, means)``
    """
    # insert the time axis at dim 1, giving shape (outputs[0].size(0), T, ...)
    stacked = torch.stack(list(outputs), dim=1)
    steps = stacked.size(1)
    step_term = sum(criterion(stacked[:, step, ...], labels) for step in range(steps)) / steps

    if lamb == 0:
        reg_term = 0
    else:
        target = torch.zeros_like(stacked).fill_(means)
        reg_term = torch.nn.MSELoss()(stacked, target)

    return (1 - lamb) * step_term + lamb * reg_term


# ------------ Dataset Transforms---------------------------------------------------

class DVSCifar10(Dataset):
    """
        CIFAR10-DVS stored as one ``<root>/<index>.pt`` file per sample.

        Each file holds a ``(data, target)`` pair. The last axis of ``data`` is
        iterated as time steps; every frame goes through ToPILImage -> Resize(48, 48)
        -> ToTensor and the frames are stacked along a new leading axis.

        reference to https://github.com/Gus-Lab/temporal_efficient_training/blob/main/data_loaders.py
    """
    def __init__(self, root, train=True, transform=False, target_transform=None):
        self.root = os.path.expanduser(root)
        self.transform = transform  # truthy -> random flip + random roll augmentation
        self.target_transform = target_transform
        self.train = train  # stored only; the split is selected through ``root``
        # per-frame pipeline helpers
        self.resize = transforms.Resize(size=(48, 48))
        self.tensorx = transforms.ToTensor()
        self.imgx = transforms.ToPILImage()

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        sample_file = self.root + '/{}.pt'.format(index)
        data, target = torch.load(sample_file)

        frames = [
            self.tensorx(self.resize(self.imgx(data[..., step])))
            for step in range(data.size(-1))
        ]
        data = torch.stack(frames, dim=0)

        if self.transform:
            # flip along dim 3 with probability 0.5, then cyclically shift dims 2 and 3
            # by up to 5 positions each
            if random.random() > 0.5:
                data = torch.flip(data, dims=(3,))
            shift_a = random.randint(-5, 5)
            shift_b = random.randint(-5, 5)
            data = torch.roll(data, shifts=(shift_a, shift_b), dims=(2, 3))

        if self.target_transform is not None:
            target = self.target_transform(target)
        return data, target.long().squeeze(-1)

    def __len__(self):
        # every directory entry counts as one sample
        entries = os.listdir(self.root)
        return len(entries)


class TinyImageNet(Dataset):
    """
        Tiny-ImageNet-200 read from its extracted directory layout.

        Files read under ``root``:
            train/<wnid>/**/*.JPEG     training images, one folder per class
            val/images/*.JPEG          validation images
            val/val_annotations.txt    tab-separated rows "<file>\t<wnid>\t..."
            wnids.txt                  one class id per line
            words.txt                  tab-separated rows "<wnid>\t<names>"

        Args:
            root: dataset root directory.
            train: load the training split if True, otherwise the validation split.
            transform: optional callable applied to each RGB PIL image.
    """

    def __init__(self, root, train=True, transform=None):
        self.Train = train
        self.root_dir = root
        self.transform = transform
        self.train_dir = os.path.join(self.root_dir, "train")
        self.val_dir = os.path.join(self.root_dir, "val")

        if self.Train:
            self._create_class_idx_dict_train()
        else:
            self._create_class_idx_dict_val()

        self._make_dataset(self.Train)

        words_file = os.path.join(self.root_dir, "words.txt")
        wnids_file = os.path.join(self.root_dir, "wnids.txt")

        self.set_nids = set()
        with open(wnids_file, 'r') as fo:
            for entry in fo:
                self.set_nids.add(entry.strip("\n"))

        # human-readable label = first comma-separated name of each known wnid
        self.class_to_label = {}
        with open(words_file, 'r') as fo:
            for entry in fo:
                words = entry.split("\t")
                if words[0] in self.set_nids:
                    self.class_to_label[words[0]] = words[1].strip("\n").split(",")[0]

    def _create_class_idx_dict_train(self):
        """Index the class folders of ``train/`` (sorted by name) and count its JPEG files."""
        classes = sorted(d.name for d in os.scandir(self.train_dir) if d.is_dir())

        # kept as a public attribute for compatibility; __len__ uses self.images
        self.len_dataset = sum(
            1
            for _, _, files in os.walk(self.train_dir)
            for f in files
            if f.endswith(".JPEG")
        )

        self.tgt_idx_to_class = {i: cls for i, cls in enumerate(classes)}
        self.class_to_tgt_idx = {cls: i for i, cls in enumerate(classes)}

    def _create_class_idx_dict_val(self):
        """Read ``val/val_annotations.txt`` into an image -> wnid map and index its classes."""
        val_annotations_file = os.path.join(self.val_dir, "val_annotations.txt")
        self.val_img_to_class = {}
        set_of_classes = set()
        with open(val_annotations_file, 'r') as fo:
            for data in fo:
                if not data.strip():
                    continue  # tolerate blank lines instead of failing on words[1]
                words = data.split("\t")
                self.val_img_to_class[words[0]] = words[1]
                set_of_classes.add(words[1])

        # kept as a public attribute for compatibility; __len__ uses self.images
        self.len_dataset = len(self.val_img_to_class)
        classes = sorted(set_of_classes)
        self.class_to_tgt_idx = {cls: i for i, cls in enumerate(classes)}
        self.tgt_idx_to_class = {i: cls for i, cls in enumerate(classes)}

    def _make_dataset(self, Train=True):
        """Collect ``(path, target_index)`` for every ``.JPEG`` of the chosen split into ``self.images``."""
        self.images = []
        if Train:
            img_root_dir = self.train_dir
            list_of_dirs = list(self.class_to_tgt_idx.keys())
        else:
            img_root_dir = self.val_dir
            list_of_dirs = ["images"]

        for tgt in list_of_dirs:
            dirs = os.path.join(img_root_dir, tgt)
            if not os.path.isdir(dirs):
                continue

            for root, _, files in sorted(os.walk(dirs)):
                for fname in sorted(files):
                    if fname.endswith(".JPEG"):
                        path = os.path.join(root, fname)
                        if Train:
                            item = (path, self.class_to_tgt_idx[tgt])
                        else:
                            item = (path, self.class_to_tgt_idx[self.val_img_to_class[fname]])
                        self.images.append(item)

    def return_label(self, idx):
        """Map target indices to readable labels; elements of ``idx`` must support ``.item()``."""
        return [self.class_to_label[self.tgt_idx_to_class[i.item()]] for i in idx]

    def __len__(self):
        # must match the indices __getitem__ can serve
        return len(self.images)

    def __getitem__(self, idx):
        img_path, tgt = self.images[idx]
        # decode from the open handle; convert() loads the pixels before the file closes
        with open(img_path, 'rb') as f:
            sample = Image.open(f).convert('RGB')
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, tgt


def build_dvscifar(path):
    """
        Create the CIFAR10-DVS datasets from ``<path>/train`` and ``<path>/test``.
        Only the training split is built with ``transform=True`` (augmentation on).

        reference to https://github.com/Gus-Lab/temporal_efficient_training/blob/main/data_loaders.py
    """
    train_root, test_root = path + '/train', path + '/test'
    augmented_train = DVSCifar10(root=train_root, transform=True)
    plain_test = DVSCifar10(root=test_root)
    return augmented_train, plain_test


def build_tiny_imagenet(data_dir='/data_smr/dataset/tiny_ImageNet/tiny-imagenet-200/'):
    """
        Build the Tiny-ImageNet train / validation datasets.

        Args:
            data_dir: root of the extracted ``tiny-imagenet-200`` dataset. Defaults
                to the path that used to be hard-coded here, so existing callers
                are unaffected.

        Returns:
            (train_dataset, val_dataset): training data uses random crop (64, padding 8)
            and horizontal flip; both splits are normalised with ImageNet statistics.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    train_transforms = transforms.Compose([
        transforms.RandomCrop(64, padding=8),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    val_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    train_dataset = TinyImageNet(data_dir, train=True, transform=train_transforms)
    val_dataset = TinyImageNet(data_dir, train=False, transform=val_transforms)

    return train_dataset, val_dataset


def get_transforms(dataset_name, ):
    """
        training from scratch

        Build the train/test datasets for ``dataset_name``.

        Supported names: 'CIFAR10', 'CIFAR100', 'CIFAR10DVS', 'Tiny_Imagenet'.
        Any other name yields ``(None, None)``.

        Returns:
            (train_dataset, test_dataset)
    """
    if dataset_name == 'CIFAR10':
        cifar10_stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        train_pipeline = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(*cifar10_stats),
        ])
        test_pipeline = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(*cifar10_stats),
        ])
        cifar10_root = '/data_smr/dataset/CIFAR10'
        train_set = CIFAR10(root=cifar10_root, train=True, download=True, transform=train_pipeline)
        test_set = CIFAR10(root=cifar10_root, train=False, download=True, transform=test_pipeline)
        return train_set, test_set

    if dataset_name == 'CIFAR100':
        cifar100_stats = ((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))
        train_pipeline = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.Normalize(*cifar100_stats)
        ])
        test_pipeline = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(*cifar100_stats)
        ])
        cifar100_root = '/data_smr/dataset/CIFAR100/'
        train_set = CIFAR100(root=cifar100_root, train=True, download=True, transform=train_pipeline)
        test_set = CIFAR100(root=cifar100_root, train=False, download=True, transform=test_pipeline)
        return train_set, test_set

    if dataset_name == 'CIFAR10DVS':
        train_set, test_set = build_dvscifar(path='/data_smr/dataset/cifar-dvs')
        return train_set, test_set

    if dataset_name == 'Tiny_Imagenet':
        train_set, test_set = build_tiny_imagenet()
        return train_set, test_set

    # unknown dataset name
    return None, None


def get_train_val_loaders(args, dataset_name, search=False, if_nassnn=False):
    """
        for super-network training

        Build the data pipeline for ``dataset_name``. Datasets are created exactly
        as in ``get_transforms``.

        Args:
            args: namespace providing ``batch_size``.
            dataset_name: 'CIFAR10', 'CIFAR100', 'CIFAR10DVS' or 'Tiny_Imagenet'.
            search: if True, split the *training* set 8:2 into train/val parts.
            if_nassnn: when not searching, double the batch size.

        Returns:
            * search, Tiny_Imagenet: ``(train_dataset, train_subset, val_subset)``
              from ``random_split`` (datasets, not loaders).
            * search, other datasets: ``(train_loader, val_loader)`` sampling randomly
              from the first 80% / last 20% of the training indices.
            * no search, Tiny_Imagenet: ``(train_dataset, test_dataset)``.
            * no search, other datasets: ``(train_loader, test_loader)``.

        Raises:
            ValueError: if ``dataset_name`` is not supported (previously this failed
                later with an unrelated error on a ``None`` dataset).
    """
    supported = ('CIFAR10', 'CIFAR100', 'CIFAR10DVS', 'Tiny_Imagenet')
    if dataset_name not in supported:
        raise ValueError("Unsupported dataset {!r}; expected one of {}".format(dataset_name, supported))

    train_dataset, test_dataset = get_transforms(dataset_name)

    if search:
        if dataset_name == 'Tiny_Imagenet':  # multi-train
            num_train = int(len(train_dataset) * 0.8)
            num_val = len(train_dataset) - num_train
            train_data, val_data = torch.utils.data.random_split(train_dataset, [num_train, num_val])
            return train_dataset, train_data, val_data

        num_train = len(train_dataset)
        indices = list(range(num_train))
        split = int(numpy.floor(0.8 * num_train))  # D_train: D_val = 8: 2
        train_idx, valid_idx = indices[:split], indices[split:]

        train_data_loader = torch.utils.data.DataLoader(
            dataset=train_dataset, batch_size=args.batch_size,
            sampler=SubsetRandomSampler(train_idx), num_workers=4, drop_last=False, pin_memory=True)
        valid_data_loader = torch.utils.data.DataLoader(
            dataset=train_dataset, batch_size=args.batch_size,
            sampler=SubsetRandomSampler(valid_idx), num_workers=4, drop_last=False, pin_memory=True)
        return train_data_loader, valid_data_loader

    if dataset_name == 'Tiny_Imagenet':  # TODO: give loader in evo_algo
        return train_dataset, test_dataset

    batch_size = args.batch_size * 2 if if_nassnn else args.batch_size
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=4, drop_last=False, pin_memory=True)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=batch_size, shuffle=False,
        num_workers=4, drop_last=False, pin_memory=True)
    return train_data_loader, test_data_loader


# ------------------ Basic Blocks---------------------------------------------------

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Bias-free 3x3 convolution whose padding equals ``dilation`` (stride 1 keeps H and W)."""
    options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **options)


def conv1x1(in_planes, out_planes, stride=1):
    """Bias-free 1x1 (pointwise) convolution."""
    layer = nn.Conv2d(in_planes, out_planes, bias=False, stride=stride, kernel_size=1)
    return layer


class BWBlock(nn.Module):
    """
        BackWard-Block for ResNet18

        Maps spikes of stage ``from`` onto the shape of stage ``to``:
        ``Interpolate(scale_factor=2 ** (from - to), mode='nearest')`` (replaced by
        ``nn.Identity`` when the stages are equal), then a 1x1 conv from the source
        stage's channel count to the target stage's, then BatchNorm. Stages 2-4
        each start with a stride-2 block, hence the power-of-two scale factor.

        Args:
            transmission: ``'<from>-<to>'`` with ``4 >= from >= to >= 1``.

        Raises:
            ValueError: for any other ``transmission``.
    """

    # output channels of each ResNet18 stage
    _STAGE_CHANNELS = {1: 64, 2: 128, 3: 256, 4: 512}
    # every implemented 'from-to' key (same stage or towards shallower stages)
    _SUPPORTED = frozenset(
        '{}-{}'.format(src, dst) for src in range(1, 5) for dst in range(1, src + 1))

    def __init__(self, transmission='4-4'):
        super(BWBlock, self).__init__()

        self.trans = transmission
        if self.trans not in self._SUPPORTED:
            # raise instead of print()+exit(): exit() reports success (status 0)
            raise ValueError(
                "Unsupported transmission {!r}: only '<from>-<to>' with "
                "4 >= from >= to >= 1 is implemented (others will be added...)".format(self.trans))

        src, dst = (int(part) for part in self.trans.split('-'))
        scale = 2 ** (src - dst)

        # registration order (inter_polate, conv, bn) matches the original layout
        if scale == 1:
            self.inter_polate = nn.Identity()
        else:
            self.inter_polate = Interpolate(scale_factor=scale, mode='nearest')
        self.conv = conv1x1(in_planes=self._STAGE_CHANNELS[src],
                            out_planes=self._STAGE_CHANNELS[dst], stride=1)
        self.bn = nn.BatchNorm2d(self._STAGE_CHANNELS[dst])

    def forward(self, input):
        output = self.inter_polate(input)
        output = self.conv(output)
        output = self.bn(output)
        return output


class BasicBWBlock(nn.Module):
    """
        ResNet basic block with spiking activations and two injection points for
        backward-connection inputs (added right before each spike node).

        ``forward`` returns ``(output_spikes, first_spike_node_output)``.
    """

    def __init__(self, inplanes, planes, stride=1, downsample=None, ):
        super(BasicBWBlock, self).__init__()

        # stage 1: carries the block's stride
        self.conv1 = conv3x3(in_planes=inplanes, out_planes=planes, stride=stride)
        self.bn1 = nn.BatchNorm2d(num_features=planes)
        self.spike1 = SpikeNode(vth=1.0, tau=0.5, gamma=1.0, v_reset=0.)

        # stage 2: keeps width and resolution
        self.conv2 = conv3x3(in_planes=planes, out_planes=planes)
        self.bn2 = nn.BatchNorm2d(num_features=planes)
        self.spike2 = SpikeNode(vth=1.0, tau=0.5, gamma=1.0, v_reset=0.)

        # optional shortcut projection; None means an identity shortcut
        self.downsample = downsample
        self.stride = stride

    def forward(self, inp, bw1_value=0., bw2_value=0.):
        pre_first = self.bn1(self.conv1(inp))
        pre_first += bw1_value
        first_spikes = self.spike1(pre_first)

        pre_second = self.bn2(self.conv2(first_spikes))

        shortcut = inp if self.downsample is None else self.downsample(inp)
        pre_second += shortcut
        pre_second += bw2_value

        return self.spike2(pre_second), first_spikes


# ------------------ ResNet18 ------------------------------------------------------

class ResNet18_bw_search(nn.Module):

    """
        ResNet18 with backward connections

        Spike nodes are numbered 1..16 in forward order: residual block k
        (block1_1, block1_2, ..., block4_2 for k = 1..8) owns node 2k-1 (its first
        spike node) and node 2k (its output); the stem spike node is not numbered.

        Args:
            args: namespace providing ``T`` (time steps), ``num_class``, ``if_static``
                (True: the same 3-channel ``inp`` is fed at every step; False: a
                2-channel frame ``inp[:, t]`` per step) and ``size`` (number of
                backward-input slots; ``forward`` reads slots 0..15).
            fb_mat: 2-D 0/1 matrix; ``fb_mat[i][j] == 1`` connects spike node
                ``i + 1`` (source) to the input of spike node ``j + 1``.
    """

    def __init__(self, args, fb_mat, ):
        super(ResNet18_bw_search, self).__init__()
        self.T = args.T
        self.fb_mat = fb_mat
        self.num_class = args.num_class
        self.static = args.if_static
        self.size = args.size

        # get the information of the backward connections: 1-based (from, to) labels
        self.transmission_list = [
            (idx_row + 1, idx_col + 1)
            for idx_row, row in enumerate(fb_mat)
            for idx_col, col in enumerate(row)
            if col == 1
        ]

        # choose BWBlocks according to the fb_mat (label -> stage mapping shared
        # with trans2str instead of a duplicated copy)
        self.BWop_all = nn.ModuleList(
            BWBlock(transmission=trans2str(trans)) for trans in self.transmission_list)

        # stem layer
        if self.static:
            self.conv_stem = conv3x3(in_planes=3, out_planes=64, stride=1)
        else:
            self.conv_stem = conv3x3(in_planes=2, out_planes=64, stride=1)
        self.bn1 = nn.BatchNorm2d(num_features=64)
        self.spike1 = SpikeNode(vth=1.0, tau=0.5, gamma=1.0, v_reset=0.)

        # block1
        self.block1_1 = BasicBWBlock(inplanes=64, planes=64, stride=1, downsample=None, )
        self.block1_2 = BasicBWBlock(inplanes=64, planes=64, stride=1, downsample=None, )

        # block2
        self.block2_downsample = conv1x1(in_planes=64, out_planes=128, stride=2)
        self.block2_downsample_bn = nn.BatchNorm2d(num_features=128)
        self.block2_downsample_seq = nn.Sequential(self.block2_downsample, self.block2_downsample_bn)
        self.block2_1 = BasicBWBlock(inplanes=64, planes=128, stride=2, downsample=self.block2_downsample_seq, )
        self.block2_2 = BasicBWBlock(inplanes=128, planes=128, stride=1, downsample=None, )

        # block3
        self.block3_downsample = conv1x1(in_planes=128, out_planes=256, stride=2)
        self.block3_downsample_bn = nn.BatchNorm2d(num_features=256)
        self.block3_downsample_seq = nn.Sequential(self.block3_downsample, self.block3_downsample_bn)
        self.block3_1 = BasicBWBlock(inplanes=128, planes=256, stride=2, downsample=self.block3_downsample_seq, )
        self.block3_2 = BasicBWBlock(inplanes=256, planes=256, stride=1, downsample=None, )

        # block4
        self.block4_downsample = conv1x1(in_planes=256, out_planes=512, stride=2)
        self.block4_downsample_bn = nn.BatchNorm2d(num_features=512)
        self.block4_downsample_seq = nn.Sequential(self.block4_downsample, self.block4_downsample_bn)
        self.block4_1 = BasicBWBlock(inplanes=256, planes=512, stride=2, downsample=self.block4_downsample_seq, )
        self.block4_2 = BasicBWBlock(inplanes=512, planes=512, stride=1, downsample=None, )

        # average pool
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # fc
        self.fc = nn.Linear(512, self.num_class)

        # backward connections: slot n-1 is added to spike node n's input
        self.bw = [0.] * self.size

        # initialize parameters in model
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inp):
        """
            Run ``self.T`` time steps. Backward inputs computed from the spikes of
            step t are injected at step t + 1 (step 0 receives none).

            Returns:
                (voltage_avg, vol): fc output averaged over the time steps, and the
                list of per-step fc outputs (for the TET loss).
        """

        self.reset()
        voltage = 0
        vol = []  # prepare for TET loss

        for t in range(self.T):

            if self.static:
                out = self.conv_stem(inp)
            else:
                out = self.conv_stem(inp[:, t, :, :, :])

            out = self.bn1(out)
            out = self.spike1(out)

            spike2, spike1 = self.block1_1(out, self.bw[0], self.bw[1])
            spike4, spike3 = self.block1_2(spike2, self.bw[2], self.bw[3])

            spike6, spike5 = self.block2_1(spike4, self.bw[4], self.bw[5])
            spike8, spike7 = self.block2_2(spike6, self.bw[6], self.bw[7])

            spike10, spike9 = self.block3_1(spike8, self.bw[8], self.bw[9])
            spike12, spike11 = self.block3_2(spike10, self.bw[10], self.bw[11])

            spike14, spike13 = self.block4_1(spike12, self.bw[12], self.bw[13])
            spike16, spike15 = self.block4_2(spike14, self.bw[14], self.bw[15])

            out = self.avgpool(spike16)
            out = out.view(inp.size(0), -1)

            out = self.fc(out)
            voltage += out
            vol.append(out)

            # bw: spikes of this step feed the next step's spike-node inputs
            spike_dict = {
                '1': spike1, '2': spike2, '3': spike3, '4': spike4,
                '5': spike5, '6': spike6, '7': spike7, '8': spike8,
                '9': spike9, '10': spike10, '11': spike11, '12': spike12,
                '13': spike13, '14': spike14, '15': spike15, '16': spike16,
            }

            self.bw = [0.] * self.size

            for idx, transs in enumerate(self.transmission_list):
                self.bw[transs[1] - 1] += self.BWop_all[idx](spike_dict[str(transs[0])])

        voltage_avg = voltage / self.T

        return voltage_avg, vol

    def reset(self):
        """Clear the backward inputs and call ``reset()`` on every SpikeNode submodule."""
        self.bw = [0.] * self.size
        for m in self.modules():
            if 'SpikeNode' in str(type(m)):
                m.reset()


class ResNet18_super(nn.Module):
    """
        super-network based on ResNet18

        Holds one BWBlock for every candidate backward connection (from, to) with
        ``size >= from >= to >= 1`` (spike-node labels). Every forward pass draws a
        fresh random set of one or two connections via ``sampling`` and applies
        only those, so the sub-networks share weights.

        Args:
            args: namespace providing ``T``, ``num_class``, ``if_static`` and ``size``
                (``forward`` reads backward slots 0..15).
    """

    def __init__(self, args):
        super(ResNet18_super, self).__init__()
        self.T = args.T
        self.num_class = args.num_class
        self.static = args.if_static
        self.size = args.size

        # stem layer (3 input channels for static images, 2 otherwise)
        if self.static:
            self.conv_stem = conv3x3(in_planes=3, out_planes=64, stride=1)
        else:
            self.conv_stem = conv3x3(in_planes=2, out_planes=64, stride=1)
        self.bn1 = nn.BatchNorm2d(num_features=64)
        self.spike1 = SpikeNode(vth=1.0, tau=0.5, gamma=1.0, v_reset=0.)

        # block1
        self.block1_1 = BasicBWBlock(inplanes=64, planes=64, stride=1, downsample=None, )
        self.block1_2 = BasicBWBlock(inplanes=64, planes=64, stride=1, downsample=None, )

        # block2
        self.block2_downsample = conv1x1(in_planes=64, out_planes=128, stride=2)
        self.block2_downsample_bn = nn.BatchNorm2d(num_features=128)
        self.block2_downsample_seq = nn.Sequential(self.block2_downsample, self.block2_downsample_bn)
        self.block2_1 = BasicBWBlock(inplanes=64, planes=128, stride=2, downsample=self.block2_downsample_seq, )
        self.block2_2 = BasicBWBlock(inplanes=128, planes=128, stride=1, downsample=None, )

        # block3
        self.block3_downsample = conv1x1(in_planes=128, out_planes=256, stride=2)
        self.block3_downsample_bn = nn.BatchNorm2d(num_features=256)
        self.block3_downsample_seq = nn.Sequential(self.block3_downsample, self.block3_downsample_bn)
        self.block3_1 = BasicBWBlock(inplanes=128, planes=256, stride=2, downsample=self.block3_downsample_seq, )
        self.block3_2 = BasicBWBlock(inplanes=256, planes=256, stride=1, downsample=None, )

        # block4
        self.block4_downsample = conv1x1(in_planes=256, out_planes=512, stride=2)
        self.block4_downsample_bn = nn.BatchNorm2d(num_features=512)
        self.block4_downsample_seq = nn.Sequential(self.block4_downsample, self.block4_downsample_bn)
        self.block4_1 = BasicBWBlock(inplanes=256, planes=512, stride=2, downsample=self.block4_downsample_seq, )
        self.block4_2 = BasicBWBlock(inplanes=512, planes=512, stride=1, downsample=None, )

        # average pool
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # fc
        self.fc = nn.Linear(512, self.num_class)

        # all possible backward connections
        # (row, col) for every row >= col, in row-major order, so connection (r, c)
        # is BWop_all[acc_sum(r - 1) + c - 1]; size * (size + 1) / 2 blocks in total
        self.bw = [0.] * self.size
        self.BWop_all = nn.ModuleList()
        self.transmission_list = []
        for row in range(1, self.size + 1):
            for col in range(1, row + 1):
                transmission = (row, col)
                self.transmission_list.append(transmission)
        for trans in self.transmission_list:
            str_ = trans2str(trans)
            self.BWop_all.append(BWBlock(transmission=str_))

        # initialize parameters in model
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inp):
        """
            Sample a sub-network, then run ``self.T`` time steps with it.

            Backward inputs computed from the spikes of step t are injected at
            step t + 1 (step 0 receives none).

            Returns:
                (voltage_avg, vol): fc output averaged over time steps, and the list
                of per-step fc outputs.
        """

        # a new random set of backward connections is drawn on every call
        backward_str_list = self.sampling()
        self.reset()
        voltage = 0
        vol = []

        for t in range(self.T):

            if self.static:
                out = self.conv_stem(inp)
            else:
                out = self.conv_stem(inp[:, t, :, :, :])
            out = self.bn1(out)
            out = self.spike1(out)

            spike2, spike1 = self.block1_1(out, self.bw[0], self.bw[1])
            spike4, spike3 = self.block1_2(spike2, self.bw[2], self.bw[3])

            spike6, spike5 = self.block2_1(spike4, self.bw[4], self.bw[5])
            spike8, spike7 = self.block2_2(spike6, self.bw[6], self.bw[7])

            spike10, spike9 = self.block3_1(spike8, self.bw[8], self.bw[9])
            spike12, spike11 = self.block3_2(spike10, self.bw[10], self.bw[11])

            spike14, spike13 = self.block4_1(spike12, self.bw[12], self.bw[13])
            spike16, spike15 = self.block4_2(spike14, self.bw[14], self.bw[15])

            out = self.avgpool(spike16)
            out = out.view(inp.size(0), -1)

            out = self.fc(out)
            voltage += out
            vol.append(out)

            # bw: spikes of this step feed the next step's spike-node inputs
            spike_dict = {
                '1': spike1, '2': spike2, '3': spike3, '4': spike4,
                '5': spike5, '6': spike6, '7': spike7, '8': spike8,
                '9': spike9, '10': spike10, '11': spike11, '12': spike12,
                '13': spike13, '14': spike14, '15': spike15, '16': spike16,
            }

            self.bw = [0.] * self.size

            for transs in backward_str_list:
                # position of (from, to) in the row-major triangular BWop_all
                iidd = int(acc_sum(transs[0] - 1) + transs[1] - 1)
                self.bw[transs[1] - 1] += self.BWop_all[iidd](spike_dict[str(transs[0])])

        voltage_avg = voltage / self.T

        return voltage_avg, vol

    def reset(self):
        """Clear the backward inputs and call ``reset()`` on every SpikeNode submodule."""
        self.bw = [0.] * self.size
        for m in self.modules():
            if 'SpikeNode' in str(type(m)):
                m.reset()

    def sampling(self):
        """
            Randomly choose one or two backward connections.

            Each connection is a 1-based ``(from, to)`` pair of spike-node labels in
            ``1..size`` with ``from >= to`` (``from == to`` is possible). If two
            identical pairs are drawn, only one is kept.

            Returns:
                list of ``(from, to)`` tuples (length 1 or 2).
        """
        bw_list = []
        num_path = random.choice(range(1, 3))
        if num_path == 1:
            x1 = random.choice(range(self.size))
            x2 = random.choice(range(self.size))
            if x1 > x2:
                bw_list.append((x1 + 1, x2 + 1))
            else:
                bw_list.append((x2 + 1, x1 + 1))

        elif num_path == 2:
            x1 = random.choice(range(self.size))
            x2 = random.choice(range(self.size))
            x3 = random.choice(range(self.size))
            x4 = random.choice(range(self.size))
            if x1 > x2:
                bw_list.append((x1 + 1, x2 + 1))
            else:
                bw_list.append((x2 + 1, x1 + 1))
            if x3 > x4:
                bw_list.append((x3 + 1, x4 + 1))
            else:
                bw_list.append((x4 + 1, x3 + 1))

            # drop a duplicate so the same connection is not applied twice
            if bw_list[0] == bw_list[1]:
                del bw_list[1]
        else:
            raise NotImplementedError

        return bw_list


class ResNet18_child(nn.Module):
    """
        child-network based on ResNet18 for validation accuracy in the evolutionary algorithm

        A fixed sub-architecture of the backward-connection search space: only the
        connections switched on in ``fb_mat`` are evaluated in ``forward``. Every
        candidate backward op is still instantiated in ``BWop_all`` (one per
        (source, target) pair with target <= source), so parameter names and op
        indices do not depend on ``fb_mat``.

        Args:
            args: namespace providing ``T`` (number of time steps), ``num_class``,
                ``if_static`` (3-channel input reused every step vs. 2-channel input
                indexed per step) and ``size`` (number of backward-connection slots;
                ``forward`` reads slots 0..15).
            fb_mat: 2-D matrix; ``fb_mat[r][c] == 1`` enables feeding spike ``r + 1``
                into backward slot ``c + 1`` at the next time step. Only entries with
                ``c <= r`` map onto ``BWop_all`` — NOTE(review): entries above the
                diagonal are presumably never set; confirm with the caller.
    """

    def __init__(self, args, fb_mat):
        super().__init__()
        self.T = args.T
        self.fb_mat = fb_mat
        self.num_class = args.num_class
        self.static = args.if_static
        self.size = args.size

        # enabled connections as 1-based (source spike, target slot) pairs, row-major
        self.tranlist = [
            (r + 1, c + 1)
            for r, row in enumerate(fb_mat)
            for c, flag in enumerate(row)
            if flag == 1
        ]

        # stem layer: static input has 3 channels, per-step input has 2
        self.conv_stem = conv3x3(in_planes=3 if self.static else 2, out_planes=64, stride=1)
        self.bn1 = nn.BatchNorm2d(num_features=64)
        self.spike1 = SpikeNode(vth=1.0, tau=0.5, gamma=1.0, v_reset=0.)

        # residual stages. Attribute names (blockS_K, blockS_downsample[_bn|_seq]) and
        # their creation/registration order are kept verbatim: they define the
        # state-dict keys and the order of the weight initialisation below.
        stage_specs = (
            # (stage, in_channels, out_channels, stride of the first block)
            (1, 64, 64, 1),
            (2, 64, 128, 2),
            (3, 128, 256, 2),
            (4, 256, 512, 2),
        )
        for stage, in_ch, out_ch, stride in stage_specs:
            shortcut = None
            if stride != 1:
                ds_conv = conv1x1(in_planes=in_ch, out_planes=out_ch, stride=stride)
                setattr(self, f'block{stage}_downsample', ds_conv)
                ds_bn = nn.BatchNorm2d(num_features=out_ch)
                setattr(self, f'block{stage}_downsample_bn', ds_bn)
                shortcut = nn.Sequential(ds_conv, ds_bn)
                setattr(self, f'block{stage}_downsample_seq', shortcut)
            setattr(self, f'block{stage}_1',
                    BasicBWBlock(inplanes=in_ch, planes=out_ch, stride=stride, downsample=shortcut))
            setattr(self, f'block{stage}_2',
                    BasicBWBlock(inplanes=out_ch, planes=out_ch, stride=1, downsample=None))

        # average pool
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # fc
        self.fc = nn.Linear(512, self.num_class)

        # all possible backward connections (row, col) with col <= row, row-major, so the
        # op for (row, col) lives at index row * (row - 1) / 2 + col - 1
        self.bw = [0.] * self.size
        self.BWop_all = nn.ModuleList()
        self.transmission_list = [
            (row, col) for row in range(1, self.size + 1) for col in range(1, row + 1)
        ]
        for trans in self.transmission_list:
            self.BWop_all.append(BWBlock(transmission=trans2str(trans)))

        # initialize parameters in model
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inp):
        """
            Run the network for ``self.T`` time steps.

            Backward inputs computed from the spikes of step t are fed to the blocks
            at step t + 1; at step 0 every slot is 0.

            Args:
                inp: if ``self.static`` the same tensor goes to the stem every step;
                    otherwise ``inp[:, t, :, :, :]`` is used at step t.

            Returns:
                (voltage_avg, vol): the fc outputs averaged over the T steps, and the
                list of per-step fc outputs.
        """
        self.reset()
        blocks = (
            self.block1_1, self.block1_2,
            self.block2_1, self.block2_2,
            self.block3_1, self.block3_2,
            self.block4_1, self.block4_2,
        )
        voltage = 0
        vol = []

        for t in range(self.T):
            frame = inp if self.static else inp[:, t, :, :, :]
            x = self.spike1(self.bn1(self.conv_stem(frame)))

            # each block returns (tensor passed on to the next block, second spike map);
            # the second map gets the odd label 2k+1, the passed-on one the even 2k+2
            spikes = {}
            for k, block in enumerate(blocks):
                x, aux = block(x, self.bw[2 * k], self.bw[2 * k + 1])
                spikes[str(2 * k + 1)] = aux
                spikes[str(2 * k + 2)] = x

            logits = self.fc(self.avgpool(x).view(inp.size(0), -1))
            voltage += logits
            vol.append(logits)

            # backward inputs for the next step: slot dst-1 accumulates the matching
            # BW op applied to spike src, for every enabled (src, dst) pair
            self.bw = [0.] * self.size
            for src, dst in self.tranlist:
                op_idx = int(acc_sum(src - 1) + dst - 1)
                self.bw[dst - 1] += self.BWop_all[op_idx](spikes[str(src)])

        return voltage / self.T, vol

    def reset(self):
        """
            Zero every backward-connection slot and call reset() on each sub-module
            whose type name contains 'SpikeNode'.
        """
        self.bw = [0.] * self.size
        for module in filter(lambda m: 'SpikeNode' in str(type(m)), self.modules()):
            module.reset()
