import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
import random


class CircularPad(nn.Module):
    """Periodically pad the last two dimensions by the same amount on every side.

    Args:
        padding: number of elements wrapped around before and after each of the
            last two dimensions.
    """

    def __init__(self, padding):
        super().__init__()
        self.padding = padding

    def forward(self, x):
        p = self.padding
        # (left, right, top, bottom) -- identical amount on all four sides
        return F.pad(x, (p,) * 4, mode='circular')


class DoubleConv(nn.Module):
    """Two stacked (circular pad -> conv -> norm -> activation) stages.

    With an odd ``kernel_size`` the padding of ``kernel_size // 2`` keeps the
    spatial size unchanged.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        kernel_size: spatial size of both convolution kernels.
        act_fn: activation class, instantiated with no arguments.
        norm_2d: normalization class, instantiated with the channel count.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, act_fn=nn.ReLU, norm_2d=nn.BatchNorm2d):
        super().__init__()
        pad = kernel_size // 2
        layers = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            layers += [
                CircularPad(pad),
                nn.Conv2d(stage_in, out_channels, kernel_size),
                norm_2d(out_channels),
                act_fn(),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class FluxNet_D(nn.Module):
    """
    Dual-bounded flux network for conservative, bounded solute transport prediction.

    For every grid cell the network predicts two sets of fluxes on a periodic grid
    (circular padding in the convolutions, ``torch.roll`` for transport):

    1. Outflow branch: each cell sends ``sigmoid(.)`` of ``field - lower_bound`` to
       its neighbours, split by a softmax. If the input field is >= lower_bound,
       no cell ends below ``lower_bound`` in this branch.
    2. Inflow branch: each cell pulls ``sigmoid(.)`` of ``upper_bound - field`` from
       its neighbours, split by a softmax. If the input field is <= upper_bound,
       no cell ends above ``upper_bound`` in this branch.

    Each branch only moves material between cells, so each conserves the grid total.
    The returned ``next_field`` is the average of the two branch updates and is not
    additionally clamped, so it is not by itself guaranteed to lie inside
    [lower_bound, upper_bound].

    Bounds are either fixed (buffers ``*_bound_value``) or learnable (parameters
    ``*_bound_logit`` passed through a sigmoid, hence restricted to (0, 1); initial
    values are clamped into [1e-7, 1 - 1e-7]).

    Args:
        in_channels: input channels; channel 0 is used as the solute field.
        base_channels: width of the hidden feature maps.
        num_blocks: number of residual (DoubleConv + 1x1 fusion) blocks.
        kernel_size: odd spatial kernel size of the feature convolutions.
        act_fn: activation class, instantiated without arguments.
        norm_2d: normalisation class, instantiated with a channel count.
        neighborhood_size: odd side length (>= 3) of the square exchange neighbourhood.
        lower_bound, upper_bound: initial or fixed bounds; lower must be < upper.
        learnable_lower_bound, learnable_upper_bound: make the bound trainable.

    Raises:
        ValueError: for an even or < 3 ``neighborhood_size``, an even
            ``kernel_size``, or ``lower_bound >= upper_bound``.
    """

    def __init__(self,
                 in_channels=2,
                 base_channels=64,
                 num_blocks=4,
                 kernel_size=3,
                 act_fn=nn.GELU,
                 norm_2d=nn.BatchNorm2d,
                 neighborhood_size=15,
                 lower_bound=0.0,
                 upper_bound=1.0,
                 learnable_lower_bound=False,
                 learnable_upper_bound=False
                 ):
        super().__init__()
        if neighborhood_size < 3 or neighborhood_size % 2 == 0:
            # An even size yields (n+1)^2 - 1 offsets but only n^2 - 1 ratio channels,
            # and size 1 has no neighbours (outflow would silently disappear).
            raise ValueError(
                f"neighborhood_size must be an odd integer >= 3, got {neighborhood_size}")
        if kernel_size % 2 == 0:
            # Symmetric padding of kernel_size // 2 only preserves H, W for odd kernels.
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")
        if torch.any(self._as_float_tensor(lower_bound) >= self._as_float_tensor(upper_bound)):
            raise ValueError(
                f"lower_bound must be < upper_bound, got {lower_bound} and {upper_bound}")

        self.num_blocks = num_blocks
        self.neighborhood_size = neighborhood_size
        self.learnable_lower_bound = learnable_lower_bound
        self.learnable_upper_bound = learnable_upper_bound

        # Registers `<name>_logit` (Parameter) or `<name>_value` (buffer); these names
        # are kept unchanged so existing state_dicts still load.
        self._register_bound('lower_bound', lower_bound, learnable_lower_bound)
        self._register_bound('upper_bound', upper_bound, learnable_upper_bound)

        # Number of neighbors (excluding the center point itself)
        self.num_neighbors = neighborhood_size * neighborhood_size - 1

        # Channel layout: [outflow %, num_neighbors outflow ratios,
        #                  inflow %,  num_neighbors inflow ratios]
        self.total_channels = 2 * (1 + self.num_neighbors)

        self.first_conv = nn.Sequential(
            CircularPad(kernel_size // 2),
            nn.Conv2d(in_channels, base_channels, kernel_size=kernel_size, padding=0),
            norm_2d(base_channels),
            act_fn()
        )

        # Each residual block: DoubleConv, then a 1x1 conv fusing [block output, block input].
        self.res_blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.res_blocks.append(nn.ModuleList([
                DoubleConv(base_channels, base_channels, kernel_size, act_fn, norm_2d),
                nn.Conv2d(base_channels * 2, base_channels, kernel_size=1)
            ]))

        self.flux_conv = nn.Conv2d(base_channels, self.total_channels, kernel_size=1)

        # Neighbour offsets (dh, dw) in row-major order, centre excluded.
        radius = neighborhood_size // 2
        offsets = [(i, j)
                   for i in range(-radius, radius + 1)
                   for j in range(-radius, radius + 1)
                   if (i, j) != (0, 0)]
        # Buffer kept for state_dict compatibility.
        self.register_buffer('neighbor_offsets', torch.tensor(offsets, dtype=torch.long))
        # Plain-int copy used by the transport loop: feeding tensor elements to
        # torch.roll would need a device->host conversion per neighbour on GPU.
        self._neighbor_offset_list = tuple(offsets)

    @staticmethod
    def _as_float_tensor(value):
        """Return ``value`` as a tensor, converting non-floating dtypes to the default float dtype."""
        tensor = torch.as_tensor(value)
        if not tensor.is_floating_point():
            tensor = tensor.to(torch.get_default_dtype())
        return tensor

    @staticmethod
    def _inverse_sigmoid(x, eps=1e-7):
        """Inverse sigmoid (logit) of ``x`` after clamping it into [eps, 1 - eps]."""
        x = FluxNet_D._as_float_tensor(x).clamp(eps, 1 - eps)
        return torch.log(x / (1 - x))

    def _register_bound(self, name, value, learnable):
        """Register bound ``name`` as a sigmoid-logit Parameter or as a fixed float buffer."""
        if learnable:
            setattr(self, f'{name}_logit', nn.Parameter(self._inverse_sigmoid(value).detach()))
        else:
            # clone() so the buffer never aliases a tensor owned by the caller
            self.register_buffer(f'{name}_value', self._as_float_tensor(value).detach().clone())

    @property
    def lower_bound(self):
        """Get current lower bound value"""
        if self.learnable_lower_bound:
            return torch.sigmoid(self.lower_bound_logit)
        else:
            return self.lower_bound_value

    @property
    def upper_bound(self):
        """Get current upper bound value"""
        if self.learnable_upper_bound:
            return torch.sigmoid(self.upper_bound_logit)
        else:
            return self.upper_bound_value

    def forward(self, x):
        """
        Predict the next solute field.

        Args:
            x: input tensor, assumed [batch, in_channels, height, width]; channel 0
                is taken as the current solute field (TODO confirm with callers).

        Returns:
            Tuple ``(next_field, outflow_change, inflow_change)``, each with one
            channel, where ``next_field = x[:, 0:1] + (outflow_change + inflow_change) / 2``.
        """
        features = self.first_conv(x)

        for main_path, fusion_conv in self.res_blocks:
            # Fuse the block output with its own input back down to base_channels.
            features = fusion_conv(torch.cat([main_path(features), features], dim=1))

        raw_fluxes = self.flux_conv(features)  # [batch, total_channels, height, width]
        k = self.num_neighbors

        # Outflow branch (respects the lower bound)
        outflow_percentage = torch.sigmoid(raw_fluxes[:, 0:1])
        outflow_distribution_ratios = F.softmax(raw_fluxes[:, 1:k + 1], dim=1)

        # Inflow branch (respects the upper bound)
        inflow_percentage = torch.sigmoid(raw_fluxes[:, k + 1:k + 2])
        inflow_distribution_ratios = F.softmax(raw_fluxes[:, k + 2:], dim=1)

        solute_field = x[:, 0:1]

        outflow_change, inflow_change = self._compute_transport(
            solute_field,
            outflow_percentage,
            outflow_distribution_ratios,
            inflow_percentage,
            inflow_distribution_ratios,
            self.lower_bound,
            self.upper_bound
        )

        # Average of the two branch updates; not clamped afterwards.
        next_field = solute_field + (outflow_change + inflow_change) / 2

        return next_field, outflow_change, inflow_change

    def _compute_transport(self, current_field, outflow_percentage, outflow_distribution_ratios,
                           inflow_percentage, inflow_distribution_ratios, lower_bound, upper_bound):
        """
        Convert predicted percentages/ratios into per-cell changes for both branches.

        Transport is periodic: amounts are shifted with ``torch.roll`` over dims
        (2, 3), using ``(dh, dw)`` from the neighbour offset list.

        Outflow: each cell gives away ``outflow_percentage * (field - lower_bound)``.
        Ratio channel n delivers its share to the cell at ``-(dh, dw)`` from the sender.

        Inflow: each cell gains ``inflow_percentage * (upper_bound - field)``.
        Ratio channel n takes its share from the cell at ``+(dh, dw)`` from the receiver.

        NOTE(review): the two branches use opposite sign conventions for the same
        channel index. Because the offset set is symmetric, this only permutes channel
        meaning and does not affect conservation or bounds.

        Returns:
            ``(outflow_change, inflow_change)``. Each sums to zero over the grid when
            the ratios sum to 1 along dim 1.
        """
        outflow_amount = (current_field - lower_bound) * outflow_percentage
        outflow_change = -outflow_amount
        outflow_to_all = outflow_amount * outflow_distribution_ratios  # per-neighbour shares

        inflow_amount = (upper_bound - current_field) * inflow_percentage
        inflow_change = inflow_amount
        inflow_from_all = inflow_amount * inflow_distribution_ratios  # per-neighbour shares

        for n, (dh, dw) in enumerate(self._neighbor_offset_list):
            # Outflow: add what each cell receives from its sender.
            outflow_change = outflow_change + torch.roll(
                outflow_to_all[:, n:n + 1], shifts=(-dh, -dw), dims=(2, 3))
            # Inflow: subtract what each cell gives to the puller.
            inflow_change = inflow_change - torch.roll(
                inflow_from_all[:, n:n + 1], shifts=(dh, dw), dims=(2, 3))

        return outflow_change, inflow_change



