# import timm
import torch
import numpy as np
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F
from mnasnet import mnasnet1_3
from regnet import regnet_y_800mf
from convnext import convnext_tiny
from swin_transformer import swin_t
from squeezenet import squeezenet1_1
from efficient import efficientnet_b0
from mobilenet import mobilenet_v3_large
from vgg import vgg16_bn, vgg19_bn
from torchvision.models.densenet import densenet121
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101


class VGGBackbone(nn.Module):
    """Frozen ``vgg19_bn`` used as a feature extractor.

    Weights are read from ``./pretrained/vgg19_bn-c79401a0.pth``; ``forward``
    returns the output of the ``features`` stage only (no pooling / classifier).
    ``num_channels`` is declared as 512.
    """

    def __init__(self):
        super().__init__()
        net = vgg19_bn()
        weights = torch.load('./pretrained/vgg19_bn-c79401a0.pth', map_location='cpu')
        net.load_state_dict(weights, strict=True)

        self.num_channels = 512
        self.backbone = net

        # No parameter of the wrapped network receives gradients.
        net.requires_grad_(False)

    def forward(self, x):
        """Map an image batch (batch x 3 x H x W) to the ``features`` map."""
        return self.backbone.features(x)


class SqueezeBackbone(nn.Module):
    """Frozen ``squeezenet1_1`` used as a feature extractor.

    Weights are read from ``./pretrained/squeezenet1_1-b8a52dc0.pth``;
    ``forward`` returns the output of the ``features`` stage (classifier not
    applied).
    """

    def __init__(self):
        super(SqueezeBackbone, self).__init__()
        model = squeezenet1_1()
        model_path = './pretrained/squeezenet1_1-b8a52dc0.pth'
        model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)

        # SqueezeNet 1.1's last Fire module concatenates two 256-channel expand
        # branches, so `features` emits 512 channels (its classifier's 1x1 conv
        # takes 512 inputs) -- assuming the local `squeezenet` module mirrors
        # torchvision, which the strict load of the torchvision checkpoint
        # suggests. The former value 960 (MobileNetV3-Large's width) made
        # SSOD's ood_head, an nn.Linear(num_channels, ...), reject these features.
        self.num_channels = 512
        self.backbone = model

        # freeze the backbone
        for p in model.parameters():
            p.requires_grad = False

    def forward(self, x):
        # x: batch, 3, H, W -> batch, channel, h, w
        return self.backbone.features(x)


class MNasBackbone(nn.Module):
    """Frozen ``mnasnet1_3`` whose ``layers`` stack provides the feature map.

    Weights come from ``./pretrained/mnasnet1_3-a4c69d6f.pth``;
    ``num_channels`` is declared as 1280.
    """

    def __init__(self):
        super().__init__()
        checkpoint = './pretrained/mnasnet1_3-a4c69d6f.pth'
        net = mnasnet1_3()
        net.load_state_dict(torch.load(checkpoint, map_location='cpu'), strict=True)

        self.num_channels = 1280
        self.backbone = net

        # Freeze every parameter of the wrapped network.
        for param in net.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run only the convolutional ``layers``; pooling and classifier are skipped."""
        feature_map = self.backbone.layers(x)
        return feature_map


class RegBackbone(nn.Module):
    """Frozen ``regnet_y_800mf`` returning the output of its stem + trunk.

    Weights come from ``./pretrained/regnet_y_800mf-1b27b58c.pth``;
    ``num_channels`` is declared as 784.
    """

    def __init__(self):
        super().__init__()
        net = regnet_y_800mf()
        state = torch.load('./pretrained/regnet_y_800mf-1b27b58c.pth', map_location='cpu')
        net.load_state_dict(state, strict=True)

        self.num_channels = 784
        self.backbone = net

        # The whole wrapped network is excluded from gradient updates.
        net.requires_grad_(False)

    def forward(self, x):
        """Apply ``stem`` then ``trunk_output``; avgpool and fc are not used."""
        return self.backbone.trunk_output(self.backbone.stem(x))


class MobileBackbone(nn.Module):
    """Frozen ``mobilenet_v3_large`` whose ``features`` stage is the output.

    Weights come from ``./pretrained/mobilenet_v3_large-8738ca79.pth``;
    ``num_channels`` is declared as 960.
    """

    def __init__(self):
        super().__init__()
        net = mobilenet_v3_large()
        state = torch.load('./pretrained/mobilenet_v3_large-8738ca79.pth', map_location='cpu')
        net.load_state_dict(state, strict=True)

        self.num_channels = 960
        self.backbone = net

        # no parameter of the wrapped network is trained
        net.requires_grad_(False)

    def forward(self, x):
        # images (batch x 3 x H x W) -> feature map; pooling / classifier skipped
        return self.backbone.features(x)


class EfficientBackbone(nn.Module):
    """Frozen ``efficientnet_b0`` whose ``features`` stage is the output.

    Weights come from ``./pretrained/efficientnet_b0_rwightman-3dd342df.pth``;
    ``num_channels`` is declared as 1280.
    """

    def __init__(self):
        super().__init__()
        weights_file = './pretrained/efficientnet_b0_rwightman-3dd342df.pth'
        net = efficientnet_b0()
        net.load_state_dict(torch.load(weights_file, map_location='cpu'), strict=True)

        self.num_channels = 1280
        self.backbone = net

        # Freeze all parameters of the wrapped network.
        for weight in net.parameters():
            weight.requires_grad_(False)

    def forward(self, x):
        """Return the ``features`` map; avgpool and classifier are not applied."""
        out = self.backbone.features(x)
        return out


class ConvNextBackbone(nn.Module):
    """Frozen ``convnext_tiny`` whose ``features`` stage is the output.

    Weights come from ``./pretrained/convnext_tiny.pth``; ``num_channels`` is
    declared as 768.
    """

    def __init__(self):
        super().__init__()
        net = convnext_tiny()
        ckpt = torch.load('./pretrained/convnext_tiny.pth', map_location='cpu')
        net.load_state_dict(ckpt, strict=True)

        self.num_channels = 768
        self.backbone = net

        # the wrapped network is never updated by the optimizer
        net.requires_grad_(False)

    def forward(self, x):
        """Map images to the ``features`` output; avgpool / classifier are skipped."""
        return self.backbone.features(x)


class ResNet50d(nn.Module):
    """Frozen timm ``resnet50d`` returning ``forward_features`` output.

    ``num_channels`` is taken from the model's ``fc.in_features``.
    """

    def __init__(self):
        super(ResNet50d, self).__init__()
        # The module-level `import timm` is commented out, so `timm` was an
        # unbound name here (NameError). Import it on use so the rest of the
        # module still loads when timm is not installed.
        import timm

        model = timm.create_model('resnet50d', pretrained=True)

        self.num_channels = model.fc.in_features
        self.backbone = model

        # freeze the backbone
        for p in model.parameters():
            p.requires_grad = False

    def forward(self, x):
        # x: batch, 3, H, W -> batch, channel, h, w (global pool / fc not applied)
        return self.backbone.forward_features(x)


class DenseNetBackbone(nn.Module):
    """torchvision ``densenet121(pretrained=True)`` used as a frozen extractor.

    ``forward`` returns ReLU(``features(x)``); the classifier stays attached to
    ``backbone`` (``num_channels`` is its ``in_features``) but is not applied.
    """

    def __init__(self):
        super().__init__()
        net = densenet121(pretrained=True)

        self.num_channels = net.classifier.in_features
        self.backbone = net

        # Both the feature extractor and the classifier are frozen.
        for part in (net.features, net.classifier):
            part.requires_grad_(False)

    def forward(self, x):
        """Return the ReLU-activated ``features`` map (activation done in place)."""
        fmap = self.backbone.features(x)
        return F.relu(fmap, inplace=True)


class WideResNet50(nn.Module):
    """Frozen timm ``wide_resnet50_2`` returning ``forward_features`` output.

    ``num_channels`` is taken from the model's ``fc.in_features``.
    """

    def __init__(self):
        super(WideResNet50, self).__init__()
        # The module-level `import timm` is commented out, so `timm` was an
        # unbound name here (NameError). Import it on use so the rest of the
        # module still loads when timm is not installed.
        import timm

        model = timm.create_model('wide_resnet50_2', pretrained=True)

        self.num_channels = model.fc.in_features
        self.backbone = model

        # freeze the backbone
        for p in model.parameters():
            p.requires_grad = False

    def forward(self, x):
        # x: batch, 3, H, W -> batch, channel, h, w (global pool / fc not applied)
        return self.backbone.forward_features(x)


class ResBackbone(nn.Module):
    """Select a backbone by ``depth`` and expose its spatial feature map.

    Args:
        depth: an int in (18, 34, 50, 101) builds the torchvision ResNet of that
            depth with ``pretrained=True``; a string selects one of this
            module's wrapper backbones ('densenet', 'mobile', 'efficient',
            'reg', 'mnas', 'convnext', 'resnet50d', 'wide_resnet50_2',
            'squeeze', 'vgg').
        train_backbone: when False, every backbone parameter is frozen.

    Raises:
        ValueError: if ``depth`` is not one of the supported values.

    Attributes:
        num_channels: ``fc.in_features`` for ResNets, otherwise the wrapper's
            declared ``num_channels``.
    """

    # Depths served by torchvision's resnet18/34/50/101 constructors.
    _RESNET_DEPTHS = (18, 34, 50, 101)
    # String selectors of the wrapper backbones defined in this module.
    _WRAPPER_NAMES = ('densenet', 'mobile', 'efficient', 'reg', 'mnas',
                      'convnext', 'resnet50d', 'wide_resnet50_2', 'squeeze', 'vgg')

    def __init__(self, depth, train_backbone=True):
        super(ResBackbone, self).__init__()
        self.depth = depth
        # Lookup tables are built only after the key is validated, so an
        # unsupported depth fails with a clear ValueError.
        if depth in self._RESNET_DEPTHS:
            builders = {18: resnet18, 34: resnet34, 50: resnet50, 101: resnet101}
            model = builders[depth](pretrained=True)
            self.num_channels = model.fc.in_features
        elif depth in self._WRAPPER_NAMES:
            wrappers = {
                'densenet': DenseNetBackbone,
                'mobile': MobileBackbone,
                'efficient': EfficientBackbone,
                'reg': RegBackbone,
                'mnas': MNasBackbone,
                'convnext': ConvNextBackbone,
                'resnet50d': ResNet50d,
                'wide_resnet50_2': WideResNet50,
                'squeeze': SqueezeBackbone,
                'vgg': VGGBackbone,
            }
            model = wrappers[depth]()
            self.num_channels = model.num_channels
        else:
            # Without this branch `model` would be unbound below.
            raise ValueError(
                f'unsupported backbone depth {depth!r}; expected one of '
                f'{self._RESNET_DEPTHS + self._WRAPPER_NAMES}'
            )

        self.backbone = model

        if not train_backbone:
            # freeze the backbone, only the heads built on top remain trainable
            self.backbone.requires_grad_(False)

    def forward(self, x):
        """Return the backbone feature map for an image batch ``x``."""
        if isinstance(self.depth, str):
            # Wrapper backbones already return their feature map. Calling the
            # module (rather than .forward) keeps registered hooks active.
            return self.backbone(x)

        # torchvision ResNet trunk without the final avgpool / fc.
        net = self.backbone
        x = net.maxpool(net.relu(net.bn1(net.conv1(x))))
        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            x = stage(x)
        return x


class SSOD(nn.Module):
    """Backbone + classification head + per-location 2-way OOD head (ood: 0, id: 1).

    Args:
        num_classes: 1000 re-uses the backbone's own pretrained classifier as
            ``cls_head``; any other value builds a new 2-layer MLP head.
        depth: backbone selector passed to ``ResBackbone``.
        latent_dim: hidden width of the OOD head MLP.
        train_cls: if False and the backbone classifier is re-used, ``cls_head``
            is frozen.
        train_backbone: passed to ``ResBackbone``.
        block_sampler: stored as ``self.block_sampler``; not read in this class.
    """

    def __init__(self, num_classes, depth=50, latent_dim=256, train_cls=True, train_backbone=True, block_sampler=True):
        super().__init__()
        self.depth = depth
        # random combination of image patches
        self.block_sampler = block_sampler

        # backbone returning a batch x channel x h x w feature map
        self.backbone = ResBackbone(depth=depth, train_backbone=train_backbone)

        # ood head: channel-dim features -> 2 logits (ood: 0, id: 1)
        self.ood_head = nn.Sequential(
            nn.Linear(self.backbone.num_channels, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, 2)
            )

        # need_reduce: cls_head is the squeeze classifier, applied to 4-D maps
        # and flattened afterwards.
        self.need_reduce = False
        # vgg_head: cls_head is VGG's pretrained classifier, whose input is the
        # 7x7-pooled, flattened map rather than a 1x1-pooled vector.
        self.vgg_head = depth == 'vgg' and num_classes == 1000

        # classification head
        if num_classes == 1000:
            wrapped = self.backbone.backbone
            if not isinstance(depth, str):
                # torchvision ResNet (same test ResBackbone.forward uses)
                self.cls_head = wrapped.fc
            elif depth == 'densenet':
                self.cls_head = wrapped.backbone.classifier
            elif depth == 'convnext':
                self.cls_head = nn.Sequential(
                    wrapped.backbone.classifier[0],
                    wrapped.backbone.classifier[-1],
                )
            elif depth in ['efficient', 'mnas', 'mobile', 'vgg']:
                self.cls_head = wrapped.backbone.classifier
            elif depth in ['reg', 'resnet50d', 'wide_resnet50_2']:
                self.cls_head = wrapped.backbone.fc
            elif depth == 'squeeze':
                self.need_reduce = True
                self.cls_head = wrapped.backbone.classifier

            if not train_cls:
                self.cls_head.requires_grad_(False)
        else:
            self.cls_head = nn.Sequential(
                nn.Linear(self.backbone.num_channels, 512),
                nn.ReLU(),
                nn.Linear(512, num_classes)
            )

        # feature pooling
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if depth == 'vgg':
            # VGG's classifier consumes its input pooled to 7x7 and flattened;
            # used by cls_head_forward when vgg_head is set.
            self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

    def cls_head_forward(self, feat):
        """Return image-level and per-location class logits.

        Args:
            feat: batch x channel x h x w feature map.

        Returns:
            cls_logits: batch x num_classes, from the globally pooled map.
            mask_logits: batch x (h w) x num_classes, each spatial location's
                feature classified on its own.
        """
        batch = feat.shape[0]
        if not (self.need_reduce or self.vgg_head):
            cls_feat = self.avg_pool(feat).reshape(batch, -1)
            cls_logits = self.cls_head(cls_feat)
            mask_feat = rearrange(feat, 'b c h w -> b (h w) c')
            return cls_logits, self.cls_head(mask_feat)

        # These heads take 4-D maps, so every location becomes its own 1x1 map:
        # the same "classify one location" meaning as the generic path above.
        loc_maps = rearrange(feat, 'b c h w -> (b h w) c 1 1')
        if self.need_reduce:
            # squeeze classifier output is flattened to batch x num_classes
            cls_logits = torch.flatten(self.cls_head(feat), 1)
            mask_logits = torch.flatten(self.cls_head(loc_maps), 1)
        else:
            # VGG: pool to 7x7 (a 1x1 map is replicated), flatten, classify.
            # NOTE: the per-location input is channel*49 wide per location.
            cls_logits = self.cls_head(torch.flatten(self.avgpool(feat), 1))
            mask_logits = self.cls_head(torch.flatten(self.avgpool(loc_maps), 1))
        mask_logits = rearrange(mask_logits, '(b hw) n -> b hw n', b=batch)
        return cls_logits, mask_logits

    def ood_head_forward(self, feat):
        """Return OOD logits per location (batch x (h w) x 2) and for the pooled map (batch x 2)."""
        split_feat = rearrange(feat, 'b c h w -> b (h w) c')
        split_ood_logits = self.ood_head(split_feat)

        avg_feat = self.avg_pool(feat).reshape(feat.shape[0], -1)
        avg_ood_logits = self.ood_head(avg_feat)
        return split_ood_logits, avg_ood_logits

    @staticmethod
    def _masked_ce(logits, target, ignore_index):
        """Mean cross-entropy over targets != ignore_index; 0 if all are ignored.

        nn.CrossEntropyLoss(ignore_index=...) divides the summed loss by the
        number of kept targets and therefore returns NaN when none are kept.
        Dividing by max(count, 1) gives the same value otherwise and a
        graph-connected 0 in that case.
        """
        total = F.cross_entropy(logits, target, ignore_index=ignore_index, reduction='sum')
        kept = (target != ignore_index).sum().clamp(min=1)
        return total / kept

    def loss(self, x, y, ood_weight=0.1, train_cls=False, thresh=0.99):
        """Compute the training loss.

        Args:
            x: batch x 3 x H x W float images.
            y: batch long class labels.
            ood_weight: weight of the ID/OOD loss term.
            train_cls: add image-level cross-entropy on ``cls_logits``.
            thresh: confidence threshold that selects ID locations; locations
                below ``0.3 * thresh`` are candidate OOD locations.

        Returns:
            (feat, cls_logits, loss)
        """
        feat, cls_logits, mask_logits, split_ood_logits, avg_ood_logits = self.forward(x)

        # mask_conf / mask_label: batch x (hw) max softmax prob and its class
        mask_conf, mask_label = torch.max(nn.Softmax(-1)(mask_logits), dim=-1)
        # 1 where the location's predicted label equals the image label
        label_match = rearrange((mask_label == y.unsqueeze(-1)).long(), 'b hw -> (b hw)')
        flat_conf = rearrange(mask_conf, 'b hw -> (b hw)')
        split_ood_logits = rearrange(split_ood_logits, 'b hw n -> (b hw) n')

        # ood: 0, id: 1
        # positives: correctly predicted locations with confidence > thresh
        # (label 1); everything else is 0 and ignored.
        id_label = label_match * (flat_conf > thresh).long()
        id_loss = self._masked_ce(split_ood_logits, id_label, ignore_index=0)

        # negatives: correctly predicted locations with confidence < 0.3 * thresh
        # (label 0); everything else is 1 and ignored.
        # NOTE(review): locations whose prediction differs from y never become
        # negatives here -- confirm this is intended.
        ood_label = 1 - label_match * (flat_conf < 0.3 * thresh).long()
        ood_loss = self._masked_ce(split_ood_logits, ood_label, ignore_index=1)

        # LWB: loss wise balance
        id_ood_loss = 0.5 * (id_loss + ood_loss)

        loss = ood_weight * id_ood_loss
        if train_cls:
            loss = nn.CrossEntropyLoss()(cls_logits, y) + loss
        return feat, cls_logits, loss

    def ood_infer(self, x):
        """Return (max_softmax, pred_label, rectified_p) per image.

        ``rectified_p`` is currently the ID probability from the pooled OOD
        head (the product with ``max_softmax`` is commented out).
        """
        feat = self.backbone(x)

        # cls_logits: batch, num_classes; max_softmax / pred_label: batch
        cls_logits, mask_logits = self.cls_head_forward(feat)
        cls_conf = nn.Softmax(-1)(cls_logits)
        max_softmax, pred_label = torch.max(cls_conf, dim=-1)

        # avg_ood_logits: batch, 2; id_conf: batch (probability of class id=1)
        split_ood_logits, avg_ood_logits = self.ood_head_forward(feat)
        id_conf = torch.nn.Softmax(-1)(avg_ood_logits)[:, 1]

        # rectified posterior probability
        # rectified_p = max_softmax * id_conf
        rectified_p = id_conf
        return max_softmax, pred_label, rectified_p

    def forward(self, x):
        """Run backbone and both heads.

        Returns:
            feat: batch x channel x h x w
            cls_logits: batch x num_classes (pooled features)
            mask_logits: batch x (h w) x num_classes (per location)
            split_ood_logits: batch x (h w) x 2
            avg_ood_logits: batch x 2
        """
        feat = self.backbone(x)
        cls_logits, mask_logits = self.cls_head_forward(feat)
        split_ood_logits, avg_ood_logits = self.ood_head_forward(feat)
        return feat, cls_logits, mask_logits, split_ood_logits, avg_ood_logits

    def extract_id_ood_feature(self, x, y, thresh=0.99):
        """Split per-location backbone features into ID and OOD sets.

        A location is ID when its predicted label equals ``y`` with confidence
        > ``thresh``; every other location is OOD. With ``y=None`` no split
        is made.

        Returns:
            id_feat: (n_id, channel) or None
            ood_feat: (n_ood, channel) or None
            pool_feat: batch x channel spatial mean of the feature map
        """
        feat, cls_logits, mask_logits, split_ood_logits, avg_ood_logits = self.forward(x)

        # mask_conf / mask_label: batch x (hw)
        mask_conf, mask_label = torch.max(nn.Softmax(-1)(mask_logits), dim=-1)

        id_indices, ood_indices = list(), list()
        if y is not None:
            id_label = (mask_label == y.unsqueeze(-1)).long()
            mask_conf_binary = rearrange(mask_conf, 'b hw -> (b hw)') > thresh
            id_ood_label = rearrange(id_label, 'b hw -> (b hw)') * mask_conf_binary.long()

            id_indices = id_ood_label.nonzero().reshape(-1)
            ood_indices = (1 - id_ood_label).nonzero().reshape(-1)

        feat = rearrange(feat, 'b c h w -> b c (h w)')
        pool_feat = torch.mean(feat, dim=-1)

        feat = rearrange(feat, 'b c hw -> (b hw) c')
        id_feat, ood_feat = None, None
        if len(id_indices):
            id_feat = feat[id_indices]

        if len(ood_indices):
            ood_feat = feat[ood_indices]

        return id_feat, ood_feat, pool_feat


if __name__ == '__main__':
    # Smoke test on random data: 10 images of shape 3 x 224 x 224 and integer
    # labels in [0, 999]. Other depth values tried during development:
    # 'efficient', 'densenet', 'reg', 'mnas', 'mobile', 'convnext', 'resnet50d'.
    # A bare backbone can be checked with
    # ResBackbone(depth=50, train_backbone=False)(images).shape.
    images = torch.rand(10, 3, 224, 224)
    labels = (1000 * torch.rand(10)).long()
    model = SSOD(num_classes=1000, depth='vgg')

    _, cls_logits, loss = model.loss(images, labels, ood_weight=0.1, train_cls=True, thresh=0.99)
    print(loss)
    print(cls_logits.shape)

    id_feat, ood_feat, pool_feat = model.extract_id_ood_feature(images, labels)
    print(id_feat, ood_feat, pool_feat.shape)

