#!/usr/bin/env python3

"""
ViT-related models
Note: models return logits instead of probabilities
"""
from turtle import forward
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from torchvision import models

from .build_vit_backbone import (
    build_vit_sup_models, build_swin_model,
    build_mocov3_model, build_mae_model
)
from .mlp import MLP
from ..utils import logging
# Module-level logger obtained from the project's logging utilities under the
# name "ELSE"; used below to report e.g. the "end2end" transfer mode.
logger = logging.get_logger("ELSE")


class ViT(nn.Module):
    """ViT-related model.

    Builds a supervised ViT encoder (``build_vit_sup_models``), freezes its
    parameters according to ``cfg.MODEL.TRANSFER_TYPE``, optionally adds an
    AlexNet side branch (transfer type ``"side"``) and attaches an ``MLP``
    classification head. The model returns logits, not probabilities.
    """

    def __init__(self, cfg, load_pretrain=True, vis=False):
        super(ViT, self).__init__()
        transfer_type = cfg.MODEL.TRANSFER_TYPE

        prompt_cfg = cfg.MODEL.PROMPT if "prompt" in transfer_type else None

        # linear, cls, tiny-tl, partial, adapter -> encoder treated as frozen;
        # prompt, end2end, cls+prompt -> encoder not treated as frozen.
        # NOTE(review): froze_enc is not read by this class's forward.
        self.froze_enc = (
            transfer_type != "end2end" and "prompt" not in transfer_type
        )

        ELSE_cfg = cfg.MODEL.ELSE if transfer_type == "ELSE" else None

        self.build_backbone(
            prompt_cfg, cfg, ELSE_cfg, load_pretrain, vis=vis)
        self.cfg = cfg
        self.setup_side()
        self.setup_head(cfg)

    def setup_side(self):
        """Create the AlexNet side branch for transfer type ``"side"``.

        Sets ``self.side`` to ``None`` for every other transfer type. For
        ``"side"``, adds a learnable blend logit ``self.side_alpha``, the
        AlexNet feature extractor + avgpool, and a bias-free projection from
        the flattened 9216-dim AlexNet output to ``self.feat_dim``.
        """
        if self.cfg.MODEL.TRANSFER_TYPE != "side":
            self.side = None
        else:
            self.side_alpha = nn.Parameter(torch.tensor(0.0))
            m = models.alexnet(pretrained=True)
            self.side = nn.Sequential(OrderedDict([
                ("features", m.features),
                ("avgpool", m.avgpool),
            ]))
            self.side_projection = nn.Linear(9216, self.feat_dim, bias=False)

    def _freeze_except(self, is_trainable):
        """Set ``requires_grad=False`` on every encoder parameter whose name
        does not satisfy ``is_trainable(name)``."""
        for k, p in self.enc.named_parameters():
            if not is_trainable(k):
                p.requires_grad = False

    def _freeze_all_but_last_layers(self, num_tuned):
        """Keep only the last ``num_tuned`` encoder layers and the final
        encoder norm trainable."""
        total_layer = len(self.enc.transformer.encoder.layer)
        # Trailing "." so that e.g. "layer.1." cannot also match "layer.10.".
        tuned = tuple(
            "transformer.encoder.layer.{}.".format(total_layer - i)
            for i in range(1, num_tuned + 1)
        )
        self._freeze_except(
            lambda k: any(t in k for t in tuned)
            or "transformer.encoder.encoder_norm" in k
        )

    def _reinit_cls_token(self):
        """Re-initialize the CLS token with a near-zero normal init."""
        nn.init.normal_(
            self.enc.transformer.embeddings.cls_token,
            std=1e-6
        )

    def build_backbone(self, prompt_cfg, cfg, ELSE_cfg, load_pretrain, vis):
        """Build ``self.enc`` / ``self.feat_dim`` and freeze parameters.

        Which encoder parameters stay trainable is chosen by
        ``cfg.MODEL.TRANSFER_TYPE``.

        Raises:
            ValueError: if the transfer type is not supported.
        """
        transfer_type = cfg.MODEL.TRANSFER_TYPE
        self.enc, self.feat_dim = build_vit_sup_models(
            cfg.DATA.FEATURE, cfg.DATA.CROPSIZE, prompt_cfg,
            cfg.MODEL.MODEL_ROOT, ELSE_cfg, load_pretrain, vis
        )

        if transfer_type == "partial-1":
            self._freeze_all_but_last_layers(1)

        elif transfer_type == "partial-2":
            self._freeze_all_but_last_layers(2)

        elif transfer_type == "partial-4":
            self._freeze_all_but_last_layers(4)

        elif transfer_type in ("linear", "side"):
            self._freeze_except(lambda k: "share" in k)

        elif transfer_type == "tinytl-bias":
            self._freeze_except(lambda k: 'bias' in k)

        elif transfer_type == "prompt" and prompt_cfg.LOCATION == "below":
            # Prompts are inserted below the patch embedding, so the patch
            # embedding projection is trained together with the prompts.
            self._freeze_except(
                lambda k: "prompt" in k
                or "embeddings.patch_embeddings.weight" in k
                or "embeddings.patch_embeddings.bias" in k
            )

        elif transfer_type == "prompt":
            self._freeze_except(lambda k: "prompt" in k)

        elif transfer_type == "prompt+bias":
            self._freeze_except(lambda k: "prompt" in k or 'bias' in k)

        elif transfer_type == "prompt-noupdate":
            self._freeze_except(lambda k: False)

        elif transfer_type == "cls":
            self._freeze_except(lambda k: "cls_token" in k)

        elif transfer_type == "cls-reinit":
            self._reinit_cls_token()
            self._freeze_except(lambda k: "cls_token" in k)

        elif transfer_type == "cls+prompt":
            self._freeze_except(lambda k: "prompt" in k or "cls_token" in k)

        elif transfer_type == "cls-reinit+prompt":
            self._reinit_cls_token()
            self._freeze_except(lambda k: "prompt" in k or "cls_token" in k)

        elif transfer_type == "ELSE":
            self._freeze_except(lambda k: "ELSE" in k)

        elif transfer_type == "end2end":
            logger.info("Enable all parameters update during training")

        else:
            raise ValueError("transfer type {} is not supported".format(
                transfer_type))

    def setup_head(self, cfg):
        """Attach an ``MLP`` head: ``MODEL.MLP_NUM`` hidden layers of width
        ``feat_dim`` followed by a ``DATA.NUMBER_CLASSES``-way output."""
        self.head = MLP(
            input_dim=self.feat_dim,
            mlp_dims=[self.feat_dim] * self.cfg.MODEL.MLP_NUM + \
                [cfg.DATA.NUMBER_CLASSES], # noqa
            special_bias=True
        )

    def forward(self, x, return_feature=False):
        """Classify ``x``.

        Args:
            x: input batch passed to the encoder (and to the side branch).
            return_feature: if True, skip the head and return the feature
                twice.

        Returns:
            ``(logits, feature)``, or ``(feature, feature)`` when
            ``return_feature`` is True. ``feature`` is the first token of the
            encoder output, blended with the side branch when enabled.
        """
        if self.side is not None:
            side_output = self.side(x)
            side_output = side_output.view(side_output.size(0), -1)
            side_output = self.side_projection(side_output)

        # assumes the encoder returns (batch, tokens, dim) with the CLS token
        # at index 0 — TODO confirm against build_vit_sup_models.
        tokens = self.enc(x)
        feature = tokens[:, 0]

        if self.side is not None:
            # Blend the encoder feature (not the raw input) with the side one.
            alpha_squashed = torch.sigmoid(self.side_alpha)
            feature = (alpha_squashed * feature
                       + (1 - alpha_squashed) * side_output)

        if return_feature:
            return feature, feature
        logits = self.head(feature)
        return logits, feature

    def forward_cls_layerwise(self, x):
        """Delegate to the encoder's ``forward_cls_layerwise``."""
        cls_embeds = self.enc.forward_cls_layerwise(x)
        return cls_embeds

    def get_features(self, x):
        """get a (batch_size, self.feat_dim) feature"""
        x = self.enc(x)
        # First token of the encoder output (CLS position).
        return x[:,0]


class Swin(ViT):
    """Swin-related model.

    Inherits construction, the optional side branch, the head and ``forward``
    from ``ViT``; only backbone building and parameter freezing differ.
    """

    def __init__(self, cfg):
        super(Swin, self).__init__(cfg)

    def build_backbone(self, prompt_cfg, cfg, ELSE_cfg, load_pretrain, vis):
        """Build the Swin encoder and freeze parameters for the transfer type.

        ``ELSE_cfg``, ``load_pretrain`` and ``vis`` are accepted to match the
        ``ViT.build_backbone`` signature but are not used by this override.

        Raises:
            ValueError: if ``cfg.MODEL.TRANSFER_TYPE`` is not supported.
        """
        mode = cfg.MODEL.TRANSFER_TYPE
        self.enc, self.feat_dim = build_swin_model(
            cfg.DATA.FEATURE, cfg.DATA.CROPSIZE,
            prompt_cfg, cfg.MODEL.MODEL_ROOT
        )

        def is_final_norm(name):
            # Exact-name match so per-block norm layers are not included.
            return name in ("norm.weight", "norm.bias")

        if mode == "partial-1":
            n_stages = len(self.enc.layers)
            n_blocks = len(self.enc.layers[-1].blocks)
            last_block = "layers.{}.blocks.{}".format(
                n_stages - 1, n_blocks - 1)

            def trainable(name):
                return last_block in name or is_final_norm(name)

        elif mode == "partial-2":
            last_stage = "layers.{}".format(len(self.enc.layers) - 1)

            def trainable(name):
                return last_stage in name or is_final_norm(name)

        elif mode == "partial-4":
            n_stages = len(self.enc.layers)
            n_blocks = len(self.enc.layers[-2].blocks)
            kept = (
                "layers.{}".format(n_stages - 1),
                "layers.{}.blocks.{}".format(n_stages - 2, n_blocks - 1),
                "layers.{}.blocks.{}".format(n_stages - 2, n_blocks - 2),
                "layers.{}.downsample".format(n_stages - 2),
            )

            def trainable(name):
                return any(key in name for key in kept) or is_final_norm(name)

        elif mode in ("linear", "side"):
            def trainable(name):
                return False

        elif mode == "tinytl-bias":
            def trainable(name):
                return 'bias' in name

        elif mode == "prompt" and prompt_cfg.LOCATION in ["below"]:
            def trainable(name):
                return "prompt" in name or "patch_embed" in name

        elif mode == "prompt":
            def trainable(name):
                return "prompt" in name

        elif mode == "prompt+bias":
            def trainable(name):
                return "prompt" in name or 'bias' in name

        elif mode == "end2end":
            logger.info("Enable all parameters update during training")
            return

        else:
            raise ValueError("transfer type {} is not supported".format(
                mode))

        for name, param in self.enc.named_parameters():
            if not trainable(name):
                param.requires_grad = False


class SSLViT(ViT):
    """MoCo-v3 and MAE (self-supervised) ViT models.

    Inherits construction, the optional side branch, the head and ``forward``
    from ``ViT``; only backbone building and parameter freezing differ.
    """

    def __init__(self, cfg):
        super(SSLViT, self).__init__(cfg)

    def build_backbone(self, prompt_cfg, cfg, ELSE_cfg, load_pretrain, vis):
        """Build the MoCo-v3 / MAE encoder and freeze parameters.

        ``load_pretrain`` and ``vis`` are accepted to match the
        ``ViT.build_backbone`` signature but are not used by this override.

        Raises:
            ValueError: if ``cfg.DATA.FEATURE`` names neither a MoCo nor an
                MAE model, or if the transfer type is not supported.
        """
        feature = cfg.DATA.FEATURE
        if "moco" in feature:
            build_fn = build_mocov3_model
        elif "mae" in feature:
            build_fn = build_mae_model
        else:
            # Without this branch build_fn would be unbound below.
            raise ValueError(
                "feature {} is not supported: expected a 'moco' or 'mae' "
                "model".format(feature))

        self.enc, self.feat_dim = build_fn(
            feature, cfg.DATA.CROPSIZE,
            prompt_cfg, cfg.MODEL.MODEL_ROOT, ELSE_cfg=ELSE_cfg
        )

        transfer_type = cfg.MODEL.TRANSFER_TYPE
        # Exact parameter names of the final norm. Parameters are named
        # "norm.weight"/"norm.bias", never "norm", so compare against these.
        final_norm = ("norm.weight", "norm.bias")

        if transfer_type in ("partial-1", "partial-2", "partial-4"):
            num_tuned = int(transfer_type.split("-")[1])
            total_layer = len(self.enc.blocks)
            # Trailing "." so that e.g. "blocks.1." cannot match "blocks.10.".
            tuned = tuple(
                "blocks.{}.".format(total_layer - i)
                for i in range(1, num_tuned + 1)
            )
            for k, p in self.enc.named_parameters():
                if (not any(t in k for t in tuned)
                        and "fc_norm" not in k and k not in final_norm):
                    p.requires_grad = False

        elif transfer_type in ("linear", "side", "sidetune"):
            # "side" matches the transfer type ViT.setup_side builds for.
            for k, p in self.enc.named_parameters():
                p.requires_grad = False

        elif transfer_type == "tinytl-bias":
            for k, p in self.enc.named_parameters():
                if 'bias' not in k:
                    p.requires_grad = False

        elif transfer_type == "prompt+bias":
            for k, p in self.enc.named_parameters():
                if "prompt" not in k and 'bias' not in k:
                    p.requires_grad = False

        elif transfer_type == "prompt" and prompt_cfg.LOCATION == "below":
            # Prompts sit below the patch embedding, so train its projection.
            for k, p in self.enc.named_parameters():
                if ("prompt" not in k and "patch_embed.proj.weight" not in k
                        and "patch_embed.proj.bias" not in k):
                    p.requires_grad = False

        elif transfer_type == "prompt":
            for k, p in self.enc.named_parameters():
                if "prompt" not in k:
                    p.requires_grad = False

        elif transfer_type == "end2end":
            logger.info("Enable all parameters update during training")

        elif transfer_type == "adapter":
            for k, p in self.enc.named_parameters():
                if "adapter" not in k:
                    p.requires_grad = False

        else:
            raise ValueError("transfer type {} is not supported".format(
                transfer_type))

class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(beta * x)``.

    Args:
        beta: scale applied to ``x`` inside the sigmoid gate.
    """

    def __init__(self, beta=1.0):
        super().__init__()
        self.beta = beta

    def forward(self, x):
        gate = torch.sigmoid(self.beta * x)
        return x * gate

class RoundSTE(torch.autograd.Function):
    """Round to the nearest integer with a straight-through gradient.

    Forward applies ``torch.round``; backward passes the incoming gradient
    through unchanged (identity), because the true derivative of rounding is
    zero almost everywhere.
    """

    @staticmethod
    def forward(ctx, x):
        # Nothing is saved for backward: the straight-through gradient does
        # not depend on the input, so keeping it alive would only waste memory.
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()

class RoundLayer(nn.Module):
    """Module wrapper around ``RoundSTE``: round forward, identity gradient."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        rounded = RoundSTE.apply(x)
        return rounded

class MLP_weight(nn.Module):
    """Small MLP producing a binary (0/1) weight per sample.

    Architecture: ``input_dim + class_num -> 2*input_dim -> input_dim ->
    input_dim // 2 -> 1``, each hidden layer being Linear + BatchNorm1d + ReLU.
    The output is passed through a sigmoid and then rounded with a
    straight-through estimator (``RoundLayer``), so values are 0 or 1 while
    gradients still flow.

    Args:
        input_dim: feature size; presumably the backbone feature that is
            concatenated with ``class_num`` class scores — verify with caller.
        class_num: number of class scores appended to the feature.
    """

    def __init__(self, input_dim, class_num):
        super(MLP_weight, self).__init__()
        self.norm1 = nn.BatchNorm1d(input_dim * 2)
        self.fc1 = nn.Linear(input_dim + class_num, input_dim * 2)
        self.norm2 = nn.BatchNorm1d(input_dim)
        self.fc2 = nn.Linear(input_dim * 2, input_dim)
        self.norm3 = nn.BatchNorm1d(input_dim // 2)
        self.fc3 = nn.Linear(input_dim, input_dim // 2)
        self.fc4 = nn.Linear(input_dim // 2, 1)
        self._initialize_weights()

        self.act = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.round_layer = RoundLayer()

    def _initialize_weights(self):
        """Kaiming-normal weights (ReLU gain) and zero biases for Linears."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.act(self.norm1(self.fc1(x)))
        x = self.act(self.norm2(self.fc2(x)))
        x = self.act(self.norm3(self.fc3(x)))
        x = self.sigmoid(self.fc4(x))
        # The former debug print of x here ran on every forward pass.
        return self.round_layer(x)

class AutoEncoder(nn.Module):
    """Expand-then-contract MLP: ``d -> 2d -> 4d -> 8d -> 4d -> 2d -> d``.

    Every stage is Linear -> BatchNorm1d -> ReLU -> Dropout(0.1), including
    the final one, so the output is non-negative.

    Args:
        input_dim: feature size ``d`` of both input and output.
    """

    def __init__(self, input_dim):
        super().__init__()
        d = input_dim
        # Registration order is kept as-is: it determines parameter init
        # order (RNG consumption) and state_dict key order.
        self.norm1 = nn.BatchNorm1d(d * 2)
        self.up1 = nn.Linear(d, d * 2)
        self.norm2 = nn.BatchNorm1d(d * 4)
        self.up2 = nn.Linear(d * 2, d * 4)
        self.norm3 = nn.BatchNorm1d(d * 8)
        self.up3 = nn.Linear(d * 4, d * 8)
        self.norm4 = nn.BatchNorm1d(d * 4)
        self.down1 = nn.Linear(d * 8, d * 4)
        self.norm5 = nn.BatchNorm1d(d * 2)
        self.down2 = nn.Linear(d * 4, d * 2)
        self.norm6 = nn.BatchNorm1d(d)
        self.down3 = nn.Linear(d * 2, d)
        self.act = nn.ReLU()
        self._initialize_weights()
        self.drop = nn.Dropout(0.1)

    def _initialize_weights(self):
        """Kaiming-normal weights (ReLU gain) and zero biases for Linears."""
        linears = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linears:
            torch.nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            if layer.bias is not None:
                torch.nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        stages = (
            (self.up1, self.norm1),
            (self.up2, self.norm2),
            (self.up3, self.norm3),
            (self.down1, self.norm4),
            (self.down2, self.norm5),
            (self.down3, self.norm6),
        )
        out = x
        for linear, norm in stages:
            out = self.drop(self.act(norm(linear(out))))
        return out

class ParallelMLP(nn.Module):
    """Apply one independent ``AutoEncoder`` per head.

    Args:
        input_dim: feature size handled by each per-head ``AutoEncoder``.
        num_heads: number of independent heads.
    """

    def __init__(self, input_dim, num_heads):
        super(ParallelMLP, self).__init__()
        self.mlps = nn.ModuleList()
        for _ in range(num_heads):
            mlp = AutoEncoder(input_dim)
            self.mlps.append(mlp)

    def forward(self, x):
        """Head ``i`` processes ``x[:, i]``; results are stacked on dim 1.

        Assumes ``x`` is (batch, num_heads, input_dim) — TODO confirm with
        callers. Iterates over the configured heads instead of a fixed 12.
        """
        outputs = [mlp(x[:, i]) for i, mlp in enumerate(self.mlps)]
        return torch.stack(outputs, dim=1)

class Sigmoid_new(nn.Module):
    """Shifted, sharpened logistic applied to ``relu(x)``.

    Computes ``1 / (1 + exp(-alpha * (relu(x) - beta)))``.

    Args:
        alpha: steepness of the logistic.
        beta: center (value of ``relu(x)`` mapped to 0.5).
    """

    def __init__(self, alpha = 5.0, beta = 1.0):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.act = nn.ReLU()

    def forward(self, x):
        shifted = self.act(x) - self.beta
        denom = 1 + torch.exp(-self.alpha * shifted)
        return torch.reciprocal(denom)

class CustomSigmoid(nn.Module):
    """Logistic with a learnable center applied to ``relu(x)``.

    Computes ``1 / (1 + exp(-a * (relu(x) - b)))``, with ``a`` a fixed scale
    and ``b`` a learnable 1-element parameter.

    Args:
        a: fixed steepness.
        b: initial value of the learnable center; cast to float so an integer
            argument still yields a parameter that can require gradients.
    """

    def __init__(self, a, b):
        super(CustomSigmoid, self).__init__()
        self.a = a
        self.b = nn.Parameter(torch.tensor([float(b)]))
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.act(x)
        # The former per-call debug prints of a/b were removed.
        return 1 / (1 + torch.exp(-self.a * (x - self.b)))

class TransNet(nn.Module):
    """ViT backbone plus a learnable feature offset and a linear classifier.

    ``forward`` returns ``(classifier_logits, backbone_logits, weight_cls)``.
    ``weight_cls`` is currently always ``None``.

    Args:
        cfg: config; reads ``MODEL.INPUT_DIM`` (assumed to equal the backbone
            feature size — TODO confirm) and ``DATA.NUMBER_CLASSES``.
        **kwargs: accepted for call-site compatibility; unused.
    """

    def __init__(self, cfg, **kwargs):
        super(TransNet, self).__init__()
        self.backbone = ViT(cfg)
        # Learnable additive offset applied to the backbone feature.
        self.offset = nn.Parameter(torch.randn(1, cfg.MODEL.INPUT_DIM))
        # Same class count as the backbone head (was hard-coded to 200).
        self.classifier = nn.Linear(
            cfg.MODEL.INPUT_DIM, cfg.DATA.NUMBER_CLASSES)
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal classifier weight (ReLU gain), zero bias."""
        torch.nn.init.kaiming_normal_(self.classifier.weight, nonlinearity='relu')
        torch.nn.init.constant_(self.classifier.bias, 0)

    def forward(self, x):
        backbone_logits, feature = self.backbone(x)
        feature = feature + self.offset
        logits = self.classifier(feature)
        weight_cls = None
        return logits, backbone_logits, weight_cls