import argparse
import math 
from net_ns2 import *
import numpy as np
import os
import random
import torch
import torch.backends.cudnn as cudnn
import matplotlib.pyplot as plt
import shelve 
import torchdiffeq 
import logging

# Global numeric settings, applied before any tensor is created.
# Turn TF32 off for CUDA matmuls so float32 products are not computed with
# the reduced-precision TensorFloat-32 path.
torch.backends.cuda.matmul.allow_tf32 = False
# Newly created floating-point tensors default to float32.
torch.set_default_dtype(torch.float32)

# Command-line interface of the script. Only the parser is built here; the
# arguments are parsed further down. Help strings are given only for options
# whose use is visible in this script; the remaining options are presumably
# consumed by the model code that receives `args`.
parser = argparse.ArgumentParser(description='PIDO ns2 scenarios')

# launch / data
parser.add_argument('--launcher', default='', type=str)
parser.add_argument('--port', default=29051, type=int)
parser.add_argument('--data_path', type=str, default=None,
                    help='Shelve database holding the training records.')
parser.add_argument('--data_path2', type=str, default=None,
                    help='Shelve database holding the evaluation records.')
parser.add_argument('--seed', type=int, default=0, help='Random initialization.')

# optimisation and loss weights
parser.add_argument('--lr', type=float, default=1.0, help='Learning rate.')
parser.add_argument('--L_f', type=float, default=1.0, help='Multiplier on loss f.')
parser.add_argument('--L_u', type=float, default=1.0, help='Multiplier on loss u.')
parser.add_argument('--L_b', type=float, default=1.0, help='Multiplier on loss b.')
parser.add_argument('--L_pl', type=float, default=1.0, help='Multiplier on loss pl.')
parser.add_argument('--adam_lr', type=float, default=0.0, help='Learning rate for adam.')

parser.add_argument('--epoch', type=int, default=1000, help='Epoch for adam')
parser.add_argument('--repeat', type=int, default=1,
                    help='Repetition index at which the run loop stops (exclusive).')
parser.add_argument('--start_repeat', type=int, default=0,
                    help='Repetition index at which the run loop starts.')

# space / time discretisation
parser.add_argument('--xgrid', type=int, default=256,
                    help='Number of grid points per spatial axis.')
parser.add_argument('--tspan', type=int, default=100,
                    help='Number of time snapshots.')
parser.add_argument('--xmin', type=float, default=0)
parser.add_argument('--xmax', type=float, default=1)
parser.add_argument('--tmin', type=float, default=0)
parser.add_argument('--tmax', type=float, default=2)

# network
parser.add_argument('--net', type=str, default='FISRwithODE')
parser.add_argument('--evaluate', action='store_true')
parser.add_argument('--save_freq', type=int, default=10000)
# NOTE: despite its name this is a boolean switch, not a string option.
parser.add_argument('--checkpoint_name', action='store_true')
parser.add_argument('--init', action='store_true')
parser.add_argument('--gpu', default=False, action='store_true',
                    help='Run on the CUDA device instead of the CPU.')
parser.add_argument('--work_dir', default='debug',
                    help='Directory that receives the results_<i> folders.')
parser.add_argument('--sub_name', default='')
parser.add_argument('--print_freq', type=int, default=1000)
parser.add_argument('--valid_freq', type=int, default=10000)
parser.add_argument('--plot_freq', type=int, default=10000)
parser.add_argument('--exp_dir', default='')
parser.add_argument('--batch_theta', type=int, default=1)
parser.add_argument('--batch_init', type=int, default=20)
parser.add_argument('--batch_coords', type=int, default=6400)
parser.add_argument('--batch_f', type=int, default=6400)
parser.add_argument('--batch_u', type=int, default=512)
parser.add_argument('--batch_b', type=int, default=200)
parser.add_argument('--split', type=str, default='')

# ode & isr
parser.add_argument('--ode_code_size', type=int, default=64)
parser.add_argument('--ode_hidden_size', type=int, default=64)
parser.add_argument('--ode_depth', type=int, default=2)
parser.add_argument('--ode_activation', type=str, default='sin')
parser.add_argument('--ode_method', type=str, default='rk4')
parser.add_argument('--ode_adjoint', action='store_true',
                    help='Integrate with torchdiffeq.odeint_adjoint instead of odeint.')
parser.add_argument('--z0_ratio', type=float, default=1.0)
parser.add_argument('--ode_ratio', type=float, default=1.0)
parser.add_argument('--num_init_cond', type=int, default=256,
                    help='Number of records read from --data_path.')
parser.add_argument('--num_init_cond_eval', type=int, default=32,
                    help='Number of records read from --data_path2.')
parser.add_argument('--eval_decode_steps', type=int, default=1000)
parser.add_argument('--eval_eval_set', action='store_true')
parser.add_argument('--epsilon_t', type=float, default=0.0)
parser.add_argument('--step_size', type=float, default=0.0)
parser.add_argument('--valid_batch', type=int, default=32)
parser.add_argument('--valid_batch_t', type=int, default=10)
parser.add_argument('--visc', type=float, default=0.001)
# The previous help text here ("Number of points in the xgrid.") was copied
# from another option and did not describe this flag.
# NOTE(review): tstart is not read anywhere in this script -- confirm its
# exact meaning against the model code.
parser.add_argument('--tstart', type=float, default=0, help='Start time.')
parser.add_argument('--visc_train_set', type=str, default='-',
                    help="'-'-separated list; split on '-' after parsing.")
parser.add_argument('--visc_eval_set', type=str, default='-',
                    help="'-'-separated list; split on '-' after parsing.")
# Parse the command line; everything below is configured through `args`.
args = parser.parse_args()

# The run name is the working directory; log it together with the full
# argument namespace.
args.name = args.work_dir
logger.info(args.name)
logger.info(args)

# CUDA support: --gpu selects the CUDA device, otherwise run on the CPU.
device = torch.device('cuda' if args.gpu else 'cpu')

# Choose the torchdiffeq integrator by name: the adjoint variant when
# --ode_adjoint is set, the plain one otherwise.
odeint_func = getattr(torchdiffeq, 'odeint_adjoint' if args.ode_adjoint else 'odeint')

# define spatial and temporal domain
args.x_range = [args.xmin, args.xmax]
args.t_range = [args.tmin, args.tmax]
logger.info(f'x range {args.x_range}')
logger.info(f't range {args.t_range}')

# xgrid equally spaced points on [0, 1): build xgrid+1 points including the
# right endpoint and drop that endpoint.
x_base = np.linspace(0, 1, args.xgrid+1)[:args.xgrid]

# Time step between consecutive snapshots. The argument parser in this file
# does not register `tspan_dt`, so reading `args.tspan_dt` directly raised
# AttributeError. Use the attribute when it is present; otherwise fall back to
# a step that spreads the tspan snapshots uniformly over [tmin, tmax).
# NOTE(review): the fallback value is an assumption -- confirm it matches the
# sampling interval of the stored data.
tspan_dt = getattr(args, 'tspan_dt', None)
if tspan_dt is None:
    tspan_dt = (args.tmax - args.tmin) / args.tspan
    logger.info(f'tspan_dt not provided, using (tmax - tmin) / tspan = {tspan_dt}')
# Store the resolved value so code that receives `args` sees the same step.
args.tspan_dt = tspan_dt

# tspan time stamps starting at tmin, shaped (tspan, 1, 1, 1, 1, 1) so they
# can later be tiled over the spatial grid.
t_train = np.linspace(args.tmin, args.tmin+args.tspan*tspan_dt, args.tspan+1)[:args.tspan].reshape(-1, 1, 1, 1, 1, 1)
t_test = t_train.copy()
t_eval = t_train.copy()

# Index of the first time stamp that reaches tmax: after clipping to
# [0, tmax] every stamp >= tmax equals the maximum and argmax returns the
# first of them. When no stamp reaches tmax this is the index of the largest
# stamp. The index is forced to be at least 1.
args.tmax_ind = (np.argmax(np.clip(t_train,0,args.tmax)))
if args.tmax_ind == 0:
    args.tmax_ind +=1

# load data
def _load_records(path, count):
    """Read records '0' .. str(count - 1) from the shelve database at `path`.

    Each record is a mapping with the entries 'data' and 'theta'; both are
    converted to float32. Per the original inline note the arrays are
    presumably shaped (1, T, H, W, 1) -- not verifiable from this script.

    The database is opened read-only, so a mistyped path fails immediately
    instead of silently creating an empty database, and it is closed when
    loading is done. Every record is fetched once, because each shelve lookup
    unpickles the whole record.

    Returns:
        (data_list, theta_list): two lists of `count` float32 arrays.
    """
    data_list, theta_list = [], []
    with shelve.open(path, flag='r') as records:
        for idx in range(count):
            record = records[f'{idx}']
            data_list.append(record['data'].astype(np.float32))
            theta_list.append(record['theta'].astype(np.float32))
    return data_list, theta_list


# Training set: the '-'-separated viscosity spec is replaced by its parts.
visc_train_set = args.visc_train_set.split('-')
args.visc_train_set = visc_train_set
Exact_train, train_theta = _load_records(args.data_path, args.num_init_cond)
name_list_train = ['ns'] * args.num_init_cond

# Evaluation set, loaded the same way from the second database.
visc_eval_set = args.visc_eval_set.split('-')
args.visc_eval_set = visc_eval_set
Exact_eval, eval_theta = _load_records(args.data_path2, args.num_init_cond_eval)
name_list_eval = ['ns'] * args.num_init_cond_eval


# Log the 1-D time axes and the tmax index before the axes are expanded.
logger.info(f't train_set {t_train.reshape(-1)}')
logger.info(f't test_set {t_test.reshape(-1)}')
logger.info(f't eval_set {t_eval.reshape(-1)}')
logger.info(f'tmax ind {args.tmax_ind}')

# Expand coordinates to the common shape (tspan, 1, 1, xgrid, xgrid, 1):
# x varies along axis 3, y along axis 4 and time along axis 0.
grid_n = args.xgrid
x_along_axis3 = x_base.reshape(1, 1, 1, -1, 1, 1)
y_along_axis4 = x_base.reshape(1, 1, 1, 1, -1, 1)
x = np.tile(x_along_axis3, (args.tspan, 1, 1, 1, grid_n, 1))
y = np.tile(y_along_axis4, (args.tspan, 1, 1, grid_n, 1, 1))

# Each time axis is repeated over both spatial axes.
time_reps = (1, 1, 1, grid_n, grid_n, 1)
t_train, t_eval, t_test = (
    np.tile(time_axis, time_reps) for time_axis in (t_train, t_eval, t_test)
)


repeat = args.repeat

# `rank` is not registered by the argument parser in this file; when the
# attribute is absent treat the process as rank 0 so that it creates the
# output folders.
rank = getattr(args, 'rank', 0)

# One independent training run per repetition index.
for i in range(args.start_repeat, repeat):
    save_path = os.path.join(args.work_dir, f'results_{i}')
    if rank == 0:
        # makedirs(exist_ok=True) also creates a missing work_dir and makes
        # re-runs safe. The previous code tested for 'model' but created
        # 'models', so a second run failed with FileExistsError.
        os.makedirs(save_path, exist_ok=True)
        for sub_dir in ('models', 'pred', 'error', 'loss_f'):
            os.makedirs(os.path.join(save_path, sub_dir), exist_ok=True)
    # PIDO is defined in net_ns2.py
    model = PIDO(args, odeint_func, i, logger, device, x, y, t_train, t_eval, train_theta, Exact_train, name_list_train, eval_theta, Exact_eval, name_list_eval, args.lr, args.net)

    # The returned loss values are not used further in this script.
    loss_all, loss_u, loss_b, loss_f = model.train_adam(args.adam_lr, args.epoch)

    # Drop the reference to the finished model before the next repetition
    # builds a new one; the model of the last repetition stays available.
    if i != repeat-1:
        del model



