import torch
import torch.nn as nn
from torch.nn import Module as Module
from collections import OrderedDict
import torchvision
import math
import numpy as np
from .classifier import Interventional_Classifier, CosNorm_Classifier
from .swin_transformer import SwinTransformer
from .vision_transformer import VisionTransformer

class IDA(Module):
    """Classifier made of a feature backbone followed by a linear or interventional head.

    Args:
        backbone: One of ``"resnet101"``, ``"swim_transformer"`` or
            ``"swim_transformer_large"`` (the "swim" spelling is kept for
            backward compatibility with existing configs).
        num_classes: Number of output logits.
        pretrain: Optional checkpoint path forwarded to the backbone constructor.
        use_intervention: If True, ``Interventional_Classifier`` is applied to the
            raw backbone features; otherwise the features are averaged over their
            trailing dimensions and fed to an ``nn.Linear`` layer.
        heavy: Forwarded to ``Interventional_Classifier``; unused when
            ``use_intervention`` is False.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    def __init__(self, backbone="resnet101", num_classes=80, pretrain=None, use_intervention=False, heavy=False):
        super(IDA, self).__init__()
        if backbone == "resnet101":
            self.backbone = resnet101_backbone(pretrain)
        elif backbone == "swim_transformer":
            self.backbone = swimtrans_backbone(num_classes, pretrain)
        elif backbone == "swim_transformer_large":
            self.backbone = swimtrans_backbone(num_classes, pretrain, large=True)
        else:
            # Fail fast: otherwise the missing ``self.backbone`` only surfaces
            # below as an unhelpful AttributeError.
            raise ValueError(
                f"Unsupported backbone {backbone!r}; expected one of "
                "'resnet101', 'swim_transformer', 'swim_transformer_large'"
            )

        self.feat_dim = self.backbone.feat_dim
        self.use_intervention = use_intervention

        if not use_intervention:
            self.clf = nn.Linear(self.feat_dim, num_classes)
        else:
            # ``heavy`` is forwarded (it used to be hard-coded to False, so the
            # constructor argument had no effect).
            self.clf = Interventional_Classifier(
                num_classes=num_classes,
                feat_dim=self.feat_dim,
                num_head=4,
                beta=0.03125,
                heavy=heavy,
            )

    def forward(self, x):
        """Run the backbone and classifier head.

        Args:
            x: Input batch passed straight to the backbone.

        Returns:
            Tuple ``(feats, logits)``: the backbone output and the head output.
        """
        feats = self.backbone(x)

        if self.use_intervention:
            logits = self.clf(feats)
        else:
            # Global average pooling: collapse every dim after the first two and
            # average. Assumes channel-first features (B, C, *spatial) so the
            # result is (B, C) matching ``feat_dim`` — TODO confirm for the Swin
            # backbone, whose output shape is defined in .swin_transformer.
            logits = self.clf(feats.flatten(2).mean(-1))
        return feats, logits

 
class resnet101_backbone(Module):
    """ResNet-101 trunk used as a feature extractor (``avgpool`` and ``fc`` removed).

    Args:
        pretrain: Optional path to a checkpoint. Two layouts are accepted:
            a plain state dict for ``torchvision.models.resnet101``, or a dict
            holding a ``"state_dict"`` entry whose keys may carry the
            ``"module."`` prefix that ``nn.DataParallel`` adds.

    Attributes:
        resnet_layer: ``nn.Sequential`` of all ResNet-101 children except the
            last two (``avgpool`` and ``fc``).
        feat_dim: ``in_features`` of the original ``fc`` layer, i.e. the channel
            count of ``resnet_layer``'s output.
    """

    _DP_PREFIX = "module."

    def __init__(self, pretrain):
        super(resnet101_backbone, self).__init__()
        # ``pretrained=`` kept for compatibility with older torchvision releases;
        # newer ones deprecate it in favour of ``weights=``.
        res101 = torchvision.models.resnet101(pretrained=False)
        if pretrain:
            # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
            state = torch.load(pretrain, map_location='cpu')
            if isinstance(state, dict) and "state_dict" in state:
                # Strip the DataParallel "module." prefix directly instead of
                # wrapping the model in nn.DataParallel: its constructor can move
                # the module onto a GPU as a side effect, and the wrapper rejected
                # checkpoints whose keys were not prefixed.
                prefix = self._DP_PREFIX
                state = {
                    (key[len(prefix):] if key.startswith(prefix) else key): value
                    for key, value in state["state_dict"].items()
                }
            res101.load_state_dict(state)
        num_fit = res101.fc.in_features
        # Keep everything up to the last residual stage; drop pooling and the classifier.
        self.resnet_layer = nn.Sequential(*list(res101.children())[:-2])

        self.feat_dim = num_fit

    def forward(self, x):
        """Return the convolutional feature map produced by ``resnet_layer``."""
        feats = self.resnet_layer(x)

        return feats

class swimtrans_backbone(Module):
    """Swin Transformer feature extractor with its classification head removed.

    Args:
        num_classes: Passed to the ``SwinTransformer`` constructor; the model's
            ``head`` is deleted once construction and weight loading are done.
        pretrain: Optional checkpoint path. Weights are taken from the
            ``"model"`` entry when the loaded object is a dict containing one;
            otherwise the loaded object itself is used as the state dict. Only
            keys present in this model and not containing ``"head"`` are loaded,
            with ``strict=False``.
        large: If True, build the variant with ``embed_dim=192`` and
            ``num_heads=(6, 12, 24, 48)`` (presumably Swin-L); otherwise
            ``embed_dim=128`` and ``num_heads=(4, 8, 16, 32)`` (presumably
            Swin-B). Both use ``img_size=384``, ``patch_size=4``,
            ``window_size=12`` and ``depths=(2, 2, 18, 2)``.

    Attributes:
        model: The underlying ``SwinTransformer`` (without ``head``).
        feat_dim: ``model.num_features``.
    """

    def __init__(self, num_classes, pretrain, large=False):
        super(swimtrans_backbone, self).__init__()
        if large:
            self.model = SwinTransformer(
                img_size=384, patch_size=4, num_classes=num_classes, embed_dim=192,
                depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), window_size=12,
            )
        else:
            self.model = SwinTransformer(
                img_size=384, patch_size=4, num_classes=num_classes, embed_dim=128,
                depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), window_size=12,
            )
        if pretrain:
            # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
            checkpoint = torch.load(pretrain, map_location='cpu')
            # Official Swin releases nest weights under "model"; also accept a bare state dict.
            if isinstance(checkpoint, dict) and 'model' in checkpoint:
                state = checkpoint['model']
            else:
                state = checkpoint
            # Build the target key set once; calling state_dict() per key
            # re-walks the whole model for every checkpoint entry.
            model_keys = set(self.model.state_dict())
            filtered_dict = {
                k: v for k, v in state.items()
                if k in model_keys and 'head' not in k
            }
            self.model.load_state_dict(filtered_dict, strict=False)
        self.feat_dim = self.model.num_features
        # Only forward_features is used, so the classification head is dropped.
        del self.model.head

    def forward(self, x):
        """Return ``self.model.forward_features(x)``."""
        feats = self.model.forward_features(x)
        return feats




























































