import torch
from neuralop.models import FNO
import matplotlib.pyplot as plt
from neuralop.data.datasets import load_MHD
import os
from scipy import io as sio
import numpy as np
import tqdm
import json
import pandas as pd

# Prefer the GPU, but fall back to CPU so the script still runs on CUDA-less machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Output channels of the model, in channel order (index 0..4).
_CHANNELS = ('Jx', 'Jy', 'Jz', 'u', 'v')

# Running sums of per-sample metrics; they are divided by the number of evaluated
# samples after the test loop. 'fRMSE' starts empty because its '<channel>_<band>'
# keys are created inside get_fRMSE().
res_dict = {
    'RMSE': dict.fromkeys(_CHANNELS, 0),
    'nRMSE': dict.fromkeys(_CHANNELS, 0),
    'MaxError': dict.fromkeys(_CHANNELS, 0),
    'fRMSE': {},
    'bRMSE': dict.fromkeys(_CHANNELS, 0),
}

if __name__ == '__main__':
    save_dir = './eval_results/MHD'
    # makedirs also creates a missing ./eval_results parent (os.mkdir would raise).
    os.makedirs(save_dir, exist_ok=True)

    def _load_range(name):
        """Return ``(min, max)`` read from ``.../MHD/<name>/range_all<name>.mat``.

        The ``range_all<name>`` array in that file stores the minimum at ``[0, 0]``
        and the maximum at ``[0, 1]``.
        """
        path = f"./DiffusionPDE/data/training/MHD/{name}/range_all{name}.mat"
        value_range = sio.loadmat(path)[f'range_all{name}']
        return value_range[0, 0], value_range[0, 1]

    # Per-field value ranges. min_Br / max_Br are not referenced elsewhere in this script.
    min_Br, max_Br = _load_range('Br')
    min_Jx, max_Jx = _load_range('Jx')
    min_Jy, max_Jy = _load_range('Jy')
    min_Jz, max_Jz = _load_range('Jz')
    min_u_u, max_u_u = _load_range('u_u')
    min_u_v, max_u_v = _load_range('u_v')

    # Load the MHD dataset; the evaluation loop only uses the 128-resolution test loader.
    train_loader, test_loaders, data_processor = load_MHD(
        n_train=10000, batch_size=16,
        test_resolutions=[128], n_tests=[100],
        test_batch_sizes=[64],
    )
    data_processor = data_processor.to(device)
    data_processor.eval()

    model = FNO(n_modes=(12, 12),
                in_channels=1,
                out_channels=5,
                hidden_channels=128,
                projection_channel_ratio=2)
    model = model.to(device)

    # Attach the ranges to the model as model.min_<field> / model.max_<field>;
    # the evaluation loop reads them from there to denormalise outputs.
    for field, (lo, hi) in {
        'Jx': (min_Jx, max_Jx),
        'Jy': (min_Jy, max_Jy),
        'Jz': (min_Jz, max_Jz),
        'u_u': (min_u_u, max_u_u),
        'u_v': (min_u_v, max_u_v),
    }.items():
        setattr(model, f'max_{field}', hi)
        setattr(model, f'min_{field}', lo)

    checkpoint_path = "./checkpoints/MHD/5/model_epoch_15_state_dict.pt"
    # map_location lets a checkpoint saved on another device load onto `device`.
    # NOTE(review): weights_only=False unpickles arbitrary objects -- only load trusted
    # checkpoints. A plain state_dict should also load with weights_only=True.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=False))
    print(f"Model weights loaded from {checkpoint_path}")

    # Switch the model to evaluation mode.
    model.eval()

    # Number of evaluated test samples; the accumulated metric sums are divided by it.
    sample_total = 0


    def get_nRMSE():
        """Add the batch sum of per-sample relative L2 errors to ``res_dict['nRMSE']``.

        For each channel the score is ``||pred - target||_2 / ||target||_2``, taken
        over the two spatial dims of every sample. Reads the script-level ``pred``
        (the five per-channel predictions Jx, Jy, Jz, u, v) and ``outputs`` (targets,
        channels on dim 1).
        """
        Jx, Jy, Jz, u, v = pred
        per_channel = (('Jx', Jx), ('Jy', Jy), ('Jz', Jz), ('u', u), ('v', v))

        # Compute every channel first, then accumulate.
        scores = {}
        for ch, (key, estimate) in enumerate(per_channel):
            target = outputs[:, ch, :, :]
            err_norm = torch.linalg.vector_norm(estimate - target, ord=2, dim=(1, 2))
            ref_norm = torch.linalg.vector_norm(target, ord=2, dim=(1, 2))
            scores[key] = err_norm / ref_norm

        for key, score in scores.items():
            res_dict['nRMSE'][key] += score.sum()


    def get_RMSE():
        """Add the batch sum of per-sample RMSE values to ``res_dict['RMSE']``.

        For each channel the score is ``sqrt(mean((pred - target) ** 2))`` over the
        two spatial dims of every sample. ``pred`` is the script-level tuple of five
        per-channel predictions (Jx, Jy, Jz, u, v); ``outputs`` holds the targets
        with channels on dim 1.
        """
        Jx, Jy, Jz, u, v = pred
        scores = {
            key: torch.sqrt(torch.mean((estimate - outputs[:, ch, :, :]) ** 2, dim=(1, 2)))
            for ch, (key, estimate) in enumerate(
                (('Jx', Jx), ('Jy', Jy), ('Jz', Jz), ('u', u), ('v', v))
            )
        }
        # Accumulate into the running totals.
        for key, score in scores.items():
            res_dict['RMSE'][key] += score.sum()


    def get_MaxError():
        """Add the batch sum of per-sample maximum absolute errors to ``res_dict['MaxError']``.

        For each channel, take ``max |pred - target|`` over the two spatial dims of
        every sample. Reads the script-level ``pred`` (five per-channel predictions
        Jx, Jy, Jz, u, v) and ``outputs`` (targets, channels on dim 1).
        """
        Jx, Jy, Jz, u, v = pred
        names = ('Jx', 'Jy', 'Jz', 'u', 'v')
        worst = [
            torch.amax(torch.abs(estimate - outputs[:, ch, :, :]), dim=(1, 2))
            for ch, estimate in enumerate((Jx, Jy, Jz, u, v))
        ]
        for name, value in zip(names, worst):
            res_dict['MaxError'][name] += value.sum()


    def get_bRMSE():
        """Add the batch sum of per-sample boundary RMSE to ``res_dict['bRMSE']``.

        The boundary is the one-pixel ring formed by the first and last row and the
        first and last column of the spatial grid. For each channel the score is the
        RMSE between prediction and target over those pixels only. Reads the
        script-level ``pred`` (five per-channel predictions Jx, Jy, Jz, u, v) and
        ``outputs`` (targets, channels on dim 1).
        """
        Jx, Jy, Jz, u, v = pred

        # The ring is identical for every sample, so build it over the spatial dims only.
        ring = torch.zeros(outputs.shape[-2], outputs.shape[-1],
                           dtype=torch.bool, device=outputs.device)
        ring[0, :] = True   # top row
        ring[-1, :] = True  # bottom row
        ring[:, 0] = True   # left column
        ring[:, -1] = True  # right column

        for ch, (key, estimate) in enumerate(
                (('Jx', Jx), ('Jy', Jy), ('Jz', Jz), ('u', u), ('v', v))):
            # (batch, n_ring_pixels) error on the boundary pixels of each sample.
            edge_err = estimate[:, ring] - outputs[:, ch, :, :][:, ring]
            res_dict['bRMSE'][key] += torch.sqrt(torch.mean(edge_err ** 2, dim=1)).sum()


    def get_fRMSE():
        """Accumulate per-sample Fourier-space error (fRMSE) per channel and frequency band.

        Reads the script-level ``pred`` (five per-channel predictions Jx, Jy, Jz, u, v)
        and ``outputs`` (targets, channels on dim 1). For every channel and band, the
        batch sum of ``sqrt(sum |FFT(pred) - FFT(target)|^2) / band_width`` is added to
        ``res_dict['fRMSE']['<channel>_<band>']``.

        The inner sum runs over the 2-D Fourier modes whose integer radial wavenumber
        ``floor(|k|)`` lies in the band. ``band_width = k_max - k_min + 1``, with
        ``k_max`` set to the Nyquist wavenumber for the open-ended high band.

        The FFT is torch's default unnormalised forward transform, so the values scale
        with the grid size.
        """
        names = ('Jx', 'Jy', 'Jz', 'u', 'v')

        # Inclusive ranges of floor(|k|); None means "no upper limit".
        freq_bands = {
            'low': (0, 4),
            'middle': (5, 12),
            'high': (13, None),
        }

        # Create the accumulators only once. Re-initialising them on every call would
        # keep only the last batch, while the final average divides by every sample.
        for freq_band in freq_bands:
            for name in names:
                res_dict['fRMSE'].setdefault(f'{name}_{freq_band}', 0.0)

        def band_masks(H, W, device):
            """Return ``{band: (bool mask over the (H, W) FFT grid, band_width)}``."""
            # Signed integer wavenumbers in torch.fft.fft2's (unshifted) output layout:
            # 0, 1, ..., then the negative frequencies ..., -2, -1. Using the raw index
            # as the wavenumber would put the lowest negative frequencies (index H-1,
            # W-1, ...) into the high band.
            kx = torch.round(torch.fft.fftfreq(H, device=device) * H)
            ky = torch.round(torch.fft.fftfreq(W, device=device) * W)
            kx, ky = torch.meshgrid(kx, ky, indexing='ij')
            # floor() puts every mode into exactly one band. Comparing the raw radius
            # against integer limits leaves 4 < |k| < 5 and 12 < |k| < 13 in no band.
            k = torch.floor(torch.sqrt(kx ** 2 + ky ** 2))
            masks = {}
            for band, (k_min, k_max) in freq_bands.items():
                if k_max is None:
                    mask = k >= k_min
                    k_max = max(H // 2, W // 2)  # Nyquist; only used for band_width
                else:
                    mask = (k >= k_min) & (k <= k_max)
                masks[band] = (mask, k_max - k_min + 1)
            return masks

        Jx, Jy, Jz, u, v = pred
        for ch, (name, pred_ch) in enumerate(zip(names, (Jx, Jy, Jz, u, v))):
            true_ch = outputs[:, ch, :, :]
            H, W = pred_ch.shape[-2], pred_ch.shape[-1]
            # Power of the spectral error, FFT over the last two (spatial) dims, no shift.
            diff_power = torch.abs(torch.fft.fft2(pred_ch) - torch.fft.fft2(true_ch)) ** 2
            for band, (mask, width) in band_masks(H, W, pred_ch.device).items():
                band_error = torch.sqrt(diff_power[:, mask].sum(dim=1)) / width
                res_dict['fRMSE'][f'{name}_{band}'] += band_error.sum()

    # Start testing: run the model over the 128-resolution test split and accumulate metrics.

    # Field whose model.min_<field> / model.max_<field> range applies to each output
    # channel, in channel order (Jx, Jy, Jz, u, v).
    channel_fields = ('Jx', 'Jy', 'Jz', 'u_u', 'u_v')

    def denormalize_(tensor):
        """Undo the per-channel output scaling of ``tensor`` in place.

        Channel ``ch`` (dim 1) is mapped by x -> (x + 0.9) / 1.8 * (max - min) + min,
        i.e. -0.9 -> min and 0.9 -> max of ``channel_fields[ch]``. Targets and
        predictions go through this same map, so they are compared in the same units.
        Assigning into the slice keeps ``tensor``'s dtype.
        """
        for ch, field in enumerate(channel_fields):
            lo = getattr(model, f'min_{field}')
            hi = getattr(model, f'max_{field}')
            tensor[:, ch, :, :] = (tensor[:, ch, :, :] + 0.9) / 1.8 * (hi - lo) + lo

    with torch.no_grad():
        for sample in tqdm.tqdm(test_loaders[128]):
            sample = data_processor.preprocess(sample)
            # NOTE(review): squeeze() drops *every* size-1 dim, including the batch dim
            # if a batch ever holds a single sample -- confirm which dim it should remove.
            inputs, outputs = sample['x'].to(device), sample['y'].to(device).squeeze()
            pred_outputs = model(inputs)
            pred_outputs, _ = data_processor.postprocess(pred_outputs)
            pred_outputs = pred_outputs.squeeze()

            # Denormalise targets and predictions with the identical channel -> range map.
            denormalize_(outputs)
            denormalize_(pred_outputs)

            # Per-channel predictions, read by the get_* metric helpers.
            pred = tuple(pred_outputs[:, ch, :, :] for ch in range(5))

            get_RMSE()
            get_nRMSE()
            get_MaxError()
            get_bRMSE()
            get_fRMSE()

            sample_total += outputs.shape[0]

    # Turn the accumulated sums into per-sample means.
    if sample_total == 0:
        raise RuntimeError('No test samples were evaluated; cannot average the metrics.')
    for metric in res_dict:
        for var in res_dict[metric]:
            # float() accepts both 0-dim tensors and plain numbers; an accumulator that
            # was never updated is still the int 0, on which .item() would fail.
            res_dict[metric][var] = float(res_dict[metric][var] / sample_total)

    print('-' * 20)
    print('metric:')
    for metric, values in res_dict.items():
        for var, value in values.items():
            print(f'{metric}\t\t{var}:\t\t{value}')

    # Save all metrics as JSON.
    with open(os.path.join(save_dir, 'log_final.json'), "w", encoding="utf-8") as f:
        json.dump(res_dict, f, ensure_ascii=False)

    # Also write every value, in res_dict's metric/variable order, as one unlabeled CSV column.
    res = [value for values in res_dict.values() for value in values.values()]
    output_file = os.path.join(save_dir, 'exp.csv')
    frmse_df = pd.DataFrame(res)
    frmse_df.to_csv(output_file, index=False, encoding="utf-8", float_format="%.16f")
