#### Setup ####
# Project root. Defaults to the original location but can be overridden with
# the WASSERSTEIN_DIR environment variable. A single trailing '/' is enforced
# because every path below is built with paste0().
project_dir = Sys.getenv('WASSERSTEIN_DIR')
if(!nzchar(project_dir)) project_dir = '/mnt/r/wasserstein/'
project_dir = sub('/*$', '/', project_dir)
source(paste0(project_dir,'00_functions.R'))

#kernel range parameter (km, used in the slice directory name)
range_km = 1000

#determine resolution of approximation grid
#stride must divide the 360 x 720 base grid exactly; otherwise the grid
#dimensions below would be fractional and the weight matrices malformed
stride = 6
stopifnot(stride > 0, 360 %% stride == 0)
n_lat_convo = 360/stride
n_long_convo = 720/stride

slice_dir = paste0(project_dir,'sliced_',range_km,'km_',stride,'stride/')

#### Prepare TAS datasets ####

convo_coords = readRDS(paste0(slice_dir,'convo_coords.RDS'))
era5_sliced = readRDS(paste0(slice_dir,'tas/era5.RDS'))
ncep_sliced = readRDS(paste0(slice_dir,'tas/ncep.RDS'))

# Model name = file name up to its first '.'. vapply guarantees a character
# vector even when a directory is empty (sapply would return list(), which
# would later turn the results table's model column into a list column).
cmip5_dir = paste0(slice_dir,'tas/cmip5/')
cmip5_files = list.files(cmip5_dir)
cmip5_models = vapply(strsplit(cmip5_files,"[.]"),function(x){x[1]},character(1))
n_cmip5 = length(cmip5_models)
if(n_cmip5 == 0) warning('No CMIP5 files found in ', cmip5_dir, call. = FALSE)

cmip6_dir = paste0(slice_dir,'tas/cmip6/')
cmip6_files = list.files(cmip6_dir)
cmip6_models = vapply(strsplit(cmip6_files,"[.]"),function(x){x[1]},character(1))
n_cmip6 = length(cmip6_models)
if(n_cmip6 == 0) warning('No CMIP6 files found in ', cmip6_dir, call. = FALSE)

#### Calculate distance to ERA5 Reanalysis ####

#weights for global
weights = matrix(w_latitude(n_lat_convo),n_lat_convo,n_long_convo)

#weights and indices for tropics
convo_lats = sort(unique(convo_coords$dlat))
tropics_lats = which(abs(convo_lats)<30) #as defined in Vissio
tropics_weights = weights[tropics_lats,]

#quantiles of interest
q = seq(0,1,0.005)

# Compare one set of sliced quantiles against a reference with scwd(), both
# over the full grid and restricted to the tropical latitude rows.
compare_to_reference = function(ref_q, mod_q, weights, tropics_lats, tropics_weights){
  list(global  = scwd(ref_q, mod_q, weights),
       tropics = scwd(ref_q[,tropics_lats,], mod_q[,tropics_lats,], tropics_weights))
}

#quantiles of the ERA5 reference and of NCEP
era5_quantiles = slice_quantiles(era5_sliced,q)
ncep_quantiles = slice_quantiles(ncep_sliced,q)

#one row (and one slice entry) per comparison: NCEP, then CMIP5, then CMIP6
tas_results = data.frame(model = c('NCEP Reanalysis',cmip5_models,cmip6_models),
                         mip = c('NCEP Reanalysis',rep('CMIP5',n_cmip5),rep('CMIP6',n_cmip6)),
                         scw = numeric(1+n_cmip5+n_cmip6),
                         scw_tropics = numeric(1+n_cmip5+n_cmip6))
tas_wd_slices = vector('list', nrow(tas_results))

# `[k] = list(x)` rather than `[[k]] = x`: assigning NULL via [[ ]] would
# delete the element and shift every later model's slot.
d = compare_to_reference(era5_quantiles, ncep_quantiles, weights, tropics_lats, tropics_weights)
tas_results$scw[1] = d$global$scwd
tas_results$scw_tropics[1] = d$tropics$scwd
tas_wd_slices[1] = list(d$global$wd_vals)

# distances for CMIP5 then CMIP6 models, in the same order as the table rows;
# seq_along() makes the loop a no-op when no model files exist
model_paths = c(paste0(cmip5_dir,cmip5_files), paste0(cmip6_dir,cmip6_files))
for(j in seq_along(model_paths)){
  mod_quantiles = slice_quantiles(readRDS(model_paths[j]),q)
  d = compare_to_reference(era5_quantiles, mod_quantiles, weights, tropics_lats, tropics_weights)
  row = j + 1
  tas_results$scw[row] = d$global$scwd
  tas_results$scw_tropics[row] = d$tropics$scwd
  tas_wd_slices[row] = list(d$global$wd_vals)
}

# explicit print so the ranking is shown even when the script is source()d
print(tas_results %>% arrange(scw))

save(tas_results,tas_wd_slices,file=paste0(slice_dir,'tas/results.RDA'))



#### Prepare PR datasets ####

gpcp_sliced = readRDS(paste0(slice_dir,'pr/gpcp.RDS'))
era5_sliced = readRDS(paste0(slice_dir,'pr/era5.RDS'))
ncep_sliced = readRDS(paste0(slice_dir,'pr/ncep.RDS'))

# Model name = file name up to its first '.'. vapply guarantees a character
# vector even when a directory is empty (sapply would return list(), which
# would later turn the results table's model column into a list column).
cmip5_dir = paste0(slice_dir,'pr/cmip5/')
cmip5_files = list.files(cmip5_dir)
cmip5_models = vapply(strsplit(cmip5_files,"[.]"),function(x){x[1]},character(1))
n_cmip5 = length(cmip5_models)
if(n_cmip5 == 0) warning('No CMIP5 files found in ', cmip5_dir, call. = FALSE)

cmip6_dir = paste0(slice_dir,'pr/cmip6/')
cmip6_files = list.files(cmip6_dir)
cmip6_models = vapply(strsplit(cmip6_files,"[.]"),function(x){x[1]},character(1))
n_cmip6 = length(cmip6_models)
if(n_cmip6 == 0) warning('No CMIP6 files found in ', cmip6_dir, call. = FALSE)

# Number of days from 1979-01-01 up to `end_date` under a given calendar
# (temporary helper, removed below).
days_from_1979 = function(end_date, cal){
  length(ymd_range(19790101, end_date, calendar = cal))
}

# Full record length (through 2005-11-30) for each calendar a model may use
n_days     = days_from_1979(20051130, 'standard')
n_days_365 = days_from_1979(20051130, '365_day')
n_days_360 = days_from_1979(20051130, '360_day')

# Day indices of the GPCP overlap: from the 1996-10-01 index to the record end
gpcp_period     = days_from_1979(19961001, 'standard'):n_days
gpcp_period_365 = days_from_1979(19961001, '365_day'):n_days_365
gpcp_period_360 = days_from_1979(19961001, '360_day'):n_days_360

rm(days_from_1979)

#### Calculate distance to GPCP ####

#weights for global
weights = matrix(w_latitude(n_lat_convo),n_lat_convo,n_long_convo)

#weights and indices for tropics
convo_coords = readRDS(paste0(slice_dir,'convo_coords.RDS'))
convo_lats = sort(unique(convo_coords$dlat))
tropics_lats = which(abs(convo_lats)<30) #as defined in Vissio
tropics_weights = weights[tropics_lats,]

#quantiles of interest
q = seq(0,1,0.005)

# Compare one set of sliced quantiles against a reference with scwd(), both
# over the full grid and restricted to the tropical latitude rows.
compare_to_reference = function(ref_q, mod_q, weights, tropics_lats, tropics_weights){
  list(global  = scwd(ref_q, mod_q, weights),
       tropics = scwd(ref_q[,tropics_lats,], mod_q[,tropics_lats,], tropics_weights))
}

# Day indices of the GPCP period matching a model's time-axis length (which
# identifies its calendar). Unknown lengths are an error: previously the
# quantiles of the preceding model were silently reused.
select_gpcp_days = function(n_days_mod, label){
  if(n_days_mod == n_days)     return(gpcp_period)
  if(n_days_mod == n_days_365) return(gpcp_period_365)
  if(n_days_mod == n_days_360) return(gpcp_period_360)
  stop('Unrecognised time-axis length ', n_days_mod, ' in ', label,
       ' (expected ', n_days, ', ', n_days_365, ' or ', n_days_360, ' days)',
       call. = FALSE)
}

#quantiles of the GPCP reference and of both reanalyses over the GPCP period
gpcp_quantiles = slice_quantiles(gpcp_sliced,q)
era5_quantiles = slice_quantiles(era5_sliced[gpcp_period,,],q)
ncep_quantiles = slice_quantiles(ncep_sliced[gpcp_period,,],q)

#one row (and one slice entry) per comparison: ERA5, NCEP, CMIP5, CMIP6
pr_results = data.frame(model = c('ERA5','NCEP',cmip5_models,cmip6_models),
                        mip = c('Reanalysis','Reanalysis',rep('CMIP5',n_cmip5),rep('CMIP6',n_cmip6)),
                        scw = numeric(2+n_cmip5+n_cmip6),
                        scw_tropics = numeric(2+n_cmip5+n_cmip6))
pr_wd_slices = vector('list', nrow(pr_results))

# `[k] = list(x)` rather than `[[k]] = x`: assigning NULL via [[ ]] would
# delete the element and shift every later model's slot.

#gpcp to era5 (row 1) and gpcp to ncep (row 2)
reanalysis_quantiles = list(era5_quantiles, ncep_quantiles)
for(r in seq_along(reanalysis_quantiles)){
  d = compare_to_reference(gpcp_quantiles, reanalysis_quantiles[[r]], weights, tropics_lats, tropics_weights)
  pr_results$scw[r] = d$global$scwd
  pr_results$scw_tropics[r] = d$tropics$scwd
  pr_wd_slices[r] = list(d$global$wd_vals)
}

# distances for CMIP5 then CMIP6 models, in the same order as the table rows;
# seq_along() makes the loop a no-op when no model files exist
model_paths = c(paste0(cmip5_dir,cmip5_files), paste0(cmip6_dir,cmip6_files))
for(j in seq_along(model_paths)){
  mod_sliced = readRDS(model_paths[j])
  n_days_mod = dim(mod_sliced)[1]
  mod_days = select_gpcp_days(n_days_mod, model_paths[j])
  mod_quantiles = slice_quantiles(mod_sliced[mod_days,,],q)

  d = compare_to_reference(gpcp_quantiles, mod_quantiles, weights, tropics_lats, tropics_weights)
  row = j + 2
  pr_results$scw[row] = d$global$scwd
  pr_results$scw_tropics[row] = d$tropics$scwd
  pr_wd_slices[row] = list(d$global$wd_vals)
}

# explicit print so the ranking is shown even when the script is source()d
print(pr_results %>% arrange(scw))

save(pr_results,pr_wd_slices,file=paste0(slice_dir,'pr/results.RDA'))
