'''
# author: Younghun Kim
# email: younghun1664@kaist.ac.kr
# date: 2025-0512
# description: Class for the SELFIDetector

Reference:
TODO: Add reference here
'''


import logging
import torch
import torch.nn as nn
import torch.nn.functional as F

from metrics.base_metrics_class import calculate_metrics_for_train
from transformers import AutoProcessor, CLIPModel, ViTModel, ViTConfig

from .base_detector import AbstractDetector
from detectors import DETECTOR
from networks import BACKBONE
from loss import LOSSFUNC

logger = logging.getLogger(__name__)

@DETECTOR.register_module(module_name='selfi')
class SELFIDetector(AbstractDetector):
    """Real/fake detector that fuses CLIP visual features with identity features.

    ``forward`` runs the following stages:

    1. CLIP vision backbone -> pooled visual feature (``features``).
    2. Forgery-Aware Identity Adapter (FAIA): a frozen face-recognition model
       yields an identity embedding, which is linearly projected to the
       backbone feature width; an auxiliary real/fake head is applied to the
       projection.
    3. Identity-Aware Fusion Module (IAFM): an MLP ending in a sigmoid
       predicts a relevance weight used to blend the layer-normalized visual
       and identity features.
    4. A linear real/fake classifier on the fused feature.
    """

    def __init__(self, config):
        """Build backbone, identity adapter, fusion module, heads and loss.

        Args:
            config: Mapping read for the keys ``backbone_name`` (only
                ``'clip'`` is accepted), ``backbone_config`` (optional
                ``dropout``), ``face_model_name``, ``face_model_pretrained``,
                ``loss_func['cls_loss']`` and ``loss_weight`` (optional
                ``cls_weight`` / ``identity_cls_weight``, both defaulting
                to 1).

        Raises:
            ValueError: If ``backbone_name`` is not ``'clip'`` or if no
                face-model weights path is configured.
        """
        super().__init__()
        self.config = config
        self.backbone = self.build_backbone(config)
        self.dropout = config['backbone_config'].get('dropout', False)
        if self.dropout:
            # NOTE(review): this layer is created but never applied in
            # forward(); confirm whether dropout was meant to be used.
            self.dropout_layer = nn.Dropout(p=self.dropout)

        if config['backbone_name'] == 'clip':
            self.feat_dim = 768
            self.id_feat_dim = 512
        else:
            raise ValueError(f"Unsupported backbone name: {config['backbone_name']}")

        # Forgery-Aware Identity Adapter (FAIA)
        self.iresnet = self.build_face_recognition_model(config)
        self.proj_id = nn.Linear(self.id_feat_dim, self.feat_dim, bias=False)
        self.proj_id_norm = nn.LayerNorm(self.feat_dim)
        self.identity_real_fake_head = nn.Linear(self.feat_dim, 2)

        # Identity-Aware Fusion Module (IAFM)
        self.feat_norm = nn.LayerNorm(self.feat_dim)  # Feature normalization
        self.relevance_predictor = nn.Sequential(
            nn.Linear(self.feat_dim * 2, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        )

        # Real or Fake Classifier
        self.real_or_fake_head = nn.Linear(self.feat_dim, 2)  # Default output dim

        # Losses
        self.cls_loss_func = self.build_loss(config['loss_func']['cls_loss'])
        self.cls_weight = config['loss_weight'].get('cls_weight', 1)
        self.identity_cls_weight = config['loss_weight'].get('identity_cls_weight', 1)

    def build_backbone(self, config):
        """Load the CLIP vision tower used as the visual backbone.

        ``config`` is accepted for interface compatibility but is not read:
        the checkpoint name is fixed to ``openai/clip-vit-base-patch16``.
        """
        _, backbone = get_clip_visual(model_name="openai/clip-vit-base-patch16")
        logger.info('CLIP - Load pretrained model successfully!')
        return backbone

    def build_loss(self, name, tau=None):
        """Instantiate the loss registered under ``name`` in ``LOSSFUNC``.

        ``tau`` is forwarded as the single positional constructor argument
        only when it is truthy; ``None`` and ``0`` both select the
        no-argument constructor.
        """
        loss_class = LOSSFUNC[name]
        return loss_class(tau) if tau else loss_class()

    def build_face_recognition_model(self, config):
        """Create the face-recognition model and freeze all its parameters.

        Args:
            config: Mapping providing ``face_model_name`` (a key of the
                ``BACKBONE`` registry) and ``face_model_pretrained`` (passed
                to the backbone constructor).

        Returns:
            The constructed model with ``requires_grad`` disabled on every
            parameter.

        Raises:
            ValueError: If ``face_model_pretrained`` is missing or empty.
        """
        backbone_class = BACKBONE[config['face_model_name']]
        pretrained_path = config.get('face_model_pretrained', None)
        if not pretrained_path:
            # The message names the key that is actually read above.
            raise ValueError(
                "Pretrained weights for IResNet not found in config['face_model_pretrained']."
            )

        iresnet = backbone_class(pretrained_path)
        # NOTE(review): freezing parameters does not switch the module to
        # eval mode; if this model contains BatchNorm/Dropout, confirm that
        # it is kept in eval() while the detector is in train().
        for param in iresnet.parameters():
            param.requires_grad = False
        return iresnet

    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute the weighted sum of the main and auxiliary identity losses.

        Both terms use ``self.cls_loss_func`` against ``data_dict['label']``:
        one on ``pred_dict['cls']`` and one on ``pred_dict['identity_cls']``.

        Returns:
            Dict with ``overall`` (weighted sum), ``real_or_fake`` and
            ``forgery_aware_guidance_loss`` (the two unweighted terms).
        """
        # Calculate Overall Loss
        real_fake_loss = self.cls_loss_func(pred_dict['cls'], data_dict['label'])
        forgery_aware_guidance_loss = self.cls_loss_func(pred_dict['identity_cls'], data_dict['label'])
        loss = (self.cls_weight * real_fake_loss) + (self.identity_cls_weight * forgery_aware_guidance_loss)

        return {
            'overall': loss,
            'real_or_fake': real_fake_loss,
            'forgery_aware_guidance_loss': forgery_aware_guidance_loss,
        }

    def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        """Return ``acc``/``auc``/``eer``/``ap`` for the main classifier output.

        Labels and ``pred_dict['cls']`` are detached before being handed to
        ``calculate_metrics_for_train``.
        """
        label, pred = data_dict['label'], pred_dict['cls']
        auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach())
        return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap}

    def features(self, data_dict: dict) -> torch.Tensor:
        """Return the backbone's ``pooler_output`` for ``data_dict['image']``."""
        return self.backbone(data_dict['image'])['pooler_output']

    def get_face_embedding(self, x) -> "tuple[torch.Tensor, torch.Tensor]":
        """Project the frozen face model's embedding into the backbone feature width.

        Args:
            x: Image batch; it is bilinearly resized to 112x112 before being
                passed to the face model (presumably that model's expected
                input size -- confirm against the registered backbone).

        Returns:
            ``(projected, LN_projected)``: the linear projection of the
            identity embedding and its layer-normalized version.
        """
        x = F.interpolate(x, size=(112, 112), mode='bilinear', align_corners=False)
        # Only the frozen face model is excluded from autograd. The projection
        # and its LayerNorm are trainable adapter layers; running them under
        # no_grad would cut them off from both the classification loss and the
        # auxiliary identity loss, so they would never be updated.
        with torch.no_grad():
            embedding = self.iresnet.features(x)
        projected = self.proj_id(embedding)
        LN_projected = self.proj_id_norm(projected)
        return projected, LN_projected

    def get_relevance(self, projected_face_features: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
        """Predict the fusion weight from the two features concatenated on dim 1."""
        x = torch.cat([projected_face_features, features], dim=1)
        return self.relevance_predictor(x)

    def get_fused_features(self, projected_face_features: torch.Tensor, features: torch.Tensor, relevance=None) -> torch.Tensor:
        """Blend visual and identity features with a relevance weight.

        Computes ``(1 - relevance) * features + relevance * projected_face_features``.

        Args:
            projected_face_features: Identity feature to blend in.
            features: Visual feature.
            relevance: Blend weight. When ``None`` it is computed from the two
                inputs with ``get_relevance`` (previously ``None`` raised a
                ``TypeError``).
        """
        if relevance is None:
            relevance = self.get_relevance(projected_face_features, features)
        return (1 - relevance) * features + relevance * projected_face_features

    def classifier(self, features: torch.Tensor) -> torch.Tensor:
        """Return the two-way logits of the real/fake head for ``features``."""
        return self.real_or_fake_head(features)

    def forward(self, data_dict: dict, inference=False) -> dict:
        """Run the full detector on ``data_dict['image']``.

        Args:
            data_dict: Batch dictionary; only the ``image`` entry is read here.
            inference: Accepted for interface compatibility; currently unused.

        Returns:
            Dict with ``cls`` (main logits), ``identity_cls`` (auxiliary
            logits from the un-normalized identity projection), ``prob``
            (softmax probability of class index 1) and ``feat`` (the fused
            feature fed to the classifier).
        """
        # Backbone
        features = self.features(data_dict)

        # Forgery-Aware Identity Adapter (FAIA)
        projected_face_features, LN_projected_face_features = self.get_face_embedding(data_dict['image'])
        identity_cls = self.identity_real_fake_head(projected_face_features)

        # Identity-Aware Fusion Module (IAFM)
        LN_features = self.feat_norm(features)  # Apply LayerNorm here
        relevance = self.get_relevance(LN_projected_face_features, LN_features)
        fused_features = self.get_fused_features(LN_projected_face_features, LN_features, relevance)

        # Real or Fake Classifier
        cls = self.classifier(fused_features)

        prob = torch.softmax(cls, dim=1)[:, 1]

        return {
            'cls': cls,
            'identity_cls': identity_cls,
            'prob': prob,
            'feat': fused_features,
        }
        
        
def get_clip_visual(model_name = "openai/clip-vit-base-patch16"):
    """Load a pretrained CLIP checkpoint and return its processor and vision tower.

    Args:
        model_name: Checkpoint identifier passed unchanged to both
            ``AutoProcessor.from_pretrained`` and ``CLIPModel.from_pretrained``.

    Returns:
        A ``(processor, vision_model)`` pair, where ``vision_model`` is the
        ``vision_model`` attribute of the loaded ``CLIPModel``.
    """
    # Load order is kept: processor first, then the full CLIP model.
    clip_processor = AutoProcessor.from_pretrained(model_name)
    clip_model = CLIPModel.from_pretrained(model_name)
    vision_tower = clip_model.vision_model
    return clip_processor, vision_tower
