import torch.nn as nn

class EvidenceHead(nn.Module):
    """Single linear projection from a feature vector to per-class scores.

    The output is exactly what ``nn.Linear`` produces; no activation or
    normalization is applied in this module.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        num_classes: Number of scores produced per sample.
    """

    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        # The attribute name ``linear`` determines the state_dict keys
        # (``linear.weight`` / ``linear.bias``); keep it stable for checkpoints.
        self.linear = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Apply the linear projection over the last dimension of ``x``."""
        projection = self.linear
        scores = projection(x)
        return scores

class FeatureBranch(nn.Module):
    """Per-view MLP encoder followed by an ``EvidenceHead`` projection.

    Pipeline: Linear(input_dim -> hidden_dim) -> BatchNorm1d -> ReLU ->
    Dropout -> EvidenceHead(hidden_dim -> num_classes).

    Args:
        input_dim: Size of the feature dimension of the input.
        num_classes: Number of output scores per sample.
        hidden_dim: Width of the hidden layer.
        dropout: Drop probability for the ``nn.Dropout`` layer. Defaults to
            0.5, the value that was previously hard-coded.

    NOTE: ``nn.BatchNorm1d`` in training mode cannot compute batch statistics
    from a single sample, so training batches need at least two rows.
    """

    def __init__(self, input_dim, num_classes, hidden_dim=256, dropout=0.5):
        super(FeatureBranch, self).__init__()
        # Layer order inside ``net`` fixes the state_dict keys (net.0 = Linear,
        # net.1 = BatchNorm1d); keep it stable so existing checkpoints load.
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.head = EvidenceHead(hidden_dim, num_classes)

    def forward(self, x):
        """Map ``x`` of shape [batch, input_dim] to scores [batch, num_classes]."""
        feat = self.net(x)
        return self.head(feat)


class MultiViewNet(nn.Module):
    """Two-view classifier that returns one set of logits per view.

    Each view is encoded by its own ``FeatureBranch``. The per-dataset input
    widths for the two views are listed in ``_VIEW_INPUT_DIMS``.

    Args:
        dataset_name: Dataset identifier, matched case-insensitively. Only
            ``'ave'`` is currently configured.
        num_classes: Number of output scores produced by each branch.

    Raises:
        ValueError: If ``dataset_name`` has no entry in ``_VIEW_INPUT_DIMS``.
    """

    # (view1 input_dim, view2 input_dim) for each supported dataset. The
    # error message is derived from these keys so it cannot advertise
    # datasets that are not actually configured.
    _VIEW_INPUT_DIMS = {
        'ave': (512, 128),
    }

    def __init__(self, dataset_name, num_classes):
        super(MultiViewNet, self).__init__()
        self.dataset_name = dataset_name.lower()
        self.num_classes = num_classes

        print(f"[Model Init] Creating MultiViewNet for {self.dataset_name} ({num_classes} classes)...")

        try:
            view1_dim, view2_dim = self._VIEW_INPUT_DIMS[self.dataset_name]
        except KeyError:
            supported = ", ".join(sorted(self._VIEW_INPUT_DIMS))
            raise ValueError(
                f"Unknown dataset name: {dataset_name}. Supported: {supported}"
            ) from None

        self.branch1 = FeatureBranch(input_dim=view1_dim, num_classes=num_classes)
        self.branch2 = FeatureBranch(input_dim=view2_dim, num_classes=num_classes)

    def forward(self, batch):
        """Run each view through its branch.

        Args:
            batch: Mapping with keys ``'view1'`` and ``'view2'``; each value is
                passed directly to the corresponding branch.

        Returns:
            ``[branch1(batch['view1']), branch2(batch['view2'])]``.
        """
        l1 = self.branch1(batch['view1'])
        l2 = self.branch2(batch['view2'])
        return [l1, l2]