#!/usr/bin/env python

import time
from collections import OrderedDict

import os
import math
import random
import numpy as np
import xarray as xr
from pathlib import Path

import torch
import torch.nn.functional as F

from meta_modules.models import get_INR
from utils import wrap_lon_to_180
from train import MAMLModel
from torchmeta.modules import MetaModule, MetaSequential, MetaLinear

def custom_scaler(coords, labels, mon, woa23_ds):
    """Keep points with a WOA23 climatology value, normalise them, and de-mean labels.

    Each point is located on the WOA23 grid with ``np.digitize`` against all
    but the last lon/lat/depth value.  Points whose climatology
    ``t_an[mon, depth, lat, lon]`` is NaN are dropped.

    Args:
        coords: array of shape (n, 3) with (lon, lat, depth) per point.
        labels: observed values with n rows (callers pass shape (n, 1)).
        mon: month index into the first axis of ``woa23_ds.t_an``.
        woa23_ds: dataset exposing ``lon``, ``lat``, ``depth`` and ``t_an``.

    Returns:
        (coords, labels, keep) where coords are rescaled by the fixed bounds
        lon [-180, 180], lat [-90, 90], depth [0, 2000]; labels have the
        climatology subtracted; and keep holds the indices of retained points.
    """
    lon_bin = np.digitize(coords[:, 0], woa23_ds.lon.values[:-1])
    lat_bin = np.digitize(coords[:, 1], woa23_ds.lat.values[:-1])
    dep_bin = np.digitize(coords[:, 2], woa23_ds.depth.values[:-1])

    # Climatology sampled once at every point; NaN marks "no WOA23 value".
    clim_at_obs = woa23_ds.t_an.values[mon, dep_bin, lat_bin, lon_bin]
    keep = np.flatnonzero(~np.isnan(clim_at_obs))

    lower = np.array([-180, -90, 0])
    upper = np.array([180, 90, 2000])
    scaled = (coords[keep] - lower) / (upper - lower)

    # Subtract before selecting rows so broadcasting matches the label layout.
    anomalies = (labels - clim_at_obs[:, None])[keep]
    return scaled, anomalies, keep

def sample_depths(z, std_levels):
    """Pick one random index per occupied standard-level bin.

    Values of ``z`` are binned with ``np.digitize(..., right=True)`` against
    ``std_levels``.  For every distinct bin (in ascending bin order) one
    member index is drawn with ``np.random.choice``.

    Args:
        z: depth values to bin.
        std_levels: monotonic bin edges.

    Returns:
        list with one sampled index per occupied bin.
    """
    bin_of = np.digitize(z, std_levels, right=True)
    # Every bin produced by np.unique has at least one member, so each draw is valid.
    return [np.random.choice(np.where(bin_of == b)[0]) for b in np.unique(bin_of)]

def interp2obs(coords, labels, ds, dataset='IAPv4', mon=0,
               std_lon=None, std_lat=None, std_levels=None):
    """Sample a gridded temperature product at observation points.

    Each point is mapped onto the product grid with ``np.digitize`` against
    all but the last grid value, so indices lie in ``[0, n - 1]``.
    NOTE(review): for an ascending grid, a point with grid[i-1] <= x < grid[i]
    maps to index i, i.e. the next grid value up rather than the nearest one.
    Confirm this is intended.

    Args:
        coords: array of shape (n, 3) with (lon, lat, depth) per point.
        labels: observed values with n rows (callers pass shape (n, 1)).
        ds: dataset exposing ``lon``/``lat`` plus the depth coordinate and
            temperature variable required by ``dataset``.
        dataset: ``'woa23'`` (``t_an[mon, depth, lat, lon]``), ``'IAPv4'``
            (``temp[lat, lon, depth_std]``) or ``'BOA-Argo'``
            (``temp[0, pres, lat, lon]``).
        mon: month index; only used for ``'woa23'``.
        std_lon, std_lat, std_levels: standard grid used to bin points for
            error accounting.  When None, the module-level ``Std_lon``,
            ``Std_lat`` and ``Std_level_119`` are used.

    Returns:
        (temp, labels, coord_idx), restricted to points where the product is
        not NaN.  temp has shape (k, 1), labels holds the matching rows, and
        coord_idx has shape (k, 3) with (level, lat, lon) standard-grid indices.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset not in ('woa23', 'IAPv4', 'BOA-Argo'):
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'woa23', 'IAPv4' or 'BOA-Argo'."
        )
    # Fall back to the script-level standard grid for backward compatibility.
    if std_lon is None:
        std_lon = Std_lon
    if std_lat is None:
        std_lat = Std_lat
    if std_levels is None:
        std_levels = Std_level_119

    x_idx = np.digitize(coords[:, 0], ds.lon.values[:-1])
    y_idx = np.digitize(coords[:, 1], ds.lat.values[:-1])
    if dataset == 'woa23':
        z_idx = np.digitize(coords[:, 2], ds.depth.values[:-1])
        temp = ds.t_an.values[mon, z_idx, y_idx, x_idx][:, None]
    elif dataset == 'IAPv4':
        z_idx = np.digitize(coords[:, 2], ds.depth_std.values[:-1])
        # No month axis is indexed; presumably each IAPv4 file holds a single
        # month (a month-indexed variant was left commented out) - confirm.
        temp = ds.temp.values[y_idx, x_idx, z_idx][:, None]
    else:  # 'BOA-Argo' layout (also used for GDCSM-Argo by the script)
        z_idx = np.digitize(coords[:, 2], ds.pres.values[:-1])
        temp = ds.temp.values[0, z_idx, y_idx, x_idx][:, None]

    valid = ~np.isnan(temp).ravel()
    temp = temp[valid]
    labels = labels[valid]

    std_lon_idx = np.digitize(coords[:, 0], std_lon[:-1])
    std_lat_idx = np.digitize(coords[:, 1], std_lat[:-1])
    std_lev_idx = np.digitize(coords[:, 2], std_levels, right=True)
    coord_idx = np.stack([std_lev_idx, std_lat_idx, std_lon_idx], axis=1)[valid]
    return temp, labels, coord_idx

if __name__ == '__main__':

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    
    step_size = 0.01
    first_order = False
    Std_lat = np.linspace(-89.5, 89.5, 180)
    Std_lon = np.linspace(-179.5, 179.5, 360)
    Std_level_119 = np.array([   1,    5,   10,   15,   20,   25,   30,   35,   40,   45,   50,
                                55,   60,   65,   70,   75,   80,   85,   90,   95,  100,  110,
                                120,  130,  140,  150,  160,  170,  180,  190,  200,  220,  240,
                                260,  280,  300,  320,  340,  360,  380,  400,  425,  450,  475,
                                500,  525,  550,  575,  600,  625,  650,  675,  700,  750,  800,
                                850,  900,  950, 1000, 1050, 1100, 1150, 1200, 1250, 1300, 1350,
                                1400, 1450, 1500, 1550, 1600, 1650, 1700, 1750, 1800, 1850, 1900,
                                1950, 2000, 2100, 2200, 2300, 2400, 2500, 2600, 2700, 2800, 2900,
                                3000, 3100, 3200, 3300, 3400, 3500, 3600, 3700, 3800, 3900, 4000,
                                4100, 4200, 4300, 4400, 4500, 4600, 4700, 4800, 4900, 5000, 5100,
                                5200, 5300, 5400, 5500, 5600, 5700, 5800, 5900, 6000])
    
    nonlin = 'siren'
    nmeas = 1000            # Number of data measurement
    
    omega0 = 10.0           # Frequency of sinusoid
    sigma0 = 10.0           # Sigma of Gaussian
    
    # Network parameters
    hidden_layers = 2          # Number of hidden layers in the MLP
    hidden_features = 128       # Number of hidden units per layer
    
    if nonlin == 'relu':
        posencode = True
    else:
        posencode = False
    
    net = get_INR(nonlin=nonlin,
                  in_features=3,
                  out_features=2, 
                  hidden_features=hidden_features,
                  hidden_layers=hidden_layers,
                  first_omega_0=omega0,
                  hidden_omega_0=omega0,
                  scale=sigma0,
                  pos_encode=posencode,
                  sidelength=nmeas)
    model = MAMLModel(net, step_size)
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params (M): %.4f' % (n_parameters / 1.e6))
    
    # Load model parameters.
    model_path = 'path/to/model'
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model = model.to(device)
    model.eval()
    print(model)
    
    # Load woa23 climatology.
    woa23_ds = xr.open_dataset('path/to/WOA23', decode_times=False)
    
    # Load dataset.
    data_dir = 'path/to/data'
    dates = np.array(sorted(os.listdir(data_dir)))

    select_yr_min, select_yr_max, select_mon = 2007, 2022, 1
    cond_mon = np.array([int(d.split('-')[1])==select_mon for d in dates])
    cond_yr = [int(d.split('-')[0]) >= select_yr_min and int(d.split('-')[0]) <= select_yr_max for d in dates]
    dates = dates[cond_yr * cond_mon]
    output_folder = 'path/to/save'
    Path(output_folder).mkdir(parents=True, exist_ok=True)

    counter = 0
    epoch = 0
    
    our_ini_loss = np.zeros((2, len(Std_level_119), len(Std_lat), len(Std_lon)))
    our_opt_loss = np.zeros((2, len(Std_level_119), len(Std_lat), len(Std_lon)))
    iap_loss = np.zeros((2, len(Std_level_119), len(Std_lat), len(Std_lon)))
    boa_loss = np.zeros((2, len(Std_level_119), len(Std_lat), len(Std_lon)))
    gdcsm_loss = np.zeros((2, len(Std_level_119), len(Std_lat), len(Std_lon)))
    
    total_epochs = 100
    while epoch < total_epochs:
        start_time = time.time()
        for date in dates:
            profiles = []
            year, month, day = date.split('-')
            mon = int(month)-1
            data_path = f'{data_dir}/{date}'
            for root, dirs, files in os.walk(data_path):
                for file in files:
                    if file.endswith('.nc'):
                        fname = os.path.join(root, file)
                        profiles.append(fname)
            
            L = len(profiles)
            
            metoffice_sst_l4 = xr.open_dataset('/path/to/OSTIA/data')
            metlon = metoffice_sst_l4.lon.values
            metlat = metoffice_sst_l4.lat.values
            metlev = np.array([0])
            xx = np.tile(metlon.reshape(1,1,-1), (len(metlev),len(metlat),1))
            yy = np.tile(metlat.reshape(1,-1,1), (len(metlev),1,len(metlon)))
            zz = np.tile(metlev.reshape(-1,1,1), (1,len(metlat),len(metlon)))
            target_labels = np.round(metoffice_sst_l4.analysed_sst.values - 273.15, 3)
            cond = (~np.isnan(target_labels))
            target_coords = np.stack([xx[cond], yy[cond], zz[cond]], axis=1)
            target_labels = target_labels[cond][:, None]
            
            idx = np.random.choice(len(target_coords), int(0.01*len(target_coords)), replace=False)
            target_coords = target_coords[idx, :]
            target_labels = target_labels[idx, :]
            support_coords, support_labels = [], []
            for i in range(L):
                ds = xr.open_dataset(profiles[i])
                z = np.round(ds.depth.values, 0)
                T = np.round(ds.temperature.values, 4)
                lon = np.round(ds.lon, 2)
                lat = np.round(ds.lat, 2)
            
                # select levels between 0 and 2000 m.
                cond = (z >= 0) * (z <= 2000)
                z = z[cond]
                T = T[cond]

                # filter no value inputs.
                if len(z) < 1:
                    continue

                support_coords.append(np.stack([lon * np.ones_like(z), lat * np.ones_like(z), z], axis=1))
                support_labels.append(np.stack([T], axis=1))
                ds.close()
                
            support_coords = np.concatenate(support_coords, axis=0)
            support_labels = np.concatenate(support_labels, axis=0)
            
            iap = xr.open_dataset('path/to/IAPv4')
            iap = wrap_lon_to_180(iap, center_on_180=False)
            iap_Temp, iap_obs, iap_coord_idx = interp2obs(target_coords, target_labels, iap, dataset="IAPv4", mon=mon)
            for j, cr in enumerate(iap_coord_idx):
                iap_loss[0, cr[0], cr[1], cr[2]] += 1
                iap_loss[1, cr[0], cr[1], cr[2]] += ((iap_Temp[j,0] - iap_obs[j,0])**2)
    
            boa = xr.open_dataset('path/to/BOA-Argo')
            boa = wrap_lon_to_180(boa, center_on_180=False)
            boa_Temp, boa_obs, boa_coord_idx = interp2obs(target_coords, target_labels, boa, dataset="BOA-Argo", mon=mon)
            for j, cr in enumerate(boa_coord_idx):
                boa_loss[0, cr[0], cr[1], cr[2]] += 1
                boa_loss[1, cr[0], cr[1], cr[2]] += ((boa_Temp[j,0] - boa_obs[j,0])**2)
            
            gdcsm = xr.open_dataset('path/to/GDCSM-Argo', decode_times=False)
            gdcsm = wrap_lon_to_180(gdcsm, center_on_180=False)
            gdcsm_Temp, gdcsm_obs, gdcsm_coord_idx = interp2obs(target_coords, target_labels, gdcsm, dataset="BOA-Argo", mon=mon)
            for j, cr in enumerate(gdcsm_coord_idx):
                gdcsm_loss[0, cr[0], cr[1], cr[2]] += 1
                gdcsm_loss[1, cr[0], cr[1], cr[2]] += ((gdcsm_Temp[j,0] - gdcsm_obs[j,0])**2)
                
            std_lon_idx = np.digitize(target_coords[:,0], Std_lon[:-1])[:,None]
            std_lat_idx = np.digitize(target_coords[:,1], Std_lat[:-1])[:,None]
            std_lev_idx = np.digitize(target_coords[:,2], Std_level_119, right=True)[:,None]
            target_coord_idx = np.concatenate([std_lev_idx, std_lat_idx, std_lon_idx], axis=1)
            
            support_coords, support_labels, _ = custom_scaler(support_coords, support_labels, mon, woa23_ds)
            target_coords, target_labels, woa_filter = custom_scaler(target_coords, target_labels, mon, woa23_ds)
            target_coord_idx = target_coord_idx[woa_filter]
            
            support_coords = torch.from_numpy(support_coords).float().to(device)
            support_labels = torch.from_numpy(support_labels).float().to(device)
            target_coords = torch.from_numpy(target_coords).float().to(device)
            target_labels = torch.from_numpy(target_labels).float().to(device)
            
            support_output = model(support_coords)
            support_pred = support_output[:, :1]
            support_lvar = support_output[:, 1:]

            # uncertainty loss.
            mseloss = (support_pred - support_labels) ** 2 
            inner_loss = torch.mean(torch.exp(-support_lvar) * mseloss + 0.5 * support_lvar)
            params = model.update_parameters(inner_loss, first_order=first_order)
            
            target_pred_initial = model(target_coords)[:, :1].cpu().detach().numpy()
            target_pred_update = model(target_coords, params)[:, :1].cpu().detach().numpy()
            target_labels = target_labels.cpu().detach().numpy()
            for j, cr in enumerate(target_coord_idx):
                our_ini_loss[0, cr[0], cr[1], cr[2]] += 1
                our_ini_loss[1, cr[0], cr[1], cr[2]] += ((target_pred_initial[j,0]-target_labels[j,0])**2)
                our_opt_loss[0, cr[0], cr[1], cr[2]] += 1
                our_opt_loss[1, cr[0], cr[1], cr[2]] += ((target_pred_update[j,0]-target_labels[j,0])**2)
            
            counter += 1
            output_str = f"Epoch: {epoch}, ID: {counter}, Date: {date}, inner: {np.sum(our_ini_loss[1])/np.sum(our_ini_loss[0]):.3f}, outer: {np.sum(our_opt_loss[1])/np.sum(our_opt_loss[0]):.3f}, IAPv4: {np.sum(iap_loss[1])/np.sum(iap_loss[0]):.3f}, BOA-Argo: {np.sum(boa_loss[1])/np.sum(boa_loss[0]):.3f}, GDCSM-Argo: {np.sum(gdcsm_loss[1])/np.sum(gdcsm_loss[0]):.3f}, spend {(time.time()-start_time):.3f}s."
            print(output_str, flush=True)  
            np.save(f'{output_folder}/inner.npy', our_ini_loss)
            np.save(f'{output_folder}/outer.npy', our_opt_loss)
            np.save(f'{output_folder}/iapv4.npy', iap_loss)
            np.save(f'{output_folder}/boaargo.npy', boa_loss)
            np.save(f'{output_folder}/gdcsmargo.npy', gdcsm_loss)
        epoch += 1     