#!/usr/bin/env python

import time
from collections import OrderedDict

import os
import math
import random
import numpy as np
import xarray as xr
from pathlib import Path

import torch
import torch.nn.functional as F

from meta_modules.models import get_INR
from utils import wrap_lon_to_180
from train import MAMLModel
from torchmeta.modules import MetaModule, MetaSequential, MetaLinear

def custom_scaler(coords, labels, mon, woa23_ds):
    """Turn raw observations into scaled network inputs and WOA23 anomalies.

    Every observation is matched to a WOA23 grid cell with ``np.digitize``
    against each axis minus its last point. With the default ``right=False``
    this returns the number of axis points <= the value. So a value that lies
    exactly on an axis point maps to the *next* point (clipped to the last
    index). Observations whose climatology ``t_an[mon, ...]`` is NaN are
    discarded. For the remaining ones:

    * coords are min-max scaled with fixed bounds lon [-180, 180],
      lat [-90, 90], depth [0, 2000];
    * the matching climatological value is subtracted from labels.

    Args:
        coords: array whose columns 0/1/2 are lon, lat, depth.
        labels: array with the same number of rows as ``coords``.
        mon: month index into the first axis of ``t_an``.
        woa23_ds: dataset exposing ``lon``, ``lat``, ``depth`` and ``t_an``.

    Returns:
        (scaled_coords, anomaly_labels, kept) where ``kept`` holds the row
        positions (into the inputs) that survived the NaN filter.
    """
    lon_bin = np.digitize(coords[:, 0], woa23_ds.lon.values[:-1])
    lat_bin = np.digitize(coords[:, 1], woa23_ds.lat.values[:-1])
    dep_bin = np.digitize(coords[:, 2], woa23_ds.depth.values[:-1])
    clim = woa23_ds.t_an.values[mon, dep_bin, lat_bin, lon_bin]
    kept = np.flatnonzero(~np.isnan(clim))

    lower = np.array([-180, -90, 0])
    upper = np.array([180, 90, 2000])
    scaled = (coords[kept] - lower) / (upper - lower)
    anomalies = (labels - clim[:, None])[kept]
    return scaled, anomalies, kept

def sample_depths(z, std_levels):
    """Pick one random sample index per standard-level bin.

    Each depth is binned with ``np.digitize(z, std_levels, right=True)``, so
    bin ``i`` holds depths with ``std_levels[i-1] < z <= std_levels[i]``.
    Occupied bins are visited in increasing order, and one member index is
    drawn from each with ``np.random.choice``.

    Args:
        z: 1-D array of depths.
        std_levels: increasing array of standard levels.

    Returns:
        list of indices into ``z``, one per occupied bin.
    """
    bin_of = np.digitize(z, std_levels, right=True)
    # np.unique only yields bins that actually occur, so no member set is empty.
    return [np.random.choice(np.flatnonzero(bin_of == b)) for b in np.unique(bin_of)]

def interp2obs(coords, labels, ds, dataset='IAPv4', mon=0,
               std_lon=None, std_lat=None, std_levels=None):
    """Look up a gridded temperature product at observation locations.

    Each observation is matched to a cell of ``ds`` with ``np.digitize``
    against the product's own lon/lat/depth axes, each minus its last point.
    No interpolation is done. Observations that land on a NaN cell are dropped.

    Args:
        coords: array whose columns 0/1/2 are lon, lat, depth.
        labels: observations; first axis matches ``coords``.
        ds: gridded product. Variables read for each ``dataset`` layout:
            'woa23'    -> lon, lat, depth,     t_an[mon, depth, lat, lon]
            'IAPv4'    -> lon, lat, depth_std, temp[lat, lon, depth]
            'BOA-Argo' -> lon, lat, pres,      temp[0, pres, lat, lon]
        dataset: which of the layouts above ``ds`` follows.
        mon: month index; only used by the 'woa23' layout.
        std_lon, std_lat, std_levels: standard grid used to bin the kept
            observations. When omitted they default to the module-level
            ``Std_lon``, ``Std_lat`` and ``Std_level_119``. Those are only
            defined when this file runs as a script, so pass them explicitly
            when importing this function.

    Returns:
        (grid_temp, labels, coord_idx) for the observations with a non-NaN
        product value: the product value with shape (K, 1), the matching
        labels, and a (K, 3) integer array of (level, lat, lon) indices on
        the standard grid.

    Raises:
        ValueError: if ``dataset`` is not one of the supported layouts.
    """
    layouts = ('woa23', 'IAPv4', 'BOA-Argo')
    if dataset not in layouts:
        raise ValueError(f"unknown dataset {dataset!r}; expected one of {layouts}")

    # NOTE(review): np.digitize(x, axis[:-1]) counts the axis points <= x.
    # A value lying exactly on an axis point therefore maps to the NEXT point
    # (e.g. depth 5 on [0, 5, 10, ...] -> 10). Confirm whether nearest-cell
    # lookup was intended.
    lon_idx = np.digitize(coords[:, 0], ds.lon.values[:-1])
    lat_idx = np.digitize(coords[:, 1], ds.lat.values[:-1])
    if dataset == 'woa23':
        lev_idx = np.digitize(coords[:, 2], ds.depth.values[:-1])
        grid_temp = ds.t_an.values[mon, lev_idx, lat_idx, lon_idx]
    elif dataset == 'IAPv4':
        lev_idx = np.digitize(coords[:, 2], ds.depth_std.values[:-1])
        grid_temp = ds.temp.values[lat_idx, lon_idx, lev_idx]
    else:  # 'BOA-Argo'
        lev_idx = np.digitize(coords[:, 2], ds.pres.values[:-1])
        grid_temp = ds.temp.values[0, lev_idx, lat_idx, lon_idx]

    keep = ~np.isnan(grid_temp)

    if std_lon is None:
        std_lon = Std_lon
    if std_lat is None:
        std_lat = Std_lat
    if std_levels is None:
        std_levels = Std_level_119
    std_lon = np.asarray(std_lon)
    std_lat = np.asarray(std_lat)
    std_lon_idx = np.digitize(coords[:, 0], std_lon[:-1])
    std_lat_idx = np.digitize(coords[:, 1], std_lat[:-1])
    # right=True: a depth exactly on a standard level is binned to that level.
    std_lev_idx = np.digitize(coords[:, 2], std_levels, right=True)
    coord_idx = np.stack([std_lev_idx, std_lat_idx, std_lon_idx], axis=1)

    return grid_temp[keep][:, None], labels[keep], coord_idx[keep]

def _accumulate_sq_err(loss, coord_idx, pred, obs):
    """Add point counts and squared errors onto a (2, level, lat, lon) grid.

    For every row j, ``loss[0]`` gets +1 and ``loss[1]`` gets
    ``(pred[j, 0] - obs[j, 0]) ** 2`` at the cell ``coord_idx[j]``, which is a
    (level, lat, lon) triple. ``np.add.at`` is unbuffered, so several points
    in the same cell all contribute; a fancy-indexed ``+=`` would keep only one.
    """
    coord_idx = np.asarray(coord_idx)
    if coord_idx.size == 0:
        return
    cells = tuple(coord_idx.T)
    sq_err = (np.asarray(pred)[:, 0] - np.asarray(obs)[:, 0]) ** 2
    np.add.at(loss[0], cells, 1)
    np.add.at(loss[1], cells, sq_err)


def _std_grid_index(coords, std_lon, std_lat, std_levels):
    """Return (N, 3) integer (level, lat, lon) indices of lon/lat/depth coords on the standard grid."""
    std_lon_idx = np.digitize(coords[:, 0], std_lon[:-1])
    std_lat_idx = np.digitize(coords[:, 1], std_lat[:-1])
    std_lev_idx = np.digitize(coords[:, 2], std_levels, right=True)
    return np.stack([std_lev_idx, std_lat_idx, std_lon_idx], axis=1)


def _read_profile(fname, std_levels):
    """Read one profile file; return (coords, labels), or None if it has no usable level.

    Levels outside 0-2000 m or with a NaN temperature are dropped. The rest are
    thinned to one random sample per standard level with ``sample_depths``.
    ``coords`` has columns (lon, lat, depth); ``labels`` is temperature, shape (n, 1).
    """
    # The context manager closes the file even on the early return below.
    with xr.open_dataset(fname) as ds:
        z = np.round(ds.depth.values, 0)
        T = np.round(ds.temperature.values, 4)
        # Raw numpy values, so the broadcasting below is plain NumPy.
        lon = np.round(ds.lon.values, 2)
        lat = np.round(ds.lat.values, 2)

    # One NaN label would turn the whole inner-loop update into NaN.
    keep = (z >= 0) & (z <= 2000) & ~np.isnan(T)
    z, T = z[keep], T[keep]
    if len(z) < 1:
        return None

    sampled = sample_depths(z, std_levels)
    z, T = z[sampled], T[sampled]
    ones = np.ones_like(z)
    return np.stack([lon * ones, lat * ones, z], axis=1), T[:, None]


def _score_product(path, dataset, coords, labels, mon, loss):
    """Open one gridded product, sample it at ``coords`` and accumulate its error into ``loss``."""
    with xr.open_dataset(path, decode_times=False) as raw:
        grid = wrap_lon_to_180(raw, center_on_180=False)
        # interp2obs reads .values, so the data are in memory before the file closes.
        grid_temp, obs, coord_idx = interp2obs(coords, labels, grid, dataset=dataset, mon=mon)
    _accumulate_sq_err(loss, coord_idx, grid_temp, obs)


if __name__ == '__main__':

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    step_size = 0.01
    first_order = False
    # Standard evaluation grid. These are module-level names that interp2obs
    # falls back to when no grid is passed explicitly.
    Std_lat = np.linspace(-89.5, 89.5, 180)
    Std_lon = np.linspace(-179.5, 179.5, 360)
    Std_level_119 = np.array([   1,    5,   10,   15,   20,   25,   30,   35,   40,   45,   50,
                                55,   60,   65,   70,   75,   80,   85,   90,   95,  100,  110,
                                120,  130,  140,  150,  160,  170,  180,  190,  200,  220,  240,
                                260,  280,  300,  320,  340,  360,  380,  400,  425,  450,  475,
                                500,  525,  550,  575,  600,  625,  650,  675,  700,  750,  800,
                                850,  900,  950, 1000, 1050, 1100, 1150, 1200, 1250, 1300, 1350,
                                1400, 1450, 1500, 1550, 1600, 1650, 1700, 1750, 1800, 1850, 1900,
                                1950, 2000, 2100, 2200, 2300, 2400, 2500, 2600, 2700, 2800, 2900,
                                3000, 3100, 3200, 3300, 3400, 3500, 3600, 3700, 3800, 3900, 4000,
                                4100, 4200, 4300, 4400, 4500, 4600, 4700, 4800, 4900, 5000, 5100,
                                5200, 5300, 5400, 5500, 5600, 5700, 5800, 5900, 6000])

    nonlin = 'siren'
    nmeas = 1000            # Number of data measurement

    omega0 = 10.0           # Frequency of sinusoid
    sigma0 = 10.0           # Sigma of Gaussian

    # Network parameters
    hidden_layers = 2          # Number of hidden layers in the MLP
    hidden_features = 128       # Number of hidden units per layer

    # Positional encoding is only enabled for the ReLU network.
    posencode = nonlin == 'relu'

    net = get_INR(nonlin=nonlin,
                  in_features=3,
                  out_features=2,
                  hidden_features=hidden_features,
                  hidden_layers=hidden_layers,
                  first_omega_0=omega0,
                  hidden_omega_0=omega0,
                  scale=sigma0,
                  pos_encode=posencode,
                  sidelength=nmeas)
    model = MAMLModel(net, step_size)
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params (M): %.4f' % (n_parameters / 1.e6))

    model_path = 'path/to/model'
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model = model.to(device)
    model.eval()
    print(model)

    # Load woa23 climatology.
    woa23_ds = xr.open_dataset('path/to/WOA23', decode_times=False)

    # Load dataset: one sub-directory per date, named 'YYYY-MM-...'.
    data_dir = '/path/to/data'
    dates = np.array(sorted(os.listdir(data_dir)))

    select_yr_min, select_yr_max, select_mon = 2000, 2005, 1
    years = np.array([int(d.split('-')[0]) for d in dates], dtype=int)
    months = np.array([int(d.split('-')[1]) for d in dates], dtype=int)
    dates = dates[(months == select_mon) & (years >= select_yr_min) & (years <= select_yr_max)]

    # define output folder to save results
    output_folder = '/path/to/save'
    Path(output_folder).mkdir(parents=True, exist_ok=True)

    total_num_tests = 100
    print('total number of tests: ', total_num_tests)

    counter = 0

    # Accumulators: [0] = point count, [1] = summed squared error per std cell.
    grid_shape = (2, len(Std_level_119), len(Std_lat), len(Std_lon))
    our_ini_loss = np.zeros(grid_shape)
    our_opt_loss = np.zeros(grid_shape)
    iap_loss = np.zeros(grid_shape)
    boa_loss = np.zeros(grid_shape)
    gdcsm_loss = np.zeros(grid_shape)

    for epoch in range(total_num_tests):
        start_time = time.time()
        for date in dates:
            year, month, day = date.split('-')
            mon = int(month) - 1
            data_path = f'{data_dir}/{date}'
            profiles = []
            for root, dirs, files in os.walk(data_path):
                profiles.extend(os.path.join(root, f) for f in files if f.endswith('.nc'))

            # Random 80/20 split of the profiles into support and target sets.
            random.shuffle(profiles)
            support_size = int(len(profiles) * 0.8)

            support_coords, support_labels = [], []
            target_coords, target_labels = [], []
            for i, fname in enumerate(profiles):
                sample = _read_profile(fname, Std_level_119)
                if sample is None:
                    continue
                coords_i, labels_i = sample
                if i < support_size:
                    support_coords.append(coords_i)
                    support_labels.append(labels_i)
                else:
                    target_coords.append(coords_i)
                    target_labels.append(labels_i)

            if not support_coords or not target_coords:
                print(f'Skip {date}: no usable support or target profiles.', flush=True)
                continue

            support_coords = np.concatenate(support_coords, axis=0)
            support_labels = np.concatenate(support_labels, axis=0)
            target_coords = np.concatenate(target_coords, axis=0)
            target_labels = np.concatenate(target_labels, axis=0)

            # Scale inputs and convert labels to WOA23 anomalies. Points without
            # a climatology are dropped; woa_filter indexes the kept targets.
            support_x, support_y, _ = custom_scaler(support_coords, support_labels, mon, woa23_ds)
            target_x, target_y, woa_filter = custom_scaler(target_coords, target_labels, mon, woa23_ds)
            # Check before accumulating anything, so every product is scored on the same dates.
            if len(support_x) == 0 or len(target_x) == 0:
                print(f'Skip {date}: no observations with WOA23 climatology.', flush=True)
                continue

            # Baseline products are scored on the raw (unscaled) target observations.
            # The paths are placeholders.
            _score_product(f"/path/to/IAPv4", "IAPv4", target_coords, target_labels, mon, iap_loss)
            _score_product(f"/path/to/BOA-Argo", "BOA-Argo", target_coords, target_labels, mon, boa_loss)
            # GDCSM-Argo is read with the BOA-Argo layout (temp[0, pres, lat, lon]).
            _score_product(f"/path/to/GDCSM-Argo", "BOA-Argo", target_coords, target_labels, mon, gdcsm_loss)

            target_coord_idx = _std_grid_index(target_coords, Std_lon, Std_lat, Std_level_119)[woa_filter]

            support_x = torch.from_numpy(support_x).float().to(device)
            support_y = torch.from_numpy(support_y).float().to(device)
            target_x = torch.from_numpy(target_x).float().to(device)
            # Compare in float32, the precision the model predicts in.
            target_y = target_y.astype(np.float32)

            support_output = model(support_x)
            support_pred = support_output[:, :1]
            support_lvar = support_output[:, 1:]

            # Uncertainty-aware loss: the second output channel weights the
            # squared error by exp(-lvar) and is itself penalised by 0.5 * lvar.
            mseloss = (support_pred - support_y) ** 2
            inner_loss = torch.mean(torch.exp(-support_lvar) * mseloss + 0.5 * support_lvar)
            params = model.update_parameters(inner_loss, first_order=first_order)

            # Evaluation only: no autograd graph is needed for the target predictions.
            with torch.no_grad():
                target_pred_initial = model(target_x)[:, :1].cpu().numpy()
                target_pred_update = model(target_x, params)[:, :1].cpu().numpy()
            _accumulate_sq_err(our_ini_loss, target_coord_idx, target_pred_initial, target_y)
            _accumulate_sq_err(our_opt_loss, target_coord_idx, target_pred_update, target_y)

            counter += 1
            output_str = f"Test: {epoch}, ID: {counter}, Date: {date}, inner: {np.sum(our_ini_loss[1])/np.sum(our_ini_loss[0]):.3f}, outer: {np.sum(our_opt_loss[1])/np.sum(our_opt_loss[0]):.3f}, IAPv4: {np.sum(iap_loss[1])/np.sum(iap_loss[0]):.3f}, spend {(time.time()-start_time):.3f}s."
            print(output_str, flush=True)
            np.save(f'{output_folder}/inner.npy', our_ini_loss)
            np.save(f'{output_folder}/outer.npy', our_opt_loss)
            np.save(f'{output_folder}/iapv4.npy', iap_loss)
            np.save(f'{output_folder}/boaargo.npy', boa_loss)
            np.save(f'{output_folder}/gdcsmargo.npy', gdcsm_loss)