import os
import torch
import torch.nn as nn
import core
import torch.nn.functional as F
import torchvision.models as models

  
class U64R18(nn.Module):
    """UNet pre-processor followed by a ResNet classifier with permuted outputs.

    ``forward`` passes the input through ``self.unet`` (a
    ``core.models.UNetLittle`` built with ``first_channels=64``), clamps the
    result to ``[0, 1]``, classifies it with ``self.net``
    (``core.models.ResNet(18)``), applies a softmax over dim 1 and finally
    reorders the class columns with the permutation stored in
    ``self.shuffle``.

    Args:
        save_path: Directory containing a ``ResNet18*`` sub-directory that
            holds the classifier checkpoint.
        unet_path: Path to the UNet ``state_dict`` file.
        shuffle_path: Path to the file holding the label permutation.
    """

    def __init__(self, save_path, unet_path, shuffle_path):
        super().__init__()
        self.init_net(save_path)
        self.init_unet(unet_path)
        self.init_shuffle(shuffle_path)

    def init_net(self, save_path, ckpt_name='ckpt_epoch_200.pth'):
        """Build ``self.net`` and load its weights from under ``save_path``.

        Looks for sub-directories of ``save_path`` whose name starts with
        ``'ResNet18'`` and loads ``ckpt_name`` from the first one in sorted
        order. Sorting makes the choice independent of ``os.listdir`` order.

        Args:
            save_path: Directory to search for ``ResNet18*`` sub-directories.
            ckpt_name: Checkpoint file name inside the chosen sub-directory.

        Raises:
            FileNotFoundError: If no matching sub-directory exists.
        """
        self.net = core.models.ResNet(18)
        model_dirs = sorted(
            name for name in os.listdir(save_path)
            if name.startswith('ResNet18')
            and os.path.isdir(os.path.join(save_path, name))
        )
        if not model_dirs:
            raise FileNotFoundError(
                f"no 'ResNet18*' directory found in {save_path!r}"
            )
        ckpt_path = os.path.join(save_path, model_dirs[0], ckpt_name)
        # load_state_dict copies values into the existing parameters, so
        # mapping to CPU first only avoids failures when the checkpoint was
        # saved from a GPU that is not available here.
        self.net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

    def init_unet(self, path):
        """Build ``self.unet`` and load its ``state_dict`` from ``path``."""
        self.unet = core.models.UNetLittle(
            args=None, n_channels=3, n_classes=3, first_channels=64
        )
        self.unet.load_state_dict(torch.load(path, map_location='cpu'))

    def init_shuffle(self, path):
        """Load the label permutation used by ``label_shuffle``.

        ``label_shuffle`` accepts either a numpy array or a tensor here.

        NOTE(review): ``torch.load`` unpickles arbitrary objects; only load
        trusted files.
        """
        self.shuffle = torch.load(path, map_location='cpu')

    def label_shuffle(self, label):
        """Permute the columns of ``label`` according to ``self.shuffle``.

        For every row ``i``: ``out[i, shuffle[j]] = label[i, j]``.

        Assumes ``label`` is 2-D (batch, num_classes) and ``self.shuffle`` is a
        1-D permutation of length num_classes. TODO: confirm against callers.
        """
        # as_tensor accepts a numpy array or a tensor. It also gives scatter
        # the int64 index it requires, and keeps the index on label's device
        # instead of assuming CUDA.
        index = torch.as_tensor(self.shuffle, dtype=torch.long, device=label.device)
        index = index.repeat(label.shape[0], 1)
        return torch.zeros_like(label).scatter(1, index, label)

    def forward(self, image):
        """Return the permuted class probabilities for ``image``.

        Side effect: stores the clamped UNet output in ``self.X_adv``.
        """
        image = self.unet(image)
        self.X_adv = torch.clamp(image, 0, 1)
        probs = F.softmax(self.net(self.X_adv), dim=1)
        return self.label_shuffle(probs)
