import math
import sys
sys.path.append("/here/is/code/M2F-PINN")
import numpy as np
from matplotlib import pyplot as plt
from util_data import utils, utils_data
from util_data.config import cfg
from torch import nn
import torch
import copy
from util_data import score
import os
from tqdm import tqdm
from models.FourierFeatureEmbedding import FeedForward, TrainableFourierFeatureEmbedding


def train(model, accelerator, train_loader, val_loader, optimizer, lr_scheduler, res_path, device, writer, logger, start_epoch,
          rank=0):
    """Train ``model`` with an L1 data loss plus two PDE-residual losses.

    The three loss terms are combined with a learnable 3-vector ``log_sigma``
    (initialised to ``[0, 2, 2]``)::

        100 * (0.5 * exp(-s0) * data_loss    + s0)
            + (0.5 * exp(-s1) * pde_u / 100  + s1)
            + (0.5 * exp(-s2) * pde_v / 100  + s2)

    where ``data_loss`` is the mean absolute error between the model output
    and the normalised target, and ``pde_u`` / ``pde_v`` come from
    ``pde_temperature_two_loss``.

    When ``cfg.FOURIER.FF_ENABLED`` is true, Fourier features of a fixed
    (depth, lat, lon) coordinate grid are computed once before the epoch loop,
    passed through ``ff_layers`` on every batch and concatenated to the model
    input along dim 1.

    Per epoch the function trains over ``train_loader``, steps
    ``lr_scheduler``, logs a histogram of ``alpha_uv``, writes a checkpoint
    every ``cfg.PG.TRAIN.SAVE_INTERVAL`` epochs and validates every
    ``cfg.PG.VAL.INTERVAL`` epochs. Training stops early after 5 consecutive
    validation rounds without a new best validation loss.

    Side effects: several parameter groups are appended to ``optimizer``
    (``log_sigma``, ``alpha_temp``, ``alpha_salt``, ``alpha_uv`` and, when
    Fourier features are enabled, both embedders and ``ff_layers``); files are
    written under ``res_path/models``; scalars are sent to ``writer``.

    Args:
        model: network called as ``model(x, weather_statistics, constant_maps, const_h)``.
        accelerator: object providing ``backward(loss)`` and ``unwrap_model(model)``.
        train_loader: iterable of 5-tuples ``(input, _, target, _, periods)``.
        val_loader: iterable with the same item layout as ``train_loader``.
        optimizer: optimizer that is stepped once per training batch.
        lr_scheduler: scheduler that is stepped once per epoch.
        res_path: output directory; ``models`` and ``png_training`` are created in it.
        device: device that inputs, constants and Fourier modules are moved to.
        writer: summary writer providing ``add_histogram`` / ``add_scalars``.
        logger: logger providing ``info``.
        start_epoch: first epoch index; the loop runs up to and including
            ``cfg.PG.TRAIN.EPOCHS``.
        rank: process rank; only rank 0 emits the per-epoch training log line.

    Returns:
        A deep copy of ``model`` taken at the lowest validation loss seen, or
        ``None`` if no validation round ever produced a new best loss.
    """

    # Element-wise absolute error; the mean is taken explicitly further down.
    criterion = nn.L1Loss(reduction='none')

    # Index of the last epoch to run (inclusive, see the range() below).
    epochs = cfg.PG.TRAIN.EPOCHS

    loss_list = []
    best_loss = float('inf')
    epochs_since_last_improvement = 0
    best_model = None

    # Load constants and teleconnection indices (utils_data is used here instead of utils).
    aux_constants = utils_data.loadAllConstants(
        device=device)
    # NOTE(review): upper_weights / surface_weights are not referenced again in this function.
    upper_weights, surface_weights = aux_constants['variable_weights']

    # Learnable log-weights balancing the data term and the two PDE terms.
    log_sigma = torch.nn.Parameter(torch.tensor([0.0, 2.0, 2.0]))
    optimizer.add_param_group({'params': [log_sigma]})

    # NOTE(review): alpha_temp and alpha_salt are registered with the optimizer but are
    # not referenced anywhere else in this function. They are left in place because
    # removing them would change the optimizer's param-group layout.
    alpha_temp = torch.nn.Parameter(torch.ones(13,721,1440), requires_grad=True) # depth, lat, lon
    optimizer.add_param_group({'params': [alpha_temp]})
    alpha_salt = torch.nn.Parameter(torch.ones(13,721,1440), requires_grad=True) # depth, lat, lon
    optimizer.add_param_group({'params': [alpha_salt]})

    # Learnable per-grid-point coefficient; a softmax of it is passed to the PDE loss
    # as the diffusion coefficient.
    alpha_uv = torch.nn.Parameter(torch.ones(13, 721, 1440), requires_grad=True)  # depth, lat, lon

    optimizer.add_param_group({'params': [alpha_uv]})


    # === Multi-Scale Fourier Feature Embedding Setup ===
    # Depth levels, divided by the deepest level so the values lie in (0, 1].
    depth_original = torch.tensor([
        0.494025, 2.645669, 5.078224, 7.92956, 11.405, 15.81007,
        21.59882, 29.44473, 40.34405, 55.76429, 77.85385, 109.7293, 130.666
    ], dtype=torch.float32)
    max_depth = depth_original[-1]
    depth_normalized = depth_original / max_depth  # shape: (13,)
    depth_normalized = depth_normalized.to(device)  # make sure it lives on the GPU (if a GPU is used)

    fourier_embedder_low = None
    fourier_embedder_high = None
    ff_layers = None  # FeedForward layers for processing Fourier Features
    # NOTE(review): batch_mapped_coord_features_static_processed is never assigned again.
    batch_mapped_coord_features_static_processed = None
    batch_mapped_coord_features_static_low_raw = None
    batch_mapped_coord_features_static_high_raw = None
    if cfg.FOURIER.FF_ENABLED:  # Check if Fourier Features are enabled in config
        ff_input_dims = cfg.FOURIER.FF_INPUT_DIMS  # e.g., 3 for (depth, lat, lon)

        # Embedder for lower frequencies (higher sigma, captures global structure)
        fourier_embedder_low = TrainableFourierFeatureEmbedding(
            input_dims=ff_input_dims,
            mapping_size=cfg.FOURIER.FF_MAPPING_SIZE_LOW,
            initial_scale=cfg.FOURIER.FF_SCALE_LOW,
            device=device
        ).to(device)
        optimizer.add_param_group({'params': fourier_embedder_low.parameters()})

        # Embedder for higher frequencies (lower sigma, captures details)
        fourier_embedder_high = TrainableFourierFeatureEmbedding(
            input_dims=ff_input_dims,
            mapping_size=cfg.FOURIER.FF_MAPPING_SIZE_HIGH,
            initial_scale=cfg.FOURIER.FF_SCALE_HIGH,
            device=device
        ).to(device)
        optimizer.add_param_group({'params': fourier_embedder_high.parameters()})

        # The input to ff_layers is the raw Fourier features, e.g., mapping_size*2
        # The output of ff_layers will be ff_hidden_dim
        # A single ff_layers instance is shared by the low- and high-frequency branches,
        # so its input width is taken from the LOW mapping size only.
        ff_raw_output_dims = cfg.FOURIER.FF_MAPPING_SIZE_LOW * 2  # Assuming low and high mapping_size are same
        if cfg.FOURIER.FF_MAPPING_SIZE_LOW * 2 != cfg.FOURIER.FF_MAPPING_SIZE_HIGH * 2:
            logger.info("Warning: Low and High freq mapping sizes differ. FF layer design might need adjustment.")

        ff_layers = FeedForward(
            input_dimensions=ff_raw_output_dims,  # e.g., 16*2 = 32
            output_dimensions=cfg.FOURIER.FF_HIDDEN_DIM,  # e.g., 32 or 64
            layers_config=[cfg.FOURIER.FF_HIDDEN_DIM],  # e.g., [32] -> one hidden layer of 32
            device=device
        ).to(device)
        optimizer.add_param_group({'params': ff_layers.parameters()})

        # 2. Latitude & Longitude Coordinates (assuming 721 lat, 1440 lon points)
        # Normalize to [-1, 1] as an example
        lat_coords_norm = torch.linspace(-1.0, 1.0, 721, device=device)  # Shape (721)
        lon_coords_norm = torch.linspace(-1.0, 1.0, 1440, device=device)  # Shape (1440)

        # 3. Create Meshgrid
        # Order: Depth, Latitude, Longitude to match (13, 721, 1440)
        D_coords, Y_coords, X_coords = torch.meshgrid(
            depth_normalized,
            lat_coords_norm,
            lon_coords_norm,
            indexing='ij'
        )  # Each has shape (13, 721, 1440)

        # 4. Stack to get coordinate vectors for each point: (13, 721, 1440, 3)
        coords_grid = torch.stack([D_coords, Y_coords, X_coords], dim=-1)

        # 5. Get mapped coordinate features (once)
        # NOTE(review): the embedders are only ever called inside this no_grad block, so
        # although their parameters were added to the optimizer above, no gradient reaches
        # them from this function. Confirm whether they are meant to be trained.
        with torch.no_grad():  # B matrix is fixed, no need for gradients here
            mapped_coords_low = fourier_embedder_low(coords_grid)  # (13, 721, 1440, ff_mapping_size_low*2)
            mapped_coords_high = fourier_embedder_high(coords_grid)  # (13, 721, 1440, ff_mapping_size_high*2)

        # 6. Permute to (channels, D, H, W) for concatenation:
        batch_mapped_coord_features_static_low_raw = mapped_coords_low.permute(3, 0, 1, 2)
        batch_mapped_coord_features_static_high_raw = mapped_coords_high.permute(3, 0, 1, 2)
    else:
        logger.info("Fourier Feature Embedding is DISABLED.")
    # === End Fourier Feature Embedding Setup ===

    # Train a single Pangu-Weather model
    for i in range(start_epoch, epochs + 1):
        epoch_loss = 0.0
        # Per-epoch accumulators; only the data / u / v ones are updated below.
        data_loss1, pde_tem_loss1, pde_salt_loss1, pde_u_loss1, pde_v_loss1, pde_zos_loss1 = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
        # Grid spacing of the [-1, 1] normalised lon / lat axes (1440 and 721 points).
        dx = 2.0 / (1440 - 1)
        dy = 2.0 / (721 - 1)
        g = 9.81

        for id, train_data in enumerate(tqdm(train_loader, desc="epoch "+str(i)+"-Train")):
            # Load weather data at time t as the input; load weather data at time t+336 as the output
            # Note the data need to be randomly shuffled
            input, _, target, _, periods = train_data
            # Keep only channels 2 and 3 (presumably the u / v velocity components,
            # judging by how pde_temperature_two_loss names them; verify against the dataset).
            input, target = input[:, 2:4, :, :, :].to(device), target[:, 2:4, :, :, :].to(device)

            optimizer.zero_grad()
            model.train()

            batch_fourier_features_to_cat = []
            if cfg.FOURIER.FF_ENABLED and \
                    batch_mapped_coord_features_static_low_raw is not None and \
                    batch_mapped_coord_features_static_high_raw is not None and \
                    ff_layers is not None:

                batch_size_current = input.shape[0]

                # Expand static RAW features to match current batch size
                # Low frequency features
                expanded_ff_low_raw = batch_mapped_coord_features_static_low_raw.unsqueeze(0).expand(
                    batch_size_current, -1, -1, -1, -1
                )  # (B, ff_raw_output_dims, D, H, W)
                # Permute for ff_layers: (B, D, H, W, ff_raw_output_dims)
                expanded_ff_low_raw = expanded_ff_low_raw.permute(0, 2, 3, 4, 1)
                processed_ff_low = ff_layers(expanded_ff_low_raw)  # (B, D, H, W, ff_hidden_dim)
                # Permute back for concatenation: (B, ff_hidden_dim, D, H, W)
                batch_fourier_features_to_cat.append(processed_ff_low.permute(0, 4, 1, 2, 3))

                # High frequency features (same ff_layers instance as the low branch)
                expanded_ff_high_raw = batch_mapped_coord_features_static_high_raw.unsqueeze(0).expand(
                    batch_size_current, -1, -1, -1, -1
                )  # (B, ff_raw_output_dims, D, H, W)
                expanded_ff_high_raw = expanded_ff_high_raw.permute(0, 2, 3, 4, 1)
                processed_ff_high = ff_layers(expanded_ff_high_raw)  # (B, D, H, W, ff_hidden_dim)
                batch_fourier_features_to_cat.append(processed_ff_high.permute(0, 4, 1, 2, 3))

                # Concatenate processed Fourier features with original input
                all_fourier_features = torch.cat(batch_fourier_features_to_cat, dim=1)  # Cat along channel
                input_to_model = torch.cat([input, all_fourier_features], dim=1)
            else:
                input_to_model = input  # Original input if FF is disabled

            # Note the input and target need to be normalized (done within the function)
            # Call the model and get the output
            output = model(input_to_model, aux_constants['weather_statistics'],
                           aux_constants['constant_maps'], aux_constants['const_h'])

            # Normalize the ground truth so that the loss is comparable with the model output
            target, _ = utils_data.normData(target, None, aux_constants['weather_statistics_last'])
            loss_upper = criterion(output, target)
            weighted_upper_loss = torch.mean(loss_upper)

            # The diffusion coefficient is a softmax over ALL grid points of alpha_uv,
            # so its entries are positive and sum to 1 across the whole (13, 721, 1440) grid.
            pde_u_loss, pde_v_loss= (
                    pde_temperature_two_loss(output,input, dx, dy,
                    torch.nn.functional.softmax(alpha_uv.view(-1), dim=0).view(13, 721, 1440),
                    g,depth_normalized))

            # Weighted sum of the three terms; each term adds its log-weight as a regulariser.
            loss = (
                    100*(0.5 * torch.exp(-log_sigma[0]) * (weighted_upper_loss) + log_sigma[0]) +
                    (0.5 * torch.exp(-log_sigma[1]) * (pde_u_loss/100) + log_sigma[1]) +
                    (0.5 * torch.exp(-log_sigma[2]) * (pde_v_loss/100)+ log_sigma[2])
            )

            # NaN diagnostics only: the batch is reported but training continues.
            if (math.isnan((weighted_upper_loss).item())):
                print('!!!data_problem for 1:'+str(periods))
            if (math.isnan(pde_u_loss.item())):
                print('!!!data_problem for 2pde:'+str(periods))

            accelerator.backward(loss)

            optimizer.step()  # this is what the original code used
            epoch_loss += loss.item()
            # NOTE(review): the per-term values below are recomputed AFTER optimizer.step(),
            # i.e. with the already-updated log_sigma, so they need not add up exactly
            # to the `loss` value accumulated in epoch_loss.
            data_loss1 += (100*(0.5 * torch.exp(-log_sigma[0]) * (weighted_upper_loss) + log_sigma[0])).item()
            pde_u_loss1 += ((0.5 * torch.exp(-log_sigma[1]) * (pde_u_loss/100) + log_sigma[1])).item()
            pde_v_loss1 += ((0.5 * torch.exp(-log_sigma[2]) * (pde_v_loss/100) + log_sigma[2])).item()

        # Convert the accumulated sums into per-batch means.
        epoch_loss /= len(train_loader)
        data_loss1 /= len(train_loader)
        pde_u_loss1 /= len(train_loader)
        pde_v_loss1 /= len(train_loader)
        # NOTE(review): only this training log is guarded by rank; the checkpointing and
        # validation logging below run on every rank.
        if rank == 0:
            logger.info("Epoch {} : {:.3f}".format(i, epoch_loss))
            print(
                f"TRAIN Loss ratio (%): data={100 * data_loss1 / epoch_loss:.1f}%,"
                f"u={100 * pde_u_loss1 / epoch_loss:.1f}%, "
                f"v={100 * pde_v_loss1 / epoch_loss:.1f}%")

        loss_list.append(epoch_loss)
        lr_scheduler.step()

        model_save_path = os.path.join(res_path, 'models')
        utils.mkdirs(model_save_path)
        writer.add_histogram('alpha_uv', alpha_uv, i)

        # Save the training model
        if i % cfg.PG.TRAIN.SAVE_INTERVAL == 0:
            # The learnable loss weights are stored under the key "alpha_loss_raw".
            save_file = {"model": accelerator.unwrap_model(model).state_dict(),
                         "optimizer": optimizer.state_dict(),
                         "lr_scheduler": lr_scheduler.state_dict(),
                         "alpha_loss_raw": log_sigma.data,
                         "alpha_uv": alpha_uv.data,
                         "epoch": i}

            if cfg.FOURIER.FF_ENABLED:
                if fourier_embedder_low is not None:
                    save_file['fourier_embedder_low'] = fourier_embedder_low.state_dict()
                if fourier_embedder_high is not None:
                    save_file['fourier_embedder_high'] = fourier_embedder_high.state_dict()
                if ff_layers is not None:
                    save_file['ff_layers'] = ff_layers.state_dict()
            torch.save(save_file, os.path.join(model_save_path, 'train_{}.pth'.format(i)))

        # Begin to validate
        if i % cfg.PG.VAL.INTERVAL == 0:
            with (torch.no_grad()):
                model.eval()
                val_loss = 0.0
                # Per-validation accumulators; only the data / u / v ones are updated below.
                val_data_loss2, pde_tem_loss_val2,pde_salt_loss_val2, pde_u_loss_val2, pde_v_loss_val2, pde_zos_loss_val2 = 0.0, 0.0,0.0, 0.0,0.0, 0.0
                for id, val_data in enumerate(tqdm(val_loader, desc="epoch "+str(i)+"-Validation"), 0):
                    input_val, _, target_val, _, periods_val = val_data
                    # NOTE(review): input_val_raw is not used afterwards.
                    input_val_raw = input_val
                    input_val, target_val = input_val[:, 2:4, :, :, :].to(device),target_val[:, 2:4, :, :, :].to(device)

                    # Same Fourier-feature construction as in the training loop.
                    batch_fourier_features_to_cat_val = []
                    if cfg.FOURIER.FF_ENABLED and \
                            batch_mapped_coord_features_static_low_raw is not None and \
                            batch_mapped_coord_features_static_high_raw is not None and \
                            ff_layers is not None:
                        batch_size_current_val = input_val.shape[0]
                        # Expand static RAW features to match current batch size
                        # Low frequency features
                        expanded_ff_low_raw_val = batch_mapped_coord_features_static_low_raw.unsqueeze(0).expand(
                            batch_size_current_val, -1, -1, -1, -1
                        )  # (B, ff_raw_output_dims, D, H, W)
                        # Permute for ff_layers: (B, D, H, W, ff_raw_output_dims)
                        expanded_ff_low_raw_val = expanded_ff_low_raw_val.permute(0, 2, 3, 4, 1)
                        processed_ff_low_val = ff_layers(expanded_ff_low_raw_val)  # (B, D, H, W, ff_hidden_dim)
                        # Permute back for concatenation: (B, ff_hidden_dim, D, H, W)
                        batch_fourier_features_to_cat_val.append(processed_ff_low_val.permute(0, 4, 1, 2, 3))

                        # High frequency features
                        expanded_ff_high_raw_val = batch_mapped_coord_features_static_high_raw.unsqueeze(0).expand(
                            batch_size_current_val, -1, -1, -1, -1
                        )  # (B, ff_raw_output_dims, D, H, W)
                        expanded_ff_high_raw_val = expanded_ff_high_raw_val.permute(0, 2, 3, 4, 1)
                        processed_ff_high_val = ff_layers(expanded_ff_high_raw_val)  # (B, D, H, W, ff_hidden_dim)
                        batch_fourier_features_to_cat_val.append(processed_ff_high_val.permute(0, 4, 1, 2, 3))

                        # Concatenate processed Fourier features with original input
                        all_fourier_features_val = torch.cat(batch_fourier_features_to_cat_val, dim=1)  # Cat along channel
                        input_to_model_val = torch.cat([input_val, all_fourier_features_val], dim=1)
                    else:
                        input_to_model_val = input_val  # Original input if FF is disabled


                    # Inference
                    output_val = model(input_to_model_val, aux_constants['weather_statistics'],
                                                           aux_constants['constant_maps'], aux_constants['const_h'])
                    # Normalize the ground truth so that the loss is comparable with the model output
                    target_val, _ = utils_data.normData(target_val, None,
                                                              aux_constants['weather_statistics_last'])

                    val_loss_upper = criterion(output_val, target_val)
                    weighted_val_loss_upper = torch.mean(val_loss_upper)

                    pde_u_loss_val, pde_v_loss_val = pde_temperature_two_loss(
                        output_val, input_val,dx, dy,
                        torch.nn.functional.softmax(alpha_uv.view(-1), dim=0).view(13, 721, 1440),
                        g, depth_normalized
                    )
                    # Same weighting as the training loss.
                    loss = (
                            100*(0.5 * torch.exp(-log_sigma[0]) * (weighted_val_loss_upper) + log_sigma[0]) +
                            (0.5 * torch.exp(-log_sigma[1]) *( pde_u_loss_val/100) + log_sigma[1]) +
                            (0.5 * torch.exp(-log_sigma[2]) * (pde_v_loss_val/100) + log_sigma[2])
                    )

                    val_loss += loss.item()
                    val_data_loss2 += (100*(0.5 * torch.exp(-log_sigma[0]) * (weighted_val_loss_upper) + log_sigma[0])).item()
                    pde_u_loss_val2 += ((0.5 * torch.exp(-log_sigma[1]) * (pde_u_loss_val/100) + log_sigma[1])).item()
                    pde_v_loss_val2 += ((0.5 * torch.exp(-log_sigma[2]) * (pde_v_loss_val/100) + log_sigma[2])).item()

                # Convert the accumulated sums into per-batch means.
                val_loss /= len(val_loader)
                val_data_loss2 /= len(val_loader)
                pde_u_loss_val2 /= len(val_loader)
                pde_v_loss_val2 /= len(val_loader)
                writer.add_scalars('Loss',
                                   {'train': epoch_loss,
                                    'val': val_loss},
                                   i)
                writer.add_scalars('Data_Loss',
                                   {'train': data_loss1,
                                    'val': val_data_loss2},
                                   i)
                writer.add_scalars('U_Loss',
                                   {'train': pde_u_loss1,
                                    'val': pde_u_loss_val2},
                                   i)
                writer.add_scalars('V_Loss',
                                   {'train': pde_v_loss1,
                                    'val': pde_v_loss_val2},
                                   i)
                # Current values of the learnable loss log-weights.
                writer.add_scalars('Alpha_Loss/Val',
                                   {'alpha_loss_0': log_sigma[0].item(),
                                    'alpha_loss_1': log_sigma[1].item(),
                                    'alpha_loss_2': log_sigma[2].item()},
                                   i)
                logger.info("Validate at Epoch {} : {:.3f}".format(i, val_loss))
                print(
                    f"VAL Loss ratio (%): data={100 * val_data_loss2 / val_loss:.1f}%,"
                    f"u={100 * pde_u_loss_val2 / val_loss:.1f}%, "
                    f"v={100 * pde_v_loss_val2 / val_loss:.1f}%")
                # Directory for training visualisations.
                # NOTE(review): the directory is created but nothing is written to it here.
                png_path = os.path.join(res_path, "png_training")
                utils.mkdirs(png_path)

                # Early stopping
                if val_loss < best_loss:
                    best_loss = val_loss
                    best_model = copy.deepcopy(model)
                    # Save the best model (the whole module object is pickled, not a state_dict)
                    torch.save(best_model, os.path.join(model_save_path, 'best_model.pth'))
                    logger.info(
                        f"current best model is saved at {i} epoch.")
                    epochs_since_last_improvement = 0
                else:
                    # The counter advances once per validation round, not once per epoch.
                    epochs_since_last_improvement += 1
                    if epochs_since_last_improvement >= 5:
                        logger.info(
                            f"No improvement in validation loss for {epochs_since_last_improvement} epochs, terminating training.")
                        # Leaves the epoch loop.
                        break

    return best_model


def test(test_loader, model, device, res_path, alpha_uv):
    """Evaluate ``model`` on ``test_loader`` and write RMSE / ACC / PIC scores to CSV.

    For every batch the model input is built (optionally concatenating the
    processed Fourier coordinate features along dim 1), inference is run
    without gradient tracking, and the prediction is mapped back to the
    original data range with ``utils_data.normBackData``. Scores are stored in
    dictionaries keyed by the batch's target time and finally written with
    ``utils.save_errorScores`` into ``res_path/csv``. One PIC figure per batch
    is written into ``res_path/png_PICuo`` via ``visuailze_pic``.

    Args:
        test_loader: iterable of 5-tuples ``(input, _, target, _, periods)``;
            channels ``2:4`` of ``input`` and ``target`` are used.
        model: network called as ``model(x, weather_statistics, constant_maps, const_h)``.
        device: device that inputs, constants and Fourier modules are moved to.
        res_path: output directory; ``png_PICuo`` and ``csv`` are created in it.
        alpha_uv: tensor forwarded to ``pic_two_loss`` as the diffusion coefficient.

    Returns:
        None.

    Notes:
        * The metrics index the squeezed tensors with ``[0]`` / ``[1]`` and the
          target time is always read from position ``batch_id == 0``, so this
          presumably assumes a batch size of 1 — TODO confirm against the loader.
        * When Fourier features are enabled, their weights are read from the
          checkpoint at ``cfg.PG.BENCHMARK.var2_change_structure_noFF``.
    """
    # Score containers keyed by target time. Only the *_z / *_q RMSE and ACC
    # dictionaries and the u / v PIC dictionaries are filled below; the others
    # stay empty and are handed to utils.save_errorScores unchanged.
    rmse_upper_z, rmse_upper_q, rmse_upper_t, rmse_upper_u, rmse_upper_v = dict(), dict(), dict(), dict(), dict()
    rmse_surface = dict()

    acc_upper_z, acc_upper_q, acc_upper_t, acc_upper_u, acc_upper_v = dict(), dict(), dict(), dict(), dict()
    acc_surface = dict()
    pic_upper_temp, pic_upper_salt, pic_upper_u, pic_upper_v, pic_upper_zos = dict(), dict(), dict(), dict(), dict()
    pic_gloabl_t, pic_gloabl_s, pic_gloabl_u, pic_gloabl_v, pic_gloabl_z = dict(), dict(), dict(), dict(), dict()

    # Load all statistics and constants
    aux_constants = utils_data.loadAllConstants(device=device)

    # === Fourier feature embedding setup ===
    # Processed coordinate features, channel-first, without a batch axis.
    # Stays None when Fourier features are disabled, so no Fourier module is
    # ever touched in that case (the previous code called .eval() on None).
    static_fourier_features = None

    if cfg.FOURIER.FF_ENABLED:
        ff_input_dims = cfg.FOURIER.FF_INPUT_DIMS

        fourier_embedder_low = TrainableFourierFeatureEmbedding(
            input_dims=ff_input_dims,
            mapping_size=cfg.FOURIER.FF_MAPPING_SIZE_LOW,
            initial_scale=cfg.FOURIER.FF_SCALE_LOW,
            device=device
        ).to(device)

        fourier_embedder_high = TrainableFourierFeatureEmbedding(
            input_dims=ff_input_dims,
            mapping_size=cfg.FOURIER.FF_MAPPING_SIZE_HIGH,
            initial_scale=cfg.FOURIER.FF_SCALE_HIGH,
            device=device
        ).to(device)

        # One FeedForward instance is shared by both frequency branches; its
        # input width is taken from the LOW mapping size.
        ff_raw_output_dims = cfg.FOURIER.FF_MAPPING_SIZE_LOW * 2
        ff_layers = FeedForward(
            input_dimensions=ff_raw_output_dims,
            output_dimensions=cfg.FOURIER.FF_HIDDEN_DIM,
            layers_config=[cfg.FOURIER.FF_HIDDEN_DIM],
            device=device
        ).to(device)

        checkpoint = torch.load(cfg.PG.BENCHMARK.var2_change_structure_noFF, map_location=device)
        if 'fourier_embedder_low' in checkpoint:
            fourier_embedder_low.load_state_dict(checkpoint['fourier_embedder_low'])
            fourier_embedder_high.load_state_dict(checkpoint['fourier_embedder_high'])
            ff_layers.load_state_dict(checkpoint['ff_layers'])
        else:
            print("警告：在检查点文件中未找到傅里叶网络权重，将使用随机权重。")

        fourier_embedder_low.eval()
        fourier_embedder_high.eval()
        ff_layers.eval()

        # Normalised coordinate grid: lat / lon in [-1, 1], depth divided by
        # the deepest level.
        lat_coords_norm = torch.linspace(-1.0, 1.0, 721, device=device)
        lon_coords_norm = torch.linspace(-1.0, 1.0, 1440, device=device)

        depth_original = torch.tensor([
            0.494025, 2.645669, 5.078224, 7.92956, 11.405, 15.81007,
            21.59882, 29.44473, 40.34405, 55.76429, 77.85385, 109.7293, 130.666
        ], dtype=torch.float32)
        max_depth = depth_original[-1]
        depth_normalized = depth_original / max_depth  # shape: (13,)
        depth_normalized = depth_normalized.to(device)

        D_coords, Y_coords, X_coords = torch.meshgrid(
            depth_normalized, lat_coords_norm, lon_coords_norm, indexing='ij'
        )
        coords_grid = torch.stack([D_coords, Y_coords, X_coords], dim=-1)

        # The grid and the (eval-mode) weights do not change between batches,
        # so the processed features are computed once instead of per batch.
        with torch.no_grad():
            # (D, H, W, C) -> (C, D, H, W) for channel-wise concatenation.
            processed_ff_low = ff_layers(fourier_embedder_low(coords_grid)).permute(3, 0, 1, 2)
            processed_ff_high = ff_layers(fourier_embedder_high(coords_grid)).permute(3, 0, 1, 2)
            # Low-frequency channels first, then high-frequency ones.
            static_fourier_features = torch.cat([processed_ff_low, processed_ff_high], dim=0)
    else:
        print("测试模式：未启用傅里叶特征嵌入。")
    # === End of Fourier feature embedding setup ===

    model.eval()

    # Loop-invariant lookups.
    png_path = os.path.join(res_path, "png_PICuo")
    utils.mkdirs(png_path)
    _, _, upper_mean, _ = aux_constants['weather_statistics_last']
    upper_mean_selected = upper_mean[:, 2:4, :, :, :].squeeze(0)

    batch_id = 0
    for step_idx, data in enumerate(tqdm(test_loader, desc="epoch "+"-Test"), 0):
        print(f"predict on {step_idx}")
        input_test, _, target_test, _, periods_test = data
        input_test, target_test = input_test[:, 2:4, :, :, :].to(device), target_test[:, 2:4, :, :, :].to(device)

        # Inference only: no autograd graph is needed for scoring.
        with torch.no_grad():
            if static_fourier_features is not None:
                batch_size_current = input_test.shape[0]
                expanded_features = static_fourier_features.unsqueeze(0).expand(
                    batch_size_current, -1, -1, -1, -1
                )
                input_to_model_test = torch.cat([input_test, expanded_features], dim=1)
            else:
                input_to_model_test = input_test

            output_test = model(input_to_model_test, aux_constants['weather_statistics'],
                                aux_constants['constant_maps'], aux_constants['const_h'])
            # Transfer the output back to the original data range
            output_test, _ = utils_data.normBackData(output_test, None,
                                                     aux_constants['weather_statistics_last'])

        target_time = periods_test[1][batch_id]

        # Compute test scores
        # rmse
        input_test = input_test.squeeze()
        output_test = output_test.squeeze()
        target_test = target_test.squeeze()
        rmse_upper_z[target_time] = score.weighted_rmse_torch_channels(output_test[0],
                                                                       target_test[0]).detach().cpu().numpy()
        rmse_upper_q[target_time] = score.weighted_rmse_torch_channels(output_test[1],
                                                                       target_test[1]).detach().cpu().numpy()

        # acc: computed on anomalies, i.e. after subtracting the stored mean.
        output_test_anomaly = output_test - upper_mean_selected
        target_test_anomaly = target_test - upper_mean_selected

        acc_upper_z[target_time] = score.weighted_acc_torch_channels(output_test_anomaly[0],
                                                                     target_test_anomaly[0]).detach().cpu().numpy()
        acc_upper_q[target_time] = score.weighted_acc_torch_channels(output_test_anomaly[1],
                                                                     target_test_anomaly[1]).detach().cpu().numpy()

        # physical inconsistency (PIC)
        (pic_upper_u[target_time], pic_upper_v[target_time],
         pic_gloabl_u[target_time], pic_gloabl_v[target_time]) = pic_two_loss(
            output_test[0], input_test[0], target_test[0],
            output_test[1], input_test[1], target_test[1],
            alpha_uv)

        visuailze_pic(pic_gloabl_u[target_time], pic_gloabl_v[target_time],
                      var='uo', z=0, step=target_time, path=png_path)

    # Save scores to csv
    csv_path = os.path.join(res_path, "csv")
    utils.mkdirs(csv_path)
    utils.save_errorScores(csv_path, rmse_upper_z, rmse_upper_q, rmse_upper_t, rmse_upper_u, rmse_upper_v, rmse_surface,
                     "rmse")
    utils.save_errorScores(csv_path, acc_upper_z, acc_upper_q, acc_upper_t, acc_upper_u, acc_upper_v, acc_surface, "acc")
    utils.save_errorScores(csv_path, pic_upper_temp, pic_upper_salt, pic_upper_u, pic_upper_v, pic_upper_zos, acc_surface, "pic")

def pic_two_loss(output_u, input_u, target_u,
                  output_v, input_v, target_v,
                  alpha_uv):
    """Return the physical-inconsistency (PIC) error of the predicted u / v fields.

    For each component a residual of the form
    ``d/dt + u * d/dx + v * d/dy - alpha_uv * laplacian`` is evaluated once
    for the prediction and once for the ground truth; the PIC error is the
    absolute difference between the two residuals.

    The fields are expected to be 3-D ``(depth, lat, lon)`` tensors with 13
    depth levels (the depth spacing is a fixed 13-entry table). Horizontal
    spacing is hard-coded for a 721 x 1440 grid normalised to [-1, 1], and
    the time step is 1.

    NOTE(review): the ground-truth residuals also use ``output_u`` /
    ``output_v`` (not ``target_u`` / ``target_v``) as the advecting velocity.
    This is reproduced as-is; confirm that it is intended.

    Args:
        output_u, output_v: predicted fields.
        input_u, input_v: fields at the previous time step.
        target_u, target_v: ground-truth fields.
        alpha_uv: diffusion coefficient tensor, broadcast against the fields.

    Returns:
        Tuple of four NumPy arrays: the u and v errors averaged over dims
        1 and 2 (one value per depth level), followed by the full u and v
        element-wise error fields.
    """
    diffusivity = alpha_uv.to(device=output_u.device)

    # Spacing of a [-1, 1] axis sampled at 1440 (lon) and 721 (lat) points.
    lon_step = 2.0 / (1440 - 1)
    lat_step = 2.0 / (721 - 1)
    time_step = 1.0

    # Depth levels, divided by the deepest one and used as vertical coordinates.
    layer_depths = torch.tensor([
        0.494025, 2.645669, 5.078224, 7.92956, 11.405, 15.81007,
        21.59882, 29.44473, 40.34405, 55.76429, 77.85385, 109.7293, 130.666
    ], dtype=torch.float32)
    z_coords = (layer_depths / layer_depths[-1]).to(output_u.device)

    def first_derivatives(field):
        """Return (d/dz, d/dy, d/dx) of ``field``, each clamped to [-1e5, 1e5]."""
        d_dz = torch.gradient(field, spacing=(z_coords,), dim=-3)[0]
        d_dy = torch.gradient(field, spacing=lat_step, dim=-2)[0]
        d_dx = torch.gradient(field, spacing=lon_step, dim=-1)[0]
        return (torch.clamp(d_dz, min=-1e5, max=1e5),
                torch.clamp(d_dy, min=-1e5, max=1e5),
                torch.clamp(d_dx, min=-1e5, max=1e5))

    def laplacian(d_dz, d_dy, d_dx):
        """Differentiate each first derivative once more and sum the results."""
        d2_dz2 = torch.gradient(d_dz, spacing=(z_coords,), dim=(-3,))[0]
        d2_dy2 = torch.gradient(d_dy, spacing=(lat_step,), dim=(-2,))[0]
        d2_dx2 = torch.gradient(d_dx, spacing=(lon_step,), dim=(-1,))[0]
        return d2_dx2 + d2_dy2 + d2_dz2

    def residual(state, previous_state):
        """PDE residual of ``state``; advection always uses the predicted velocity."""
        with torch.no_grad():
            time_term = (state - previous_state) / time_step
        d_dz, d_dy, d_dx = first_derivatives(state)
        return (time_term + output_u * d_dx + output_v * d_dy
                - diffusivity * laplacian(d_dz, d_dy, d_dx))

    # Element-wise absolute difference between predicted and ground-truth residuals.
    u_error = torch.abs(residual(output_u, input_u) - residual(target_u, input_u))
    v_error = torch.abs(residual(output_v, input_v) - residual(target_v, input_v))

    u_error_per_level = u_error.mean(dim=(1, 2))
    v_error_per_level = v_error.mean(dim=(1, 2))

    return (u_error_per_level.detach().cpu().numpy(),
            v_error_per_level.detach().cpu().numpy(),
            u_error.detach().cpu().numpy(),
            v_error.detach().cpu().numpy())


def pde_temperature_two_loss(output, input, dx, dy, kappa_uv, g, depth_normalized):
    """Return the mean squared PDE residuals of the predicted u and v components.

    Channel 0 (u) and channel 1 (v) are taken from dim 1 of ``input`` and
    ``output``. Each of the four resulting fields is standardised with its
    own global mean and standard deviation, then the residual
    ``d/dt + u * d/dx + v * d/dy - kappa_uv * laplacian`` is evaluated on the
    standardised prediction with a time step of 1.

    NOTE(review): the time-derivative term is computed under
    ``torch.no_grad()``, so no gradient flows through it; ``g`` is accepted
    but not used. Both are reproduced as-is.

    Args:
        output: model prediction, indexed as ``output[:, channel, ...]``.
        input: previous state, indexed the same way as ``output``.
        dx: grid spacing along the last dimension.
        dy: grid spacing along the second-to-last dimension.
        kappa_uv: diffusion coefficient tensor, broadcast against the fields.
        g: unused.
        depth_normalized: 1-D coordinate tensor for the third-to-last dimension.

    Returns:
        Tuple ``(loss_u, loss_v)`` of scalar tensors.
    """
    diffusivity = kappa_uv.to(device=output.device)
    time_step = 1.0

    def standardize(field):
        """Shift to zero mean and scale by the standard deviation (+1e-8)."""
        return (field - field.mean()) / (field.std() + 1e-8)

    prev_u = standardize(input[:, 0, :, :, :])
    pred_u = standardize(output[:, 0, :, :, :])
    prev_v = standardize(input[:, 1, :, :, :])
    pred_v = standardize(output[:, 1, :, :, :])

    def first_derivatives(field):
        """Return (d/dz, d/dy, d/dx) of ``field``, each clamped to [-1e5, 1e5]."""
        d_dz = torch.gradient(field, spacing=(depth_normalized,), dim=-3)[0]
        d_dy = torch.gradient(field, spacing=dy, dim=-2)[0]
        d_dx = torch.gradient(field, spacing=dx, dim=-1)[0]
        return (torch.clamp(d_dz, min=-1e5, max=1e5),
                torch.clamp(d_dy, min=-1e5, max=1e5),
                torch.clamp(d_dx, min=-1e5, max=1e5))

    def laplacian(d_dz, d_dy, d_dx):
        """Differentiate each first derivative once more and sum the results."""
        d2_dz2 = torch.gradient(d_dz, spacing=(depth_normalized,), dim=(-3,))[0]
        d2_dy2 = torch.gradient(d_dy, spacing=(dy,), dim=(-2,))[0]
        d2_dx2 = torch.gradient(d_dx, spacing=(dx,), dim=(-1,))[0]
        return d2_dx2 + d2_dy2 + d2_dz2

    def mean_squared_residual(pred, prev):
        """Mean of the squared residual for one velocity component."""
        with torch.no_grad():
            time_term = (pred - prev) / time_step
        d_dz, d_dy, d_dx = first_derivatives(pred)
        res = (time_term + pred_u * d_dx + pred_v * d_dy
               - diffusivity * laplacian(d_dz, d_dy, d_dx))
        return torch.mean(res ** 2)

    return mean_squared_residual(pred_u, prev_u), mean_squared_residual(pred_v, prev_v)


def visuailze_pic(pic_global_u, pic_global_v, var, z, step, path):
    """Save an image of one depth level of the u-component PIC field.

    The slice ``pic_global_u[z]`` is drawn flipped along its first axis with
    the ``RdBu`` colour map, using ``-mean`` / ``+mean`` of the slice as the
    colour limits. The figure is saved to ``path`` under the name
    ``PIC_<step>_<var>_Z<z>`` and all figures are closed afterwards.

    Args:
        pic_global_u: array indexed as ``[z, :, :]``.
        pic_global_v: accepted for symmetry with the caller; not plotted.
        var: variable name; must be an entry of ``cfg.ERA5_UPPER_VARIABLES``
            (``list.index`` raises ``ValueError`` otherwise).
        z: index of the level to plot.
        step: label inserted into the file name.
        path: directory the figure is written to.
    """
    variable_names = cfg.ERA5_UPPER_VARIABLES
    var_index = variable_names.index(var)

    figure = plt.figure(figsize=(4, 2))
    axis = figure.add_subplot(111)

    level_slice = pic_global_u[z, :, :]
    # Symmetric colour range around zero, scaled by the mean of the slice.
    color_limit = np.mean(level_slice)
    image = axis.imshow(level_slice[::-1], cmap="RdBu", vmin=-color_limit, vmax=color_limit)
    plt.colorbar(image, ax=axis, fraction=0.05, pad=0.05)
    axis.title.set_text('PIC of u')

    file_name = 'PIC_{}_{}_Z{}'.format(step, variable_names[var_index], z)
    plt.savefig(fname=os.path.join(path, file_name))

    # Release the figure.
    plt.clf()
    plt.cla()
    plt.close("all")


# Running this file directly does nothing.
if __name__ == "__main__":
    pass
