'''ResNet in PyTorch.

For Pre-activation ResNet, see 'preact_resnet.py'.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Device on which the Sentinel bookkeeping tensors (roster, radii, levels) are created.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions (used by ResNet-18/34 here).

    Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.  The shortcut
    is the identity unless ``stride != 1`` or the channel count changes, in
    which case a strided 1x1 convolution followed by BN projects the input.
    The sum of both paths goes through a final ReLU.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        width_out = self.expansion * planes

        # Attribute names double as state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == width_out:
            # Shapes already match: an empty Sequential returns x unchanged.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width_out, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(width_out),
            )

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + self.shortcut(x))


class Bottleneck(nn.Module):
    """Bottleneck residual block (used by ResNet-50/101/152 here).

    Main path: conv1x1 -> BN -> ReLU -> conv3x3(stride) -> BN -> ReLU ->
    conv1x1 (to ``expansion * planes`` channels) -> BN.  The shortcut is the
    identity unless ``stride != 1`` or the channel count changes, in which
    case a strided 1x1 convolution followed by BN projects the input.  The
    sum of both paths goes through a final ReLU.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        width_out = self.expansion * planes

        # Attribute names double as state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out)

        if stride == 1 and in_planes == width_out:
            # Shapes already match: an empty Sequential returns x unchanged.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width_out, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(width_out),
            )

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        return F.relu(hidden + self.shortcut(x))


class ResNet(nn.Module):
    """ResNet for 3-channel images with an optional linear classifier head.

    Args:
        block: residual block class; must expose an ``expansion`` attribute
            and accept ``block(in_planes, planes, stride)``.
        num_blocks: four ints, the number of blocks in each stage.
        num_classes: output size of the linear head.
        want_linear: if False the head is ``nn.Identity`` and ``forward``
            returns the pooled feature vector of size
            ``dimE = 512 * block.expansion``.
    """

    def __init__(self, block, num_blocks, num_classes=10, want_linear=True):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.dimE = 512*block.expansion
        # nn.Identity rather than a lambda: it is a registered, picklable module
        # (torch.save(model) / multiprocessing work) and adds no state_dict keys.
        self.linear = nn.Linear(self.dimE, num_classes) if want_linear else nn.Identity()

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling.  For 32x32 inputs layer4 is 4x4, so this equals
        # the former avg_pool2d(out, 4); unlike a fixed 4x4 kernel it reduces a
        # feature map of any spatial size to 1x1, so the head always gets dimE.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)









class Sentinel(ResNet):
    """ResNet-18 feature extractor with a growing, class-grouped set of sentinels.

    A sentinel is a pair (xi, L): a stored embedding ``xi`` (1-D, length
    ``dimE``) and a radius ``L = L_ini * L_decay ** Llevel``.  Its activation
    for an embedding ``e`` is ``relu(1 - distance_Nnel(xi, e) / L)``.

    ``xis``, ``Ls`` and ``Llevels`` are parallel lists of length Nnel, grouped
    by class: the sentinels of class ``c`` occupy indices
    ``roster[c]:roster[c + 1]``.
    """

    # Non-parameter entries that state_dict() adds and load_state_dict() restores.
    _SENTINEL_KEYS = ('roster', 'xis', 'Ls', 'Llevels')

    def __init__(self, num_classes=10):
        super().__init__(BasicBlock, [2, 2, 2, 2], num_classes, False)  # ResNet18 without the linear head
        self.xis = []           # elem: 1-D tensor of length dimE
        self.Ls = []            # elem: 0-d float tensor (radius)
        self.Llevels = []       # elem: 0-d int64 tensor (level), see new()
        self.roster = torch.zeros(num_classes+1, dtype=int, device=DEVICE)
        self.L_ini = 1.
        self.L_decay = 0.5
        self.test_acc = 0.
        self.best_acc = 0.

    @staticmethod
    def _scalar(value, dtype=None):
        """Return a fresh 0-d tensor on DEVICE holding ``value``.

        ``as_tensor(...).clone()`` always copies, like ``torch.tensor``, but
        without torch.tensor's copy-construct warning when ``value`` is a tensor.
        """
        return torch.as_tensor(value, dtype=dtype, device=DEVICE).detach().clone()

    def new(self, label, loc, Llevel, verbose=False):
        """Create a sentinel of class ``label`` at embedding ``loc``.

        Args:
            label: class index (int or 0-d int tensor).
            loc: embedding, 1-D tensor of length dimE; stored detached and cloned.
            Llevel: level (int or 0-d tensor); radius is ``L_ini * L_decay ** Llevel``.
            verbose: print the class and level.

        The sentinel is inserted at the front of its class's segment, and the
        roster boundaries of every later class shift right by one.
        """
        if verbose: print(f'Class:  {label}   ||   Llevel: {Llevel}')
        start = int(self.roster[label])
        self.xis.insert(start, loc.detach().clone())
        self.Ls.insert(start, self._scalar(self.L_decay**Llevel * self.L_ini))
        self.Llevels.insert(start, self._scalar(Llevel, dtype=int))
        self.roster[label+1:] += 1

    def _grade(self, acti, label):
        """Grade one sample's activations; shared by learn() and exam().

        Returns:
            (hard, inel): ``inel`` is the index of the most-activated sentinel;
            ``hard`` is 1 if no sentinel is active, 2 if ``inel`` lies outside
            ``label``'s roster segment, and 0 otherwise.
        """
        maxv, inel = acti.max(dim=0)  # both are 0-d tensors
        if maxv == 0:
            return 1, inel
        if inel < self.roster[label] or inel >= self.roster[label+1]:
            return 2, inel
        return 0, inel

    def learn(self, emb, acti, label):
        """Update the sentinels with one sample (no batch dimension).

        Args:
            emb: embedding, [dimE].
            acti: activations of all sentinels for this sample, [Nnel].
            label: class index (scalar int tensor).

        Returns:
            0: the strongest sentinel belongs to ``label``; nothing changes.
            1: no sentinel is active; a level-0 sentinel is added at ``emb``.
            2: the strongest sentinel belongs to another class; a sentinel one
               level deeper than it (radius scaled by L_decay) is added at ``emb``.
        """
        verbose = 0
        hard, inel = self._grade(acti, label)
        if hard == 1:
            self.new(label, emb, 0, verbose)
        elif hard == 2:
            self.new(label, emb, self.Llevels[inel]+1, verbose)
        elif verbose:
            print(f'Class:  {label}   ||    --- Nothing new ---')
        return hard

    def hashmap(self, activations):
        """Map sentinel activations to per-class beliefs.

        Args:
            activations: [batch_size, Nnel] (or report_activations' zero placeholder).

        Returns:
            [batch_size, num_classes]; column ``c`` is the sum, over class-``c``
            sentinels, of activation / L.
        """
        # Stack the radii once here instead of once per class.
        Ls = self.collect_Ls()[0]
        classBeliefs = [self.belief_classi_structured(activations, ic, Ls)
                        for ic in range(self.num_classes)]
        return torch.stack(classBeliefs, dim=1)

    def report_activations(self, x):
        """Activation of every sentinel for every embedding in ``x``.

        Args:
            x: embeddings, [batch_size, dimE].

        Returns:
            [batch_size, Nnel] tensor ``relu(1 - dist / L)``.  With no sentinels
            yet, returns a zero placeholder of shape [batch_size, num_classes];
            its max is 0, so every sample is graded as undetected.
        """
        if len(self.xis) == 0:
            return torch.zeros(len(x), self.num_classes, device=DEVICE)

        dist = distance_Nnel(self.assemble(), x)   # [batch_size, Nnel]
        return F.relu(1 - dist/self.collect_Ls()[0])

    @staticmethod
    def _percentages(how_hard):
        """Convert outcome counts to percentages of all samples seen (0 if none)."""
        total = sum(how_hard.values())
        return {hard: (count / total * 100 if total else 0.)
                for hard, count in how_hard.items()}

    def train_batch(self, x, labels):
        """Grade a batch and grow the sentinels.

        All activations are computed from the sentinels present before the
        batch, so sentinels added for one sample do not affect the grading of
        later samples in the same batch (unlike train_batch_one_by_one).

        Returns:
            {0: %, 1: %, 2: %}, the share of each learn() outcome.
        """
        embeddings = super().forward(x)  # [batch_size, dimE]
        activations = self.report_activations(embeddings)

        how_hard = {0: 0, 1: 0, 2: 0}
        for emb, acti, label in zip(embeddings, activations, labels):
            how_hard[self.learn(emb, acti, label)] += 1
        return self._percentages(how_hard)

    def forward(self, x):
        """Return class beliefs, [batch_size, num_classes], for images ``x``."""
        embeddings = super().forward(x)  # [batch_size, dimE]
        return self.hashmap(self.report_activations(embeddings))

    def assemble(self):
        """Stack all sentinel embeddings into [Nnel, dimE] (empty if none)."""
        if not self.xis:
            return torch.zeros(0, self.dimE, device=DEVICE)
        return torch.stack(self.xis)

    def collect_Ls(self):
        """Return (Ls, Llevels) as 1-D tensors of length Nnel (empty if none)."""
        if not self.Ls:
            return (torch.zeros(0, device=DEVICE),
                    torch.zeros(0, dtype=int, device=DEVICE))
        return torch.stack(self.Ls), torch.stack(self.Llevels)

    def load_state_dict(self, network_dict, strict=True, **kwargs):
        """Load network weights, then the sentinels if the checkpoint has them.

        The sentinel entries are not module parameters, so they are filtered
        out before delegating to nn.Module.load_state_dict (with strict=True
        it would otherwise reject them as unexpected keys).  Sentinels are
        restored only when all of their entries are present.

        Returns:
            Whatever nn.Module.load_state_dict returns.
        """
        from collections import OrderedDict

        weights = OrderedDict((k, v) for k, v in network_dict.items()
                              if k not in self._SENTINEL_KEYS)
        # Keep the per-module version metadata that nn.Module.state_dict() attaches.
        metadata = getattr(network_dict, '_metadata', None)
        if metadata is not None:
            weights._metadata = metadata
        result = super().load_state_dict(weights, strict=strict, **kwargs)

        if not all(k in network_dict for k in self._SENTINEL_KEYS):
            print('\nSentinels NOT loaded ...  ')
            return result

        def restore(key):
            # Copy onto DEVICE so later in-place roster updates don't touch the checkpoint.
            return torch.as_tensor(network_dict[key], device=DEVICE).clone()

        self.roster = restore('roster')
        self.Ls = list(restore('Ls'))
        self.Llevels = list(restore('Llevels'))
        self.xis = list(restore('xis'))
        print('\nSentinels load Sucess !  ')
        return result

    def state_dict(self, *args, **kwargs):
        """nn.Module.state_dict plus 'roster', 'xis', 'Ls' and 'Llevels'.

        Accepts nn.Module.state_dict's arguments, because a parent module
        calls its children with destination/prefix/keep_vars; the sentinel
        keys get the same prefix as the weights.
        """
        dic = super().state_dict(*args, **kwargs)
        prefix = kwargs.get('prefix', args[1] if len(args) > 1 else '')
        dic[prefix + 'roster'] = self.roster.clone()
        dic[prefix + 'xis'] = self.assemble()
        dic[prefix + 'Ls'], dic[prefix + 'Llevels'] = self.collect_Ls()
        return dic

    def Nmm(self):
        """Return (number of sentinels, max level, min level); ValueError if none exist."""
        return len(self.Llevels), max(self.Llevels).item(), min(self.Llevels).item()

    def exam(self, x, labels, cnt):
        """Grade a batch without learning and accumulate outcome counts.

        ``cnt`` must hold counters under keys 0, 1, 2 and 'total'; it is
        updated in place and also returned.
        """
        embeddings = super().forward(x)  # [batch_size, dimE]
        activations = self.report_activations(embeddings)

        for acti, label in zip(activations, labels):
            hard, _ = self._grade(acti, label)
            cnt[hard] = cnt[hard]+1
            cnt['total'] = cnt['total']+1
        return cnt

    def train_batch_one_by_one(self, x, labels):
        """Like train_batch, but embeds and grades each sample separately.

        Each sample is processed after the previous one has been learned, so
        sentinels it added are already in effect.

        Returns:
            {0: %, 1: %, 2: %}, the share of each learn() outcome over all
            samples (the count is i + 1, not the last index i).
        """
        how_hard = {0: 0, 1: 0, 2: 0}
        for i, label in enumerate(labels):
            emb = super().forward(x[i:i+1])   # [1, dimE]
            acti = self.report_activations(emb)
            how_hard[self.learn(emb[0], acti[0], label)] += 1
        return self._percentages(how_hard)

    def belief_classi_structured(self, bn, iclass, Ls=None):
        """Belief in class ``iclass`` for every sample.

        Args:
            bn: activations of all sentinels for all samples, [batch_size, Nnel].
            iclass: class index.
            Ls: optional precomputed ``collect_Ls()[0]``; computed here if omitted.

        Returns:
            [batch_size] tensor: sum over the class's sentinels of activation / L.
        """
        if Ls is None:
            Ls = self.collect_Ls()[0]
        start, end = self.roster[iclass:iclass+2]
        return (1/Ls[start:end] * bn[:, start:end]).sum(dim=1)







def rand_pair_dist(arr):
    """Print max |arr[i] - arr[j]| for two distinct, randomly chosen indices.

    For 1-D entries this is their L-infinity distance.  Uses NumPy's global
    RNG; returns None.
    """
    first, second = np.random.choice(len(arr), 2, replace=False)
    gap = abs(arr[first] - arr[second])
    print(max(gap))

def distance_Nnel(xi, x):
    """Pairwise L-infinity distances between a batch and all sentinels.

    Args:
        xi: sentinel embeddings, [Nnel, dimE].
        x: query embeddings, [batch_size, dimE].

    Returns:
        [batch_size, Nnel]; entry (b, n) is max over d of |xi[n, d] - x[b, d]|.

    Only the 'inf' metric is implemented; the mode is fixed, and any other
    mode would raise NotImplementedError.
    """
    mode = 'inf'
    if mode != 'inf':
        raise NotImplementedError
    pairwise = xi[None] - x[:, None]   # [batch_size, Nnel, dimE]
    return pairwise.abs().max(dim=-1).values
















def distance_1nel(xi, x):
    """L-infinity distance from one sentinel to each sample of a batch.

    Args:
        xi: one sentinel embedding, [1, dimE].
        x: embeddings, [batch_size, dimE].

    Returns:
        The (values, indices) result of ``torch.max(..., dim=1, keepdim=True)``:
        per-sample distance and the dimension attaining it, each [batch_size, 1].
    """
    mode = 'inf'
    if mode == 'inf':
        gaps = torch.abs(xi - x)
        return gaps.max(dim=1, keepdim=True)
    raise NotImplementedError

def ResNet18(num_classes=10):
    """ResNet-18: BasicBlock stages with (2, 2, 2, 2) blocks."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths, num_classes=num_classes)


def ResNet34(num_classes=10):
    """ResNet-34: BasicBlock stages with (3, 4, 6, 3) blocks."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, num_classes=num_classes)


def ResNet50(num_classes=10):
    """ResNet-50: Bottleneck stages with (3, 4, 6, 3) blocks."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, num_classes=num_classes)


def ResNet101(num_classes=10):
    """ResNet-101: Bottleneck stages with (3, 4, 23, 3) blocks."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths, num_classes=num_classes)


def ResNet152(num_classes=10):
    """ResNet-152: Bottleneck stages with (3, 8, 36, 3) blocks."""
    stage_depths = [3, 8, 36, 3]
    return ResNet(Bottleneck, stage_depths, num_classes=num_classes)


def test():
    """Smoke test: run ResNet18 on one random 32x32 RGB image and print the output size."""
    model = ResNet18()
    logits = model(torch.randn(1, 3, 32, 32))
    print(logits.size())

# test()
