import torch
import torch.nn as nn

from model.regression import IQAHead
from model.swin_transformer import SwinTransformer as swin_tiny
from model.vit_mae import ViT_MAE
from model.fusion import guide_fusion
from model.mi_estimators import CLUB, CLUBSample

class DisPA(nn.Module):
    """Two-branch model: ViT features + Swin features -> MI estimate and score.

    Sub-modules (all constructed in ``__init__``):

    * ``swin``        -- ``SwinTransformer`` backbone applied to ``frags``.
    * ``vit``         -- ``ViT_MAE`` backbone applied to ``imgs``.
    * ``miestimator`` -- ``CLUBSample`` estimator fed with the two feature
      vectors (presumably a mutual-information bound between the branches;
      see ``model.mi_estimators`` to confirm).
    * ``head``        -- ``IQAHead`` regressor on the concatenated features.
    """

    # Modes accepted by ``forward``.
    _VALID_OPTS = ('optimizing_estimators', 'optimizing_networks', 'inference')

    # Per-branch feature width. The estimator is built with this size for both
    # inputs and the head with twice this size, so both backbones are assumed
    # to emit vectors of this width -- TODO confirm against the backbones.
    _FEATURE_DIM = 768

    def __init__(self, swin_checkpoint_path=None, vit_checkpoint_path=None):
        """Build the backbones, the estimator and the regression head.

        Args:
            swin_checkpoint_path: Optional path to a Swin checkpoint. When not
                ``None`` it is loaded through ``initialize_weight``.
            vit_checkpoint_path: Forwarded unchanged to ``ViT_MAE`` as
                ``checkpoint_path``.
        """
        super().__init__()
        dim = self._FEATURE_DIM

        self.swin = swin_tiny()
        self.vit = ViT_MAE(checkpoint_path=vit_checkpoint_path)
        self.miestimator = CLUBSample(x_dim=dim, y_dim=dim, hidden_size=dim)
        self.head = IQAHead(in_channels=dim * 2)

        if swin_checkpoint_path is not None:
            self.initialize_weight(swin_checkpoint_path)

    def initialize_weight(self, swin_checkpoint_path=None):
        """Load Swin backbone weights from a checkpoint file.

        The checkpoint must be a mapping whose ``'model'`` entry is the state
        dict. Loading is non-strict, so missing or unexpected keys are ignored
        by ``load_state_dict``.

        Args:
            swin_checkpoint_path: Path (or file-like object) accepted by
                ``torch.load``.
        """
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
        # hosts; load_state_dict then copies the values into the existing
        # parameters, wherever those live.
        checkpoint = torch.load(swin_checkpoint_path, map_location='cpu')
        self.swin.load_state_dict(checkpoint['model'], strict=False)

    def forward(self, imgs=None, frags=None, opt=None):
        """Run both branches and either train the estimator or predict.

        Args:
            imgs: 5-D tensor unpacked as ``(B, num_view, C, H, W)``. The views
                are folded into the batch axis before entering the ViT.
            frags: Input passed directly to the Swin backbone. Its output is
                averaged over dim 1 (presumably the token axis -- confirm
                against ``SwinTransformer``).
            opt: One of ``'optimizing_estimators'``, ``'optimizing_networks'``
                or ``'inference'``. The last two take the same path here.

        Returns:
            * ``opt == 'optimizing_estimators'``: the value of
              ``miestimator.learning_loss(feat_vit, feat_swin)``.
            * otherwise: a tuple ``(estimate_mi, pred_mos)`` -- the estimator
              output and the head output on the concatenated features.

        Raises:
            ValueError: If ``opt`` is not one of the supported modes.

        Side effects:
            Switches ``self.miestimator`` to train mode in
            ``'optimizing_estimators'`` and to eval mode otherwise.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if opt not in self._VALID_OPTS:
            raise ValueError(
                f"opt must be one of {self._VALID_OPTS}, got {opt!r}"
            )

        B, num_view, C, H, W = imgs.shape

        # NOTE: the ViT call is not wrapped in torch.no_grad() and the module
        # is not forced into eval mode here, so this method does not freeze it.
        # reshape (not view) so non-contiguous inputs are accepted as well.
        feat_vit = self.vit(imgs.reshape(-1, C, H, W), mode='only_encode')
        # Un-fold the views, then keep the element-wise maximum across them.
        feat_vit = feat_vit.reshape(B, num_view, -1).max(dim=1)[0]
        feat_swin = self.swin(frags).mean(dim=1)

        if opt == 'optimizing_estimators':
            self.miestimator.train()
            return self.miestimator.learning_loss(feat_vit, feat_swin)

        self.miestimator.eval()
        estimate_mi = self.miestimator(feat_vit, feat_swin)

        # Fusion is a plain concatenation along the feature axis.
        feat_fuse = torch.cat((feat_vit, feat_swin), dim=-1)
        pred_mos = self.head(feat_fuse)

        return estimate_mi, pred_mos
            



