#!/usr/bin/env python

import time
from collections import OrderedDict

import os
import math
import random
import numpy as np
import xarray as xr
from pathlib import Path

import torch
import torch.nn.functional as F

from meta_modules.models import get_INR
from utils import wrap_lon_to_180
from torchmeta.modules import MetaModule, MetaSequential, MetaLinear

def custom_scaler(coords, labels, mon, woa23_ds):
    """Normalise coordinates and turn labels into anomalies w.r.t. a climatology.

    Parameters
    ----------
    coords : `numpy.ndarray`
        Array whose columns 0, 1 and 2 are compared against
        `woa23_ds.lon`, `woa23_ds.lat` and `woa23_ds.depth` respectively.

    labels : `numpy.ndarray`
        Observed values, one row per row of `coords`.

    mon : int
        Index used along the first axis of `woa23_ds.t_an`.

    woa23_ds : object
        Provides `lon`, `lat`, `depth` and `t_an`, each exposing `.values`.
        `t_an.values` is indexed as `[mon, depth, lat, lon]`.

    Returns
    -------
    coords : `numpy.ndarray`
        Rows with a non-NaN climatology value, rescaled column-wise from
        [-180, 180], [-90, 90] and [0, 2000] to [0, 1].

    labels : `numpy.ndarray`
        `labels` minus the climatology value of the matching grid cell,
        restricted to the same rows.
    """
    # Locate every sample on the climatology grid; the last grid value is
    # dropped from the bin edges so the resulting indices stay in range.
    lon_bin = np.digitize(coords[:, 0], woa23_ds.lon.values[:-1])
    lat_bin = np.digitize(coords[:, 1], woa23_ds.lat.values[:-1])
    depth_bin = np.digitize(coords[:, 2], woa23_ds.depth.values[:-1])

    # Climatological value of the cell each sample falls into.
    clim_at_obs = woa23_ds.t_an.values[mon, depth_bin, lat_bin, lon_bin]

    # Rows whose cell has no climatology (NaN) are discarded.
    keep = np.argwhere(~np.isnan(clim_at_obs)).ravel()

    lower = np.array([-180, -90, 0])
    upper = np.array([180, 90, 2000])
    scaled_coords = (coords[keep] - lower) / (upper - lower)

    anomalies = labels - clim_at_obs[:, None]
    return scaled_coords, anomalies[keep]

def sample_depths(z, std_levels):
    """Pick one random sample index per occupied depth bin.

    Every value of `z` is assigned to a bin of `std_levels` (bins are closed
    on the right), and for each bin that contains at least one value a single
    member is drawn with `np.random.choice`.

    Parameters
    ----------
    z : `numpy.ndarray`
        Depth values of one profile.

    std_levels : `numpy.ndarray`
        Monotonic bin edges passed to `np.digitize`.

    Returns
    -------
    list
        Indices into `z`, one per occupied bin, in ascending bin order.
    """
    bin_of = np.digitize(z, std_levels, right=True)

    picks = []
    for bin_id in np.unique(bin_of):
        members = np.nonzero(bin_of == bin_id)[0]
        if members.size == 0:
            continue
        picks.append(np.random.choice(members))
    return picks

def gradient_update_parameters(model,
                               loss,
                               params=None,
                               step_size=None,
                               first_order=False):
    """Update the meta-parameters with one step of gradient descent on the
    loss function, with trainable step sizes for each parameter.

    Parameters
    ----------
    model : `torchmeta.modules.MetaModule` instance
        The model.

    loss : `torch.Tensor` instance
        The value of the inner-loss. This is the result of the training dataset
        through the loss function.

    params : `collections.OrderedDict` instance, optional
        Dictionary containing the meta-parameters of the model. If `None`, then
        the values stored in `model.meta_named_parameters()` are used.

    step_size : mapping, `torch.nn.ParameterDict`, float or `torch.Tensor`
        Either a container of per-parameter step sizes, or a single step size
        shared by every parameter. In the per-parameter case the keys are the
        parameter names of `params` with every `'.'` replaced by `'_'`.
        Despite the `None` default (kept for signature compatibility), a value
        is required.

    first_order : bool (default: `False`)
        If `True`, then the first order approximation of MAML is used.

    Returns
    -------
    updated_params : `collections.OrderedDict` instance
        Dictionary containing the updated meta-parameters of the model, with one
        gradient update wrt. the inner-loss.

    Raises
    ------
    ValueError
        If `model` is not a `torchmeta.modules.MetaModule`.

    TypeError
        If `step_size` is `None`.
    """
    if not isinstance(model, MetaModule):
        raise ValueError('The model must be an instance of `torchmeta.modules.'
                         'MetaModule`, got `{0}`'.format(type(model)))

    # Fail early with an explicit message instead of the opaque
    # "'NoneType' object is not subscriptable" raised inside the update loop.
    if step_size is None:
        raise TypeError('`step_size` is required: pass per-parameter step '
                        'sizes (keyed by parameter name with "." replaced by '
                        '"_") or a single scalar/tensor step size.')

    if params is None:
        params = OrderedDict(model.meta_named_parameters())

    # create_graph keeps the graph of the gradients so that the outer loss
    # can be differentiated through this update (second-order MAML).
    grads = torch.autograd.grad(loss,
                                params.values(),
                                create_graph=not first_order)

    # A plain number or tensor is one step size shared by all parameters;
    # anything else is looked up per parameter name.
    shared_step = isinstance(step_size, (int, float, torch.Tensor))

    updated_params = OrderedDict()
    for (name, param), grad in zip(params.items(), grads):
        if shared_step:
            lr = step_size
        else:
            # Container keys cannot hold '.', hence the sanitised name.
            lr = step_size[name.replace('.', '_')]
        updated_params[name] = param - lr * grad

    return updated_params


class MAMLModel(torch.nn.Module):
    """Wrap a meta-model together with learnable inner-loop step sizes.

    For every entry of `model.meta_named_parameters()` a trainable tensor of
    the same shape is created and filled with `step_size`. They are stored in
    `self.step_size` under the parameter name with `'.'` replaced by `'_'`.
    """

    def __init__(self, model, step_size=0.1):
        """Store `model` and create one step-size tensor per meta-parameter.

        Parameters
        ----------
        model : module exposing `meta_named_parameters()`
            The network whose parameters are adapted in the inner loop.

        step_size : float (default: `0.1`)
            Initial value of every element of every step-size tensor.
        """
        super().__init__()
        self.model = model

        # One learnable step size per parameter element.
        self.step_size = torch.nn.ParameterDict()
        for param_name, weight in self.model.meta_named_parameters():
            # '.' is not allowed in container keys, so it is replaced.
            key = param_name.replace('.', '_')
            initial = torch.ones_like(weight) * step_size
            self.step_size[key] = torch.nn.Parameter(initial, requires_grad=True)

    def forward(self, *args, **kwargs):
        """Forward all arguments to the wrapped model and return its output."""
        return self.model(*args, **kwargs)

    def update_parameters(self, loss, params=None, first_order=False):
        """Return parameters after one gradient step on `loss`.

        Delegates to `gradient_update_parameters`, passing the wrapped model
        and the learnable step sizes held in `self.step_size`.
        """
        return gradient_update_parameters(
            self.model,
            loss,
            params=params,
            step_size=self.step_size,
            first_order=first_order,
        )

if __name__ == '__main__':
    # Meta-training driver. Each epoch builds one task per (year, month)
    # group of the dataset, adapts the network with a single inner gradient
    # step on the support profiles, evaluates the adapted network on the
    # target profiles, and finally takes one Adam step on the averaged
    # target loss.

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    step_size = 0.01        # Initial value of the learnable inner-loop step sizes
    first_order = False     # Keep the graph through the inner update
    # Depth bin edges handed to sample_depths (presumably metres -- confirm).
    Std_level_119 = np.array([   1,    5,   10,   15,   20,   25,   30,   35,   40,   45,   50,
                                55,   60,   65,   70,   75,   80,   85,   90,   95,  100,  110,
                                120,  130,  140,  150,  160,  170,  180,  190,  200,  220,  240,
                                260,  280,  300,  320,  340,  360,  380,  400,  425,  450,  475,
                                500,  525,  550,  575,  600,  625,  650,  675,  700,  750,  800,
                                850,  900,  950, 1000, 1050, 1100, 1150, 1200, 1250, 1300, 1350,
                                1400, 1450, 1500, 1550, 1600, 1650, 1700, 1750, 1800, 1850, 1900,
                                1950, 2000, 2100, 2200, 2300, 2400, 2500, 2600, 2700, 2800, 2900,
                                3000, 3100, 3200, 3300, 3400, 3500, 3600, 3700, 3800, 3900, 4000,
                                4100, 4200, 4300, 4400, 4500, 4600, 4700, 4800, 4900, 5000, 5100,
                                5200, 5300, 5400, 5500, 5600, 5700, 5800, 5900, 6000])

    nonlin = 'siren'

    nmeas = 1000            # Number of data measurement
    omega0 = 10.0           # Frequency of sinusoid
    sigma0 = 10.0           # Sigma of Gaussian

    # Network parameters
    hidden_layers = 2           # Number of hidden layers in the MLP
    hidden_features = 128       # Number of hidden units per layer

    # Positional encoding is only switched on for the ReLU network.
    posencode = nonlin == 'relu'

    # Two outputs per point: column 0 is used as the prediction and column 1
    # as the log-variance term of the inner loss below.
    net = get_INR(nonlin=nonlin,
                  in_features=3,
                  out_features=2,
                  hidden_features=hidden_features,
                  hidden_layers=hidden_layers,
                  first_omega_0=omega0,
                  hidden_omega_0=omega0,
                  scale=sigma0,
                  pos_encode=posencode,
                  sidelength=nmeas)
    model = MAMLModel(net, step_size)
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params (M): %.4f' % (n_parameters / 1.e6))

    # load ?
    epoch = 0
    outerloss_list = []
    select_yr_min, select_yr_max, select_mon = 2006, 2020, 2
    model = model.to(device)
    model.train()
    print(model)

    # Load woa23 climatology.
    woa23_ds = xr.open_dataset('path/to/WOA23', decode_times=False)

    # Load dataset. Entries of data_dir are parsed as '<year>-<month>-...'.
    data_dir = 'path/to/data'
    dates = np.array(sorted(os.listdir(data_dir)))

    # Keep only the selected month within the selected range of years.
    months = np.array([int(d.split('-')[1]) for d in dates])
    years = np.array([int(d.split('-')[0]) for d in dates])
    selected = (years >= select_yr_min) & (years <= select_yr_max) & (months == select_mon)
    dates = dates[selected]
    year_and_mon = np.array([d[:7] for d in dates])
    output_folder = 'path/to/save'
    Path(output_folder).mkdir(parents=True, exist_ok=True)

    # One task per distinct year-month, all accumulated into one meta-batch.
    batch_size = len(np.unique(year_and_mon))
    total_epochs = 2000
    print('total number of epochs: ', total_epochs)

    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1.e-3)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=meta_optimizer, T_max=total_epochs, eta_min=1.e-6)
    # NOTE(review): passing an epoch to step() looks like a resume hook and is
    # deprecated in recent PyTorch releases -- confirm before upgrading.
    lr_scheduler.step(epoch)

    counter = 0
    base_loss = 0
    outer_loss = torch.tensor(0., device=device)
    model.zero_grad()

    # while epoch < total_epochs:
    while epoch < 1000:
        start_time = time.time()

        for ynm in np.unique(year_and_mon):
            # Draw one random day out of the current year-month.
            in_cluster = np.where(year_and_mon == ynm)[0]
            sampled_index = np.random.choice(in_cluster)
            day = dates[sampled_index]

            # Zero-based month, used to index the climatology.
            mon = int(day.split('-')[1]) - 1
            data_path = f'{data_dir}/{day}'
            profiles = [
                os.path.join(root, file)
                for root, _dirs, files in os.walk(data_path)
                for file in files
                if file.endswith('.nc')
            ]

            # 80 % of the (shuffled) profiles form the support set, the
            # remainder the target set.
            L = len(profiles)
            support_size = int(L * 0.8)
            target_size = L - support_size
            random.shuffle(profiles)

            support_coords, support_labels = [], []
            target_coords, target_labels = [], []
            for i, profile_path in enumerate(profiles):
                # The context manager closes the file on every path,
                # including the early `continue` below.
                with xr.open_dataset(profile_path) as ds:
                    z = np.round(ds.depth.values, 0)
                    T = np.round(ds.temperature.values, 4)

                    lon = np.round(ds.lon, 2)
                    lat = np.round(ds.lat, 2)

                    # select levels between 0 and 2000 m.
                    cond = (z >= 0) * (z <= 2000)
                    z = z[cond]
                    T = T[cond]

                    # filter no value inputs.
                    if len(z) < 1:
                        continue

                    # sample levels when training
                    sampled_indices = sample_depths(z, Std_level_119)
                    z = z[sampled_indices]
                    T = T[sampled_indices]

                    # Each row is (lon, lat, depth) with its temperature label.
                    point_coords = np.stack([lon * np.ones_like(z), lat * np.ones_like(z), z], axis=1)
                    point_labels = np.stack([T], axis=1)
                    if i < support_size:
                        support_coords.append(point_coords)
                        support_labels.append(point_labels)
                    else:
                        target_coords.append(point_coords)
                        target_labels.append(point_labels)

            support_coords = np.concatenate(support_coords, axis=0)
            support_labels = np.concatenate(support_labels, axis=0)
            target_coords = np.concatenate(target_coords, axis=0)
            target_labels = np.concatenate(target_labels, axis=0)

            # Normalise coordinates and subtract the climatology.
            support_coords, support_labels = custom_scaler(support_coords, support_labels, mon, woa23_ds)
            target_coords, target_labels = custom_scaler(target_coords, target_labels, mon, woa23_ds)

            support_coords = torch.from_numpy(support_coords).float().to(device)
            support_labels = torch.from_numpy(support_labels).float().to(device)
            target_coords = torch.from_numpy(target_coords).float().to(device)
            target_labels = torch.from_numpy(target_labels).float().to(device)

            support_output = model(support_coords)
            support_pred = support_output[:, :1]
            support_lvar = support_output[:, 1:]

            # uncertainty loss.
            mseloss = (support_pred - support_labels) ** 2
            inner_loss = torch.mean(torch.exp(-support_lvar) * mseloss + 0.5 * support_lvar)
            params = model.update_parameters(inner_loss, first_order=first_order)

            # Outer loss: adapted network evaluated on the target set.
            target_pred = model(target_coords, params)[:, :1]
            loss = F.mse_loss(target_pred, target_labels)
            outer_loss += loss
            counter += 1

            # Diagnostic only: target loss of the un-adapted network. No
            # graph is needed because it is never back-propagated.
            with torch.no_grad():
                initial_loss = F.mse_loss(model(target_coords)[:, :1], target_labels).cpu().detach()
            base_loss += initial_loss.item()
            loss = loss.cpu().detach()
            # The accumulation above must happen exactly once per task; a
            # second (detached) addition would double the reported losses
            # without changing the gradient.
            print(f"ID: {counter}, Date: {day}, Number of profiles collected: {L}, Train, Val and Test split: {support_size}, {target_size}, initialize: {initial_loss.item():.3e}, loss: {loss.item():.3e}.", flush=True)

        # Average over the tasks of this epoch, log, then take the meta-step.
        outer_loss.div_(batch_size)
        outerloss_list.append(outer_loss.item())
        base_loss = base_loss / batch_size
        output_str = f"Epoch: {epoch}, learning rate: {meta_optimizer.param_groups[0]['lr']:.3e}, averaged inner loss: {base_loss:.3e}, averaged outer loss: {outer_loss.item():.3e}, spend {(time.time()-start_time):.3f}s."
        print(f"##########{output_str}##########", flush=True)
        with open(f'{output_folder}/log.txt', 'a') as f:
            f.write(output_str)
            f.write('\n')
        meta_optimizer.zero_grad()
        outer_loss.backward()
        meta_optimizer.step()
        lr_scheduler.step()

        # Checkpoint every 100 epochs.
        if (epoch + 1) % 100 == 0:
            torch.save(model.state_dict(), f'{output_folder}/checkpoint_{epoch}')

        # Reset the per-epoch accumulators.
        counter = 0
        epoch += 1
        base_loss = 0
        outer_loss = torch.tensor(0., device=device)
        model.zero_grad()
        np.savetxt(f'{output_folder}/outerloss.txt', outerloss_list)