import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from functools import partial
from einops import rearrange
import utils.misc as utils
from utils.metrics import WRMSE
from utils.builder import get_optimizer, get_lr_scheduler
#import modules
from utils.distributions import DiagonalGaussianDistribution
from models.vaeformer import VAEformer
from models.diffusion.score import ScoreSiT
from models.diffusion.score import VPSDE
from models.diffusion.score import GuidanceSampling
import pandas as pd
import h5py

class EMAHelper(object):
    """Maintain an exponential moving average (EMA) of a module's trainable weights.

    ``shadow`` maps parameter names to the averaged tensors and ``mu`` is the
    decay factor applied to the running average on every ``update`` call.
    Modules wrapped in ``nn.DataParallel`` are unwrapped before use.
    """

    def __init__(self, mu=0.999):
        self.mu = mu
        self.shadow = {}

    @staticmethod
    def _unwrap(module):
        """Return the inner module of an ``nn.DataParallel`` wrapper, else the module itself."""
        return module.module if isinstance(module, nn.DataParallel) else module

    def _trainable(self, module):
        """Yield ``(name, parameter)`` for every parameter that requires gradients."""
        for name, param in self._unwrap(module).named_parameters():
            if param.requires_grad:
                yield name, param

    def register(self, module):
        """Initialise the shadow weights with a copy of the module's current weights."""
        for name, param in self._trainable(module):
            self.shadow[name] = param.data.clone()

    def update(self, module):
        """Fold the module's current weights into the running average."""
        for name, param in self._trainable(module):
            averaged = self.shadow[name]
            averaged.data = (1. - self.mu) * param.data + self.mu * averaged.data

    def ema(self, module):
        """Overwrite the module's trainable weights in place with the averaged ones."""
        for name, param in self._trainable(module):
            param.data.copy_(self.shadow[name].data)

    def ema_copy(self, module):
        """Build a new module of the same type carrying the averaged weights.

        The copy is constructed from ``module.config`` and moved to
        ``module.config.device``; a ``nn.DataParallel`` input yields a
        ``nn.DataParallel`` output.
        """
        is_parallel = isinstance(module, nn.DataParallel)
        source = module.module if is_parallel else module
        clone = type(source)(source.config).to(source.config.device)
        clone.load_state_dict(source.state_dict())
        if is_parallel:
            clone = nn.DataParallel(clone)
        self.ema(clone)
        return clone

    def state_dict(self):
        """Return the shadow dict itself (not a copy)."""
        return self.shadow

    def load_state_dict(self, state_dict):
        """Adopt ``state_dict`` as the shadow dict (stored by reference)."""
        self.shadow = state_dict

def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))`` evaluated with NumPy."""
    decay = np.exp(-x)
    return 1 / (1 + decay)

def cal_da_score(pool,truth,data_std):
    """Rank every sample in ``pool`` against the others by per-channel WRMSE.

    For sample ``n`` and channel ``m`` the pairwise term is
    ``sigmoid(r_m(x_j) - r_m(x_n))`` averaged over all ``j != n``, where ``r``
    is the value returned by ``WRMSE(sample, truth, data_std)``. The result is
    then averaged over the first 69 channels, so a sample with lower error
    than its peers receives a higher score.

    Args:
        pool: non-empty sequence of sample tensors.
        truth: reference tensor handed to ``WRMSE``.
        data_std: normalisation statistics handed to ``WRMSE``.

    Returns:
        float32 tensor of shape ``(len(pool),)`` on the device of the last sample.

    Raises:
        ValueError: if ``pool`` is empty (previously an obscure NameError).
    """
    if len(pool) == 0:
        raise ValueError("cal_da_score needs at least one sample in `pool`")

    # m channel: u10 v10 t2m z500 u500 v500 t500
    var_s = 69
    N = len(pool)
    device = pool[-1].device

    # Assumes WRMSE returns a 1-D per-channel tensor (the code indexes it as rmse[m]).
    rmses = torch.stack([WRMSE(sample, truth, data_std).to(device)[:var_s] for sample in pool])

    # gap[n, j, m] = r_m(x_j) - r_m(x_n): positive when sample n has the lower error.
    gap = rmses.unsqueeze(0) - rmses.unsqueeze(1)
    # A sample is never compared with itself; masking (rather than multiplying
    # by zero) keeps a non-finite RMSE from leaking in through the diagonal.
    self_pair = torch.eye(N, dtype=torch.bool, device=device).unsqueeze(-1)
    pairwise = torch.sigmoid(gap).masked_fill(self_pair, 0.0)
    # max(..., 1) keeps the single-sample case at score 0 instead of 0/0.
    Rnm = pairwise.sum(dim=1) / max(N - 1, 1)
    Rn = Rnm.mean(dim=1).to(torch.float32)

    if utils.get_rank()==0:
        for n in range(N):
            print("da score:",n,Rn[n])
    return Rn

def cal_fore_score(pool, h_42, h_48_era5 ,data_std,args):
    """Rank every sample in ``pool`` by the WRMSE of the forecast started from it.

    Each sample is rolled forward with ``cal_fore_score_`` (eight calls of
    ``args.forecast_model.run``) and the forecast is compared with
    ``h_48_era5`` through ``WRMSE``. Scores are then formed exactly like the
    assimilation score: for sample ``n`` and channel ``m`` the term
    ``sigmoid(r_m(x_j) - r_m(x_n))`` is averaged over all ``j != n`` and over
    the first 69 channels.

    Args:
        pool: non-empty sequence of sample tensors.
        h_42: tensor concatenated in front of each sample along dim 1 to form
            the model input.
        h_48_era5: reference tensor for the forecast error.
        data_std: normalisation statistics handed to ``WRMSE``.
        args: object exposing ``forecast_model.run``.

    Returns:
        float32 tensor of shape ``(len(pool),)`` on the device of the last sample.

    Raises:
        ValueError: if ``pool`` is empty (previously an obscure NameError).
    """
    if len(pool) == 0:
        raise ValueError("cal_fore_score needs at least one sample in `pool`")

    # m channel: u10 v10 t2m z500 u500 v500 t500
    var_s = 69
    N = len(pool)
    device = pool[-1].device

    per_sample = []
    for n, da in enumerate(pool):
        da_fore_ = cal_fore_score_(da, h_42, h_48_era5, args)
        rmse = WRMSE(da_fore_, h_48_era5, data_std)
        if utils.get_rank()==0:
            print("forecast:",n,rmse[0].item(),rmse[1].item(),rmse[2].item(),rmse[3].item(),rmse[11].item())
        # Assumes WRMSE returns a 1-D per-channel tensor (the code indexes it as rmse[m]).
        per_sample.append(rmse.to(device)[:var_s])
    rmses = torch.stack(per_sample)

    # gap[n, j, m] = r_m(x_j) - r_m(x_n): positive when sample n has the lower error.
    gap = rmses.unsqueeze(0) - rmses.unsqueeze(1)
    # A sample is never compared with itself; masking (rather than multiplying
    # by zero) keeps a non-finite RMSE from leaking in through the diagonal.
    self_pair = torch.eye(N, dtype=torch.bool, device=device).unsqueeze(-1)
    pairwise = torch.sigmoid(gap).masked_fill(self_pair, 0.0)
    # max(..., 1) keeps the single-sample case at score 0 instead of 0/0.
    Rnm = pairwise.sum(dim=1) / max(N - 1, 1)
    Rn = Rnm.mean(dim=1).to(torch.float32)

    if utils.get_rank()==0:
        for n in range(N):
            print("forecast score:",n,Rn[n])
    return Rn

def cal_fore_score_(da, h_42, h_48_era5,args):
    """Roll the forecast model forward eight steps starting from ``[h_42, da]``.

    The model input is always the two most recent states concatenated along
    dim 1; after each call the older state is dropped and the new prediction
    appended. Only the first ``da.shape[1]`` output channels of the model are
    kept.

    Returns:
        The last prediction as a tensor on ``h_48_era5``'s device.
    """
    n_channels = da.shape[1]
    window = torch.cat([h_42, da], dim=1).cpu().numpy()
    for _step in range(8):
        outputs = args.forecast_model.run(None, {'input':window})
        latest = outputs[0][:, :n_channels]
        newest_state = window[:, -n_channels:]
        window = np.concatenate([newest_state, latest], axis=1)
    return torch.from_numpy(latest).to(h_48_era5.device)
    
    
def cal_geo(field,data_std,data_mean):
    """Measure how far a field's winds are from geostrophic balance.

    The field is de-normalised with ``field * data_std + data_mean``, channel
    11 of the first batch element is divided by 9.81 to obtain a height field,
    and geostrophic winds are derived from its gradients on a fixed 128 x 256
    grid spanning -90..90 and -180..180 degrees. They are compared with
    channels 37 and 50 of the same field (presumably the matching u/v wind
    components - confirm against the dataset's channel layout) inside the
    30-60 degree latitude bands of both hemispheres.

    Returns:
        Half the sum of the two relative mean absolute errors (u and v).
    """
    n_lat, n_lon = 128, 256
    gravity = 9.81  # m/s²
    rotation_rate = 7.292*1e-5  # rad/s

    lat = np.linspace(-90, 90, n_lat)
    lon = np.linspace(-180, 180, n_lon)
    coriolis = 2 * rotation_rate * np.sin(np.deg2rad(lat))

    # Grid spacing in metres (111319 m per degree); dx shrinks with latitude.
    dy = 111319 * np.diff(lat).mean()
    dx = 111319 * np.cos(np.deg2rad(lat)) * np.diff(lon).mean()

    # Only the mid-latitude bands take part in the comparison.
    abs_lat = np.abs(lat)
    in_band = (abs_lat >= 30) & (abs_lat <= 60)
    mask = np.tile(in_band[:, np.newaxis], (1, n_lon))

    physical = (field*data_std+data_mean).detach().cpu().numpy()

    height = physical[0,11,:,:] / gravity
    dz_dy = np.gradient(height, dy, axis=0)   # latitude gradient
    dz_dx = np.zeros((n_lat, n_lon))
    for row, spacing in enumerate(dx):
        dz_dx[row] = -np.gradient(height[row], spacing)

    g_over_f = gravity / coriolis[:, np.newaxis]
    u_geo = np.where(mask, -g_over_f * dz_dy, 0)
    v_geo = np.where(mask, g_over_f * dz_dx, 0)
    u_ref = np.where(mask, physical[0,4+13*2+8-1,:,:], 0)
    v_ref = np.where(mask, physical[0,4+13*3+8-1,:,:], 0)

    u_rel_err = np.mean(abs(u_geo - u_ref))/np.mean(abs(u_ref))
    v_rel_err = np.mean(abs(v_geo - v_ref))/np.mean(abs(v_ref))
    ratio = u_rel_err + v_rel_err
    return ratio/2

def cal_phys_score(pool,data_std,data_mean):
    """Score each sample in ``pool`` by how close it is to geostrophic balance.

    The imbalance of every sample is measured with ``cal_geo`` (which returns
    half the summed relative wind errors) and mapped through
    ``sigmoid(-10 * imbalance)``, so a smaller imbalance gives a higher score.
    Reusing ``cal_geo`` replaces a verbatim copy of its body that used to live
    in this function.

    Args:
        pool: sequence of normalised sample tensors.
        data_std: statistics used to de-normalise the samples.
        data_mean: statistics used to de-normalise the samples.

    Returns:
        NumPy array of shape ``(len(pool),)`` with one score per sample.
    """
    Rn = np.zeros(len(pool))
    for j, sample in enumerate(pool):
        # cal_geo already halves the ratio, so this equals sigmoid(-10*ratio/2).
        Rn[j] = sigmoid(-10 * cal_geo(sample, data_std, data_mean))
        if utils.get_rank()==0:
            print("phys score:",j,Rn[j])
    return Rn

def select(scores, top_k=10):
    """Pick samples that are consistently best / worst on all three scores.

    ``scores`` has one row per sample and three columns (read below as
    columns 1-3 of the working table). A sample is a winner when it is among
    the ``top_k`` best in every column and a loser when it is among the
    ``top_k`` worst in every column.

    Args:
        scores: array-like of shape ``(N, 3)``; higher is better.
        top_k: size of the best / worst group per column. It is clamped to
            ``N`` so that small pools no longer raise ``IndexError``; note that
            with ``N < 2 * top_k`` the two groups can overlap.

    Returns:
        ``(wins_index, loses_index)``: two lists of original row indices,
        ordered by the third score (best first).
    """
    N = len(scores)
    if N == 0:
        return [], []

    # Table columns: 0 = original index, 1-3 = scores, 4-6 = rank per score.
    ranked = np.zeros((N, 7))
    ranked[:, 0] = np.arange(N)
    ranked[:, 1:4] = scores

    # Sort by each score in turn (descending) and record the resulting rank.
    for score_col, rank_col in ((1, 4), (2, 5), (3, 6)):
        ranked = ranked[(-ranked[:, score_col]).argsort()]
        ranked[:, rank_col] = np.arange(N)

    len_ = min(top_k, N)
    worst_from = N - len_

    # The table is now ordered by the third score, so the slices already
    # guarantee the third criterion; only the first two ranks need checking.
    wins_index = [
        int(row[0]) for row in ranked[:len_]
        if row[4] < len_ and row[5] < len_
    ]
    loses_index = [
        int(row[0]) for row in ranked[worst_from:]
        if row[4] > worst_from - 1 and row[5] > worst_from - 1
    ]
    return wins_index,loses_index

def singleselect(scores, top_k=5):
    """Pick the best and worst samples for each of the three scores separately.

    ``scores`` has one row per sample and three columns, named DA, forecast
    and physics in the code below. For every column the samples are sorted in
    descending order and the first / last ``top_k`` original indices are
    returned.

    The previous implementation read rows 27..31 for the losers, which is only
    correct for exactly 32 samples (``IndexError`` for fewer, not the worst
    samples for more). Slicing from the end gives the same result for 32
    samples and works for any pool size.

    Args:
        scores: array-like of shape ``(N, 3)``; higher is better.
        top_k: number of winners / losers per score (clamped to ``N``).

    Returns:
        ``(DA_wins, DA_loses, Fore_wins, Fore_loses, Phys_wins, Phys_loses)``,
        each a list of original row indices in descending score order.
    """
    N = len(scores)
    if N == 0:
        return [], [], [], [], [], []

    # Table columns: 0 = original index, 1-3 = scores, 4-6 = rank per score.
    ranked = np.zeros((N, 7))
    ranked[:, 0] = np.arange(N)
    ranked[:, 1:4] = scores

    k = min(top_k, N)
    selections = []
    # (score column, rank column) for DA, forecast and physics, in that order.
    for score_col, rank_col in ((1, 4), (2, 5), (3, 6)):
        ranked = ranked[(-ranked[:, score_col]).argsort()]
        ranked[:, rank_col] = np.arange(N)
        ordered = [int(idx) for idx in ranked[:, 0]]
        selections.append(ordered[:k])       # winners: best k
        selections.append(ordered[N - k:])   # losers: worst k

    return tuple(selections)

def doubleselect(scores, top_k=10):
    """Pick samples that are consistently best / worst on each pair of scores.

    ``scores`` has one row per sample and three columns, named DA, forecast
    and physics in the code below. For each pair (DA+forecast, DA+physics,
    forecast+physics) a sample is a winner when it is among the ``top_k`` best
    on both scores of the pair, and a loser when it is among the ``top_k``
    worst on both.

    Args:
        scores: array-like of shape ``(N, 3)``; higher is better.
        top_k: size of the best / worst group per score. It is clamped to
            ``N`` so that small pools no longer raise ``IndexError``; note that
            with ``N < 2 * top_k`` the two groups can overlap.

    Returns:
        ``(da_fore_wins, da_fore_loses, da_phys_wins, da_phys_loses,
        fore_phys_wins, fore_phys_loses)``, each a list of original row
        indices ordered by the second score of the pair (best first).
    """
    N = len(scores)
    if N == 0:
        return [], [], [], [], [], []

    # Table columns: 0 = original index, 1-3 = scores, 4-6 = rank per score.
    ranked = np.zeros((N, 7))
    ranked[:, 0] = np.arange(N)
    ranked[:, 1:4] = scores

    len_ = min(top_k, N)
    worst_from = N - len_

    def rank_by(table, score_col, rank_col):
        # Sort descending by one score and store each row's position as its rank.
        table = table[(-table[:, score_col]).argsort()]
        table[:, rank_col] = np.arange(N)
        return table

    def extremes(table, rank_col):
        # The table is ordered by the pair's second score; filter on the first.
        wins = [int(row[0]) for row in table[:len_] if row[rank_col] < len_]
        loses = [int(row[0]) for row in table[worst_from:] if row[rank_col] > worst_from - 1]
        return wins, loses

    selections = []
    # The sorts are chained on the same table, in the same sequence as before,
    # so that the ordering of tied scores is unchanged.
    for (first_score, first_rank), (second_score, second_rank) in (
        ((1, 4), (2, 5)),   # DA and forecast
        ((1, 4), (3, 6)),   # DA and physics
        ((2, 5), (3, 6)),   # forecast and physics
    ):
        ranked = rank_by(ranked, first_score, first_rank)
        ranked = rank_by(ranked, second_score, second_rank)
        selections.extend(extremes(ranked, first_rank))

    return tuple(selections)
    
class diffDA_model(nn.Module):
    """Bundle of a ``VPSDE`` diffusion model over latents and a frozen ``VAEformer``.

    ``self.diff`` wraps a ``ScoreSiT`` backbone for data of size (69, 32, 64);
    ``self.AE`` is loaded from a checkpoint and excluded from gradient updates.
    """

    def __init__(self,):
        super().__init__()

        latent_size = (69,32,64)
        score_backbone = ScoreSiT(channels=69, input_size=(32,64), patch_size=(2,2), depth=28, hidden_size=1152, num_heads=16)
        self.diff = VPSDE(score_backbone, latent_size)

        self.AE = VAEformer(model_version=69)
        self.AE.load_state_dict(torch.load("/path/to/VAE_model"))
        # The autoencoder is used as-is; none of its weights are trained here.
        for frozen_param in self.AE.parameters():
            frozen_param.requires_grad = False

    def forward(self, bkg_latent, ana_latent):
        """Return ``(noise_prediction, noise)`` from ``self.diff.loss(ana_latent, bkg_latent)``."""
        noise_prediction, noise = self.diff.loss(ana_latent, bkg_latent)
        return noise_prediction, noise

    def sample(self,bkg_latent):
        """Draw one analysis latent conditioned on ``bkg_latent``.

        Returns the latent exactly as produced by ``self.diff.sample``; no
        rescaling or decoding happens here.
        """
        # NOTE(review): the first argument is the int 1 (originally written
        # ``(1)``), not a 1-tuple - confirm what VPSDE.sample expects.
        return self.diff.sample(1, bkg_latent)

    def sample_laop(self,A,bkg_latent,guidance):
        """Forward to ``self.diff.sample_laop(A, 1, bkg_latent, guidance)`` and return its latent."""
        return self.diff.sample_laop(A, 1, bkg_latent, guidance)

    def sample_hard(self,A,bkg_latent):
        """Forward to ``self.diff.sample_hard_mask(A, 1, bkg_latent)`` and return its latent."""
        return self.diff.sample_hard_mask(A, 1, bkg_latent)

    def get_latent(self,field):
        """Encode ``field`` with the autoencoder and return the mode of the result."""
        posterior = self.AE.encode(field)
        return posterior.mode()

class diffDA(object):
    def __init__(self, **model_params) -> None:
        """Create the model, the loss functions and pick the compute device.

        ``model_params`` may carry ``'optimizer'`` and ``'lr_scheduler'``
        dicts, which are stored on the instance. When the distributed backend
        is initialised a GPU is mandatory; otherwise ``cuda:0`` is used when
        available and the CPU as a fallback.

        Raises:
            EnvironmentError: distributed run without any CUDA device.
        """
        super().__init__()

        print('model_params:',model_params)
        self.optimizer_params = model_params.get('optimizer', {})
        self.scheduler_params = model_params.get('lr_scheduler', {})

        self.kernel = diffDA_model()

        self.best_loss = 9999999
        self.criterion_mae = nn.L1Loss()
        self.criterion_mse = nn.MSELoss()

        distributed = utils.is_dist_avail_and_initialized()
        gpu_available = torch.cuda.is_available()
        if distributed:
            self.device = torch.device('cuda' if gpu_available else "cpu")
            if not gpu_available:
                raise EnvironmentError('No GPUs, cannot initialize multigpu training.')
        else:
            self.device = torch.device('cuda:0' if gpu_available else 'cpu')
    
    def train(self, train_data_loader, valid_data_loader, logger, args):
        """Train the diffusion model and keep EMA checkpoints of ``kernel.diff``.

        Each batch supplies a background latent (index 0) and an analysis
        latent (index 1); the loss is the MSE between the two tensors returned
        by the kernel. An EMA of the diffusion weights is maintained on rank 0
        (or on the only process) and written to ``args.rundir``:
        periodically during training, and whenever the mean validation loss of
        an epoch improves on ``self.best_loss``.

        Also loads the latent mean / std from the HDF5 statistics file into
        ``self.latent_mean`` / ``self.latent_std``.
        """
        train_step = len(train_data_loader)
        valid_step = len(valid_data_loader)
        self.optimizer = torch.optim.AdamW(self.kernel.parameters(), lr=0.0001, weight_decay=0)

        with h5py.File("/path/to/latent_mean_std_h5file", "r") as stats_file:
            latent_mean_np = stats_file["mean"][:]
            latent_std_np = stats_file["std"][:]
        self.latent_mean = torch.from_numpy(latent_mean_np).to(self.device)
        self.latent_std = torch.from_numpy(latent_std_np).to(self.device)

        def owns_ema():
            # Rank 0 of a multi-process run, or the single process of a plain run.
            return (utils.get_world_size() > 1 and utils.get_rank() == 0) or utils.get_world_size() == 1

        def diffusion_module():
            # With more than one process the kernel is reached through ``.module``.
            if utils.get_world_size() > 1:
                return self.kernel.module.diff
            return self.kernel.diff

        ema_helper = EMAHelper()
        if owns_ema():
            ema_helper.register(diffusion_module())

        t_step = 0
        for epoch in range(args.max_epoch):
            begin_time = time.time()

            # ---- training pass ----
            self.kernel.train()
            for step, batch_data in enumerate(train_data_loader):
                bkg_latent = batch_data[0].float().to(self.device)
                ana_latent = batch_data[1].float().to(self.device)

                self.optimizer.zero_grad()
                noise_prediction, noise = self.kernel(bkg_latent, ana_latent)
                loss = self.criterion_mse(noise_prediction, noise)
                loss.backward()
                clip_grad_norm_(self.kernel.parameters(), max_norm=1)
                self.optimizer.step()

                if owns_ema():
                    ema_helper.update(diffusion_module())
                t_step += 1

                batch_no = step + 1
                if batch_no % 100 == 0 or batch_no == train_step:
                    logger.info(f"Train epoch:[{epoch+1}/{args.max_epoch}], step:[{batch_no}/{train_step}], lr:[{self.optimizer.param_groups[0]['lr']}], loss:[{loss.item()}]")

                # NOTE(review): t_step has already been incremented, so this
                # fires at t_step = 19999, 39999, ... - confirm the cadence.
                if (t_step + 1) % 20000 == 0:
                    if owns_ema():
                        torch.save(ema_helper.state_dict(), f'{args.rundir}/models_{t_step}_mode.pth')
                    logger.info(f'New model appears in epoch {epoch+1}.')

            # ---- validation pass ----
            self.kernel.eval()
            total_loss = 0
            with torch.no_grad():
                for step, batch_data in enumerate(valid_data_loader):
                    bkg_latent = batch_data[0].float().to(self.device)
                    ana_latent = batch_data[1].float().to(self.device)

                    noise_prediction, noise = self.kernel(bkg_latent, ana_latent)
                    loss = self.criterion_mse(noise_prediction, noise).item()
                    total_loss += loss

                    batch_no = step + 1
                    if batch_no % 100 == 0 or batch_no == valid_step:
                        logger.info(f'Valid epoch:[{epoch+1}/{args.max_epoch}], step:[{batch_no}/{valid_step}], loss:[{loss}]')

            mean_valid_loss = total_loss/valid_step
            if mean_valid_loss < self.best_loss:
                if owns_ema():
                    torch.save(ema_helper.state_dict(), f'{args.rundir}/48ema_model_norm_vae_mode.pth')
                logger.info(f'New best model appears in epoch {epoch+1}.')
                self.best_loss = mean_valid_loss

            logger.info(f'Epoch {epoch+1} average loss:[{mean_valid_loss}], time:[{time.time()-begin_time}]')
    
    def test(self,test_data_loader, logger, args):
        """Sample an analysis latent for every test batch and log its latent-space MSE.

        Each batch supplies a background latent (index 0) and an analysis
        latent (index 1). Both are rescaled with ``self.latent_std`` /
        ``self.latent_mean``, decoded once for a printed reconstruction MSE,
        and the background latent is then used to draw a sample that is
        compared with the analysis latent and handed to ``self.save_sample``.

        NOTE(review): ``self.latent_mean`` / ``self.latent_std`` are not set
        here; they are assigned in ``train`` and ``test_da``, so one of those
        has to run first. ``self.save_sample`` is not defined in the visible
        part of this file - confirm it exists on the class.
        """
        test_step = len(test_data_loader)
        self.kernel.eval()
        with torch.no_grad():
            for step, batch_data in enumerate(test_data_loader):

                bkg_latent = batch_data[0].to(self.device)
                ana_latent = batch_data[1].to(self.device)

                bkg_latent = (bkg_latent*self.latent_std+self.latent_mean).float().to(self.device)
                ana_latent = (ana_latent*self.latent_std+self.latent_mean).float().to(self.device)
                recons_bkg = self.kernel.AE.decode(bkg_latent)
                recons_ana = self.kernel.AE.decode(ana_latent)
                print(self.criterion_mse(recons_ana,recons_bkg))

                # diffDA_model.sample returns a single latent tensor; the old
                # two-target unpacking would iterate over its first dimension.
                ana_sampled = self.kernel.sample(bkg_latent)

                loss = self.criterion_mse(ana_latent, ana_sampled).item()

                # Reported for every step (the former ``% 1`` guard was always true).
                print(bkg_latent.shape)
                self.save_sample(step,ana_sampled,ana_latent)

                logger.info(f'Valid step:[{step+1}/{test_step}], loss:[{loss}]')
    
    def test_da(self,test_data_loader,logger,args):
        test_step = len(test_data_loader)
        data_mean, data_std = test_data_loader.dataset.get_meanstd()
        self.data_std = data_std.to(self.device)
        self.data_mean = data_mean.to(self.device)
        #print('data_std:',self.data_std)
        #print('data_mean:',self.data_mean)
        self.kernel.eval()
        with h5py.File("/path/to/latent_mean_std_h5file", "r") as f:
            self.latent_mean = f["mean"][:]
            self.latent_std = f["std"][:]
        self.latent_mean = torch.from_numpy(self.latent_mean).to(self.device)
        self.latent_std = torch.from_numpy(self.latent_std).to(self.device)

        with torch.no_grad():
            total_loss = 0
            total_mae = 0
            total_mse = 0
            total_rmse = 0
            fengwu_mse = 0
            fengwu_mae = 0
            fengwu_rmse = 0
            for step, batch_data in enumerate(test_data_loader):
                
                inp_data = torch.cat([batch_data[0][0], batch_data[0][1]], dim=1)
                inp_data = F.interpolate(inp_data, size=(128,256), mode='bilinear').numpy()
                print(batch_data[1])
                truth_ana = batch_data[0][2].to(self.device, non_blocking=True)
                #print(inp_data.shape)
                for _ in range(48 // 6):
                    predict_data = args.forecast_model.run(None, {'input':inp_data})[0][:,:truth_ana.shape[1]]
                    inp_data = np.concatenate([inp_data[:,-truth_ana.shape[1]:], predict_data], axis=1)        

                #predict_data is the bkg field. shape is 128*256
                bkg_field = torch.from_numpy(predict_data).to(self.device)
                ana_field = F.interpolate(truth_ana, size=(128,256), mode='bilinear').to(self.device)
                ana_field_ = ana_field
                bkg_field_ = bkg_field
                fengwu_mse += self.criterion_mse(ana_field,bkg_field).item()
                fengwu_mae += self.criterion_mae(ana_field,bkg_field).item()
                fengwu_rmse += WRMSE(ana_field,bkg_field, self.data_std)
                
                mse3 = self.criterion_mse(ana_field_,bkg_field_).item()
                rmse3 = WRMSE(ana_field_,bkg_field_,self.data_std)
                
                if utils.get_world_size() > 1:
                    bkg_latent = self.kernel.module.get_latent(bkg_field.float())
                    ana_latent = self.kernel.module.get_latent(ana_field.float())
                else:
                    bkg_latent = self.kernel.get_latent(bkg_field.float())
                    ana_latent = self.kernel.get_latent(ana_field.float())
                
                
                if utils.get_world_size() > 1:
                    recons_bkg = self.kernel.module.AE.decode(bkg_latent)
                    recons_ana = self.kernel.module.AE.decode(ana_latent)
                    
                else:
                    recons_bkg = self.kernel.AE.decode(bkg_latent)
                    recons_ana = self.kernel.AE.decode(ana_latent)
                    
                rmse1 = WRMSE(ana_field_,recons_ana,self.data_std)
                mse1 = self.criterion_mse(ana_field_,recons_ana).item()

                mse2 = self.criterion_mse(bkg_field_,recons_bkg).item()
                rmse2 = WRMSE(bkg_field_,recons_bkg, self.data_std)
                bkg_latent_s = bkg_latent
                bkg_latent = (bkg_latent-self.latent_mean)/self.latent_std
                bkg_latent = bkg_latent.float().to(self.device)
                pool = []
                latent_pool = []
                
                for i in range(32):
                    if utils.get_world_size() > 1:
                        ana_sampled_latent = self.kernel.module.sample(bkg_latent)
                    else:
                        ana_sampled_latent = self.kernel.sample(bkg_latent)

                    ana_sampled_latent =(ana_sampled_latent*self.latent_std+self.latent_mean).float().to(self.device)
                    if utils.get_world_size() > 1:
                        ana_sampled = self.kernel.module.AE.decode(ana_sampled_latent)
                    else:
                        ana_sampled = self.kernel.AE.decode(ana_sampled_latent)

                    latent_pool.append(ana_sampled_latent)

                    l_mse = self.criterion_mse(ana_sampled_latent,ana_latent).item()
                    y_pred = ana_sampled 
                    y_target = ana_field 
                    pool.append(ana_sampled)
                    mae = self.criterion_mae(y_pred, y_target).item()
                    mse = self.criterion_mse(y_pred, y_target).item()
                    
                    rmse = WRMSE(y_pred, y_target, self.data_std)
                    logger.info(f'{i},{self.criterion_mse(y_pred, y_target).item()}, {rmse[0]}, {rmse[1]}, {rmse[2]}, {rmse[3]}, {rmse[11]}, {l_mse}')
                
                #score on data assimilation
                scores_da = cal_da_score(pool,y_target,self.data_std)
                h_42 = F.interpolate(batch_data[0][3], size=(128,256), mode='bilinear').to(self.device)
                h_48_era5 =  F.interpolate(batch_data[0][4], size=(128,256), mode='bilinear').to(self.device)
                scores_fore = cal_fore_score(pool, h_42, h_48_era5 ,self.data_std,args)
                scores_phys = cal_phys_score(pool,self.data_std.unsqueeze(1).unsqueeze(1),self.data_mean.unsqueeze(1).unsqueeze(1))
                
                scores = np.vstack([scores_da.cpu().numpy(),scores_fore.cpu().numpy(),scores_phys])
                scores = np.transpose(scores)
                if utils.get_rank()==0:
                    print(torch.from_numpy(scores))
                
                
                #single reward
                DA_wins_index,DA_loses_index, Fore_wins_index,Fore_loses_index, Phys_wins_index,Phys_loses_index = singleselect(scores)
                logger.info(f"Single Index: {DA_wins_index},{DA_loses_index}, {Fore_wins_index},{Fore_loses_index}, {Phys_wins_index},{Phys_loses_index}")

                wins_index,loses_index = DA_wins_index,DA_loses_index
                logger.info(f"single DA wins_index,loses_index: {wins_index},{loses_index}")
                if len(wins_index)*len(loses_index)==0:
                    logger.info("single DA No selected data")
                    continue 
                pairs = np.array(np.meshgrid(wins_index, loses_index)).T.reshape(-1, 2)
                np.random.shuffle(pairs)
                selected_pairs = pairs[:5]
                for i in range(len(selected_pairs)):
                    pair = selected_pairs[i]
                    win_latent = latent_pool[int(pair[0])]
                    lose_latent = latent_pool[int(pair[1])]
                    with h5py.File("validDare_pair_data_{}_{}.h5".format(step*utils.get_world_size()+utils.get_rank(),i), "w") as f:
                        f.create_dataset("bkg_latent", data=bkg_latent_s.detach().cpu().numpy())
                        f.create_dataset("truth", data=y_target.detach().cpu().numpy())
                        f.create_dataset("win_sampled", data=win_latent.detach().cpu().numpy())
                        f.create_dataset("lose_sampled", data=lose_latent.detach().cpu().numpy())
                
                wins_index,loses_index = Fore_wins_index,Fore_loses_index
                logger.info(f"single Fore wins_index,loses_index: {wins_index},{loses_index}")
                if len(wins_index)*len(loses_index)==0:
                    logger.info("single Fore No selected data")
                    continue 
                pairs = np.array(np.meshgrid(wins_index, loses_index)).T.reshape(-1, 2)
                np.random.shuffle(pairs)
                selected_pairs = pairs[:5]
                for i in range(len(selected_pairs)):
                    pair = selected_pairs[i]
                    win_latent = latent_pool[int(pair[0])]
                    lose_latent = latent_pool[int(pair[1])]
                    with h5py.File("validForere_pair_data_{}_{}.h5".format(step*utils.get_world_size()+utils.get_rank(),i), "w") as f:
                        f.create_dataset("bkg_latent", data=bkg_latent_s.detach().cpu().numpy())
                        f.create_dataset("truth", data=y_target.detach().cpu().numpy())
                        f.create_dataset("win_sampled", data=win_latent.detach().cpu().numpy())
                        f.create_dataset("lose_sampled", data=lose_latent.detach().cpu().numpy())
                
                wins_index,loses_index = Phys_wins_index,Phys_loses_index
                logger.info(f"single Phys wins_index,loses_index: {wins_index},{loses_index}")
                if len(wins_index)*len(loses_index)==0:
                    logger.info("single Phys No selected data")
                    continue 
                pairs = np.array(np.meshgrid(wins_index, loses_index)).T.reshape(-1, 2)
                np.random.shuffle(pairs)
                selected_pairs = pairs[:5]
                for i in range(len(selected_pairs)):
                    pair = selected_pairs[i]
                    win_latent = latent_pool[int(pair[0])]
                    lose_latent = latent_pool[int(pair[1])]
                    with h5py.File("validPhysre_pair_data_{}_{}.h5".format(step*utils.get_world_size()+utils.get_rank(),i), "w") as f:
                        f.create_dataset("bkg_latent", data=bkg_latent_s.detach().cpu().numpy())
                        f.create_dataset("truth", data=y_target.detach().cpu().numpy())
                        f.create_dataset("win_sampled", data=win_latent.detach().cpu().numpy())
                        f.create_dataset("lose_sampled", data=lose_latent.detach().cpu().numpy())

                #Double reward
                da_fore_wins_index,da_fore_loses_index, da_phys_wins_index,da_phys_loses_index, fore_phys_wins_index,fore_phys_loses_index = doubleselect(scores)

                wins_index,loses_index = da_fore_wins_index,da_fore_loses_index
                logger.info(f"double DA-Fore wins_index,loses_index: {wins_index},{loses_index}")
                if len(wins_index)*len(loses_index)==0:
                    logger.info("double DA-Fore No selected data")
                    continue 
                pairs = np.array(np.meshgrid(wins_index, loses_index)).T.reshape(-1, 2)
                np.random.shuffle(pairs)
                if len(pairs)>24:
                    selected_pairs = pairs[:int(len(pairs)//3)]
                else:
                    selected_pairs = pairs
                for i in range(len(selected_pairs)):
                    
                    with h5py.File("validDaForere_pair_data_{}_{}.h5".format(step*utils.get_world_size()+utils.get_rank(),i), "w") as f:
                        f.create_dataset("bkg_latent", data=bkg_latent_s.detach().cpu().numpy())
                        f.create_dataset("truth", data=y_target.detach().cpu().numpy())
                    pair = selected_pairs[i]
                    win_latent = latent_pool[pair[0]]
                    lose_latent = latent_pool[pair[1]]
                    with h5py.File("validDaForere_pair_data_{}_{}.h5".format(step*utils.get_world_size()+utils.get_rank(),i), "a") as f:
                        f.create_dataset("win_sampled", data=win_latent.detach().cpu().numpy())
                        f.create_dataset("lose_sampled", data=lose_latent.detach().cpu().numpy())

                wins_index,loses_index = da_phys_wins_index,da_phys_loses_index
                logger.info(f"double DA-Phys wins_index,loses_index: {wins_index},{loses_index}")
                if len(wins_index)*len(loses_index)==0:
                    logger.info("double DA-Phys No selected data")
                    continue 
                pairs = np.array(np.meshgrid(wins_index, loses_index)).T.reshape(-1, 2)
                np.random.shuffle(pairs)
                if len(pairs)>24:
                    selected_pairs = pairs[:int(len(pairs)//3)]
                else:
                    selected_pairs = pairs
                for i in range(len(selected_pairs)):
                    
                    with h5py.File("validDaPhysre_pair_data_{}_{}.h5".format(step*utils.get_world_size()+utils.get_rank(),i), "w") as f:
                        f.create_dataset("bkg_latent", data=bkg_latent_s.detach().cpu().numpy())
                        f.create_dataset("truth", data=y_target.detach().cpu().numpy())
                    pair = selected_pairs[i]
                    win_latent = latent_pool[pair[0]]
                    lose_latent = latent_pool[pair[1]]
                    with h5py.File("validDaPhysre_pair_data_{}_{}.h5".format(step*utils.get_world_size()+utils.get_rank(),i), "a") as f:
                        f.create_dataset("win_sampled", data=win_latent.detach().cpu().numpy())
                        f.create_dataset("lose_sampled", data=lose_latent.detach().cpu().numpy())
                
                wins_index,loses_index = fore_phys_wins_index,fore_phys_loses_index
                logger.info(f"double Fore-Phys wins_index,loses_index: {wins_index},{loses_index}")
                if len(wins_index)*len(loses_index)==0:
                    logger.info("double Fore-Phys No selected data")
                    continue 
                pairs = np.array(np.meshgrid(wins_index, loses_index)).T.reshape(-1, 2)
                np.random.shuffle(pairs)
                if len(pairs)>24:
                    selected_pairs = pairs[:int(len(pairs)//3)]
                else:
                    selected_pairs = pairs
                for i in range(len(selected_pairs)):
                    
                    with h5py.File("validForePhysre_pair_data_{}_{}.h5".format(step*utils.get_world_size()+utils.get_rank(),i), "w") as f:
                        f.create_dataset("bkg_latent", data=bkg_latent_s.detach().cpu().numpy())
                        f.create_dataset("truth", data=y_target.detach().cpu().numpy())
                    pair = selected_pairs[i]
                    win_latent = latent_pool[pair[0]]
                    lose_latent = latent_pool[pair[1]]
                    with h5py.File("validForePhysre_pair_data_{}_{}.h5".format(step*utils.get_world_size()+utils.get_rank(),i), "a") as f:
                        f.create_dataset("win_sampled", data=win_latent.detach().cpu().numpy())
                        f.create_dataset("lose_sampled", data=lose_latent.detach().cpu().numpy())
                
                #3 reward
                wins_index,loses_index = select(scores)
                logger.info(f"wins_index,loses_index: {wins_index}, {loses_index}")
                if len(wins_index)*len(loses_index)==0:
                    logger.info("No selected data")
                    continue 
                pairs = np.array(np.meshgrid(wins_index, loses_index)).T.reshape(-1, 2)
                np.random.shuffle(pairs)
                if len(pairs)>24:
                    selected_pairs = pairs[:int(len(pairs)//3)]
                else:
                    selected_pairs = pairs

                for i in range(len(selected_pairs)):
                    
                    with h5py.File("valid3re_pair_data_{}_{}.h5".format(step*utils.get_world_size()+utils.get_rank(),i), "w") as f:
                        f.create_dataset("bkg_latent", data=bkg_latent_s.detach().cpu().numpy())
                        f.create_dataset("truth", data=y_target.detach().cpu().numpy())
                    pair = selected_pairs[i]
                    win_latent = latent_pool[pair[0]]
                    lose_latent = latent_pool[pair[1]]
                    with h5py.File("valid3re_pair_data_{}_{}.h5".format(step*utils.get_world_size()+utils.get_rank(),i), "a") as f:
                        f.create_dataset("win_sampled", data=win_latent.detach().cpu().numpy())
                        f.create_dataset("lose_sampled", data=lose_latent.detach().cpu().numpy())

                loss = self.criterion_mse(ana_sampled_latent,ana_latent).item()

                total_loss += mse
                total_mae += mae
                total_mse += mse
                total_rmse += rmse
                
    def test_da_guide(self,test_data_loader,logger,args):
        """Evaluate observation-guided assimilation on ``test_data_loader``.

        For every batch the method:
          1. rolls ``args.forecast_model`` forward ``48 // 6`` times to obtain a
             background field,
          2. encodes background and (resized) truth analysis into latents,
          3. draws one analysis latent with ``GuidanceSampling`` conditioned on
             the normalised background latent and a random sparse observation
             of the truth,
          4. decodes it and logs MAE / MSE / WRMSE against the truth.

        Args:
            test_data_loader: iterable of batches; ``batch_data[0][0]`` and
                ``batch_data[0][1]`` are concatenated on dim 1 as forecast
                input and ``batch_data[0][2]`` is the truth analysis. Its
                ``dataset`` must provide ``get_meanstd()``.
            logger: object with an ``info`` method.
            args: must expose ``forecast_model`` with a
                ``run(None, {'input': ndarray})`` method (looks like an ONNX
                Runtime session -- TODO confirm).

        Side effects:
            Sets ``self.data_std``, ``self.latent_mean`` and
            ``self.latent_std``; puts ``self.kernel`` in eval mode.
        """
        test_step = len(test_data_loader)
        # Only the std is used in this method (it is passed to WRMSE).
        _, data_std = test_data_loader.dataset.get_meanstd()
        self.data_std = data_std.to(self.device)

        # Reach the underlying model through ``.module`` in multi-process runs,
        # following the same rule test_da_hard / test_da_noguide apply.
        kernel = self.kernel.module if utils.get_world_size() > 1 else self.kernel

        with h5py.File("/path/to/latent_mean_std_h5file", "r") as f:
            latent_mean = f["mean"][:]
            latent_std = f["std"][:]
        self.latent_mean = torch.from_numpy(latent_mean).to(self.device)
        self.latent_std = torch.from_numpy(latent_std).to(self.device)

        def A(x):
            # Observation operator handed to the sampler: undo the latent
            # normalisation, then decode back to field space.
            x = (x*self.latent_std+self.latent_mean).float().to(self.device)
            return kernel.AE.decode(x)

        sampling = GuidanceSampling(A=A, std = 0.1, sde = kernel.diff, device = self.device)

        self.kernel.eval()
        with torch.no_grad():
            # NOTE(review): total_loss accumulates the field-space MSE (same as
            # total_mse), not the latent-space ``loss`` -- confirm intended.
            total_loss = 0
            total_mae = 0
            total_mse = 0
            total_rmse = 0
            for step, batch_data in enumerate(test_data_loader):

                inp_data = torch.cat([batch_data[0][0], batch_data[0][1]], dim=1)
                inp_data = F.interpolate(inp_data, size=(128,256), mode='bilinear').numpy()

                truth_ana = batch_data[0][2].to(self.device, non_blocking=True)
                n_channels = truth_ana.shape[1]
                # Autoregressive rollout: each call's output becomes the newest
                # half of the next input (presumably 6-hour steps out to 48 h
                # -- TODO confirm).
                for _ in range(48 // 6):
                    predict_data = args.forecast_model.run(None, {'input':inp_data})[0][:,:n_channels]
                    inp_data = np.concatenate([inp_data[:,-n_channels:], predict_data], axis=1)

                # The last forecast is the background field on the 128x256 grid.
                bkg_field = torch.from_numpy(predict_data).to(self.device)
                ana_field = F.interpolate(truth_ana, size=(128,256), mode='bilinear').to(self.device)

                bkg_latent = kernel.get_latent(bkg_field)
                ana_latent = kernel.get_latent(ana_field)

                # Baseline: how far the auto-encoded background is from the truth.
                recons_bkg = kernel.AE.decode(bkg_latent)
                rmse1 = WRMSE(recons_bkg,ana_field,self.data_std)
                mse1 = self.criterion_mse(recons_bkg,ana_field).item()

                bkg_latent = ((bkg_latent-self.latent_mean)/self.latent_std).float().to(self.device)

                # Simulated observations: keep each truth value with probability 0.01.
                obs_mask = torch.rand(ana_field.shape, device=self.device) >= 0.99
                observation = obs_mask*ana_field

                ana_sampled_latent = sampling.sample(shape=(1,),guidance=observation,c=bkg_latent,obs_mask=obs_mask)
                ana_sampled_latent = (ana_sampled_latent*self.latent_std+self.latent_mean).float().to(self.device)
                ana_sampled = kernel.AE.decode(ana_sampled_latent)

                mae = self.criterion_mae(ana_sampled, ana_field).item()
                mse = self.criterion_mse(ana_sampled, ana_field).item()
                rmse = WRMSE(ana_sampled, ana_field, self.data_std)
                loss = self.criterion_mse(ana_sampled_latent,ana_latent).item()

                total_loss += mse
                total_mae += mae
                total_mse += mse
                total_rmse += rmse

                # The modulus is 1, so this logs on every step.
                if (step + 1) % 1 == 0 or step+1 == test_step:
                    logger.info(f'Valid step:[{step+1}/{test_step}], MAE:[{mae}], MSE:[{mse}] RMSE:[{rmse}], MSE1:[{mse1}] RMSE1:[{rmse1}], latent_loss [{loss}] {rmse}')

        logger.info(f'Average loss:[{total_loss/test_step}], MAE:[{total_mae/test_step}], MSE:[{total_mse/test_step}]')
        logger.info(f'Average RMSE:[{total_rmse/test_step}]')


    def test_da_hard(self,test_data_loader,logger,args):
        """Evaluate assimilation with hard observation replacement.

        For every batch the method rolls ``args.forecast_model`` forward
        ``48 // 6`` times to get a background field, samples an analysis latent
        with ``sample_hard`` (conditioned on the normalised background latent
        and the operator ``A`` defined below), decodes it, and accumulates:

          * DA metrics: decoded sample vs. the resized truth analysis,
          * forecast metrics: ``cal_fore_score_`` output vs. ``h_48_era5``,
          * the value returned by ``cal_geo`` for the decoded sample.

        Args:
            test_data_loader: iterable of batches. ``batch_data[0][0]`` and
                ``batch_data[0][1]`` form the forecast input,
                ``batch_data[0][2]`` is the truth analysis,
                ``batch_data[0][3]`` / ``batch_data[0][4]`` are passed to
                ``cal_fore_score_`` as ``h_42`` / ``h_48_era5``;
                ``batch_data[1]`` is printed. Its ``dataset`` must provide
                ``get_meanstd()``.
            logger: object with an ``info`` method.
            args: must expose ``forecast_model`` with a
                ``run(None, {'input': ndarray})`` method; also forwarded to
                ``cal_fore_score_``.

        Side effects:
            Sets ``self.data_std``, ``self.data_mean``, ``self.latent_mean``
            and ``self.latent_std``; puts ``self.kernel`` in eval mode; prints
            debugging information to stdout.
        """
        test_step = len(test_data_loader)
        data_mean, data_std = test_data_loader.dataset.get_meanstd()
        self.data_std = data_std.to(self.device)
        self.data_mean = data_mean.to(self.device)
        print(self.data_std.shape)
        self.kernel.eval()
        # In multi-process runs the model is reached through ``.module``.
        kernel = self.kernel.module if utils.get_world_size()>1 else self.kernel
        with h5py.File("/path/to/latent_mean_std_h5file", "r") as f:
            latent_mean = f["mean"][:]
            latent_std = f["std"][:]
        self.latent_mean = torch.from_numpy(latent_mean).to(self.device)
        self.latent_std = torch.from_numpy(latent_std).to(self.device)

        # Number of analysis samples drawn per batch. Per-batch metrics are
        # divided by this same number so they are true means over the draws.
        n_samples = 1
        # Statistics reshaped once for cal_geo (two trailing singleton dims).
        geo_std = self.data_std.unsqueeze(1).unsqueeze(1)
        geo_mean = self.data_mean.unsqueeze(1).unsqueeze(1)

        with torch.no_grad():
            t_da_mse = 0
            t_da_mae = 0
            t_da_rmse = 0
            t_fore_mse = 0
            t_fore_mae = 0
            t_fore_rmse = 0
            t_geo_d = 0

            for step, batch_data in enumerate(test_data_loader):

                inp_data = torch.cat([batch_data[0][0], batch_data[0][1]], dim=1)
                inp_data = F.interpolate(inp_data, size=(128,256), mode='bilinear').numpy()
                print(batch_data[1])
                truth_ana = batch_data[0][2].to(self.device, non_blocking=True)
                n_channels = truth_ana.shape[1]
                # Autoregressive rollout: each output becomes the newest half
                # of the next input.
                for _ in range(48 // 6):
                    predict_data = args.forecast_model.run(None, {'input':inp_data})[0][:,:n_channels]
                    inp_data = np.concatenate([inp_data[:,-n_channels:], predict_data], axis=1)

                # The last forecast is the background field on the 128x256 grid.
                bkg_field = torch.from_numpy(predict_data).to(self.device)
                ana_field = F.interpolate(truth_ana, size=(128,256), mode='bilinear').to(self.device)
                print(WRMSE(bkg_field,ana_field,self.data_std)[11])

                bkg_latent = kernel.get_latent(bkg_field)
                # NOTE(review): the analysis latent is not used below; the call
                # is kept in case get_latent draws random numbers, so seeded
                # runs consume the same random stream as before.
                kernel.get_latent(ana_field)

                bkg_latent = ((bkg_latent-self.latent_mean)/self.latent_std).float().to(self.device)

                # Forecast-verification inputs do not depend on the sample draw.
                h_42 = F.interpolate(batch_data[0][3], size=(128,256), mode='bilinear').to(self.device)
                h_48_era5 = F.interpolate(batch_data[0][4], size=(128,256), mode='bilinear').to(self.device)

                step_da_mse = 0
                step_da_mae = 0
                step_da_rmse = 0
                step_fore_mse = 0
                step_fore_mae = 0
                step_fore_rmse = 0
                step_geo_d = 0
                for _ in range(n_samples):

                    # Each truth value is "observed" with probability 0.01.
                    obs_mask = torch.rand(ana_field.shape, device=self.device) >= 0.99
                    #obs_mask = torch.rand((128,256), device=self.device) >= (1-330/(128*256))
                    #obs_mask = (obs_mask.expand(69,-1,-1)).unsqueeze(0)
                    mask = obs_mask.float()

                    def A(x,t):
                        # Operator passed to sample_hard. Given a normalised
                        # latent ``x`` and ``t`` (forwarded to diff.noising):
                        #   1. decode ``x`` to field space,
                        #   2. noise the normalised truth latent at ``t`` and
                        #      decode it,
                        #   3. overwrite the observed points of (1) with (2),
                        #   4. re-encode and re-normalise the blended field.
                        x = (x*self.latent_std+self.latent_mean).float().to(self.device)
                        observation_latent = kernel.get_latent(ana_field)
                        x = kernel.AE.decode(x)
                        observation_latent = (observation_latent-self.latent_mean)/(self.latent_std+1e-10)
                        noised_observation_latent = kernel.diff.noising(observation_latent,t).float().to(self.device)
                        noised_observation = kernel.AE.decode(noised_observation_latent)
                        temp_ana = x*(1-mask)+mask*noised_observation
                        temp_ana_latent = kernel.get_latent(temp_ana)
                        temp_ana_latent = (temp_ana_latent-self.latent_mean)/(self.latent_std+1e-10)
                        return temp_ana_latent.float()

                    ana_sampled_latent = kernel.sample_hard(A,bkg_latent)
                    ana_sampled_latent = (ana_sampled_latent*self.latent_std+self.latent_mean).float().to(self.device)
                    ana_sampled = kernel.AE.decode(ana_sampled_latent)

                    step_da_rmse += WRMSE(ana_sampled, ana_field, self.data_std)/n_samples
                    step_da_mse += self.criterion_mse(ana_sampled, ana_field).item()/n_samples
                    step_da_mae += self.criterion_mae(ana_sampled, ana_field).item()/n_samples

                    da_forecast = cal_fore_score_(ana_sampled, h_42, h_48_era5, args)
                    step_fore_rmse += WRMSE(da_forecast, h_48_era5, self.data_std)/n_samples
                    step_fore_mse += self.criterion_mse(da_forecast, h_48_era5).item()/n_samples
                    step_fore_mae += self.criterion_mae(da_forecast, h_48_era5).item()/n_samples

                    step_geo_d += cal_geo(ana_sampled,geo_std,geo_mean)/n_samples

                t_da_mae += step_da_mae
                t_da_mse += step_da_mse
                t_da_rmse += step_da_rmse
                t_fore_mae += step_fore_mae
                t_fore_mse += step_fore_mse
                t_fore_rmse += step_fore_rmse
                t_geo_d += step_geo_d
                # Progress lines show running totals as "<sum>/<count>" text;
                # the modulus is 1, so this logs on every step.
                if (step+1)%1==0:
                    logger.info(f'Average DA-MAE:[{t_da_mae}/{step+1}], DA-MSE:[{t_da_mse}/{step+1}], Average DA-RMSE:[{t_da_rmse}/{step+1}]')
                    logger.info(f'Average Fore-MAE:[{t_fore_mae}/{step+1}], Fore-MSE:[{t_fore_mse}/{step+1}], Average Fore-RMSE:[{t_fore_rmse}/{step+1}]')
                    logger.info(f'Average Geo [{t_geo_d}/{step+1}]')

        # Final summary: every metric is divided inside the braces so the log
        # carries the actual average.
        logger.info(f'Average DA-MAE:[{t_da_mae/test_step}], DA-MSE:[{t_da_mse/test_step}], DA-RMSE:[{t_da_rmse/test_step}]')
        logger.info(f'Average Fore-MAE:[{t_fore_mae/test_step}], Fore-MSE:[{t_fore_mse/test_step}], Fore-RMSE:[{t_fore_rmse/test_step}]')
        logger.info(f'Average Geo [{t_geo_d/test_step}]')

    
    def test_da_noguide(self,test_data_loader,logger,args):
        """Evaluate assimilation without observation guidance.

        For every batch the method rolls ``args.forecast_model`` forward
        ``48 // 6`` times to get a background field, samples an analysis latent
        with ``sample`` conditioned only on the normalised background latent,
        decodes it, and accumulates:

          * DA metrics: decoded sample vs. the resized truth analysis,
          * forecast metrics: ``cal_fore_score_`` output vs. ``h_48_era5``,
          * the value returned by ``cal_geo`` for the decoded sample.

        Args:
            test_data_loader: iterable of batches. ``batch_data[0][0]`` and
                ``batch_data[0][1]`` form the forecast input,
                ``batch_data[0][2]`` is the truth analysis,
                ``batch_data[0][3]`` / ``batch_data[0][4]`` are passed to
                ``cal_fore_score_`` as ``h_42`` / ``h_48_era5``;
                ``batch_data[1]`` is printed. Its ``dataset`` must provide
                ``get_meanstd()``.
            logger: object with an ``info`` method.
            args: must expose ``forecast_model`` with a
                ``run(None, {'input': ndarray})`` method; also forwarded to
                ``cal_fore_score_``.

        Side effects:
            Sets ``self.data_std``, ``self.data_mean``, ``self.latent_mean``
            and ``self.latent_std``; puts ``self.kernel`` in eval mode; prints
            debugging information to stdout.
        """
        test_step = len(test_data_loader)
        data_mean, data_std = test_data_loader.dataset.get_meanstd()
        self.data_std = data_std.to(self.device)
        self.data_mean = data_mean.to(self.device)
        print(self.data_std.shape)
        self.kernel.eval()
        # In multi-process runs the model is reached through ``.module``.
        kernel = self.kernel.module if utils.get_world_size()>1 else self.kernel
        with h5py.File("/path/to/latent_mean_std_h5file", "r") as f:
            latent_mean = f["mean"][:]
            latent_std = f["std"][:]
        self.latent_mean = torch.from_numpy(latent_mean).to(self.device)
        self.latent_std = torch.from_numpy(latent_std).to(self.device)

        # Number of analysis samples drawn per batch. Per-batch metrics are
        # divided by this same number so they are true means over the draws.
        n_samples = 1
        # Statistics reshaped once for cal_geo (two trailing singleton dims).
        geo_std = self.data_std.unsqueeze(1).unsqueeze(1)
        geo_mean = self.data_mean.unsqueeze(1).unsqueeze(1)

        with torch.no_grad():
            t_da_mse = 0
            t_da_mae = 0
            t_da_rmse = 0
            t_fore_mse = 0
            t_fore_mae = 0
            t_fore_rmse = 0
            t_geo_d = 0

            for step, batch_data in enumerate(test_data_loader):

                inp_data = torch.cat([batch_data[0][0], batch_data[0][1]], dim=1)
                inp_data = F.interpolate(inp_data, size=(128,256), mode='bilinear').numpy()
                print(batch_data[1])
                truth_ana = batch_data[0][2].to(self.device, non_blocking=True)
                n_channels = truth_ana.shape[1]
                # Autoregressive rollout: each output becomes the newest half
                # of the next input.
                for _ in range(48 // 6):
                    predict_data = args.forecast_model.run(None, {'input':inp_data})[0][:,:n_channels]
                    inp_data = np.concatenate([inp_data[:,-n_channels:], predict_data], axis=1)

                # The last forecast is the background field on the 128x256 grid.
                bkg_field = torch.from_numpy(predict_data).to(self.device)
                ana_field = F.interpolate(truth_ana, size=(128,256), mode='bilinear').to(self.device)
                print(WRMSE(bkg_field,ana_field,self.data_std)[11])

                bkg_latent = kernel.get_latent(bkg_field)
                # NOTE(review): the analysis latent is not used below; the call
                # is kept in case get_latent draws random numbers, so seeded
                # runs consume the same random stream as before.
                kernel.get_latent(ana_field)

                bkg_latent = ((bkg_latent-self.latent_mean)/self.latent_std).float().to(self.device)

                # Forecast-verification inputs do not depend on the sample draw.
                h_42 = F.interpolate(batch_data[0][3], size=(128,256), mode='bilinear').to(self.device)
                h_48_era5 = F.interpolate(batch_data[0][4], size=(128,256), mode='bilinear').to(self.device)

                step_da_mse = 0
                step_da_mae = 0
                step_da_rmse = 0
                step_fore_mse = 0
                step_fore_mae = 0
                step_fore_rmse = 0
                step_geo_d = 0
                for _ in range(n_samples):

                    # NOTE(review): unguided sampling never reads this mask; the
                    # draw is kept only so the random stream consumed per
                    # sample is unchanged.
                    obs_mask = torch.rand(ana_field.shape, device=self.device) >= 0.99

                    ana_sampled_latent = kernel.sample(bkg_latent)
                    ana_sampled_latent = (ana_sampled_latent*self.latent_std+self.latent_mean).float().to(self.device)
                    ana_sampled = kernel.AE.decode(ana_sampled_latent)

                    step_da_rmse += WRMSE(ana_sampled, ana_field, self.data_std)/n_samples
                    step_da_mse += self.criterion_mse(ana_sampled, ana_field).item()/n_samples
                    step_da_mae += self.criterion_mae(ana_sampled, ana_field).item()/n_samples

                    da_forecast = cal_fore_score_(ana_sampled, h_42, h_48_era5, args)
                    step_fore_rmse += WRMSE(da_forecast, h_48_era5, self.data_std)/n_samples
                    step_fore_mse += self.criterion_mse(da_forecast, h_48_era5).item()/n_samples
                    step_fore_mae += self.criterion_mae(da_forecast, h_48_era5).item()/n_samples

                    step_geo_d += cal_geo(ana_sampled,geo_std,geo_mean)/n_samples

                t_da_mae += step_da_mae
                t_da_mse += step_da_mse
                t_da_rmse += step_da_rmse
                t_fore_mae += step_fore_mae
                t_fore_mse += step_fore_mse
                t_fore_rmse += step_fore_rmse
                t_geo_d += step_geo_d
                # Progress lines show running totals as "<sum>/<count>" text;
                # the modulus is 1, so this logs on every step.
                if (step+1)%1==0:
                    logger.info(f'Average DA-MAE:[{t_da_mae}/{step+1}], DA-MSE:[{t_da_mse}/{step+1}], Average DA-RMSE:[{t_da_rmse}/{step+1}]')
                    logger.info(f'Average Fore-MAE:[{t_fore_mae}/{step+1}], Fore-MSE:[{t_fore_mse}/{step+1}], Average Fore-RMSE:[{t_fore_rmse}/{step+1}]')
                    logger.info(f'Average Geo [{t_geo_d}/{step+1}]')

        # Final summary: every metric is divided inside the braces so the log
        # carries the actual average.
        logger.info(f'Average DA-MAE:[{t_da_mae/test_step}], DA-MSE:[{t_da_mse/test_step}], DA-RMSE:[{t_da_rmse/test_step}]')
        logger.info(f'Average Fore-MAE:[{t_fore_mae/test_step}], Fore-MSE:[{t_fore_mse/test_step}], Fore-RMSE:[{t_fore_rmse/test_step}]')
        logger.info(f'Average Geo [{t_geo_d/test_step}]')