
import numpy as np
import torch
import torch.nn as nn
import os
import sys 
import time
import math  
import CRPS.CRPS as pscore

from utils.tools import adjust_learning_rate
from torch import optim
from logging import Logger
from data_provider.data_loader_epf import Dataset_Custom
from torch.utils.data import DataLoader
from utils.tools import EarlyStopping
from utils.metrics import  metric_mask

from model9_NS_transformer.ns_models import    timeXer_v3_multi
from model9_NS_transformer.exp.exp_basic import Exp_Basic
from model9_NS_transformer.diffusion_models import diffuMTS
from model9_NS_transformer.diffusion_models.diffusion_utils import q_sample, p_sample_loop


def ccc(id, pred, true):
    """Score every (prediction, observation) pair with the CRPS scorer.

    Args:
        id: Unused; kept so existing call sites keep working.
        pred: Indexable collection; ``pred[i]`` is passed to the scorer
            together with ``true[i]``.
        true: Sized, indexable collection of observations.

    Returns:
        1-D float array of length ``len(true)``. Entry ``i`` is element 0 of
        what ``pscore(pred[i], true[i]).compute()`` returns (presumably the
        plain CRPS value -- confirm against the CRPS package).
    """
    n_points = len(true)
    scores = np.zeros(n_points)
    for idx in range(n_points):
        scorer = pscore(pred[idx], true[idx])
        scores[idx] = scorer.compute()[0]
    return scores


def log_normal(x, mu, var):
    """Return the mean squared error between ``x`` and ``mu``.

    Despite its name this function does not evaluate a Gaussian
    log-likelihood: the result is ``mean((x - mu) ** 2)`` and ``var`` has no
    influence on it. The earlier docstring described a log-normal density and
    the body added an epsilon to ``var`` without ever using the sum; both were
    misleading and have been removed. The returned value is unchanged.

    Args:
        x: Tensor of observed values.
        mu: Tensor of predicted means, broadcastable against ``x``.
        var: Ignored. Kept only so existing call sites keep working.

    Returns:
        0-dim tensor holding the mean of the element-wise squared differences.
    """
    return torch.mean(torch.pow(x - mu, 2))

class Exp_Main(Exp_Basic):
    def __init__(self, args):
        """Run the base-class initialisation, then build every data split.

        Args:
            args: Experiment configuration; forwarded unchanged to
                ``Exp_Basic.__init__`` and read later through ``self.args``.
        """
        super().__init__(args)
        # NOTE(review): _init_data reads self.thread, which this class never
        # assigns -- presumably Exp_Basic.__init__ sets it; verify.
        self._init_data()
        
        
    def _build_model(self):
        """Instantiate the diffusion model and two conditional predictors.

        Returns:
            A 3-tuple ``(diffusion_model, cond_pred_model,
            cond_pred_model_train)``. All three are cast to float32. The two
            predictors are separate ``timeXer_v3_multi.Model`` instances built
            from the same ``self.args``.
        """
        diffusion_net = diffuMTS.Model(self.args, self.device).float()
        # Built after the diffusion model and one after the other, so the
        # construction order (and with it any RNG-driven weight init) is kept.
        cond_net, cond_net_train = (
            timeXer_v3_multi.Model(self.args).float() for _ in range(2)
        )
        return diffusion_net, cond_net, cond_net_train

    def set_logger(self, log: Logger):
        self.logger = log

    def _init_data(self):
        """Create the train, test and validation datasets with their loaders.

        The splits are built in the order train, test, val. Every loader
        iterates in dataset order (``shuffle=False``) and keeps the final
        partial batch (``drop_last=False``). The train and val loaders use
        ``args.batch_size``; the test loader uses ``args.test_batch_size``.

        Side effects:
            Sets ``train_dataset``/``train_loader``, ``test_dataset``/
            ``test_loader``, ``vali_dataset``/``val_loader`` and
            ``data_dict``, which maps each flag to ``[dataset, loader]``.
        """
        cfg = self.args
        # One window specification shared by all three datasets.
        window = [cfg.history, cfg.label_len, cfg.pred_window]

        def build_split(flag, batch_attr):
            # Build the dataset for `flag`, then wrap it in a loader whose
            # batch size is read from the named attribute of the config.
            dataset = Dataset_Custom(
                root_path=cfg.root_path,
                flag=flag,
                size=window,
                data_path=cfg.data_path,
                timeenc=cfg.timeenc,
                mask_covar_ratio=cfg.mask_covar_ratio,
                mask_target_ratio=cfg.mask_target_ratio,
                down_sample=cfg.downsample,
            )
            loader = DataLoader(
                dataset,
                getattr(cfg, batch_attr),
                shuffle=False,
                drop_last=False,
                num_workers=self.thread,
            )
            return dataset, loader

        self.train_dataset, self.train_loader = build_split("train", "batch_size")
        self.test_dataset, self.test_loader = build_split("test", "test_batch_size")
        self.vali_dataset, self.val_loader = build_split("val", "batch_size")

        self.data_dict = {
            "train": [self.train_dataset, self.train_loader],
            "test": [self.test_dataset, self.test_loader],
            "val": [self.vali_dataset, self.val_loader],
        }

    def _get_data(self, flag):
        return self.data_dict[flag]

    def _select_optimizer(self, mode='Model'):
        if mode == 'Model':
            # model_optim = optim.Adam([
            #     {'params': self.model.parameters()}, 
            #     {'params': self.cond_pred_model.parameters()}
            # ], lr=self.args.learning_rate)

            model_optim = optim.Adam([
                {'params': self.model.parameters(), 'lr': self.args.learning_rate * 1}, 
                {'params': self.cond_pred_model.parameters(), 'lr': self.args.learning_rate * 1}
            ])

        elif mode == 'Cond':
            model_optim = optim.Adam(self.cond_pred_model_train.parameters(), lr=self.args.learning_rate_Cond)
        else:
            model_optim = None
        return model_optim

    def _select_criterion_mask(self, loss = "L2"):
        def mse_mask(x, x_hat, mask):
            if torch.sum(mask) <= 0:
                return torch.zeros(size=1)
            return torch.sum(torch.square(x-x_hat)*mask)/torch.sum(mask)
        def mae_mask(x, x_hat, mask):
            if torch.sum(mask) <= 0:
                return torch.zeros(size=1)
            return torch.sum(torch.abs(x-x_hat)*mask)/torch.sum(mask)
        
        def total_mask(x, x_hat, mask):
            return mae_mask(x, x_hat, mask) + mse_mask(x, x_hat, mask)
        if loss == "L2":
            return mse_mask
        if loss == 'L1':
            return mae_mask
        if loss == "mix":
            return total_mask
        return mse_mask
    
    def _select_criterion(self):
        criterion = nn.MSELoss()
        return criterion

    def vali(self, vali_data, vali_loader, criterion):
        """Score the diffusion model and the conditional predictor on one split.

        Both networks are put in eval mode, every batch is processed without
        gradients, and both networks are put back in train mode before
        returning. Previously only ``self.model`` was restored, which left
        ``self.cond_pred_model`` in eval mode.

        Args:
            vali_data: Dataset belonging to ``vali_loader``; not used here.
            vali_loader: Iterable yielding ``(batch_x, batch_y, batch_x_mark,
                batch_y_mark, batch_x_mask, batch_y_mask)`` tuples.
            criterion: Masked loss ``f(target, prediction, mask)`` applied to
                the conditional predictor's output.

        Returns:
            ``(diffusion_loss, cond_loss)``: the averages over batches of the
            noise-prediction MSE on the last ``pred_len`` steps and of
            ``criterion``.
        """
        diff_losses = []
        cond_losses = []
        self.model.eval()
        self.cond_pred_model.eval()
        target_num = self.args.target_num
        with torch.no_grad():
            for batch_x, batch_y, batch_x_mark, batch_y_mark, batch_x_mask, batch_y_mask in vali_loader:
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)
                batch_x = batch_x * batch_x_mask.to(self.device)
                if self.args.pseudo_len > 0:
                    # Blank the target columns over the trailing pseudo_len steps.
                    batch_x[:, -self.args.pseudo_len:, -target_num:] = 0

                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)
                batch_y_mask = batch_y_mask.to(self.device)

                # Decoder input: the label segment followed by zeros for the horizon.
                # NOTE(review): train() sizes the zero block with args.pred_window,
                # this method with args.pred_len -- confirm the two always agree.
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                # Pair every timestep t with num_timesteps - 1 - t, then trim to the batch size.
                n = batch_x.size(0)
                t = torch.randint(
                    low=0, high=self.model.num_timesteps, size=(n // 2 + 1,)
                ).to(self.device)
                t = torch.cat([t, self.model.num_timesteps - 1 - t], dim=0)[:n]

                y_0_hat_batch = self.cond_pred_model(batch_x, batch_x_mark, dec_inp, batch_y_mark,
                                                     mode=self.args.mode, target=target_num)

                loss_vae_all = criterion(batch_y[:, -self.args.pred_len:, -target_num:],
                                         y_0_hat_batch[:, :, -target_num:],
                                         batch_y_mask[:, -self.args.pred_len:, -target_num:])

                y_T_mean = y_0_hat_batch
                e = torch.randn_like(batch_y[:, :, -target_num:]).to(self.device)

                y_t_batch = q_sample(batch_y[:, :, -target_num:], y_T_mean[:, :, -target_num:],
                                     self.model.alphas_bar_sqrt,
                                     self.model.one_minus_alphas_bar_sqrt, t, noise=e)
                output = self.model(batch_x, batch_x_mark, batch_y[:, :, -target_num:],
                                    y_t_batch[:, :, -target_num:], y_0_hat_batch[:, :, -target_num:], t)

                # Only the forecast horizon is scored here.
                diff_loss = (e[:, -self.args.pred_len:, :] - output[:, -self.args.pred_len:, :]).square().mean()

                cond_losses.append(loss_vae_all.detach().cpu().numpy())
                diff_losses.append(diff_loss.detach().cpu().numpy())

        total_loss = np.average(diff_losses)
        cond_loss = np.average(cond_losses)
        self.model.train()
        self.cond_pred_model.train()
        return total_loss, cond_loss



    def train(self, setting):
        """Jointly train the diffusion model and the conditional predictor.

        Each step minimises the noise-prediction MSE plus the masked
        forecasting loss of the conditional predictor, the latter weighted by
        ``k_cond * k_cond_decay ** epoch``. After every epoch both the
        validation and the test split are scored with ``vali`` and the sum of
        the two validation losses is handed to ``EarlyStopping``.

        The pseudo-masking of the encoder input now blanks the last
        ``target_num`` columns, as ``vali``, ``vali_cond``, ``train_cond`` and
        ``test`` do. It previously blanked only the last column, so with more
        than one target the network was trained on inputs it never saw at
        evaluation time.

        Args:
            setting: Run identifier; checkpoints go to
                ``<args.checkpoints>/<setting>``.

        Returns:
            ``self.model`` as it is after the last completed epoch. The best
            checkpoint is not reloaded.
        """
        _, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        # Creating the nested directory also creates `path` itself.
        os.makedirs(os.path.join(path, 'best_cond_model_dir/'), exist_ok=True)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True, log=self.logger)

        model_optim = self._select_optimizer()

        criterion = self._select_criterion_mask()
        target_num = self.args.target_num
        for epoch in range(self.args.train_epochs):
            epoch_time = time.time()
            print(f"start epoch {epoch}")
            iter_count = 0
            train_loss = []
            train_vae_loss = []
            self.model.train()
            self.cond_pred_model.train()
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark, batch_x_mask, batch_y_mask) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()

                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)
                batch_x = batch_x * batch_x_mask.to(self.device)
                if self.args.pseudo_len > 0:
                    # Blank every target column over the trailing pseudo_len steps.
                    batch_x[:, -self.args.pseudo_len:, -target_num:] = 0
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)
                batch_y_mask = batch_y_mask.to(self.device)

                # Decoder input: the label segment followed by zeros for the horizon.
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_window:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                # Pair every timestep t with num_timesteps - 1 - t, then trim to the batch size.
                n = batch_x.size(0)
                t = torch.randint(
                    low=0, high=self.model.num_timesteps, size=(n // 2 + 1,)
                ).to(self.device)
                t = torch.cat([t, self.model.num_timesteps - 1 - t], dim=0)[:n]

                y_0_hat_batch = self.cond_pred_model(batch_x, batch_x_mark, dec_inp, batch_y_mark,
                                                     mode=self.args.mode, target=target_num)

                loss_vae_all = criterion(batch_y[:, -self.args.pred_len:, -target_num:],
                                         y_0_hat_batch[:, :, -target_num:],
                                         batch_y_mask[:, -self.args.pred_len:, -target_num:])

                y_T_mean = y_0_hat_batch
                e = torch.randn_like(batch_y[:, :, -target_num:]).to(self.device)

                y_t_batch = q_sample(batch_y[:, :, -target_num:], y_T_mean[:, :, -target_num:],
                                     self.model.alphas_bar_sqrt,
                                     self.model.one_minus_alphas_bar_sqrt, t, noise=e)

                output = self.model(batch_x, batch_x_mark, batch_y[:, :, -target_num:],
                                    y_t_batch[:, :, -target_num:], y_0_hat_batch[:, :, -target_num:], t)

                # Unlike vali(), the noise loss covers the whole label + horizon window.
                diff_loss = (e - output).square().mean()
                loss = diff_loss + (self.args.k_cond_decay**epoch) * self.args.k_cond*loss_vae_all

                train_vae_loss.append(loss_vae_all.item())
                train_loss.append(diff_loss.item())

                if (i + 1) % 100 == 0:
                    self.logger.info("\titers: {0}, epoch: {1} | loss: {2:.7f} vae loss {3:.7f}".format(
                        i + 1, epoch + 1, diff_loss.item(), loss_vae_all.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    self.logger.info('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                loss.backward()
                model_optim.step()

            train_loss = np.average(train_loss)
            vali_loss, vali_cond_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss, test_cond_loss = self.vali(test_data, test_loader, criterion)

            self.logger.info("Epoch: {0} | Train Loss: {1:.7f}  Vali Loss: {2:.7f} Test Loss: {3:.7f}".format(
                             epoch + 1, train_loss, vali_loss, test_loss))
            self.logger.info(f"Epoch: {epoch + 1} | Train cond: {np.average(train_vae_loss):.7f}, Vali cond: {vali_cond_loss:.7f}, "  +
                             f"Test cond: {test_cond_loss:.7f}")
            self.logger.info("Epoch: {} | cost time: {}".format(epoch + 1, time.time() - epoch_time))

            # Learning-rate scheduling is disabled for joint training:
            # adjust_learning_rate(optimizer=model_optim, epoch=epoch, args=self.args, log=self.logger )

            early_stopping(vali_loss+vali_cond_loss, test_loss+test_cond_loss, self.model, self.cond_pred_model, path)

            if early_stopping.early_stop:
                self.logger.info("Early stopping")
                break

        # The best checkpoint is intentionally not reloaded here:
        # best_model_path = path + '/' + 'checkpoint.pth'
        # self.model.load_state_dict(torch.load(best_model_path, map_location=self.device))

        return self.model


    def vali_cond(self, vali_data, vali_loader, criterion):
        total_loss = []
        self.model.eval()
        self.cond_pred_model.eval()
        target_num = self.args.target_num
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark, batch_x_mask, batch_y_mask) in enumerate(vali_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)
                batch_x = batch_x * batch_x_mask.to(self.device)
                if self.args.pseudo_len > 0:
                    batch_x[:, -self.args.pseudo_len:, -target_num:] = 0
        
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
              
                y_0_hat_batch = self.cond_pred_model(batch_x, batch_x_mark, dec_inp, batch_y_mark, mode=self.args.mode, target=target_num)
                
                # assert torch.isnan(y_0_hat_batch).sum() == 0, self.logger.error(y_0_hat_batch)
                
                loss_vae = criterion(batch_y[:, -self.args.pred_len:, -target_num:], y_0_hat_batch[:, :, -target_num:], batch_y_mask[:, -self.args.pred_len:, -target_num:].to(self.device))
                loss_vae_all = loss_vae

                loss = self.args.k_cond * loss_vae_all
                loss = loss.detach().cpu()

                total_loss.append(loss)
        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss


    def train_cond(self, setting):
        """Train on the conditional predictor's masked forecasting loss alone.

        The diffusion model is never run in this loop. The optimizer is the
        default one from ``_select_optimizer()``. After each epoch the
        validation and test splits are scored with ``vali_cond``, the result
        is handed to ``EarlyStopping``, and the learning rate is adjusted.

        Args:
            setting: Run identifier; checkpoints go to
                ``<args.checkpoints>/<setting>``.

        Returns:
            None.
        """
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        cfg = self.args
        path = os.path.join(cfg.checkpoints, setting)
        for directory in (path, os.path.join(path, 'best_cond_model_dir/')):
            if not os.path.exists(directory):
                os.makedirs(directory)

        tick = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=cfg.patience, verbose=True, log=self.logger)

        model_optim = self._select_optimizer()

        criterion = self._select_criterion_mask()
        criterion_val = self._select_criterion_mask()
        target_num = cfg.target_num

        for epoch in range(cfg.train_epochs):
            epoch_start = time.time()

            steps_since_log = 0
            epoch_losses = []
            self.model.train()
            self.cond_pred_model.train()
            for step, batch in enumerate(train_loader):
                batch_x, batch_y, batch_x_mark, batch_y_mark, batch_x_mask, batch_y_mask = batch
                steps_since_log += 1
                model_optim.zero_grad()

                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)
                batch_x_mask = batch_x_mask.to(self.device)
                batch_x = batch_x * batch_x_mask
                if cfg.pseudo_len > 0:
                    # Blank the target columns over the trailing pseudo_len steps.
                    batch_x[:, -cfg.pseudo_len:, -target_num:] = 0

                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)
                batch_y_mask = batch_y_mask.to(self.device)

                # Decoder input: the label segment followed by zeros for the horizon.
                horizon_zeros = torch.zeros_like(batch_y[:, -cfg.pred_window:, :]).float()
                dec_inp = torch.cat([batch_y[:, :cfg.label_len, :], horizon_zeros], dim=1).float().to(self.device)

                # NOTE(review): this is the only call site that also passes the
                # input mask positionally -- confirm it matches the model signature.
                y_0_hat_batch = self.cond_pred_model(
                    batch_x, batch_x_mark, dec_inp, batch_y_mark, batch_x_mask,
                    mode=cfg.mode, target=target_num,
                )

                horizon = slice(-cfg.pred_len, None)
                targets = slice(-target_num, None)
                loss_vae = criterion(
                    batch_y[:, horizon, targets],
                    y_0_hat_batch[:, :, targets],
                    batch_y_mask[:, horizon, targets],
                )
                loss = loss_vae

                epoch_losses.append(loss.item())

                if (step + 1) % 100 == 0:
                    self.logger.info("\titers: {0}, epoch: {1} | loss: {2:.7f} vae loss {3:.7f}".format(
                        step + 1, epoch + 1, loss.item(), loss_vae.item()))
                    speed = (time.time() - tick) / steps_since_log
                    left_time = speed * ((cfg.train_epochs - epoch) * train_steps - step)
                    self.logger.info('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    steps_since_log = 0
                    tick = time.time()

                loss.backward()
                model_optim.step()

            self.logger.info("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_start))
            train_loss = np.average(epoch_losses)
            vali_loss = self.vali_cond(vali_data, vali_loader, criterion_val)
            test_loss = self.vali_cond(test_data, test_loader, criterion_val)
            self.logger.info(
                "Epoch: {0}, Steps: {1} | Train Loss: {2:.7f}  Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                    epoch + 1, train_steps, train_loss, vali_loss, test_loss))

            early_stopping(vali_loss, test_loss, self.model, self.cond_pred_model, path)
            sys.stdout.flush()
            adjust_learning_rate(optimizer=model_optim, epoch=epoch, args=cfg, log=self.logger )

            if early_stopping.early_stop:
                self.logger.info("Early stopping")
                break

        return


    def test_cond(self, setting, test):
        test_data, test_loader = self._get_data(flag='test')
        train_data, train_loader = self._get_data(flag='train')
        if test:
            self.logger.info('loading cond model')
            self.cond_pred_model.load_state_dict(
                torch.load(os.path.join('./checkpoints/' + setting, 'cond_checkpoint.pth'), map_location=self.device))

        preds = []
        masks = []
        trues = []
        self.cond_pred_model.eval()


        # features_in_hook = []
        # features_out_hook = []

        # def hook(module, fea_in, fea_out):
        #     features_in_hook.append(fea_in[0])
        #     features_out_hook.append(fea_out)
        #     return None
        # self.cond_pred_model.en_embedding.boundarynet.register_forward_hook(hook=hook)
        
        target_num = self.args.target_num
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark, batch_x_mask, batch_y_mask) in enumerate(test_loader):   
                    batch_x = batch_x.float().to(self.device)
                    batch_y = batch_y.float().to(self.device)
                    batch_x = batch_x * batch_x_mask.to(self.device)
                    if self.args.pseudo_len > 0:
                        batch_x[:, -self.args.pseudo_len:, -1:] = 0

                    batch_x_mark = batch_x_mark.float().to(self.device)
                    batch_y_mark = batch_y_mark.float().to(self.device)
                    batch_y_mask = batch_y_mask.to(self.device)
                    # decoder input
                    dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_window:, :]).float()
                    dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
                    
                    y_0_hat_batch = self.cond_pred_model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                    
                    mask = batch_y_mask[:, -self.args.pred_len:, -target_num:].detach().cpu().numpy()
                    true = batch_y[:, -self.args.pred_len:, -target_num:].detach().cpu().numpy()
                    pred = y_0_hat_batch[:, :, -target_num:].detach().cpu().numpy()

                    masks.append(mask)
                    trues.append(true)
                    preds.append(pred)

            masks = np.concat(masks, axis=0)
            trues = np.concat(trues, axis=0)
            preds = np.concat(preds, axis=0)

            mse = np.square((trues-preds)*masks).mean() / masks.mean()
            mae = (np.abs(trues-preds)*masks).mean() / masks.mean()
            self.logger.info("Test MSE: {0:.7f}, MAE: {1:.7f}".format(mse, mae))
            
        return 

    def test(self, setting, test=0):
        """Generate forecasts with the diffusion model on the test split and log masked MSE/MAE.

        For every test batch the conditional predictor supplies the prior
        mean, ``p_sample_loop`` generates the sample trajectories, and the
        samples are averaged into a point forecast scored by ``metric_mask``.

        Changes from the previous version:
          * ``np.concat`` (only present in NumPy >= 2.0) is replaced by
            ``np.concatenate``.
          * Target channels are selected with ``[-target_num:]`` as everywhere
            else in this class. The old ``[target_num:]`` slice dropped the
            first ``target_num`` channels instead of keeping the last ones.
          * ``p_sample_loop`` runs once per sample group. It used to run
            ``n_z_samples_depart`` times per group with every result but the
            last thrown away. The random draws consumed therefore differ from
            earlier runs with the same seed.
          * Unused locals were removed.

        Args:
            setting: Run identifier; names the checkpoint and result directories.
            test: When truthy, load ``checkpoint.pth`` and
                ``cond_checkpoint.pth`` from ``./checkpoints/<setting>``
                before evaluating.

        Returns:
            None. The metrics are written to ``self.logger``.
        """

        #####################################################################################################
        ########################## local functions within the class function scope ##########################

        def store_gen_y_at_step_t(config, config_diff, idx, y_tile_seq, test_batch_size,):
            """
            Store generated y from a mini-batch to the array of corresponding time step.
            """
            current_t = self.model.num_timesteps - idx
            gen_y = y_tile_seq[idx].reshape(test_batch_size,
                                            int(config_diff.testing.n_z_samples / config_diff.testing.n_z_samples_depart),
                                            (config.label_len + config.pred_len),
                                            config.c_out_diff).cpu().numpy()
            # gen_y_by_batch_list is rebound for every batch in the loop below.
            if len(gen_y_by_batch_list[current_t]) == 0:
                gen_y_by_batch_list[current_t] = gen_y
            else:
                gen_y_by_batch_list[current_t] = np.concatenate([gen_y_by_batch_list[current_t], gen_y], axis=0)
            return gen_y

        def compute_true_coverage_by_gen_QI(config, dataset_object, all_true_y, all_generated_y):
            """
            Quantile-interval coverage of the true values (only referenced by
            the disabled QICE evaluation at the end of this method).
            """
            n_bins = config.testing.n_bins
            quantile_list = np.arange(n_bins + 1) * (100 / n_bins)
            y_pred_quantiles = np.percentile(all_generated_y.squeeze(), q=quantile_list, axis=1)
            y_true = all_true_y.T
            quantile_membership_array = ((y_true - y_pred_quantiles) > 0).astype(int)
            y_true_quantile_membership = quantile_membership_array.sum(axis=0)
            y_true_quantile_bin_count = np.array(
                [(y_true_quantile_membership == v).sum() for v in np.arange(n_bins + 2)])

            # Fold the two out-of-range bins into their neighbours.
            y_true_quantile_bin_count[1] += y_true_quantile_bin_count[0]
            y_true_quantile_bin_count[-2] += y_true_quantile_bin_count[-1]
            y_true_quantile_bin_count_ = y_true_quantile_bin_count[1:-1]
            y_true_ratio_by_bin = y_true_quantile_bin_count_ / dataset_object
            # assert np.abs(
            #     np.sum(y_true_ratio_by_bin) - 1) < 1e-10, "Sum of quantile coverage ratios shall be 1!"
            qice_coverage_ratio = np.absolute(np.ones(n_bins) / n_bins - y_true_ratio_by_bin).mean()
            return y_true_ratio_by_bin, qice_coverage_ratio, y_true

        def compute_PICP(config, y_true, all_gen_y, return_CI=False):
            """
            Another coverage metric.
            """
            low, high = config.testing.PICP_range
            CI_y_pred = np.percentile(all_gen_y.squeeze(), q=[low, high], axis=1)
            y_in_range = (y_true >= CI_y_pred[0]) & (y_true <= CI_y_pred[1])
            coverage = y_in_range.mean()
            if return_CI:
                return coverage, CI_y_pred, low, high
            else:
                return coverage, low, high

        test_data, test_loader = self._get_data(flag='test')
        if test:
            self.logger.info('loading model')
            self.model.load_state_dict(
                torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth'), map_location=self.device))
            self.cond_pred_model.load_state_dict(
                torch.load(os.path.join('./checkpoints/' + setting, 'cond_checkpoint.pth'), map_location=self.device))

        preds = []
        trues = []
        masks = []
        folder_path = './test_results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        minibatch_sample_start = time.time()

        self.model.eval()
        self.cond_pred_model.eval()
        target_num = self.args.target_num
        testing_cfg = self.model.diffusion_config.testing
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark, batch_x_mask, batch_y_mask) in enumerate(test_loader):
                # Read and written by store_gen_y_at_step_t through its closure.
                gen_y_by_batch_list = [[] for _ in range(self.model.num_timesteps + 1)]

                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)
                batch_x = batch_x * batch_x_mask.to(self.device)
                if self.args.pseudo_len > 0:
                    # Blank the target columns over the trailing pseudo_len steps.
                    batch_x[:, -self.args.pseudo_len:, -target_num:] = 0

                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # Decoder input: the label segment followed by zeros for the horizon.
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                # NOTE(review): the use_amp / output_attention branches call
                # self.model with four arguments and append nothing to
                # preds/trues/masks, so the final concatenation fails when
                # they are taken -- they look like leftovers; confirm.
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        if self.args.output_attention:
                            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                        else:
                            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                else:
                    if self.args.output_attention:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]

                    else:
                        # NOTE(review): training passes mode=/target= to the
                        # conditional predictor; here its defaults are used.
                        y_0_hat_batch = self.cond_pred_model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                        y_0_hat_batch = y_0_hat_batch[:, :, -target_num:]

                        # Tile every input repeat_n times along the batch axis
                        # so that one sampling pass yields repeat_n samples
                        # per series.
                        repeat_n = int(testing_cfg.n_z_samples / testing_cfg.n_z_samples_depart)
                        y_0_hat_tile = y_0_hat_batch.repeat(repeat_n, 1, 1, 1)
                        y_0_hat_tile = y_0_hat_tile.transpose(0, 1).flatten(0, 1).to(self.device)
                        y_T_mean_tile = y_0_hat_tile
                        x_tile = batch_x.repeat(repeat_n, 1, 1, 1)
                        x_tile = x_tile.transpose(0, 1).flatten(0, 1).to(self.device)

                        x_mark_tile = batch_x_mark.repeat(repeat_n, 1, 1, 1)
                        x_mark_tile = x_mark_tile.transpose(0, 1).flatten(0, 1).to(self.device)

                        gen_y_box = []
                        for _ in range(testing_cfg.n_z_samples_depart):
                            # One sampling pass per group of repeat_n samples.
                            y_tile_seq = p_sample_loop(self.model, x_tile, x_mark_tile, y_0_hat_tile, y_T_mean_tile,
                                                       self.model.num_timesteps,
                                                       self.model.alphas, self.model.one_minus_alphas_bar_sqrt,
                                                       )

                            gen_y = store_gen_y_at_step_t(config=self.model.args,
                                                          config_diff=self.model.diffusion_config,
                                                          idx=self.model.num_timesteps, y_tile_seq=y_tile_seq,
                                                          test_batch_size=batch_x.shape[0])
                            gen_y_box.append(gen_y)
                        outputs = np.concatenate(gen_y_box, axis=1)

                        # Keep the forecast horizon and the last target_num channels.
                        pred = outputs[:, :, -self.args.pred_len:, -target_num:]
                        true = batch_y[:, -self.args.pred_len:, -target_num:].detach().cpu().numpy()
                        mask = batch_y_mask[:, -self.args.pred_len:, -target_num:].detach().cpu().numpy()

                        preds.append(pred)
                        trues.append(true)
                        masks.append(mask)
                if i % 5 == 0 and i != 0:
                    self.logger.info('Testing: %d/%d cost time: %f min' % (
                        i, len(test_loader), (time.time() - minibatch_sample_start) / 60))
                    minibatch_sample_start = time.time()
                    sys.stdout.flush()

        # [num_batch*batch_size, n_sample, L, N]
        # [num_batch*batch_size, L, N]
        preds = np.concatenate(preds, axis=0)
        trues = np.concatenate(trues, axis=0)
        masks = np.concatenate(masks, axis=0)
        # print(f"preds shape{preds.shape}")
        # print(f"trues shape{trues.shape}")

        # Point forecast: average over the sample axis.
        preds_ns = preds.mean(axis=1)

        preds_ns = preds_ns.reshape(-1, preds_ns.shape[-2], preds_ns.shape[-1])
        trues_ns = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
        masks_ns = masks.reshape(-1, trues.shape[-2], trues.shape[-1])

        # result save
        folder_path = './results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        mae, mse = metric_mask(preds_ns, trues_ns, masks_ns)
        self.logger.info('NT metrc: mse:{:.4f}, mae:{:.4f} '.format(mse, mae) )

        # Disabled probabilistic evaluation (QICE / PICP / CRPS), kept for
        # reference. It expects `preds_save`, `trues_save` and `masks_save`
        # to be the concatenated preds / trues / masks from above.

        # masks = masks.reshape(-1)

        # # ->[*, n_sample]
        # preds = preds.reshape(-1, preds.shape[-3], preds.shape[-2] * preds.shape[-1])
        # preds = preds.transpose(0, 2, 1)
        # preds = preds.reshape(-1, preds.shape[-1])
        # preds = preds[masks==1]
        # # ->[*, 1]
        # trues = trues.reshape(-1, 1, trues.shape[-2] * trues.shape[-1])
        # trues = trues.transpose(0, 2, 1)
        # trues = trues.reshape(-1, trues.shape[-1])
        # trues = trues[masks==1]

        # y_true_ratio_by_bin, qice_coverage_ratio, y_true = compute_true_coverage_by_gen_QI(
        #     config=self.model.diffusion_config, dataset_object=preds.shape[0],
        #     all_true_y=trues, all_generated_y=preds, )

        # coverage, _, _ = compute_PICP(config=self.model.diffusion_config, y_true=y_true, all_gen_y=preds)

        # self.logger.info('CARD metrc: QICE:{:.4f}%, PICP:{:.4f}%'.format(qice_coverage_ratio * 100, coverage * 100))

        # # ->[num_batch*batch_size, n_sample, L, N]
        # pred = preds_save.reshape(-1, preds_save.shape[-3], preds_save.shape[-2], preds_save.shape[-1])
        # # ->[num_batch*batch_size, L, N]
        # true = trues_save.reshape(-1, trues_save.shape[-2], trues_save.shape[-1])
        # masks = masks_save.reshape(-1, trues_save.shape[-2], trues_save.shape[-1])

        # all_res_get = []
        # i=-1
        # mask = masks[:, :, i]
        # mask = mask.reshape(-1)

        # p_in = pred[:, :, :, i]
        # p_in = p_in.transpose(0, 2, 1)
        # # ->[num_batch*batch_size*L, n_sample]
        # p_in = p_in.reshape(-1, p_in.shape[-1])
        # p_in = np.array(p_in[mask==1])

        # t_in = true[:, :, i]
        # t_in = t_in.reshape(-1)
        # t_in = np.array(t_in[mask==1])

        # # ->[num_batch*batch_size*L]
        # print(p_in.shape, t_in.shape)
        # all_res_get.append(ccc(i, p_in, t_in))
        # CRPS_0 = np.mean(all_res_get, axis=0).mean()
        # self.logger.info('CRPS {}'.format(CRPS_0))

        return
    



