
import sys
import torch
import pickle
import torch.nn as nn

import math
from typing import Type, Any, Callable, Union, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url
from .utils import _log_api_usage_once

from functools import partial
from typing import Dict, Type, Any, Callable, Union, List, Optional

# Slope applied to a channel score when it is binarised outside score
# training: scores above 0.5 become ``1.0 + score * negative_a`` and the
# remaining ones become ``score * negative_a``.
negative_a = 0.0
# Companion constant to ``negative_a``; kept for code that imports it.
positive_a = 0.001
# Module-level list of per-layer channel-score tensors. It is filled by the
# ``cifar10_resnet*`` factory functions and read by ``CifarResNet.forward``.
scores = []


# Checkpoint download URLs keyed by architecture name.
model_urls = {
    "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
}


def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with a padding of 1.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride (defaults to 1).

    Returns:
        The freshly constructed ``nn.Conv2d`` layer.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) ``nn.Conv2d``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride (defaults to 1).

    Returns:
        The freshly constructed ``nn.Conv2d`` layer.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )


class BasicBlock(nn.Module):
    """Two-convolution residual block whose activations are scaled per channel.

    ``forward`` takes and returns a 3-item sequence so that blocks can be
    chained inside ``nn.Sequential``: the feature map, the list of channel
    score tensors that have not been consumed yet, and the
    ``is_score_training`` flag. Every block consumes the first two entries of
    the score list and forwards the remainder.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """Create the block.

        Args:
            inplanes: Number of input channels.
            planes: Number of output channels of both convolutions.
            stride: Stride of the first convolution.
            downsample: Optional module applied to the input to produce the
                shortcut branch; ``None`` means the input is used as is.
        """
        super().__init__()
        self.inplanes = inplanes
        self.planes = planes
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # S1/S2 are linear heads sized for a flattened 3x3 kernel
        # (in_channels * 9 * out_channels inputs, one output per channel).
        # forward() scales with the scores it receives in ``ins`` and does not
        # evaluate these heads; they are kept so the parameter set and the
        # state_dict keys of the block stay unchanged.
        self.S1 = nn.Linear(inplanes * 9 * planes, planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.S2 = nn.Linear(planes * 9 * planes, planes)
        self.downsample = downsample
        self.stride = stride

    @staticmethod
    def _scale_channels(feat, channel_scores, is_score_training):
        """Multiply every channel of ``feat`` by its score.

        Outside score training the scores are first thresholded at 0.5:
        entries above it become ``1.0 + score * negative_a`` and the others
        ``score * negative_a``.
        """
        channel_scores = channel_scores.to(feat.device)
        if not is_score_training:
            channel_scores = torch.where(
                channel_scores > 0.5,
                1.0 + channel_scores * negative_a,
                channel_scores * negative_a,
            )
        # Broadcasting over (N, C, H, W) yields the same values as tiling the
        # scores to the full activation shape, without materialising the copy.
        return feat * channel_scores.reshape(1, -1, 1, 1)

    def forward(self, ins):
        """Run the block.

        Args:
            ins: Sequence ``(x, scores, is_score_training)`` where ``x`` is
                the input feature map, ``scores`` holds at least two score
                tensors (one value per output channel each) and
                ``is_score_training`` selects raw scores (truthy) or
                thresholded scores (falsy).

        Returns:
            Tuple ``(out, remaining_scores, is_score_training)`` with
            ``remaining_scores == scores[2:]``.
        """
        x = ins[0]
        scores = ins[1]
        is_score_training = ins[2]

        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self._scale_channels(out, scores[0], is_score_training)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        out = self._scale_channels(out, scores[1], is_score_training)

        return (out, scores[2:], is_score_training)


class CifarResNet(nn.Module):
    """Three-stage ResNet (16/32/64 base channels) with per-channel score scaling.

    The score tensors are read from the module-level ``scores`` list: entry 0
    scales the stem output and the remaining entries are handed to the
    residual blocks, which must follow the tuple protocol of ``BasicBlock``
    (``(x, scores, is_score_training)`` in, the same kind of tuple out).
    """

    def __init__(self, block, layers, num_classes=10):
        """Create the network.

        Args:
            block: Residual block class; must expose ``expansion`` and accept
                ``(inplanes, planes, stride, downsample)``.
            layers: Sequence with the number of blocks in each of the three
                stages.
            num_classes: Size of the final linear layer's output.
        """
        super().__init__()
        self.inplanes = 16
        self.conv1 = conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        # Linear head sized for the flattened stem kernel (3 * 9 * 16 inputs).
        # forward() scales with the module-level ``scores`` and does not
        # evaluate it; it is kept so the parameter set and the state_dict keys
        # stay unchanged.
        self.S1 = nn.Linear(3 * 9 * 16, 16)

        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``."""
        downsample = None
        # A projection shortcut is needed whenever the first block changes
        # the spatial resolution or the channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    @staticmethod
    def _scale_channels(feat, channel_scores, is_score_training):
        """Multiply every channel of ``feat`` by its score.

        Outside score training the scores are first thresholded at 0.5:
        entries above it become ``1.0 + score * negative_a`` and the others
        ``score * negative_a``.
        """
        channel_scores = channel_scores.to(feat.device)
        if not is_score_training:
            channel_scores = torch.where(
                channel_scores > 0.5,
                1.0 + channel_scores * negative_a,
                channel_scores * negative_a,
            )
        # Broadcasting over (N, C, H, W) yields the same values as tiling the
        # scores to the full activation shape, without materialising the copy.
        return feat * channel_scores.reshape(1, -1, 1, 1)

    def forward(self, x, is_score_training=False):
        """Classify a batch of images.

        Args:
            x: Input batch with 3 channels.
            is_score_training: When falsy the channel scores are thresholded
                before use; when truthy they are applied unchanged.

        Returns:
            Tuple ``(scores, logits)`` where ``scores`` is the module-level
            score list and ``logits`` is the output of the final linear layer.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self._scale_channels(x, scores[0], is_score_training)

        # Each block consumes two score tensors and forwards the remainder.
        outs = self.layer1((x, scores[1:], is_score_training))
        outs = self.layer2(outs)
        outs = self.layer3(outs)

        x = self.avgpool(outs[0])
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return scores, x


class BasicBlockVanilla(nn.Module):
    """Plain residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, plus shortcut, then ReLU."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """Create the block.

        Args:
            inplanes: Number of input channels.
            planes: Number of output channels of both convolutions.
            stride: Stride of the first convolution.
            downsample: Optional module producing the shortcut from the input.
        """
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        # Main branch.
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        # Shortcut branch: the input itself unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)


class CifarResNetVanilla(nn.Module):
    """Three-stage ResNet (16/32/64 base channels) without any channel scoring."""

    def __init__(self, block, layers, num_classes=10):
        """Create the network.

        Args:
            block: Residual block class; must expose ``expansion`` and accept
                ``(inplanes, planes, stride, downsample)``.
            layers: Sequence with the number of blocks in each of the three
                stages.
            num_classes: Size of the final linear layer's output.
        """
        super().__init__()
        self.inplanes = 16

        # Stem.
        self.conv1 = conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        # Residual stages; the last two halve the resolution.
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)

        # Classifier head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        # Re-initialise convolutions and batch norms; other modules keep
        # their default initialisation.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``."""
        needs_projection = stride != 1 or self.inplanes != planes * block.expansion
        shortcut = None
        if needs_projection:
            shortcut = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        # Only the first block of a stage may change stride or width.
        stage = [block(self.inplanes, planes, stride, shortcut)]
        self.inplanes = planes * block.expansion
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x, is_score_training=False):
        """Classify a batch of images.

        ``is_score_training`` is accepted but not used by this network.

        Returns:
            Tuple ``([], logits)``; the empty list stands in for the score
            list returned by ``CifarResNet.forward``.
        """
        feat = self.relu(self.bn1(self.conv1(x)))

        for stage in (self.layer1, self.layer2, self.layer3):
            feat = stage(feat)

        pooled = self.avgpool(feat)
        logits = self.fc(pooled.view(pooled.size(0), -1))

        return [], logits


def cifar10_resnet56(scores_path: str, **kwargs: Any) -> CifarResNet:
    """Build a ``CifarResNet`` with 9 ``BasicBlock``s per stage and load its channel scores.

    The score tensors are unpickled from ``scores_path``, moved to the CPU and
    stored in the module-level ``scores`` list that ``CifarResNet.forward``
    reads.

    Args:
        scores_path: Path of a pickle file holding an iterable of tensors.
        **kwargs: Accepted for signature compatibility; currently ignored.

    Returns:
        The newly constructed model.
    """
    device = torch.device('cpu')
    # SECURITY: pickle.load can execute arbitrary code while deserialising;
    # only pass score files that come from a trusted source.
    with open(scores_path, 'rb') as fp:
        loaded_scores = pickle.load(fp)

    # Replace the contents in place instead of appending: appending left the
    # scores of any earlier factory call at the front of the list, so a model
    # built later silently picked up stale entries. Slice assignment keeps
    # the identity of the list for code that holds a reference to it.
    scores[:] = [s.to(device) for s in loaded_scores]

    model = CifarResNet(BasicBlock, [9] * 3)

    return model

def cifar10_resnet110(scores_path: str, **kwargs: Any) -> CifarResNet:
    """Build a ``CifarResNet`` with 18 ``BasicBlock``s per stage and load its channel scores.

    The score tensors are unpickled from ``scores_path``, moved to the CPU and
    stored in the module-level ``scores`` list that ``CifarResNet.forward``
    reads.

    Args:
        scores_path: Path of a pickle file holding an iterable of tensors.
        **kwargs: Accepted for signature compatibility; currently ignored.

    Returns:
        The newly constructed model.
    """
    device = torch.device('cpu')
    # SECURITY: pickle.load can execute arbitrary code while deserialising;
    # only pass score files that come from a trusted source.
    with open(scores_path, 'rb') as fp:
        loaded_scores = pickle.load(fp)

    # Replace the contents in place instead of appending: appending left the
    # scores of any earlier factory call at the front of the list, so a model
    # built later silently picked up stale entries. Slice assignment keeps
    # the identity of the list for code that holds a reference to it.
    scores[:] = [s.to(device) for s in loaded_scores]

    model = CifarResNet(BasicBlock, [18] * 3)

    return model


#         )
