import torch
import torch.nn as nn
import torch.nn.functional as F
from lib.backbone import res50, bbn_res50, res32_cifar, bbn_res32_cifar, res101, res10, res18, res152, resnext50, \
    resnext101, resnext152, res32_512, res32_cifar_rate
from lib.backbone.resnetforcifar import resnet18, resnet34
from lib.modules import GAP, Identity, FCNorm, NonLinearNeck, LinearNeck

import pdb


class BiasLayer(nn.Module):
    """Learnable affine correction (scale, shift) applied to the newest classes' logits.

    Args:
        current_classes_num: Number of trailing logit columns (within the active
            range) that receive the correction.
        active_classes_num: Number of leading logit columns kept in the output.
    """

    def __init__(self, current_classes_num, active_classes_num):
        super(BiasLayer, self).__init__()
        # params[0] is the scale (init 1), params[1] the shift (init 0), so the
        # layer starts out as the identity.
        self.params = nn.Parameter(torch.tensor([1.0, 0.0]))
        self.current_classes_num = current_classes_num
        self.active_classes_num = active_classes_num

    def forward(self, x):
        """Return corrected logits for the active classes.

        Args:
            x: Logit tensor whose class axis is dim 1.

        Returns:
            A new tensor equal to ``x[:, :active_classes_num]`` except that its
            last ``current_classes_num`` columns are replaced by
            ``params[0] * col + params[1]``. ``x`` itself is not modified.
        """
        active = x[:, :self.active_classes_num]
        # Columns before ``split`` pass through unchanged. Clamping at 0 makes an
        # oversize ``current_classes_num`` correct every column, and a value of 0
        # correct none (a ``-0:`` slice would have selected everything).
        split = max(active.shape[1] - self.current_classes_num, 0)
        corrected = active[:, split:] * self.params[0] + self.params[1]
        # Out-of-place: slicing returns a view, so in-place updates would alias
        # and mutate the caller's tensor (and fail on leaves requiring grad).
        return torch.cat([active[:, :split], corrected], dim=1)


class DDC_Network(nn.Module):
    """Classification network whose backbone is built with a width multiplier.

    The backbone class is resolved by name from ``cfg.BACKBONE.TYPE`` and
    instantiated as ``backbone_cls(rate=rate)``. Its output goes through the
    module selected by ``cfg.MODULE.TYPE``, is flattened per sample, and is fed
    to the classifier selected by ``cfg.CLASSIFIER.TYPE``. An optional neck
    (``cfg.CLASSIFIER.NECK.ENABLE``) and an optional extra linear head
    (``cfg.DISTILL.ENABLE``) can be added.

    Args:
        cfg: Experiment config object.
        mode: Accepted for signature compatibility; not used by this class.
        num_classes: Output size of every classification head.
        rate: Width multiplier passed to the backbone constructor.
    """

    def __init__(self, cfg, mode="train", num_classes=1000, rate=1):
        super(DDC_Network, self).__init__()
        self.num_classes = num_classes
        self.cfg = cfg
        # Remember the multiplier actually used for the backbone so the feature
        # length is derived from the same value.
        self.rate = rate
        # NOTE(review): ``eval`` resolves the backbone from a config string; only
        # safe with trusted configs.
        self.backbone = eval(self.cfg.BACKBONE.TYPE)(rate=rate)
        self.feature_len = self.get_feature_length()
        self.last_downsample_module = self._get_module()
        self.neck = None
        if self.cfg.CLASSIFIER.NECK.ENABLE:
            self.neck = self._get_neck_module()
        self.classifier = self._get_classifer()
        self.cls_for_distill = None
        if self.cfg.DISTILL.ENABLE:
            self.cls_for_distill = nn.Linear(self.feature_len, self.num_classes, bias=self.cfg.CLASSIFIER.BIAS)

    def forward(self, x, **kwargs):
        """Run ``_forward_func``; if ``is_nograd`` is passed (any value) it runs under ``torch.no_grad()``."""
        if "is_nograd" in kwargs:
            with torch.no_grad():
                return self._forward_func(x, **kwargs)
        return self._forward_func(x, **kwargs)

    def _apply_classifier(self, features):
        """Apply the classifier to flat features.

        Returns ``[head0(features), head1(features)]`` when the classifier is an
        ``nn.Sequential`` (the 'LDA' type), otherwise a single tensor.
        """
        if isinstance(self.classifier, nn.Sequential):
            # Each head is applied independently; calling the Sequential itself
            # would chain the two heads.
            return [self.classifier[0](features), self.classifier[-1](features)]
        return self.classifier(features)

    def _forward_func(self, x, **kwargs):
        """Forward pass whose outputs depend on which keyword flags are present.

        Only the presence of a keyword matters, not its value:

        * ``feature_2_classifier``: ``x`` is already a feature tensor; return
          the classifier output.
        * ``feature_flag`` and ``classifier_flag``: return ``(out, features)``,
          or ``(out, distill_out, features)`` when distillation is enabled.
        * ``feature_flag`` / ``feature_cb`` / ``feature_rb``: return features.
        * ``classifier_flag``: return ``out``, or ``(out, distill_out)``.
        * otherwise: ``out`` followed by ``distill_out`` (if distillation is
          enabled) and the neck output (only for the two-head classifier with
          the neck enabled); a bare ``out`` when nothing is appended.
        """
        if "feature_2_classifier" in kwargs:
            return self._apply_classifier(x)
        if "feature_flag" in kwargs and "classifier_flag" in kwargs:
            features = self.extract_feature(x)
            out = self._apply_classifier(features)
            if self.cfg.DISTILL.ENABLE:
                return out, self.cls_for_distill(features), features
            return out, features
        if "feature_flag" in kwargs or "feature_cb" in kwargs or "feature_rb" in kwargs:
            return self.extract_feature(x)

        features = self.extract_feature(x)
        out = self._apply_classifier(features)
        if "classifier_flag" in kwargs:
            if self.cfg.DISTILL.ENABLE:
                return out, self.cls_for_distill(features)
            return out

        extras = []
        if self.cfg.DISTILL.ENABLE:
            extras.append(self.cls_for_distill(features))
        # The neck output is only returned for the two-head (Sequential) classifier.
        if isinstance(self.classifier, nn.Sequential) and self.cfg.CLASSIFIER.NECK.ENABLE:
            extras.append(self.neck(features))
        return (out, *extras) if extras else out

    def extract_feature(self, x):
        """Backbone -> ``last_downsample_module`` -> flatten to ``(batch, -1)``."""
        x = self.backbone(x)
        x = self.last_downsample_module(x)
        return x.reshape(x.shape[0], -1)

    def _log(self, msg):
        """Log ``msg`` via ``self.logger`` when one was attached, else the module logger."""
        import logging
        logger = getattr(self, "logger", None)
        if logger is None:
            logger = logging.getLogger(__name__)
        logger.info(msg)

    def freeze_backbone(self):
        """Set ``requires_grad = False`` on every backbone parameter."""
        self._log("Freezing backbone .......")
        for p in self.backbone.parameters():
            p.requires_grad = False

    def refreeze_last_block(self):
        """Despite the name, set ``requires_grad = True`` on ``backbone.layer4[2]``.

        Assumes the backbone exposes ``layer4`` with at least three entries.
        """
        self._log("Refreezing the last block of backbone......")
        for p in self.backbone.layer4[2].parameters():
            p.requires_grad = True

    def refreeze_last_stage(self):
        """Despite the name, set ``requires_grad = True`` on all of ``backbone.layer4``."""
        self._log("Refreezing the last stage of backbone......")
        for p in self.backbone.layer4.parameters():
            p.requires_grad = True

    def load_backbone_model(self, backbone_path=""):
        """Delegate to the backbone's own ``load_model``."""
        self.backbone.load_model(backbone_path)

    def load_model(self, model_path, logger=None):
        """Load a checkpoint into this model.

        Accepts either a raw state dict or a dict holding one under
        ``'state_dict'``. A ``DataParallel``-style ``'module.'`` key prefix is
        stripped. Checkpoint entries override the current ones and the merged
        dict is loaded strictly.

        Args:
            model_path: Path passed to ``torch.load``.
            logger: Optional logger notified once loading is done.
        """
        from collections import OrderedDict

        checkpoint = torch.load(
            model_path, map_location="cpu" if self.cfg.CPU_MODE else "cuda"
        )
        state = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
        prefix = "module."
        # Only strip a real "module." prefix so keys that merely start with
        # "module" are not truncated.
        new_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state.items()
        )
        model_dict = self.state_dict()
        model_dict.update(new_dict)
        self.load_state_dict(model_dict)
        if logger:
            logger.info("Model has been loaded...")

    def get_feature_length(self):
        """Return the flattened feature size for ``cfg.BACKBONE.TYPE``."""
        backbone_type = self.cfg.BACKBONE.TYPE
        if backbone_type == 'res32_cifar':
            return 64
        if backbone_type == 'res32_cifar_rate':
            # Scale with the multiplier the backbone was built with.
            return int(64 * self.rate)
        if backbone_type in ('res10', 'res18', 'res32_512', 'res34', 'resnet18'):
            return 512
        return 2048

    def _get_module(self):
        """Build the module named by ``cfg.MODULE.TYPE`` ('GAP' or 'Identity')."""
        module_type = self.cfg.MODULE.TYPE
        if module_type == "GAP":
            return GAP()
        if module_type == "Identity":
            return Identity()
        raise NotImplementedError(f"Unsupported MODULE.TYPE: {module_type!r}")

    def _get_neck_module(self):
        """Build the neck named by ``cfg.CLASSIFIER.NECK.TYPE``.

        'NonLinear' builds ``NonLinearNeck``, 'Identity' builds ``Identity``,
        and any other value builds ``LinearNeck``. The non-identity necks require
        ``NECK.NUM_FEATURES == self.feature_len``.
        """
        neck_cfg = self.cfg.CLASSIFIER.NECK
        if neck_cfg.TYPE == 'Identity':
            return Identity()
        assert neck_cfg.NUM_FEATURES == self.feature_len, \
            "self.cfg.CLASSIFIER.NECK.NUM_FEATURES != self.feature_len"
        if neck_cfg.TYPE == 'NonLinear':  # MLP
            return NonLinearNeck(neck_cfg.NUM_FEATURES, neck_cfg.NUM_OUT, neck_cfg.HIDDEN_DIM)
        return LinearNeck(neck_cfg.NUM_FEATURES, neck_cfg.NUM_OUT)  # Linear layer

    def _get_classifer(self):
        """Build the classifier named by ``cfg.CLASSIFIER.TYPE``.

        'FCNorm' -> ``FCNorm``; 'FC' -> ``nn.Linear``; 'LDA' -> an
        ``nn.Sequential`` holding two independent ``nn.Linear`` heads (see
        ``_apply_classifier``).
        """
        bias_flag = self.cfg.CLASSIFIER.BIAS
        classifier_type = self.cfg.CLASSIFIER.TYPE
        if classifier_type == "FCNorm":
            return FCNorm(self.feature_len, self.num_classes)
        if classifier_type == "FC":
            return nn.Linear(self.feature_len, self.num_classes, bias=bias_flag)
        if classifier_type == 'LDA':
            return nn.Sequential(
                *[nn.Linear(self.feature_len, self.num_classes, bias=bias_flag) for _ in range(2)]
            )
        raise NotImplementedError(f"Unsupported CLASSIFIER.TYPE: {classifier_type!r}")


class Network(nn.Module):
    """Classification network built from a config-selected backbone.

    The backbone class is resolved by name from ``cfg.BACKBONE.TYPE`` and
    constructed with the config plus pretrained-weight options. Its output goes
    through the module selected by ``cfg.MODULE.TYPE``, is flattened per sample,
    and is fed to the classifier selected by ``cfg.CLASSIFIER.TYPE``. An
    optional neck (``cfg.CLASSIFIER.NECK.ENABLE``) and an optional extra linear
    head (``cfg.DISTILL.ENABLE``) can be added.

    Args:
        cfg: Experiment config object.
        mode: When ``"train"`` and ``cfg.BACKBONE.PRETRAINED_BACKBONE`` is
            non-empty, the backbone is asked to load pretrained weights.
        num_classes: Output size of every classification head.
    """

    def __init__(self, cfg, mode="train", num_classes=1000):
        super(Network, self).__init__()
        pretrain = mode == "train" and cfg.BACKBONE.PRETRAINED_BACKBONE != ""

        self.num_classes = num_classes
        self.cfg = cfg

        # NOTE(review): ``eval`` resolves the backbone from a config string; only
        # safe with trusted configs.
        self.backbone = eval(self.cfg.BACKBONE.TYPE)(
            self.cfg,
            pretrain=pretrain,
            pretrained_backbone=cfg.BACKBONE.PRETRAINED_BACKBONE,
            last_layer_stride=2,
        )
        self.feature_len = self.get_feature_length()
        self.last_downsample_module = self._get_module()
        self.neck = None
        if self.cfg.CLASSIFIER.NECK.ENABLE:
            # _forward_func calls self.neck when the neck is enabled, so it must
            # actually be built (previously it stayed None and crashed).
            self.neck = self._get_neck_module()
        self.classifier = self._get_classifer()
        self.cls_for_distill = None
        if self.cfg.DISTILL.ENABLE:
            self.cls_for_distill = nn.Linear(self.feature_len, self.num_classes, bias=self.cfg.CLASSIFIER.BIAS)

    def forward(self, x, **kwargs):
        """Run ``_forward_func``; if ``is_nograd`` is passed (any value) it runs under ``torch.no_grad()``."""
        if "is_nograd" in kwargs:
            with torch.no_grad():
                return self._forward_func(x, **kwargs)
        return self._forward_func(x, **kwargs)

    def _apply_classifier(self, features):
        """Apply the classifier to flat features.

        Returns ``[head0(features), head1(features)]`` when the classifier is an
        ``nn.Sequential`` (the 'LDA' type), otherwise a single tensor.
        """
        if isinstance(self.classifier, nn.Sequential):
            # Each head is applied independently; calling the Sequential itself
            # would chain the two heads.
            return [self.classifier[0](features), self.classifier[-1](features)]
        return self.classifier(features)

    def _forward_func(self, x, **kwargs):
        """Forward pass whose outputs depend on which keyword flags are present.

        Only the presence of a keyword matters, not its value:

        * ``feature_flag`` / ``feature_cb`` / ``feature_rb``: return features.
        * ``classifier_flag``: return ``out``, or ``(out, distill_out)`` when
          distillation is enabled.
        * otherwise: ``out`` followed by ``distill_out`` (if distillation is
          enabled) and the neck output (only for the two-head classifier with
          the neck enabled); a bare ``out`` when nothing is appended.
        """
        if "feature_flag" in kwargs or "feature_cb" in kwargs or "feature_rb" in kwargs:
            return self.extract_feature(x)

        features = self.extract_feature(x)
        out = self._apply_classifier(features)
        if "classifier_flag" in kwargs:
            if self.cfg.DISTILL.ENABLE:
                return out, self.cls_for_distill(features)
            return out

        extras = []
        if self.cfg.DISTILL.ENABLE:
            extras.append(self.cls_for_distill(features))
        # The neck output is only returned for the two-head (Sequential) classifier.
        if isinstance(self.classifier, nn.Sequential) and self.cfg.CLASSIFIER.NECK.ENABLE:
            extras.append(self.neck(features))
        return (out, *extras) if extras else out

    def extract_feature(self, x):
        """Backbone -> ``last_downsample_module`` -> flatten to ``(batch, -1)``."""
        x = self.backbone(x)
        x = self.last_downsample_module(x)
        return x.reshape(x.shape[0], -1)

    def _log(self, msg):
        """Log ``msg`` via ``self.logger`` when one was attached, else the module logger."""
        import logging
        logger = getattr(self, "logger", None)
        if logger is None:
            logger = logging.getLogger(__name__)
        logger.info(msg)

    def freeze_backbone(self):
        """Set ``requires_grad = False`` on every backbone parameter."""
        self._log("Freezing backbone .......")
        for p in self.backbone.parameters():
            p.requires_grad = False

    def refreeze_last_block(self):
        """Despite the name, set ``requires_grad = True`` on ``backbone.layer4[2]``.

        Assumes the backbone exposes ``layer4`` with at least three entries.
        """
        self._log("Refreezing the last block of backbone......")
        for p in self.backbone.layer4[2].parameters():
            p.requires_grad = True

    def refreeze_last_stage(self):
        """Despite the name, set ``requires_grad = True`` on all of ``backbone.layer4``."""
        self._log("Refreezing the last stage of backbone......")
        for p in self.backbone.layer4.parameters():
            p.requires_grad = True

    def load_backbone_model(self, backbone_path=""):
        """Delegate to the backbone's own ``load_model``."""
        self.backbone.load_model(backbone_path)

    def load_model(self, model_path, logger=None):
        """Load a checkpoint into this model.

        Accepts either a raw state dict or a dict holding one under
        ``'state_dict'``. A ``DataParallel``-style ``'module.'`` key prefix is
        stripped. Checkpoint entries override the current ones and the merged
        dict is loaded strictly.

        Args:
            model_path: Path passed to ``torch.load``.
            logger: Optional logger for the completion message. Falls back to
                ``self.logger`` if attached, else the module logger.
        """
        from collections import OrderedDict

        checkpoint = torch.load(
            model_path, map_location="cpu" if self.cfg.CPU_MODE else "cuda"
        )
        state = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
        prefix = "module."
        # Only strip a real "module." prefix so keys that merely start with
        # "module" are not truncated.
        new_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state.items()
        )
        model_dict = self.state_dict()
        model_dict.update(new_dict)
        self.load_state_dict(model_dict)
        if logger is not None:
            logger.info("Model has been loaded...")
        else:
            self._log("Model has been loaded...")

    def get_feature_length(self):
        """Return the flattened feature size for ``cfg.BACKBONE.TYPE``."""
        backbone_type = self.cfg.BACKBONE.TYPE
        if "cifar" in backbone_type:
            return 64
        # 'resnet18' matches the 512-wide mapping used by DDC_Network.
        if backbone_type in ('res10', 'res18', 'res32_512', 'res34', 'resnet18'):
            return 512
        return 2048

    def _get_module(self):
        """Build the module named by ``cfg.MODULE.TYPE`` ('GAP' or 'Identity')."""
        module_type = self.cfg.MODULE.TYPE
        if module_type == "GAP":
            return GAP()
        if module_type == "Identity":
            return Identity()
        raise NotImplementedError(f"Unsupported MODULE.TYPE: {module_type!r}")

    def _get_neck_module(self):
        """Build the neck named by ``cfg.CLASSIFIER.NECK.TYPE``.

        'NonLinear' builds ``NonLinearNeck``, 'Identity' builds ``Identity``,
        and any other value builds ``LinearNeck``. The non-identity necks require
        ``NECK.NUM_FEATURES == self.feature_len``.
        """
        neck_cfg = self.cfg.CLASSIFIER.NECK
        if neck_cfg.TYPE == 'Identity':
            return Identity()
        assert neck_cfg.NUM_FEATURES == self.feature_len, \
            "self.cfg.CLASSIFIER.NECK.NUM_FEATURES != self.feature_len"
        if neck_cfg.TYPE == 'NonLinear':  # MLP
            return NonLinearNeck(neck_cfg.NUM_FEATURES, neck_cfg.NUM_OUT, neck_cfg.HIDDEN_DIM)
        return LinearNeck(neck_cfg.NUM_FEATURES, neck_cfg.NUM_OUT)  # Linear layer

    def _get_classifer(self):
        """Build the classifier named by ``cfg.CLASSIFIER.TYPE``.

        'FCNorm' -> ``FCNorm``; 'FC' -> ``nn.Linear``; 'LDA' -> an
        ``nn.Sequential`` holding two independent ``nn.Linear`` heads (see
        ``_apply_classifier``).
        """
        bias_flag = self.cfg.CLASSIFIER.BIAS
        classifier_type = self.cfg.CLASSIFIER.TYPE
        if classifier_type == "FCNorm":
            return FCNorm(self.feature_len, self.num_classes)
        if classifier_type == "FC":
            return nn.Linear(self.feature_len, self.num_classes, bias=bias_flag)
        if classifier_type == 'LDA':
            return nn.Sequential(
                *[nn.Linear(self.feature_len, self.num_classes, bias=bias_flag) for _ in range(2)]
            )
        raise NotImplementedError(f"Unsupported CLASSIFIER.TYPE: {classifier_type!r}")
