
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from choose_optimizer import *
from matplotlib import pyplot as plt 
import matplotlib
import os 
from utils import *
import copy 
from mpl_toolkits.axes_grid1 import make_axes_locatable
import math 
import random 
from torch.utils.tensorboard import SummaryWriter 
from tqdm import tqdm 
from pinn_node import FISRwithODE



class PIDO():
    def __init__(self, args, odeint_func, repeat, logger, device, x_all, y_all, t_train, t_test, t_eval, train_theta, Exact_train, name_list_train, eval_theta, Exact_eval, name_list_eval, lr, net):
        self.args = args 
        self.repeat = repeat
        self.logger = logger 
        self.device = device 
        self.x_all = x_all 
        self.y_all = y_all 
        self.t_all = t_train 

        self.args.whole_batch_init = self.args.batch_init 
        self.args.batch_init = self.args.batch_init // self.args.world_size
        
        self.x_test = torch.tensor(x_all).float().to(device)
        self.y_test = torch.tensor(y_all).float().to(device)
        self.t_train = torch.tensor(t_train).float().to(device)
        self.t_test = torch.tensor(t_test).float().to(device)
        self.t_eval = torch.tensor(t_eval).float().to(device)
     
        self.Exact_train = Exact_train 
        self.name_list_train = name_list_train 
        self.Exact_eval = Exact_eval 
        self.name_list_eval = name_list_eval 
        self.train_theta = torch.tensor(train_theta).float().to(device) 
        self.eval_theta = torch.tensor(eval_theta).float().to(device) 
        
        self.u0_step = self.args.fix_points_x // 64
        self.init_coeff = 1.0 
        self.t_range_coeff = 1.0 

        self.param_dim = self.args.param_dim

        self.x_range = torch.tensor(self.args.x_range).float().to(device)
        self.t_range = torch.tensor(self.args.t_range).float().to(device) 
        self.theta_range = torch.tensor(self.args.theta_range).float().to(device) 

        self.use_auxiliary = args.use_auxiliary
      
        self.random_sample = args.random_sample
        if self.args.rank == 0:
            self.writer = SummaryWriter(log_dir=os.path.join(args.work_dir, f'run_{repeat}')) 


        self.L_f = args.L_f 
        self.L_u = args.L_u 
        self.L_b = args.L_b
        self.L_pl = args.L_pl

        self.lr = lr

        self.net = net
        
        if self.net == 'FISRwithODE':
            self.dnn = FISRwithODE(args, odeint_func).to(device)
            self.dnn_auxiliary = FISRwithODE(args, odeint_func).to(device) 
        else:
            raise NotImplementedError
        
        self.dnn_without_ddp = self.dnn 
        if self.args.launcher == 'slurm' or self.args.launcher == 'pytorch':
            self.dnn = torch.nn.parallel.DistributedDataParallel(self.dnn, device_ids=[self.args.gpu_id])
            self.dnn_without_ddp = self.dnn.module 

        for name,param in self.dnn_auxiliary.named_parameters():
            param.requires_grad = False 

        self.logger.info(self.dnn_without_ddp)
       
        self.iter = 0
        self.epsilon = self.args.epsilon_t # 0.99 
        self.epsilon_t = self.args.epsilon_t # 0.99 


    def train_adam(self, adam_lr, epoch):
        """Train with four AdamW optimizers and return the final evaluation losses.

        Parameters of ``self.dnn`` are routed by substring of their name:
        ``'ode'`` -> ``self.ode_optimizer`` (lr ``adam_lr * args.ode_ratio``),
        ``'z0'`` -> ``self.z0_optimizer`` (lr ``adam_lr * args.z0_ratio``),
        ``'z_consistency'`` -> ``self.z_pl_optimizer`` (same ratio as z0),
        everything else -> ``self.optimizer`` (lr ``adam_lr``).

        Args:
            adam_lr: base learning rate.
            epoch: number of epochs to run (training always starts at epoch 0).

        Returns:
            The tuple produced by ``self.loss_pinn(evaluate=True)``, computed on
            the last sampled batch. At least one batch must have been sampled,
            i.e. ``epoch > 0`` and ``args.num_init_cond >= args.whole_batch_init``.
        """
        # The order of the checks matters: 'ode' wins over 'z0', which wins
        # over 'z_consistency'.
        params_net = []
        params_ode = []
        params_coeff = []
        params_pl = []
        for name, param in self.dnn.named_parameters():
            if 'ode' in name:
                params_ode.append(param)
            elif 'z0' in name:
                params_coeff.append(param)
            elif 'z_consistency' in name:
                params_pl.append(param)
            else:
                params_net.append(param)
        self.optimizer = choose_optimizer('AdamW', [{'params': params_net}], adam_lr)
        self.z0_optimizer = choose_optimizer('AdamW', [{'params': params_coeff}], adam_lr*self.args.z0_ratio)
        self.ode_optimizer = choose_optimizer('AdamW', [{'params': params_ode}], adam_lr*self.args.ode_ratio)
        self.z_pl_optimizer = choose_optimizer('AdamW', [{'params': params_pl}], adam_lr*self.args.z0_ratio)

        # Index of the learnable embedding of every training initial condition.
        self.train_init_idx = torch.arange(self.train_theta.shape[0])
        start_epoch = 0
        self.start_epoch = start_epoch

        self.dnn.train()
        self.validating = False
        self.logger.info('>>>>>> Adam optimizer')

        # Only full global batches are used; a trailing partial batch is skipped.
        iters_per_epoch = self.args.num_init_cond // self.args.whole_batch_init

        for epoch_n in range(start_epoch, epoch):
            self.epoch_n = epoch_n
            # Seeding with the epoch index gives every process the same
            # permutation, so the rank-strided slices in sample_data are disjoint.
            g = torch.Generator()
            g.manual_seed(epoch_n)
            indices_init = torch.randperm(self.args.num_init_cond, generator=g).tolist()

            for i_iter in range(iters_per_epoch):
                # Draw the batch; loss_pinn reads it from the attributes set here.
                self.t_f, self.x_f, self.y_f, self.theta_f, self.init_gt_batch, self.batch_init_cond_idx = self.sample_data(i_iter=i_iter, indices_init=indices_init, theta_range=self.theta_range, x_range=self.x_range, t_range=self.t_range_coeff*self.t_range, repeat=self.args.sample_repeat)
                # Forward pass, parameter update and logging.
                self.loss_pinn(verbose=True, step=True)

            if (epoch_n+1) % self.args.save_freq == 0 and self.args.rank == 0:
                name = 'checkpoint.pth.tar' if self.args.checkpoint_name else f'adam_{epoch_n+1}.pth.tar'
                model_dir = os.path.join(self.args.work_dir, f'results_{self.repeat}', 'models')
                # torch.save does not create missing directories.
                os.makedirs(model_dir, exist_ok=True)
                torch.save({'state_dict': self.dnn_without_ddp.state_dict(),
                            'optimizer': self.optimizer.state_dict(),
                            'z0_optimizer': self.z0_optimizer.state_dict(),
                            'z_pl_optimizer': self.z_pl_optimizer.state_dict(),
                            'ode_optimizer': self.ode_optimizer.state_dict(),
                            'epoch': epoch_n+1},
                           os.path.join(model_dir, name))

        return self.loss_pinn(evaluate=True)
    
    def sample_data(self, i_iter, indices_init, theta_range, x_range, t_range, repeat=False):
        if self.iter == self.start_epoch:
            # avoid repeated operation
            T = torch.linspace(t_range[0], t_range[1], self.args.fix_points_t+1, dtype=torch.float32, device=self.device, requires_grad=False)[:self.args.fix_points_t] 
            X = torch.linspace(x_range[0], x_range[1], self.args.fix_points_x+1, dtype=torch.float32, device=self.device, requires_grad=False)

            self.logger.info(f'training t: {T}')
            self.logger.info(f'training x: {X}')

            self.fix_sample_T = T.reshape(-1, 1, 1, 1, 1, 1).repeat(1, self.args.batch_init, self.args.batch_theta, self.args.fix_points_x+1, self.args.fix_points_x+1, 1)
            self.fix_sample_X = X.reshape(1, 1, 1, -1, 1, 1).repeat(self.args.fix_points_t, self.args.batch_init, self.args.batch_theta, 1, self.args.fix_points_x+1, 1)
            self.fix_sample_Y = X.reshape(1, 1, 1, 1, -1, 1).repeat(self.args.fix_points_t, self.args.batch_init, self.args.batch_theta, self.args.fix_points_x+1, 1, 1)
           
        init_cond = indices_init[i_iter*self.args.whole_batch_init+self.args.rank:(i_iter+1)*self.args.whole_batch_init:self.args.world_size]
       
        # set batch_theta==1, because different initial conditions already have their corresponding thetas
        assert self.args.batch_theta == 1
        # PDE coefficients for each inital condition
        theta = self.train_theta[init_cond]
        fix_smaple_theta = theta.reshape(1,self.args.batch_init,self.args.batch_theta,1,1,-1).repeat(self.args.fix_points_t,1,1,1,1,1)
        init_gt_batch = torch.cat([self.loss_data_exact[theta_idx_i:theta_idx_i+1,0:1] for theta_idx_i in init_cond], dim=0)    # (#init,1,H,W,$)
        sample_T = self.fix_sample_T.clone()
        return sample_T, self.fix_sample_X, self.fix_sample_Y, fix_smaple_theta, init_gt_batch, init_cond 

    def loss_pinn(self, verbose=False, step=False, evaluate=False):
        """ train step. """
       
        self.optimizer.zero_grad()
        self.ode_optimizer.zero_grad()
        self.z0_optimizer.zero_grad()
        self.z_pl_optimizer.zero_grad()
        self.dnn_auxiliary.load_state_dict(self.dnn_without_ddp.state_dict())
       
        # learnable embedding for initial conditions; during training, the auto-decoding is approximated with one-step gradient-decent
        z0_batch = self.dnn_without_ddp.z0[self.batch_init_cond_idx].unsqueeze(1).contiguous() # (batch_init, 1, 1)
        
        # unroll from z0 to get z_t (z_pred)
        z_pred = self.scheduling(self.dnn_without_ddp, z0_batch.repeat(1, self.args.batch_theta, 1), self.theta_f[0,:,:,0,0,:], self.t_f[:,0,0,0,0,:].reshape(-1))  # (t, batch_init, batch_theta, codesize)
        # get derivative of z w.r.t time, used in PDE loss
        z_pred_t = self.dnn_without_ddp.get_dyn_grad(self.t_f[:,0,0,0,0,:].reshape(-1), z_pred, self.theta_f[:,:,:,0,0,:])
        # get all prediction 
        w_pred, f_pred, psi_pred, _, _, _, _ = self.net_f(self.x_f, self.y_f, z_pred, z_pred_t, self.theta_f, detach=self.args.detach_f)
        
        # loss IC: update embeddings of inital conditions (z0_batch) and decoder
        w_pred_square = (self.init_coeff * self.init_gt_batch[...,0:1] - w_pred[0, :, :, :-1:self.u0_step, :-1:self.u0_step]) ** 2  
        psi_pred_square = (self.init_coeff * self.init_gt_batch[...,1:2] - psi_pred[0, :, :, :-1:self.u0_step, :-1:self.u0_step]) ** 2  
        loss_u_avg = torch.mean(w_pred_square) + torch.mean(psi_pred_square) #+ torch.mean(psi_pred[0, :-1:self.u0_step, :-1:self.u0_step]) ** 2

        # loss PDE: update decoder and dynamics model
        f_pred_square = f_pred ** 2
        loss_f_max = torch.max(f_pred_square)
        loss_f_avg = torch.mean(f_pred_square)

        # loss BC: update decoder and dynamics model 
        loss_w_b = torch.mean((w_pred[:,:,:,0,:,:] - w_pred[:,:,:,-1,:,:]) ** 2) + torch.mean((w_pred[:,:,:,:,0,:] - w_pred[:,:,:,:,-1,:]) ** 2)
        loss_psi_b = torch.mean((psi_pred[:,:,:,0,:,:] - psi_pred[:,:,:,-1,:,:]) ** 2) + torch.mean((psi_pred[:,:,:,:,0,:] - psi_pred[:,:,:,:,-1,:]) ** 2)
        loss_b_avg = loss_w_b + loss_psi_b 
      
        # get the embedding of predicted solutions for consistency regularziation 
        z_pl = self.dnn_without_ddp.z_consistency[self.batch_init_cond_idx,:self.x_f.shape[0]].permute(1,0,2,3)
        # in the following computing, freeze the weight of decoder with use_auxiliary=True because we only update the embedding z_pl 
        psi_and_w_pl = self.net_u(self.x_f, self.y_f, z_pl.unsqueeze(3).unsqueeze(3).contiguous().repeat(1,1,1,self.x_f.shape[3],self.x_f.shape[4],1), use_auxiliary=True, detach=False)
        psi_pl = psi_and_w_pl[...,0:1]
        w_pl = psi_and_w_pl[...,1:2]
        # update z_pl 
        loss_z_pl = torch.mean((w_pred.detach()-w_pl)**2) + torch.mean((psi_pred.detach()-psi_pl)**2)
        # use z_pl as reference point to regularize z_pred 
        loss_consistency = torch.mean((z_pred - z_pl.detach())**2)
      
        # smoothness regularization
        loss_kinetic_penalty = torch.mean(z_pred_t**2, dim=-1).mean()
        jacobian_z = torch.autograd.grad(
                z_pred_t, z_pred,
                grad_outputs=torch.randn_like(z_pred_t),
                retain_graph=True,
                create_graph=True
            )[0]
        loss_jacobian_penalty = torch.mean(jacobian_z**2, dim=-1).mean()
        
        loss = self.L_u*loss_u_avg + self.L_b*loss_b_avg + self.L_f*loss_f_avg + self.L_pl*loss_z_pl + self.L_pl*loss_consistency + self.args.kinetic_penalty*loss_kinetic_penalty + self.args.jacobian_penalty*loss_jacobian_penalty
        if evaluate:
            return loss.detach().cpu().numpy(), loss_u_avg.detach().cpu().numpy(), loss_b_avg.detach().cpu().numpy(), loss_f_avg.detach().cpu().numpy() 

        if step:
            self.optimizer.step()
            self.ode_optimizer.step()
            self.z0_optimizer.step()
            self.z_pl_optimizer.step()
       
        if verbose:
          
            if (self.iter ==0 or (self.iter+1) % self.args.print_freq == 0) and self.args.rank == 0:
        
                self.logger.info(
                    '[epoch %d/iter %d], loss: %.5e, loss_u_avg: %.5e, L_u: %.5e, loss_b_avg: %.5e, L_b: %.5e, loss_f_avg: %.5e, loss_f_max: %.5e, L_f: %.5e, loss_z_pl: %.5e, loss_consistency: %.5e, L_pl: %.5e, loss_kinetic_penalty: %.5e, L_kinetic_penalty: %.5e, loss_jacobian_penalty: %.5e, L_jacobian_penalty: %.5e, lr: %.5e, ode lr: %.5e, z0 lr: %.5e' % (self.epoch_n, self.iter, loss.item(), loss_u_avg.item(), self.L_u, loss_b_avg.item(), self.L_b, loss_f_avg.item(), loss_f_max.item(), self.L_f, loss_z_pl.item(), loss_consistency.item(), self.L_pl, loss_kinetic_penalty.item(), self.args.kinetic_penalty, loss_jacobian_penalty.item(), self.args.jacobian_penalty, self.optimizer.param_groups[0]['lr'], self.ode_optimizer.param_groups[0]['lr'], self.z0_optimizer.param_groups[0]['lr'])
                )
             
                self.writer.add_scalars(f'loss_all', {'loss_all':loss.item()}, self.iter)
                self.writer.add_scalars(f'loss_u', {'loss_u':loss_u_avg.item()}, self.iter)
                self.writer.add_scalars(f'loss_b', {'loss_b':loss_b_avg.item()}, self.iter)
                self.writer.add_scalars(f'loss_f', {'loss_f':loss_f_avg.item()}, self.iter)
                self.writer.add_scalars(f'loss_f_max', {'loss_f_max':loss_f_max.item()}, self.iter)
             
            if (self.iter+1) % self.args.valid_freq == 0:
                self.logger.info('>>> Evaluation on train theta')
                num_per_visc = self.args.num_init_cond // len(self.args.visc_train_set)
                for valid_visc_idx, valid_visc in enumerate(self.args.visc_train_set):
                    self.validate(self.train_theta[valid_visc_idx*num_per_visc:(valid_visc_idx+1)*num_per_visc], self.Exact_train[valid_visc_idx*num_per_visc:(valid_visc_idx+1)*num_per_visc], self.name_list_train[valid_visc_idx*num_per_visc:(valid_visc_idx+1)*num_per_visc], z0_index=self.train_init_idx[valid_visc_idx*num_per_visc:(valid_visc_idx+1)*num_per_visc], val_name=f'Train_{valid_visc_idx}_{valid_visc}')
            
                if self.args.eval_eval_set:
                    self.logger.info('>>> Evaluation on eval theta')
                    num_per_visc = self.args.num_init_cond_eval // len(self.args.visc_eval_set)
                    for valid_visc_idx, valid_visc in enumerate(self.args.visc_eval_set):
                        self.validate(self.eval_theta[valid_visc_idx*num_per_visc:(valid_visc_idx+1)*num_per_visc], self.Exact_eval[valid_visc_idx*num_per_visc:(valid_visc_idx+1)*num_per_visc], self.name_list_eval[valid_visc_idx*num_per_visc:(valid_visc_idx+1)*num_per_visc], z0_index=None, val_name=f'Eval_{valid_visc_idx}_{valid_visc}')
            self.iter += 1
        return loss
    
    def net_u(self, x, y, z, use_auxiliary=False, detach=False):
        """ (x,y,z) --> (psi, w). see details in net_f() """
        self.dnn_without_ddp.detach = detach
        input_xt = torch.cat([x, y], dim=-1)
        if use_auxiliary:
            u = self.dnn_auxiliary(input_xt, z)
        else:
            u = self.dnn(input_xt, z)
        
        return u

    def net_f(self, x, y, z, z_t, theta, use_auxiliary=False, detach=False):
        """ PDE loss
        for ns equation: (u_x/t: the first order derivative of x/t, u_xx: the second order derivative)
        
        (1) w_t + u * w_x + v * w_y - theta * (w_xx + w_yy) = f
        (2) w - v_x + u_y = 0
        (3) u_x + v_y = 0
        
        we parameterize the network as nn(x,y,z)=(psi, w), where psi denotes the streamfunction.
        Let u=psi_y and v=-psi_u to enforce eq.(3), because we have u_x = psi_xy = psi_yx = -v_y 
        Thereby, we have two equation to satisy:
        (1) w_t + psi_y * w_x + (-psi_x) * w_y - theta * (w_xx + w_yy) - f = 0
        (2) w + psi_xx + psi_yy = 0

        """

        x.requires_grad = True 
        y.requires_grad = True 

        z = z.unsqueeze(3).unsqueeze(3).contiguous().repeat(1,1,1,x.shape[3],x.shape[4],1)
        z_t = z_t.unsqueeze(3).unsqueeze(3).contiguous()
        u_t_weight = 1
        psi_and_w = self.net_u(x, y, z, use_auxiliary=use_auxiliary, detach=detach)
        psi = psi_and_w[...,0:1]
        w = psi_and_w[...,1:2]
     
        u = torch.autograd.grad(
            psi, y,
            grad_outputs=torch.ones_like(psi),
            retain_graph=True,
            create_graph=True
        )[0]
        v = torch.autograd.grad(
            psi, x,
            grad_outputs=torch.ones_like(psi),
            retain_graph=True,
            create_graph=True
        )[0]
     
        psi_xx = torch.autograd.grad(
            v, x,
            grad_outputs=torch.ones_like(v),
            retain_graph=True,
            create_graph=True
        )[0]

        psi_yy = torch.autograd.grad(
            u, y,
            grad_outputs=torch.ones_like(u),
            retain_graph=True,
            create_graph=True
        )[0]

        v = -1.0 * v 

        w_z = torch.autograd.grad(
            w, z,
            grad_outputs=torch.ones_like(w),
            retain_graph=True,
            create_graph=True
        )[0]

        w_t = torch.sum(w_z * z_t, dim=-1, keepdim=True)

        w_x = torch.autograd.grad(
            w, x,
            grad_outputs=torch.ones_like(w),
            retain_graph=True,
            create_graph=True
        )[0]

        w_xx = torch.autograd.grad(
            w_x, x,
            grad_outputs=torch.ones_like(w_x),
            retain_graph=True,
            create_graph=True
        )[0]

        w_y = torch.autograd.grad(
            w, y,
            grad_outputs=torch.ones_like(w),
            retain_graph=True,
            create_graph=True
        )[0]

        w_yy = torch.autograd.grad(
            w_y, y,
            grad_outputs=torch.ones_like(w_y),
            retain_graph=True,
            create_graph=True
        )[0]

        f1 = u_t_weight * w_t + u * w_x + v * w_y - theta * (w_xx + w_yy) 
        if self.args.de_force:
            G = 0 
        else:
            G = 0.1 * (torch.sin(2 * np.pi * (x + y)) + torch.cos(2 * np.pi * (x + y)))

        f1 = f1 - G.detach() 
        f2 = w + (psi_xx + psi_yy)
        f = torch.cat([f1, f2], dim=-1)
        return w, f, psi, u, v, w_x, w_y

    def scheduling(self, model, true_codes, theta, t):
        codes = model.get_dyn(t, true_codes, theta)
        return codes

    def validate(self, theta_list, exact_list, name_list, z0_index=None, val_name='Train'):
        """Validation entry point; forwards every argument unchanged to ``validate_batch``."""
        options = {'z0_index': z0_index, 'val_name': val_name}
        return self.validate_batch(theta_list, exact_list, name_list, **options)
        
    def validate_batch(self, theta_list, exact_list, name_list, z0_index=None, val_name='Train'):
        self.validating = True 
        self.dnn_auxiliary.load_state_dict(self.dnn_without_ddp.state_dict())
        self.logger.info('>>> Evaluation begin')
        error_list = []
        loss_f_list = [] 

        # auto decoding
        z0_batch = self.auto_decode(self.x_test[0:1], self.y_test[0:1], exact_list, z0_index, theta_list)
        self.dnn_auxiliary.eval()
        self.dnn.eval()

        if 'Train' in val_name:
            t_test = self.t_train 
        else:
            t_test = self.t_eval 
        
        w_pred = []
        z_pred_all = []
        eval_total = exact_list.shape[0]
        
        for test_batch_id in range(math.ceil(1.0*exact_list.shape[0]/self.args.valid_batch)):
            z0 = z0_batch[test_batch_id*self.args.valid_batch:(test_batch_id+1)*self.args.valid_batch]
            z0_batch_size = z0.shape[0]
            z0 = z0.reshape(z0_batch_size, 1, -1)
            theta_star = theta_list[test_batch_id*self.args.valid_batch:(test_batch_id+1)*self.args.valid_batch].reshape(1,z0_batch_size,1,1,1,-1).repeat(t_test.shape[0],1,1,t_test.shape[3], t_test.shape[4], 1)
           
            z_pred = self.scheduling(self.dnn_auxiliary, z0, theta_star[0,:,:,0,0,:], t_test[:,0,0,0,0,:].reshape(-1))
            with torch.no_grad():
                w_pred_i, psi_pred_i = self.predict_batch(self.x_test[:t_test.shape[0]].repeat(1,z0_batch_size,1,1,1,1), self.y_test[:t_test.shape[0]].repeat(1,z0_batch_size,1,1,1,1), z_pred, use_auxiliary=True)
                w_pred.append(w_pred_i[:,:,0].cpu().numpy())
                z_pred_all.append(z_pred.cpu().numpy())
                del w_pred_i, psi_pred_i, z_pred 
        w_pred = np.concatenate(w_pred, axis=1).transpose(1,0,2,3,4)
        z_pred_all = np.concatenate(z_pred_all, axis=1)
        Exact_w = exact_list[..., 0:1]

        # error in-t: [0, tmax_ind-1], error out-t: [tmax_ind:]
        error_u_relative = (np.linalg.norm((Exact_w-w_pred).reshape(eval_total, -1), 2, axis=-1)/np.linalg.norm(Exact_w.reshape(eval_total, -1), 2, axis=-1)).mean()
        error_u_relative_in = (np.linalg.norm((Exact_w[:, :self.args.tmax_ind]-w_pred[:, :self.args.tmax_ind]).reshape(eval_total, -1), 2, axis=-1)/np.linalg.norm(Exact_w[:, :self.args.tmax_ind].reshape(eval_total, -1), 2, axis=-1)).mean()
        error_u_relative_out = np.array([np.mean(np.linalg.norm((Exact_w[:,self.args.tmax_ind:]-w_pred[:,self.args.tmax_ind:]).reshape(eval_total, -1), 2, axis=-1)/np.linalg.norm(Exact_w[:,self.args.tmax_ind:].reshape(eval_total,-1), 2, axis=-1), axis=0)])
            
        
        if self.args.rank == 0:
            error_info = f'>>> Evaluation {val_name} end: average error total: {error_u_relative}. in: {error_u_relative_in}. out: {error_u_relative_out}'
            self.logger.info(error_info)

        self.dnn_auxiliary.train()
        self.dnn.train()

        self.validating = False  
        return error_list, loss_f_list 

    def auto_decode(self, x, y, exact_list, z0_index, theta_list, save_best=True):
       
        if z0_index is not None:
            # for initial conditions in the training set, we fetch the embeddings learned during training
            z0_batch = self.dnn_auxiliary.z0[z0_index] 
            # import pdb; pdb.set_trace()
            return z0_batch.detach().clone()
        
        else:
            # for initial conditions in the eval set, we use auto-decoding
            x_batch = x.repeat(1,exact_list.shape[0],1,1,1,1)
            y_batch = y.repeat(1,exact_list.shape[0],1,1,1,1)
            init_gt_batch = torch.tensor(exact_list[:, 0:1]).float().to(self.device)   
           
            loss_min_test = 1e30
            z0_batch = nn.parameter.Parameter(torch.zeros(exact_list.shape[0], self.args.ode_code_size, dtype=torch.float32, device=self.device))
         
            z0_optimizer = choose_optimizer('AdamW', [{'params': [z0_batch]}], self.args.adam_lr, weight_decay=self.args.weight_decay)
          
            for step_idx in range(self.args.eval_decode_steps):
               
                z0_input = z0_batch.reshape(1,exact_list.shape[0],1,1,1,-1).repeat(1,1,1,x.shape[3], x.shape[4],1)
                psi_and_w = self.net_u(x_batch, y_batch, z0_input, use_auxiliary=True)
                psi = psi_and_w[...,0:1]
                w = psi_and_w[...,1:2]

                loss_u = torch.mean((w[0]-init_gt_batch[...,0:1])**2) + torch.mean((psi[0]-init_gt_batch[...,1:2])**2)
 
                if loss_u < loss_min_test and save_best:
                    loss_min_test = loss_u
                    best_z0_batch = z0_batch.detach().clone()
        
                z0_optimizer.zero_grad(True)
                loss_u.backward()
                z0_optimizer.step()

                if step_idx + 1 == self.args.eval_decode_steps:
                    self.logger.info(f'Test decode error {step_idx}: {loss_u}, best: {loss_min_test}')

            if save_best:
                return best_z0_batch
            else:
                return z0_batch.detach().clone()

    def predict_batch(self, x_test, y_test, z_pred, use_auxiliary=False):
        w_pred = []
        psi_pred = []
        z_pred = z_pred.unsqueeze(3).unsqueeze(3).contiguous().repeat(1,1,1,x_test.shape[3],x_test.shape[4],1)
        eval_total_t = x_test.shape[0]

        for test_t_id in range(math.ceil(1.0*eval_total_t/self.args.valid_batch_t)):
            psi_and_w = self.net_u(x_test[test_t_id*self.args.valid_batch_t:(test_t_id+1)*self.args.valid_batch_t], y_test[test_t_id*self.args.valid_batch_t:(test_t_id+1)*self.args.valid_batch_t], z_pred[test_t_id*self.args.valid_batch_t:(test_t_id+1)*self.args.valid_batch_t], use_auxiliary=use_auxiliary)
            psi_pred_id = psi_and_w[...,0:1]
            w_pred_id = psi_and_w[...,1:2]
           
            w_pred.append(w_pred_id.detach())
           
            psi_pred.append(psi_pred_id.detach())
        
            del w_pred_id
            del psi_pred_id
        return torch.cat(w_pred, dim=0), torch.cat(psi_pred, dim=0)
      

    