import torch
import torch.nn as nn
import torch.nn.functional as F

from clip.model import ModifiedResNet as CLIP_RN
from timm.models.regnet import RegNet

from .peft_modules import *


class RN_Tuner(nn.Module):
    """Collect the parts of a ResNet-style backbone that should be fine-tuned.

    All instance variables in this class will be optimized.

    Each enabled option registers the *existing* parameter/module objects of
    ``rn_model`` on this tuner (nothing is copied), so updating the tuner's
    parameters updates the backbone in place. Disabled options are stored as
    ``None``.

    Args:
        cfg: Config object exposing the flags ``full_tuning``,
            ``bias_tuning``, ``bn_tuning`` and ``ssf_attn``.
        rn_model: Backbone to tune; a CLIP ``ModifiedResNet`` or a timm
            ``RegNet``.
        num_classes: Not used by this class; kept so existing callers that
            pass it keep working.

    Raises:
        TypeError: If ``cfg.ssf_attn`` is set but ``rn_model`` is not one of
            the supported backbone types, because the SSF feature dimension
            and dtype cannot be determined.
    """
    def __init__(self, cfg, rn_model, num_classes):
        super().__init__()

        # Feature width and dtype are only needed to size the SSF module.
        if isinstance(rn_model, CLIP_RN):
            feat_dim = rn_model.output_dim
            dtype = rn_model.conv1.weight.dtype
        elif isinstance(rn_model, RegNet):
            feat_dim = rn_model.num_features
            dtype = rn_model.stem.conv.weight.dtype
        else:
            # Unknown backbone: the parameter-selection options below still
            # work, but SSF cannot be sized (checked where it is built).
            feat_dim = None
            dtype = None

        use_full_tuning = cfg.full_tuning
        use_bias_tuning = cfg.bias_tuning
        use_bn_tuning = cfg.bn_tuning
        use_ssf_attn = cfg.ssf_attn

        full_list = None
        bias_list = None
        bn_list = None
        ssf_attn = None

        if use_full_tuning:
            # Every backbone parameter.
            full_list = nn.ParameterList(list(rn_model.parameters()))

        if use_bias_tuning:
            # Only parameters whose qualified name ends with "bias".
            bias_list = nn.ParameterList([
                param for name, param in rn_model.named_parameters()
                if name.endswith("bias")
            ])

        if use_bn_tuning:
            # Whole BatchNorm2d modules (their affine parameters).
            bn_list = nn.ModuleList([
                mod for mod in rn_model.modules()
                if isinstance(mod, nn.BatchNorm2d)
            ])

        if use_ssf_attn:
            if feat_dim is None:
                raise TypeError(
                    "ssf_attn requires a CLIP ModifiedResNet or timm RegNet "
                    f"backbone, got {type(rn_model).__name__}"
                )
            ssf_attn = SSF(feat_dim, dtype=dtype)

        # To be optimized
        self.full_list = full_list
        self.bias_list = bias_list
        self.bn_list = bn_list
        self.ssf_attn = ssf_attn


class Peft_RN(nn.Module):
    """Wrap a ResNet-style backbone as ``stem -> blocks -> final_layer``.

    The sub-modules of ``rn_model`` are reused (not copied), so the wrapper
    shares weights with the model it was built from.

    Attributes:
        backbone: ``"CLIP-RN"`` or ``"RegNet"``, identifying the wrapped type.
        stem: Input stem of the backbone.
        blocks: The four residual stages, applied in order.
        final_layer: CLIP attention pool, or for RegNet the sequence
            ``final_conv -> head.global_pool -> head.drop``.
        dtype: dtype of the first stem convolution's weight at construction
            time; inputs are cast to it in ``forward``.
        out_dim: Feature dimension reported by the backbone.

    Args:
        rn_model: A CLIP ``ModifiedResNet`` or a timm ``RegNet``.

    Raises:
        TypeError: If ``rn_model`` is neither of the supported types.
    """
    def __init__(self, rn_model):
        super().__init__()

        if isinstance(rn_model, CLIP_RN):
            self.backbone = "CLIP-RN"
            self.stem = self._build_clip_stem(rn_model)
            self.blocks = nn.Sequential(
                rn_model.layer1,
                rn_model.layer2,
                rn_model.layer3,
                rn_model.layer4
            )
            self.final_layer = rn_model.attnpool
            self.dtype = rn_model.conv1.weight.dtype
            self.out_dim = rn_model.output_dim

        elif isinstance(rn_model, RegNet):
            self.backbone = "RegNet"
            self.stem = rn_model.stem
            self.blocks = nn.Sequential(
                rn_model.s1,
                rn_model.s2,
                rn_model.s3,
                rn_model.s4
            )
            self.final_layer = nn.Sequential(
                rn_model.final_conv,
                rn_model.head.global_pool,
                rn_model.head.drop
            )
            self.dtype = rn_model.stem.conv.weight.dtype
            self.out_dim = rn_model.num_features

        else:
            # Fail at construction instead of with a missing-attribute error
            # on the first forward pass.
            raise TypeError(
                "Peft_RN supports CLIP ModifiedResNet and timm RegNet, "
                f"got {type(rn_model).__name__}"
            )

    @staticmethod
    def _build_clip_stem(rn_model):
        """Return the CLIP stem in execution order.

        The stem computes ``relu(bn(conv(x)))`` for conv1..conv3 and then
        average-pools. The layers must therefore be interleaved as
        conv, bn, relu (x3), avgpool -- not listed in attribute order.

        NOTE: the indices inside the returned ``nn.Sequential`` (and hence
        ``stem.*`` state-dict keys) follow this interleaved order.
        """
        layers = []
        for idx in (1, 2, 3):
            # Use a per-stage activation (relu1/relu2/relu3) when the model
            # provides one; otherwise fall back to the shared ``relu``.
            act = getattr(rn_model, f"relu{idx}", None)
            if act is None:
                act = rn_model.relu
            layers.append(getattr(rn_model, f"conv{idx}"))
            layers.append(getattr(rn_model, f"bn{idx}"))
            layers.append(act)
        layers.append(rn_model.avgpool)
        return nn.Sequential(*layers)

    def forward(self, x, tuner=None, head=None):
        """Extract features and optionally apply SSF and classifier head(s).

        Args:
            x: Input tensor; cast to ``self.dtype`` before the stem.
            tuner: Optional object with an ``ssf_attn`` attribute. When that
                attribute is not ``None`` it is applied to the features.
            head: Optional head. ``None`` returns the features. An
                ``nn.ModuleDict`` must contain ``"head1"`` and ``"head2"``;
                both are applied to the same features. Any other callable
                is applied directly.

        Returns:
            The feature tensor, ``head(x)``, or a dict with keys ``"head1"``
            and ``"head2"`` holding the respective head outputs.
        """
        x = x.to(self.dtype)
        x = self.stem(x)
        x = self.blocks(x)
        x = self.final_layer(x)

        if tuner is not None and tuner.ssf_attn is not None:
            x = tuner.ssf_attn(x)

        if head is None:
            return x

        if isinstance(head, nn.ModuleDict):
            return {"head1": head["head1"](x),
                    "head2": head["head2"](x)}

        return head(x)


