#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from rfm import compute_agop_sqrt

# Default compute device: CUDA when a GPU is visible to torch, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class AttentionLayer(nn.Module):
    """Parameter-free sigmoid gate over an affine map supplied at call time.

    The layer owns no weights. ``forward`` receives the weight and bias of an
    ``nn.Linear`` owned by the caller, so one gating implementation can be
    applied with different projections.
    """

    def __init__(self):
        super(AttentionLayer, self).__init__()

    def forward(self, features, W_U, b_U):
        """Return ``sigmoid(F.linear(features, W_U, b_U))``."""
        out = F.linear(features, W_U, b_U)
        out_sigmoid = torch.sigmoid(out)
        return out_sigmoid


class AGOPMIPL(nn.Module):
    """Multi-instance partial-label classifier with AGOP (RFM) feature reweighting.

    The model works in these steps:

    1. A shared extractor embeds bag instances and per-class prototype
       instances into ``L``-dim features.
    2. A second view of every embedding is obtained by right-multiplying it
       with the ``agop_sqrt`` buffer, which is refreshed by :meth:`update_agop`.
    3. Gated attention scores are computed on both the RFM-transformed view
       and the raw view.
    4. The scores are combined as ``A_rfm + attn_lambda * A_raw``, scaled by
       ``1/sqrt(L)``, and softmax-normalised over instances.
    5. The normalised weights pool the *raw* instance features into a bag
       embedding, which ``classifier`` maps to class logits.

    ``args`` must provide ``L``, ``nr_fea`` and ``nr_class``. It may also
    provide ``proto_agg`` (default ``'mean'``), ``inst_weight`` (default 0.5)
    and ``attn_lambda`` (default 0.3). :meth:`full_loss` additionally reads
    ``mu`` and ``gamma`` from the ``args`` passed to it.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.L = self.args.L
        self.D = 128
        self.K = 1
        self.nr_fea = self.args.nr_fea
        self.nr_fea_sqrt = math.floor(math.sqrt(self.nr_fea))

        # Configurable options
        self.proto_agg = getattr(args, 'proto_agg', 'mean')
        self.inst_weight = getattr(args, 'inst_weight', 0.5)
        self.attn_lambda = getattr(args, 'attn_lambda', 0.3)

        print(f"[AGOPMIPL] Initialized with L={self.L}, attn_lambda={self.attn_lambda}")

        # AGOP feature transformation buffers (identity until update_agop runs).
        self.register_buffer('agop_matrix', torch.eye(self.L))
        self.register_buffer('agop_sqrt', torch.eye(self.L))

        # Feature extractor (no BatchNorm, for RFM compatibility).
        # Inputs whose feature count is a perfect square with side >= 8 are
        # treated as single-channel square images; everything else goes
        # through an MLP.
        is_perfect_square = (self.nr_fea_sqrt * self.nr_fea_sqrt == self.nr_fea)
        self.use_cnn = (self.nr_fea_sqrt >= 8) and is_perfect_square

        if self.use_cnn:
            # Wide LeNet: 32/64 channels and "same" padding on both
            # convolutions, with NO BatchNorm. Each 2x2 max-pool halves the
            # side (floor), e.g. 28 -> 14 -> 7.
            self.feature_extractor_part1 = nn.Sequential(
                nn.Conv2d(1, 32, kernel_size=5, padding=2),
                nn.ReLU(),
                nn.MaxPool2d(2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, stride=2),
            )
            # Derive the flattened size from the actual input side rather than
            # hard-coding the 28x28 case (64 * 7 * 7 = 3136). Every square input
            # with side >= 8 is routed here.
            pooled_side = (self.nr_fea_sqrt // 2) // 2
            self.cnn_output_size = 64 * pooled_side * pooled_side
            print(f"[AGOPMIPL] Wide LeNet (32/64), output size: {self.cnn_output_size}")

            self.feature_extractor_part2 = nn.Sequential(
                nn.Linear(self.cnn_output_size, self.L),
                nn.Dropout(0.4),
                nn.ReLU(),
            )
        else:
            self.feature_extractor_part1 = None
            self.feature_extractor_part2 = nn.Sequential(
                nn.Linear(self.nr_fea, 256),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(256, self.L),
                nn.ReLU(),
            )
            print(f"[AGOPMIPL] FC for {self.nr_fea}-dim input")

        # Prototype aggregation (linear): mixes the per-class axis into a
        # single value.
        self.linear1 = nn.Sequential(nn.Linear(self.args.nr_class, 1))

        # Shared attention layers (used for both the RFM and the raw paths).
        self.att_layer_V = AttentionLayer()
        self.att_layer_U = AttentionLayer()
        self.linear_V = nn.Linear(self.L * self.K, self.args.nr_class)
        self.linear_U = nn.Linear(self.L * self.K, 1)
        self.attention_weights = nn.Sequential(
            nn.Linear(self.args.nr_class, self.D),
            nn.ReLU(),
            nn.Linear(self.D, self.K),
        )

        # Classifiers
        self.classifier = nn.Linear(self.L, self.args.nr_class)
        self.instance_classifier = nn.Linear(self.L, self.args.nr_class)

    def _extract_features(self, x):
        """Embed a batch of raw instances into ``(n, L)`` features (pre-RFM).

        In CNN mode, 2-D input ``(n, nr_fea)`` is reshaped to
        ``(n, 1, side, side)``. 3-D input ``(n, side, side)`` gains a channel
        axis. 4-D input is used as-is. In MLP mode every instance is
        flattened.
        """
        if self.use_cnn:
            if x.dim() == 2:
                x = x.reshape(-1, 1, self.nr_fea_sqrt, self.nr_fea_sqrt)
            elif x.dim() == 3:
                x = x.unsqueeze(1)
            h = self.feature_extractor_part1(x)
            h = h.reshape(-1, self.cnn_output_size)
        else:
            h = x.reshape(x.shape[0], -1)
        return self.feature_extractor_part2(h)

    def _linear_aggregate(self, h_list):
        """Collapse per-class prototype embeddings with the learned ``linear1`` mixer.

        ``h_list`` holds one ``(P, L)`` tensor per class. ``torch.stack``
        requires the same ``P`` for every class. Returns a ``(P, L)`` tensor.
        """
        stacked = torch.stack(h_list, dim=0)  # (nr_class, P, L)
        stacked = stacked.permute(1, 2, 0).reshape(-1, self.args.nr_class)
        return self.linear1(stacked).view(-1, self.L)

    def compute_attention_scores(self, H, H2):
        """Compute raw attention scores (before softmax) using the shared attention network.

        ``H`` is ``(N, L)`` instance features. ``H2`` is the aggregated
        prototype features, which must have ``L`` columns. The gate
        ``sigmoid(linear_V(H)) * sigmoid(linear_U(H2)).T`` relies on
        broadcasting, so ``H2`` presumably has ``nr_class`` (or 1) rows.
        Returns a ``(K, N)`` tensor of unnormalised scores.
        """
        A_V = self.att_layer_V(H, self.linear_V.weight, self.linear_V.bias)
        A_U = self.att_layer_U(H2, self.linear_U.weight, self.linear_U.bias)
        A = self.attention_weights(A_V * A_U.T)
        A = torch.transpose(A, 1, 0)
        return A  # Raw scores, no softmax yet

    def forward(self, X, Z):
        """Run the dual-path attention MIL forward pass.

        Args:
            X: One bag of instances. A leading singleton batch dimension is
                squeezed away.
            Z: Iterable of per-class prototype instance tensors.

        Returns:
            A tuple ``(Y_logits, A, H_prob, H2_rfm, H_logits)``:

            - ``Y_logits`` is ``(1, nr_class)``.
            - ``A`` holds the ``(1, N)`` attention weights.
            - ``H_prob`` holds the per-instance softmax probabilities of
              ``H_logits``.
            - ``H2_rfm`` holds the aggregated RFM-view prototypes.
            - ``H_logits`` is the ``(N, nr_class)`` instance-classifier
              output on the RFM features.
        """
        X = X.squeeze(0)

        # Instance features (raw) and their RFM-transformed view.
        H_raw = self._extract_features(X)
        H_rfm = H_raw @ self.agop_sqrt

        # Prototype features, one tensor per class.
        H2_raw_list = [self._extract_features(z_class) for z_class in Z]
        H2_rfm_list = [h2_raw @ self.agop_sqrt for h2_raw in H2_raw_list]

        # Prototype aggregation
        if self.proto_agg == 'mean':
            H2_raw = torch.stack([h.mean(dim=0) for h in H2_raw_list], dim=0)
            H2_rfm = torch.stack([h.mean(dim=0) for h in H2_rfm_list], dim=0)
        else:
            H2_rfm = self._linear_aggregate(H2_rfm_list)
            H2_raw = self._linear_aggregate(H2_raw_list)

        # === DUAL-PATH ATTENTION ===
        A_rfm = self.compute_attention_scores(H_rfm, H2_rfm)
        A_raw = self.compute_attention_scores(H_raw, H2_raw)
        A_combined = A_rfm + self.attn_lambda * A_raw

        # Scaled softmax over the instance axis.
        A = F.softmax(A_combined / math.sqrt(self.L), dim=1)

        # Bag representation pools the raw (not RFM) features.
        Z_bag = torch.mm(A, H_raw)

        # Classification
        Y_logits = self.classifier(Z_bag)
        H_logits = self.instance_classifier(H_rfm)
        H_prob = F.softmax(H_logits, dim=1)

        return Y_logits, A, H_prob, H2_rfm, H_logits

    def update_agop(self, agop_matrix, momentum=0.9):
        """Blend a freshly computed AGOP into the running estimate and refresh its square root.

        ``agop_matrix`` is detached and cast to the buffer's dtype and device
        before use. This keeps the buffers from holding autograd graphs across
        updates. It also keeps ``agop_sqrt`` matmul-compatible with the
        features in ``forward``.
        """
        with torch.no_grad():
            new_agop = torch.as_tensor(
                agop_matrix,
                dtype=self.agop_matrix.dtype,
                device=self.agop_matrix.device,
            ).detach()
            self.agop_matrix = momentum * self.agop_matrix + (1 - momentum) * new_agop
        agop_sqrt = torch.as_tensor(
            compute_agop_sqrt(self.agop_matrix),
            dtype=self.agop_sqrt.dtype,
            device=self.agop_sqrt.device,
        )
        self.agop_sqrt = agop_sqrt.detach()

    def full_loss(self, prediction, target, args, H_logits=None):
        """Partial-label bag loss plus an optional instance-level term.

        Args:
            prediction: Bag class probabilities, shape ``(1, nr_class)``.
            target: Candidate-label weights. Entries > 0 mark candidates;
                a 1-D target is promoted to ``(1, nr_class)``.
            args: Supplies the weights ``mu`` (for ``sp_loss``) and
                ``gamma`` (for ``in_loss``).
            H_logits: Optional ``(N, nr_class)`` instance logits. When given
                and ``inst_weight > 0``, a soft cross-entropy against the
                normalised target is added.

        Returns:
            ``(new_prediction, total_loss)``, where ``new_prediction`` is the
            prediction restricted to candidate labels and renormalised.
        """
        if target.dim() == 1:
            target = target.unsqueeze(0)

        Y_candidate = (target > 0).float()
        prediction_can = prediction * Y_candidate
        prediction_can_sum = prediction_can.sum(dim=1, keepdim=True).clamp(min=1e-10)
        new_prediction = prediction_can / prediction_can_sum

        # Weighted negative log-likelihood on the target labels.
        mp_loss = -target * torch.log(prediction + 1e-10)
        # Entropy of the candidate-restricted prediction.
        sp_loss = -torch.sum(new_prediction * torch.log(new_prediction + 1e-10))

        # -log(1 - p_j) on non-candidate labels; candidates contribute log(1) = 0.
        Y_non_candidate = 1.0 - Y_candidate
        prediction_non = prediction * Y_non_candidate
        neg_prediction = (1.0 - prediction_non) * Y_non_candidate + Y_candidate
        neg_prediction = torch.clamp(neg_prediction, min=1e-10)
        in_loss = -Y_non_candidate * torch.log(neg_prediction)

        bag_loss = torch.sum(mp_loss) + args.mu * sp_loss + args.gamma * torch.sum(in_loss)

        if H_logits is not None and self.inst_weight > 0:
            # Every instance is pushed toward the bag's normalised target distribution.
            target_expanded = target.repeat(H_logits.shape[0], 1)
            target_dist = target_expanded / target_expanded.sum(dim=1, keepdim=True).clamp(min=1e-10)
            H_log_prob = F.log_softmax(H_logits, dim=1)
            instance_loss = -(target_dist * H_log_prob).sum() / H_logits.shape[0]
            total_loss = bag_loss + self.inst_weight * instance_loss
        else:
            total_loss = bag_loss

        return new_prediction, total_loss

    def calculate_objective(self, X, Y, Z, args):
        """Forward one bag and return ``(loss, candidate-renormalised prob, A, H_prob)``."""
        Y = Y.reshape(-1)
        Y_logits, A, H_prob, _, H_logits = self.forward(X, Z)
        Y_prob = F.softmax(Y_logits, dim=1)
        new_prob, loss = self.full_loss(Y_prob, Y, args, H_logits)
        return loss, new_prob, A, H_prob

    def evaluate_objective(self, X, Z):
        """Return the bag-level class probabilities (softmax of ``Y_logits``)."""
        Y_logits, _, _, _, _ = self.forward(X, Z)
        Y_prob = F.softmax(Y_logits, dim=1)
        return Y_prob

    def regenerate_soft_labels(self, original_s, model_pred, alpha):
        """Mix the original soft labels with the candidate-masked model prediction.

        Returns ``alpha * original_s + (1 - alpha) * masked_pred``, renormalised
        along the last axis. ``masked_pred`` is the prediction zeroed outside
        the candidates (``original_s > 0``) and renormalised.
        """
        candidate_mask = (original_s > 0).float()
        filtered_pred = model_pred * candidate_mask
        filtered_pred = filtered_pred / (filtered_pred.sum(dim=-1, keepdim=True) + 1e-10)
        new_s = alpha * original_s + (1 - alpha) * filtered_pred
        new_s = new_s / (new_s.sum(dim=-1, keepdim=True) + 1e-10)
        return new_s
