import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from scipy.io import loadmat
from torch.nn.parameter import Parameter
import torch.fft
import time
# Wall-clock timestamp taken when this module starts executing; the script
# prints the elapsed time relative to it at the end of the file.
start = time.time()

class MatReader:
    """Minimal reader for MATLAB ``.mat`` files loaded via ``scipy.io.loadmat``.

    The whole file is loaded eagerly in the constructor; individual variables
    are then converted to torch tensors on demand.
    """

    def __init__(self, file_path):
        # Mapping of MATLAB variable name -> array, as returned by loadmat.
        self.data = loadmat(file_path)

    def read_field(self, field):
        """Return variable ``field`` as a newly allocated float32 tensor.

        Raises:
            KeyError: if ``field`` is not present in the loaded file.
        """
        raw = self.data[field]
        return torch.tensor(raw, dtype=torch.float32)


class DualityMap(nn.Module):
    """Pointwise L^p-type duality map applied over the last two dimensions.

    Computes ``scale * |u|^(p-2) * u`` where
    ``scale = (mean(|u|^p) + eps)^((p-2)/p)`` and the mean is taken over the
    last two dimensions separately for every leading index.

    NOTE(review): the normalized L^p duality map is usually written
    ``||u||^(2-p) |u|^(p-2) u`` (exponent ``(2-p)/p`` on ``mean(|u|^p)``);
    this uses ``(p-2)/p`` -- confirm the sign is intended.

    Args:
        p: exponent of the map.
        eps: small constant that keeps the norm term and the pointwise
            magnitude away from zero.
    """

    def __init__(self, p=1.5, eps=1e-8):
        super().__init__()
        self.p = p
        self.eps = eps

    def forward(self, u):
        """Apply the map; the output has the same shape as ``u``."""
        norm_p = torch.mean(torch.abs(u) ** self.p, dim=(-2, -1), keepdim=True)
        scale = (norm_p + self.eps).pow((self.p - 2) / self.p)
        # For p < 2, |u|^(p-2) is inf at u == 0 and inf * 0 gives NaN (in the
        # value and in the gradient). Clamping the magnitude at eps keeps the
        # limit value 0 at u == 0 and finite gradients; entries with
        # |u| >= eps are computed exactly as before.
        magnitude = u.abs().clamp_min(self.eps)
        return scale * magnitude.pow(self.p - 2) * u


class PolarSpectralConv2d(nn.Module):
    """Spectral convolution on a (radial, angular) grid.

    The input ``x`` is ``(B, C, Nr, Ntheta)``. ``forward`` does the following:

    1. Take an orthonormal real FFT along the last (angular) axis.
    2. Project the first ``modes2`` angular modes onto a cosine basis of
       ``modes1`` functions along the radial axis.
    3. Mix channels with complex per-mode weights.
    4. Map back with the cosine basis and an inverse real FFT.

    The radial size of the input must equal the ``Nr`` given at construction.

    NOTE(review): ``forward_matrix`` is just the transpose of
    ``inverse_matrix`` and the basis carries no endpoint half-weights, so the
    pair is not an exact inverse -- confirm this is acceptable.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2, Nr=64):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2

        # Complex weights stored as trailing (real, imag) pairs.
        init_scale = 1 / (in_channels * out_channels)
        self.weights = nn.Parameter(
            init_scale * torch.rand(in_channels, out_channels, modes1, modes2, 2)
        )

        self.register_buffer('inverse_matrix', self._cosine_basis(Nr, modes1))
        self.register_buffer('forward_matrix', self.inverse_matrix.t())

    @staticmethod
    def _cosine_basis(Nr, modes1):
        """Return the (Nr, modes1) matrix cos(pi*j*k/(Nr-1)) with column norms."""
        denom = Nr - 1
        nodes = torch.arange(0, Nr, dtype=torch.float32)
        freqs = torch.arange(0, modes1, dtype=torch.float32)
        basis = torch.cos(math.pi * torch.outer(nodes, freqs) / denom)
        col_norm = torch.full((modes1,), math.sqrt(2.0 / denom))
        col_norm[0] = math.sqrt(1.0 / denom)
        return basis * col_norm.unsqueeze(0)

    @staticmethod
    def _real_matmul(equation, z, matrix):
        """Apply a real ``matrix`` to complex ``z`` via einsum on real/imag parts."""
        return torch.complex(
            torch.einsum(equation, z.real, matrix),
            torch.einsum(equation, z.imag, matrix),
        )

    def forward(self, x):
        batch, _, n_radial, n_theta = x.shape
        spectrum = torch.fft.rfft(x, dim=-1, norm='ortho')
        kept = spectrum[..., :self.modes2]

        radial_coeffs = self._real_matmul('bcrm,kr->bckm', kept, self.forward_matrix)
        mixed = torch.einsum(
            'bikm,iokm->bokm', radial_coeffs, torch.view_as_complex(self.weights)
        )
        restored = self._real_matmul('bokm,rk->borm', mixed, self.inverse_matrix)

        # Angular modes beyond modes2 stay zero (low-pass).
        out_ft = torch.zeros(
            batch, self.out_channels, n_radial, spectrum.shape[-1],
            dtype=torch.cfloat, device=x.device,
        )
        out_ft[..., :self.modes2] = restored
        return torch.fft.irfft(out_ft, n=n_theta, dim=-1, norm='ortho')


class DualPolarSpectralConv2d(nn.Module):
    """Channel-preserving polar spectral convolution plus its duality-mapped image.

    ``forward(u_p)`` returns ``(y_p, y_d)``. Here ``y_p`` is the output of a
    ``ch -> ch`` ``PolarSpectralConv2d``, and ``y_d = DualityMap(p)(y_p)``.
    """

    def __init__(self, ch, modes1, modes2, Nr=64, p=1.5):
        super().__init__()
        # Construction order is kept (conv first) so parameter init is unchanged.
        self.primal_conv = PolarSpectralConv2d(ch, ch, modes1, modes2, Nr)
        self.J = DualityMap(p)

    def forward(self, u_p):
        primal = self.primal_conv(u_p)
        return primal, self.J(primal)


class Propagation(nn.Module):
    """One primal/dual propagation layer.

    The primal output is ``gelu(BatchNorm(u_p + spectral(u_p) + conv1x1(u_p)))``.
    The dual output is ``J(primal_out) + J(spectral(u_p)) + J(conv1x1(u_p))``,
    where ``J`` is a ``DualityMap(p)``.

    NOTE(review): ``u_d`` is accepted but never read, so the incoming dual
    state does not influence either output -- confirm this is intended.
    """

    def __init__(self, channels, modes1=12, modes2=12, Nr=64, p=1.5):
        super().__init__()
        # Submodules created in the original order to keep parameter init identical.
        self.spec = DualPolarSpectralConv2d(channels, modes1, modes2, Nr, p)
        self.conv = nn.Conv2d(channels, channels, 1)
        self.norm = nn.BatchNorm2d(channels)
        self.J = DualityMap(p)

    def forward(self, u_p, u_d):
        spectral_primal, spectral_dual = self.spec(u_p)
        local_primal = self.conv(u_p)
        local_dual = self.J(local_primal)
        pre_activation = self.norm(u_p + spectral_primal + local_primal)
        z_p = F.gelu(pre_activation)
        z_d = self.J(z_p) + spectral_dual + local_dual
        return z_p, z_d


class DynamicConv(nn.Module):
    """Per-sample 2-D convolution whose kernels are generated from a latent code.

    A small MLP maps ``z_dual`` of shape ``(B, latent_dim)`` to one
    ``(out_channels, in_channels, k, k)`` kernel per sample. The kernels are
    applied with a single grouped convolution, using padding ``k // 2``.
    Output channel ``i`` is then multiplied by ``rhos[i]``. Channels beyond
    ``len(rhos)`` are left unscaled.

    Raises:
        ValueError: if ``rhos`` has more entries than ``out_channels``.
    """

    def __init__(self, latent_dim, in_channels, out_channels, k, rhos):
        super().__init__()
        if len(rhos) > out_channels:
            raise ValueError(
                f"got {len(rhos)} rhos for only {out_channels} output channels"
            )
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.ReLU(),
            nn.Linear(128, in_channels * out_channels * k * k)
        )
        self.k = k
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.rhos = rhos

    def forward(self, x, z_dual):
        """Convolve each sample of ``x`` (B, in_channels, H, W) with its own kernel."""
        B = x.size(0)
        kernel = self.net(z_dual).view(B * self.out_channels, self.in_channels, self.k, self.k)
        # Fold the batch into the channel axis so one grouped conv (groups=B)
        # applies sample b's kernel only to sample b's channels. reshape (not
        # view) so non-contiguous inputs are accepted.
        folded = x.reshape(1, B * self.in_channels, *x.shape[2:])
        out = F.conv2d(folded, kernel, padding=self.k // 2, groups=B)
        out = out.view(B, self.out_channels, *out.shape[2:])
        # Broadcast the per-channel factors out-of-place instead of in-place
        # slice assignment on an autograd view; missing factors default to 1.
        factors = list(self.rhos) + [1.0] * (self.out_channels - len(self.rhos))
        scale = torch.tensor(factors, dtype=out.dtype, device=out.device)
        return out * scale.view(1, -1, 1, 1)


class AFDONetinv(nn.Module):
    """Variational encoder -> primal/dual spectral propagation -> dynamic conv head.

    ``forward(x)`` first flattens each sample to ``input_dim`` features. Two
    encoders (primal ``enc_p`` and dual ``enc_d``) each produce a
    ``(mu, logvar)`` pair of size ``latent_dim``, and each is sampled with the
    reparameterization trick. The primal sample is projected to a
    ``(fno_channels, height, width)`` field and run through three
    ``Propagation`` blocks. A 1x1 conv head maps it to 2 channels, and
    ``DynamicConv``, conditioned on the dual sample, produces the prediction.

    Returns:
        ``(prediction, (mu_p, logv_p), (mu_d, logv_d))``.

    NOTE(review): the dual field ``f_d`` returned by the last block is not used
    by the output head.
    """

    def __init__(self, input_dim, inter_dim, latent_dim, rhos, height=64, width=64, fno_channels=32, p=1.5):
        super().__init__()
        # Grid geometry used to reshape the projected latents in forward().
        self.fno_channels = fno_channels
        self.height = height
        self.width = width
        self.enc_p = nn.Sequential(nn.Linear(input_dim, inter_dim), nn.ReLU(), nn.Linear(inter_dim, 2 * latent_dim))
        self.enc_d = nn.Sequential(nn.Linear(input_dim, inter_dim), nn.ReLU(), nn.Linear(inter_dim, 2 * latent_dim))
        self.fc_p = nn.Linear(latent_dim, fno_channels * height * width)
        self.fc_d = nn.Linear(latent_dim, fno_channels * height * width)
        # Nr must equal the radial (height) size the spectral conv will see.
        self.fno = nn.ModuleList([Propagation(fno_channels, Nr=height, p=p) for _ in range(3)])
        self.head = nn.Conv2d(fno_channels, 2, 1)
        self.dynamic = DynamicConv(latent_dim, 2, 2, 3, rhos)
        # Parameter-free; uses the same p as the propagation blocks.
        self.input_duality = DualityMap(p)

    @staticmethod
    def _repar(mu, logv):
        """Sample mu + eps * exp(logv / 2) with eps ~ N(0, 1); logv clamped to [-10, 10]."""
        logv = torch.clamp(logv, min=-10.0, max=10.0)
        eps = torch.randn_like(mu)
        return mu + eps * torch.exp(0.5 * logv)

    def forward(self, x):
        B = x.size(0)
        x_flat = x.reshape(B, -1)
        mu_p, logv_p = self.enc_p(x_flat).chunk(2, dim=1)
        mu_d, logv_d = self.enc_d(x_flat).chunk(2, dim=1)
        z_p = self._repar(mu_p, logv_p)
        z_d = self._repar(mu_d, logv_d)
        grid = (B, self.fno_channels, self.height, self.width)
        f_p = self.fc_p(z_p).view(grid)
        f_d = self.fc_d(z_d).view(grid)
        f_d = self.input_duality(f_d)
        for blk in self.fno:
            f_p, f_d = blk(f_p, f_d)
        out_p = self.head(f_p)
        return self.dynamic(out_p, z_d), (mu_p, logv_p), (mu_d, logv_d)


def _to_two_channel_grid(field, resize_to):
    """Reshape to (N, 21, 21), bilinearly resize to resize_to^2, duplicate into 2 channels."""
    grid = field.reshape(-1, 21, 21).unsqueeze(1)
    grid = F.interpolate(grid, size=(resize_to, resize_to), mode='bilinear', align_corners=False)
    # Both channels are identical copies, so interpolating once and repeating
    # yields the same tensor as repeat-then-interpolate at half the cost.
    return grid.repeat(1, 2, 1, 1)


def load_and_prepare_darcy_data(file_path, feature_dim, resize_to=64, ntrain=30, ntest=20, rands=2,
                                sample_per_task=100):
    """Load Darcy solution/source pairs and build normalized train/test tensors.

    Samples are grouped into consecutive tasks of ``sample_per_task``. For each
    of the first ``ntrain + ntest`` tasks, ``rands`` random draws of
    ``feature_dim`` distinct samples are taken. The first ``ntrain`` tasks form
    the training split and the rest form the test split. Inputs are the
    solutions (``'sol'``); targets are the sources (``'source'``).

    Both inputs and targets are standardized with statistics from the
    training split only. This keeps test data out of the normalization.

    Returns:
        ``(x_tr, y_tr, x_te, y_te, mean_f, std_f)``. ``mean_f``/``std_f`` are
        the target statistics needed to undo the target normalization.

    Raises:
        ValueError: if ``feature_dim > sample_per_task``, or if the file holds
            fewer than ``(ntrain + ntest) * sample_per_task`` samples.
    """
    if feature_dim > sample_per_task:
        raise ValueError(
            f"feature_dim={feature_dim} exceeds sample_per_task={sample_per_task}"
        )
    reader = MatReader(file_path)
    sol = _to_two_channel_grid(reader.read_field('sol'), resize_to)
    f = _to_two_channel_grid(reader.read_field('source'), resize_to)

    n_tasks = ntrain + ntest
    available = min(sol.shape[0], f.shape[0])
    if n_tasks * sample_per_task > available:
        raise ValueError(
            f"need {n_tasks * sample_per_task} samples for {n_tasks} tasks, file has {available}"
        )

    x_list, y_list = [], []
    for t in range(n_tasks):
        for _ in range(rands):
            idx = t * sample_per_task + torch.randperm(sample_per_task)[:feature_dim]
            x_list.append(sol[idx])
            y_list.append(f[idx])

    x = torch.cat(x_list)
    y = torch.cat(y_list)

    n_train = ntrain * rands * feature_dim
    x_tr, x_te = x[:n_train], x[n_train:]
    y_tr, y_te = y[:n_train], y[n_train:]

    # Normalize AFTER splitting: statistics for inputs AND targets come from
    # the training split only.
    x_mean, x_std = x_tr.mean(), x_tr.std()
    mean_f, std_f = y_tr.mean(), y_tr.std()
    x_tr = (x_tr - x_mean) / x_std
    x_te = (x_te - x_mean) / x_std
    y_tr = (y_tr - mean_f) / std_f
    y_te = (y_te - mean_f) / std_f

    return x_tr, y_tr, x_te, y_te, mean_f, std_f

def evaluate(model, loader, device, mean_f, std_f):
    """Evaluate ``model`` on ``loader``, print and return test metrics.

    ``model(xb)`` must return the prediction as its first element.

    Metrics:
        * MAE / MSE: element-wise means in the *normalized* target space.
        * Rel L2: mean over samples of ``||pred - y|| / ||y||``, computed in
          the de-normalized space (``y * std_f + mean_f``).

    Returns:
        dict with keys ``'mae'``, ``'mse'`` and ``'rel_l2'``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_abs = 0.0
    total_sq = 0.0
    total_rel_l2 = 0.0
    n_elements = 0
    n_samples = 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            pred, *_ = model(xb)

            total_abs += F.l1_loss(pred, yb, reduction='sum').item()
            total_sq += F.mse_loss(pred, yb, reduction='sum').item()
            # Count what was actually processed (robust to drop_last etc.).
            n_elements += yb.numel()
            n_samples += yb.shape[0]

            pred_phys = (pred * std_f + mean_f).flatten(1)
            true_phys = (yb * std_f + mean_f).flatten(1)
            # Per-sample relative L2 error; a sample whose target is all zero
            # would give inf/nan here.
            rel = (torch.linalg.vector_norm(pred_phys - true_phys, dim=1)
                   / torch.linalg.vector_norm(true_phys, dim=1))
            total_rel_l2 += rel.sum().item()

    if n_samples == 0:
        raise ValueError("evaluate() received a loader with no samples")
    mae = total_abs / n_elements
    mse = total_sq / n_elements
    rel_l2 = total_rel_l2 / n_samples
    print(f"Test MAE: {mae:.6f}, MSE: {mse:.6f}, Rel L2: {rel_l2:.6f}")
    return {'mae': mae, 'mse': mse, 'rel_l2': rel_l2}


def main():
    """Train AFDONetinv on the Darcy inverse problem, then evaluate on the test split.

    The .mat file named below is loaded from the working directory. Training
    runs for 1001 epochs with Adam (lr 1e-4) on an L1 reconstruction loss plus
    a 1e-8-weighted KL term for both latent branches. Gradients are clipped to
    norm 1, statistics are printed every epoch, and ``evaluate`` runs at the end.
    """
    mat_path = 'DarcyStatic_A100F100_10000x21x21_chi_sol_source.mat'
    x_tr, y_tr, x_te, y_te, mean_f, std_f = load_and_prepare_darcy_data(mat_path, feature_dim=100)
    train_loader = DataLoader(TensorDataset(x_tr, y_tr), batch_size=32, shuffle=True)
    test_loader = DataLoader(TensorDataset(x_te, y_te), batch_size=32)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = AFDONetinv(x_tr[0].numel(), 512, 128, rhos=[0.8, 0.2]).to(device)
    param_count = sum(t.numel() for t in model.parameters())
    print(f"Total model params: {param_count}")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.L1Loss()
    kl_weight = 1e-8  # constant for every epoch

    def gaussian_kl(mu, logvar):
        # Element-wise KL(N(mu, exp(logvar)) || N(0, 1)), averaged over all entries.
        return -0.5 * torch.mean(1 + logvar - mu**2 - logvar.exp())

    for epoch in range(1001):
        model.train()
        epoch_loss = 0.0
        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)
            pred, (mu_p, lv_p), (mu_d, lv_d) = model(xb)
            rec = criterion(pred, yb)
            kl_p = gaussian_kl(mu_p, lv_p)
            kl_d = gaussian_kl(mu_d, lv_d)
            loss = rec + kl_weight * (kl_p + kl_d)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            epoch_loss += loss.item()
        # Rec / KL values reported are those of the epoch's last mini-batch.
        print(f"Epoch {epoch:02d} | Train Loss: {epoch_loss / len(train_loader):.6f} | Rec: {rec.item():.4f} | KL_p: {kl_p.item():.4f} | KL_d: {kl_d.item():.4f}")

    evaluate(model, test_loader, device, mean_f, std_f)


if __name__ == "__main__":
    main()
    # Report total wall-clock runtime only when run as a script; previously
    # this also printed on plain import of the module.
    end = time.time()
    print(end - start)
