import torch
import pytorch_lightning as pl
import numpy as np
import pandas as pd
from typing import Any
from omegaconf import DictConfig
from einops import rearrange
from models.KHINR_net import KHINRNet
from utils import *
from torchmetrics.regression import MeanSquaredError
from torchmetrics.image import (PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure)
from einops import rearrange
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import TwoSlopeNorm
from scipy.ndimage import zoom
import matplotlib.colors as mcolors

    
#---------------------------------------------------------
# get model
#---------------------------------------------------------
def get_model(cfg, data_name = None):
    """
    Build the network selected by ``cfg.name``.

    Args:
        cfg: Model configuration. ``cfg.name`` selects the architecture; for
            "KHINR" the fields n_block, n_mode, n_dim, n_head, n_layer, x_dim,
            y1_dim, y2_dim, attn and act are forwarded to ``KHINRNet``.
        data_name: Forwarded to ``KHINRNet`` as its ``data`` argument.
    Returns:
        The constructed model.
    Raises:
        ValueError: If ``cfg.name`` is not a supported model name.
    """
    if cfg.name != "KHINR":
        # Without this guard the function reached ``return model`` with
        # ``model`` unbound and failed with an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported model name: {cfg.name!r} (expected 'KHINR')"
        )

    # n_block, n_mode, n_dim, n_head, n_layer, x_dim, y1_dim, y2_dim, attn, act
    return KHINRNet(
        n_block=cfg.n_block,
        n_mode=cfg.n_mode,
        n_dim=cfg.n_dim,
        n_head=cfg.n_head,
        n_layer=cfg.n_layer,
        x_dim=cfg.x_dim,
        y1_dim=cfg.y1_dim,
        y2_dim=cfg.y2_dim,
        attn=cfg.attn,
        act=cfg.act,
        data=data_name,
    )

def denormalize_lat(x):
    """
    Map a normalized latitude back to degrees.

    Subtracts a 1e-10 offset from ``x`` and rescales the result linearly onto
    the range [-89.5, 89.5].

    Args:
        x: Normalized latitude (scalar or array supporting arithmetic).
    Returns:
        Latitude in degrees, same type/shape as ``x``.
    """
    lat_lo, lat_hi = -89.5000, 89.5000
    offset = 1e-10

    shifted = x - offset
    return shifted * (lat_hi - lat_lo) + lat_lo


def denormalize_lon(x):
    """
    Map a normalized longitude back to degrees.

    Subtracts a 1e-10 offset from ``x`` and rescales the result linearly onto
    the range [-179.5, 179.5].

    Args:
        x: Normalized longitude (scalar or array supporting arithmetic).
    Returns:
        Longitude in degrees, same type/shape as ``x``.
    """
    lon_lo, lon_hi = -179.5000, 179.5000
    offset = 1e-10

    shifted = x - offset
    return shifted * (lon_hi - lon_lo) + lon_lo

def bilinear_interpolate(ds, resolution, lat_minmax, lon_minmax):
    """
    Resample the 'sla' variable of ``ds`` onto a regular lat/lon grid.

    Args:
        ds: Dataset-like object; ``ds['sla']`` must provide
            ``interp(latitude=..., longitude=..., method=...)``.
        resolution: Grid spacing used for both axes.
        lat_minmax: (min, max) latitude of the target grid.
        lon_minmax: (min, max) longitude of the target grid.
    Returns:
        Whatever ``ds['sla'].interp`` returns for the new grid.
    """
    print(f"Interpolating , resolution :{resolution}")

    lat_start, lat_stop = lat_minmax[0], lat_minmax[1]
    lon_start, lon_stop = lon_minmax[0], lon_minmax[1]

    # The stop value is pushed out by one step so the upper bound is included.
    target_grid = {
        "latitude": np.arange(lat_start, lat_stop + resolution, resolution),
        "longitude": np.arange(lon_start, lon_stop + resolution, resolution),
    }

    # Linear interpolation along both axes.
    return ds['sla'].interp(method="linear", **target_grid)
#---------------------------------------------------------
# plot samples
#---------------------------------------------------------
def _minmax_scale(field):
    """Scale ``field`` to [0, 1]; a constant field maps to all zeros."""
    lo = field.min()
    span = field.max() - lo
    if span == 0:
        # Avoid 0/0 (NaN image plus RuntimeWarning) for constant input.
        return np.zeros_like(field, dtype=float)
    return (field - lo) / span


def plotSample(yhat, yref, dir_save, sample_name):
    """
    Save a side-by-side image of the reference (left) and prediction (right).

    Each field is min-max scaled independently before drawing, so the two
    panels do not share a colour scale.

    Args:
        yhat (numpy.array): predicted field. NOTE(review): the previous
            docstring said (b, lat, lon), but the array is passed straight to
            ``imshow`` -- presumably a single (lat, lon) image; confirm.
        yref (numpy.array): reference field, same layout as ``yhat``.
        dir_save (str): output directory.
        sample_name (str): output file name without the ".png" extension.
    """
    cmap = plt.get_cmap("RdBu_r")
    plt.close("all")

    yref = _minmax_scale(yref)
    yhat = _minmax_scale(yhat)

    fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

    for axis, field in ((ax0, yref), (ax1, yhat)):
        image = axis.imshow(field, cmap=cmap)
        axis.set_xticks([], [])
        axis.set_yticks([], [])
        fig.colorbar(image, ax=axis)

    fig.savefig(dir_save + "/" + sample_name + ".png", bbox_inches="tight", dpi=300)
    # Release the figure now instead of leaving it open until the next call.
    plt.close(fig)

# CHL
# def plotSample2(x, yhat, yref, dir_save, sample_name, data):
#     """
#     Args:
#         yhat (numpy.array): (b, lat, lon)
#         yref (numpy.array): (b, lat, lon)
#     """
    
#     df = pd.DataFrame({
#     'lat': x[:, 0],
#     'lon': x[:, 1],
#     'data1': yref, 
#     'data2' : yhat
#     })

    
#     # Pivot the table to form the 2D grid!
#     reconstructed_grid_np_ref = df.pivot(index='lat', columns='lon', values='data1')
#     reconstructed_grid_np_hat = df.pivot(index='lat', columns='lon', values='data2')
#     # Sort the index (lats) descending for conventional map view
#     reconstructed_grid_np_ref = reconstructed_grid_np_ref.sort_index(ascending=False)
#     reconstructed_grid_np_hat = reconstructed_grid_np_hat.sort_index(ascending=False)
#     # If you need it as a NumPy array:
#     yref = reconstructed_grid_np_ref.to_numpy()
#     yhat = reconstructed_grid_np_hat.to_numpy()

#     # cmap = plt.get_cmap("PuOr_r").copy()
#     # cmap.set_bad(color='lightgrey') 

    
#     cmap = plt.get_cmap("RdBu_r")
#     cmap.set_bad(color='lightgrey')
   
#     # reconstructed_grid_np_hat[np.isnan(reconstructed_grid_np_hat)] = 0
#     # reconstructed_grid_np_ref[np.isnan(reconstructed_grid_np_ref)] = 0
#     # plt.close("all")
#     # vmin_ref = np.nanmin(reconstructed_grid_np_ref)
#     # vmax_ref = np.nanmax(reconstructed_grid_np_ref)

#     # vmin_hat = np.nanmin(reconstructed_grid_np_hat)
#     # vmax_hat = np.nanmax(reconstructed_grid_np_hat)

#     # min_ = min(vmin_ref, vmin_hat)
#     # max_ = max(vmax_ref, vmax_hat)
#     # yref = (reconstructed_grid_np_ref- vmin_ref)/ (vmax_ref-vmin_ref)
#     # yhat = (reconstructed_grid_np_hat-vmin_hat)/(vmax_hat-vmin_hat)
    
#     yref = reconstructed_grid_np_ref
#     yhat = reconstructed_grid_np_hat

#     # print(yref.shape)
#     # print(yhat.shape)
#     fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, figsize=(14, 6))
#     fig.suptitle('Comparison b/w ref and pred', fontsize=16)

#     # --- Plot 1: Reference Data ---
#     # Use the same vmin and vmax to ensure the data is scaled identically

#     if data == "ssh":
#         vmin = -1
#         vmax = 1

#     else:
#         vmin=0 
#         vmax = 1
    
#     plt.figure(figsize=(6, 5))
#     plt.imshow(yhat, cmap=cmap, vmin= vmin, vmax=vmax)
#     plt.axis('off')
#     plt.savefig(dir_save + "/" + sample_name + ".png", bbox_inches="tight", dpi=300)
#     plt.close()

# # CHL
# def plot_Subsamples(x, yhat, yref, dir_save, data):
    

#     if data != 'chl':
#         return

#     lat_tuples = [(-10, 70),(-15, 75)  , (-30, 40) , (-90, -7), (-90, 14)]
#     lon_tuples = [(-180, -100) ,(-90, 0), (110, 180) ,(-70, 10), (-180, -80)]
    

#     # cmap = plt.get_cmap("PuOr_r").copy()
#     # cmap.set_bad(color='lightgrey')

#     df = pd.DataFrame({
#         'lat': denormalize_lat(x[:, 0]),
#         'lon': denormalize_lon(x[:, 1]),
#         'data1': yref,
#         'data2': yhat
#     })


  
#     count = 0
#     for (lat_r, lon_r) in zip(lat_tuples, lon_tuples):
#         # make a fresh copy to avoid cumulative filtering
#         df_sub = df.copy()
        
     
#         lat_min, lat_max = lat_r
#         lon_min, lon_max = lon_r

#         # filter the copy only
#         if lat_min is not None and lat_max is not None:
#             df_sub = df_sub[(df_sub['lat'] >= lat_min) & (df_sub['lat'] <= lat_max)]
#         if lon_min is not None and lon_max is not None:
#             df_sub = df_sub[(df_sub['lon'] >= lon_min) & (df_sub['lon'] <= lon_max)]

#         # if no points in this subregion, skip (or log)
#         if df_sub.empty:
#             print(f"Subregion {count}: empty (lat {lat_min},{lat_max} lon {lon_min},{lon_max}) — skipping")
#             count += 1
#             continue

#         # use pivot_table to handle duplicates safely
#         recon_ref = df_sub.pivot_table(index='lat', columns='lon', values='data1', aggfunc='mean')
#         recon_hat = df_sub.pivot_table(index='lat', columns='lon', values='data2', aggfunc='mean')

#         # sort descending latitude for conventional map view
#         recon_ref = recon_ref.sort_index(ascending=False)
#         recon_hat = recon_hat.sort_index(ascending=False)

#         # If there is only 1 row or 1 column, pad the index/columns so imshow has at least 2x2
#         # def pad_to_min_size(df_in, min_rows=2, min_cols=2, pad_delta=0.01):
#         #     rows = list(df_in.index)
#         #     cols = list(df_in.columns)
#         #     # if only one row, add a tiny lat neighbor
#         #     if len(rows) < min_rows:
#         #         if len(rows) == 1:
#         #             r0 = rows[0]
#         #             new_rows = sorted([r0, r0 - pad_delta], reverse=True)  # keep descending order
#         #         else:
#         #             new_rows = [rows[0] - pad_delta, rows[0]]
#         #         df_in = df_in.reindex(new_rows)
#         #     # if only one col, add a tiny lon neighbor
#         #     if len(cols) < min_cols:
#         #         if len(cols) == 1:
#         #             c0 = cols[0]
#         #             new_cols = sorted([c0, c0 + pad_delta])
#         #         else:
#         #             new_cols = [cols[0], cols[0] + pad_delta]
#         #         df_in = df_in.reindex(columns=new_cols)
#         #     return df_in

#         # recon_ref = pad_to_min_size(recon_ref)
#         # recon_hat = pad_to_min_size(recon_hat)

#         yref_img = recon_ref.to_numpy()
#         yhat_img = recon_hat.to_numpy()

#         yref_img = zoom(yref_img, (4,4), order=1, mode='constant')
#         yhat_img = zoom(yhat_img, (4,4), order=1, mode='constant')

#         # Save the numpy arrays
#         np.save(f"{dir_save}/yhat_img_{count}.npy", yhat_img)
#         np.save(f"{dir_save}/yref_img_{count}.npy", yref_img)
#         plt.figure(figsize=(7, 7))
#         levels = np.linspace(0, 1, 11)
#         cmap = plt.get_cmap("RdBu_r", len(levels)-1)
#         cmap.set_bad(color='lightgrey')
#         norm = mcolors.BoundaryNorm(boundaries=levels, ncolors=cmap.N)

#         plt.imshow(yhat_img, cmap=cmap, norm=norm)
#         plt.xticks([])
#         plt.yticks([])
#         plt.savefig(f"{dir_save}/{count}.png", bbox_inches="tight", dpi=300)
#         plt.close()
#         count += 1


# SSH
def plotSample2(x, yhat, yref, dir_save, sample_name, data):
    """
    Save the predicted field as a gridded map image.

    The flat prediction is pivoted onto a (lat, lon) grid using the
    coordinates in ``x`` and drawn with a 10-level discrete "PuOr_r" colour
    map spanning [-0.75, 0.75]; NaN cells are drawn in light grey.

    Args:
        x (numpy.array): (n_points, 2) coordinates; column 0 is used as lat,
            column 1 as lon.
        yhat (numpy.array): (n_points,) predicted values (the field drawn).
        yref (numpy.array): (n_points,) reference values. Placed in the
            table alongside ``yhat`` but not drawn.
        dir_save (str): output directory.
        sample_name (str): output file name without the ".png" extension.
        data: dataset name. Kept for interface compatibility; it does not
            affect the figure.
    Raises:
        ValueError: from ``DataFrame.pivot`` if a (lat, lon) pair repeats.
    """
    df = pd.DataFrame({
        'lat': x[:, 0],
        'lon': x[:, 1],
        'data1': yref,
        'data2': yhat,
    })

    # Pivot to a 2D grid; latitude is sorted descending for the usual map view.
    grid_hat = (
        df.pivot(index='lat', columns='lon', values='data2')
        .sort_index(ascending=False)
        .to_numpy()
    )

    levels = np.linspace(-0.75, 0.75, 11)
    cmap = plt.get_cmap("PuOr_r", len(levels) - 1)
    cmap.set_bad(color='lightgrey')
    norm = mcolors.BoundaryNorm(boundaries=levels, ncolors=cmap.N)

    # Only one figure is created. The previous version also opened a
    # two-panel figure that was never drawn on, saved or closed, leaking one
    # figure per call.
    fig = plt.figure(figsize=(6, 5))
    plt.imshow(grid_hat, cmap=cmap, norm=norm)
    plt.axis('off')
    plt.savefig(dir_save + "/" + sample_name + ".png", bbox_inches="tight", dpi=300)
    plt.close(fig)




# # SSH
def plot_Subsamples(x, yhat, yref, dir_save, data):
    """
    Save zoomed reference images for a fixed list of SSH sub-regions.

    NOTE: a later definition of ``plot_Subsamples`` in this module rebinds the
    name, so this variant is not the one reached by callers at runtime.

    Does nothing unless ``data == 'ssh'``. Coordinates in ``x`` are
    denormalized to degrees, the points are filtered to each region, pivoted
    onto a (lat, lon) grid, upsampled 8x with linear interpolation and the
    reference grid is written to ``{dir_save}/{region index}.png``.

    Args:
        x (numpy.array): (n_points, 2) normalized coordinates (lat, lon).
        yhat (numpy.array): (n_points,) predicted values.
        yref (numpy.array): (n_points,) reference values.
        dir_save (str): output directory.
        data (str): dataset name.
    """
    if data != 'ssh':
        return

    # (lat range, lon range) per sub-region, in degrees.
    regions = [
        ((31, 51), (42, 61)),
        ((5, 65), (120, 180)),
        ((-55, 15), (102, 172)),
        ((-85, -5), (-3.3, 63.5)),
        ((-70, -15), (-72, -17)),
        ((15, 60), (-80, -35)),
    ]

    df = pd.DataFrame({
        'lat': denormalize_lat(x[:, 0]),
        'lon': denormalize_lon(x[:, 1]),
        'data1': yref,
        'data2': yhat,
    })

    # Shared discrete colour scale: 10 bins over [-0.75, 0.75], NaN in grey.
    levels = np.linspace(-0.75, 0.75, 11)
    cmap = plt.get_cmap("PuOr_r", len(levels)-1)
    cmap.set_bad(color='lightgrey')
    norm = mcolors.BoundaryNorm(boundaries=levels, ncolors=cmap.N)

    for count, ((lat_min, lat_max), (lon_min, lon_max)) in enumerate(regions):
        # Select from the full table each time so filters never accumulate.
        inside = df['lat'].between(lat_min, lat_max) & df['lon'].between(lon_min, lon_max)
        df_sub = df[inside]

        if df_sub.empty:
            print(f"Subregion {count}: empty (lat {lat_min},{lat_max} lon {lon_min},{lon_max}) — skipping")
            continue

        # pivot_table (mean) tolerates duplicate coordinates; rows are sorted
        # by descending latitude for the usual map orientation.
        grids = {}
        for column in ('data1', 'data2'):
            table = df_sub.pivot_table(index='lat', columns='lon', values=column, aggfunc='mean')
            grids[column] = table.sort_index(ascending=False).to_numpy()

        yref_img = zoom(grids['data1'], (8,8), order=1, mode= 'constant')
        # The prediction grid is upsampled as well but not drawn in this variant.
        yhat_img = zoom(grids['data2'], (8,8), order=1, mode= 'constant')

        plt.figure(figsize=(7, 6))
        plt.imshow(yref_img, cmap=cmap, norm=norm)
        plt.xticks([])
        plt.yticks([])
        plt.savefig(f"{dir_save}/{count}.png", bbox_inches="tight", dpi=300)
        plt.close()


def plot_Subsamples(x, yhat, yref, dir_save, data):
    """
    Save reference-vs-prediction comparison images for fixed SSH sub-regions.

    Does nothing unless ``data == 'ssh'``. Coordinates in ``x`` are
    denormalized to degrees, the points are filtered to each region, pivoted
    onto a (lat, lon) grid and upsampled 4x with linear interpolation. Each
    region is written to ``{dir_save}/{region index}.png``; the file name
    carries no other identifier, so repeated calls with the same ``dir_save``
    overwrite earlier output.

    Args:
        x (numpy.array): (n_points, 2) normalized coordinates (lat, lon).
        yhat (numpy.array): (n_points,) predicted values.
        yref (numpy.array): (n_points,) reference values.
        dir_save (str): output directory.
        data (str): dataset name.
    """
    if data != 'ssh':
        return

    # (lat range, lon range) per sub-region, in degrees.
    regions = [
        ((20, 60), (-180, -120)),
        ((20, 50), (-60, -10)),
        ((20, 65), (125, 180)),
        ((-75, -30), (-50, 20)),
    ]

    df = pd.DataFrame({
        'lat': denormalize_lat(x[:, 0]),
        'lon': denormalize_lon(x[:, 1]),
        'data1': yref,
        'data2': yhat,
    })

    # Shared discrete colour scale: 10 bins over [-0.75, 0.75], NaN in grey.
    levels = np.linspace(-0.75, 0.75, 11)
    cmap = plt.get_cmap("PuOr_r", len(levels)-1)
    cmap.set_bad(color='lightgrey')
    norm = mcolors.BoundaryNorm(boundaries=levels, ncolors=cmap.N)

    for count, ((lat_min, lat_max), (lon_min, lon_max)) in enumerate(regions):
        # Select from the full table each time so filters never accumulate.
        inside = df['lat'].between(lat_min, lat_max) & df['lon'].between(lon_min, lon_max)
        df_sub = df[inside]

        if df_sub.empty:
            print(f"Subregion {count}: empty (lat {lat_min},{lat_max} lon {lon_min},{lon_max}) — skipping")
            continue

        # pivot_table (mean) tolerates duplicate coordinates; rows are sorted
        # by descending latitude for the usual map orientation.
        grids = {}
        for column in ('data1', 'data2'):
            table = df_sub.pivot_table(index='lat', columns='lon', values=column, aggfunc='mean')
            grids[column] = table.sort_index(ascending=False).to_numpy()

        yref_img = zoom(grids['data1'], (4,4), order=1, mode= 'constant')
        yhat_img = zoom(grids['data2'], (4,4), order=1, mode= 'constant')

        fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, figsize=(14, 6))
        fig.suptitle('Comparison b/w ref and pred (Subregions)', fontsize=16)

        ax0.imshow(yref_img, cmap=cmap,  norm= norm)
        ax0.set_title('Reference (y_ref)')
        ax0.set_xticks([])
        ax0.set_yticks([])

        cset2 = ax1.imshow(yhat_img, cmap=cmap,  norm= norm)
        ax1.set_title('Prediction (y_hat)')
        ax1.set_xticks([])
        ax1.set_yticks([])

        # One colour bar shared by both panels.
        fig.colorbar(cset2, ax=[ax0, ax1])
        plt.savefig(f"{dir_save}/{count}.png", bbox_inches="tight", dpi=300)
        plt.close(fig)
    

#---------------------------------------------------------
# plotting difference
#---------------------------------------------------------
def plotdiff(yhat, yref, dir_save, sample_name):
    """
    Save an image of the element-wise difference ``yhat - yref``.

    The difference is drawn with a blue-white-red colour map and written to
    ``{dir_save}/{sample_name}difference.png``.

    Args:
        yhat (numpy.array): predicted field.
        yref (numpy.array): reference field, same shape as ``yhat``.
        dir_save (str): output directory.
        sample_name (str): prefix of the output file name.
    """
    blue_white_red = LinearSegmentedColormap.from_list(
        "blue_red", ["blue", "white", "red"]
    )
    delta = yhat - yref

    plt.figure(figsize=(6, 5))
    image = plt.imshow(delta, cmap=blue_white_red)
    plt.colorbar(image)
    plt.title("Difference")
    target = dir_save + "/" + sample_name + "difference.png"
    plt.savefig(target, bbox_inches="tight")
    plt.close()


#---------------------------------------------------------
# plotting difference 2
#---------------------------------------------------------
def plotdiff2(x, yhat, yref, dir_save, sample_name):
    """
    Save an image of the gridded difference ``yhat - yref``.

    Both flat fields are pivoted onto a (lat, lon) grid using the coordinates
    in ``x``; missing cells are treated as 0 in each grid before subtracting.
    The result is drawn with "RdBu_r" on a fixed [-1.5, 1.5] scale and written
    to ``{dir_save}/{sample_name}difference.png``.

    Args:
        x (numpy.array): (n_points, 2) coordinates; column 0 is used as lat,
            column 1 as lon.
        yhat (numpy.array): (n_points,) predicted values.
        yref (numpy.array): (n_points,) reference values.
        dir_save (str): output directory.
        sample_name (str): prefix of the output file name.
    Raises:
        ValueError: from ``DataFrame.pivot`` if a (lat, lon) pair repeats.
    """
    df = pd.DataFrame({
        'lat': x[:, 0],
        'lon': x[:, 1],
        'data1': yref,
        'data2': yhat,
    })

    def _grid(column):
        # Latitude is sorted descending for the usual map orientation.
        grid = df.pivot(index='lat', columns='lon', values=column)
        grid = grid.sort_index(ascending=False).to_numpy()
        # Build a new array rather than assigning into the one returned by
        # to_numpy(): that array may be a read-only view under pandas
        # Copy-on-Write, in which case in-place assignment raises.
        return np.where(np.isnan(grid), 0, grid)

    grid_ref = _grid('data1')
    grid_hat = _grid('data2')
    diff = grid_hat - grid_ref

    fig = plt.figure(figsize=(10, 5.5))
    im = plt.imshow(diff, cmap="RdBu_r", vmin=-1.5, vmax=1.5)
    plt.colorbar(im)
    plt.title("Difference")
    plt.savefig(dir_save + "/" + sample_name + "difference.png", bbox_inches="tight")
    plt.close(fig)


class KHINRNetModule(pl.LightningModule):
    """
    Lightning module that trains and evaluates the model built by ``get_model``.

    Validation accumulates per-batch outputs and, at epoch end, logs MSE/PSNR
    and writes diagnostic plots. Testing logs a set of regression and
    threshold-based scores plus SSIM on the re-gridded fields.
    """

    def __init__(self,
        normalizer,
        data,
        params_data: DictConfig,
        params_model: DictConfig,
        params_optim: DictConfig,
        params_scheduler: DictConfig,
    ):
        """
        Args:
            normalizer: Stored on the module as ``self.normalizer``.
            data: Dataset name; passed to ``get_model`` and to the plotting
                helpers.
            params_data: Data configuration (``n_train_val[0]`` is read).
            params_model: Model configuration (also provides ``save_dir``).
            params_optim: Optimizer configuration (also provides ``loss``).
            params_scheduler: Scheduler configuration.
        """
        super().__init__()

        self.save_hyperparameters()

        self.cfg_data      = params_data
        self.cfg_model     = params_model
        self.cfg_optim     = params_optim
        self.cfg_scheduler = params_scheduler

        self.cfg_model.n_train = self.cfg_data.n_train_val[0]
        self.model      = get_model(self.cfg_model, data)
        self.optimizer  = get_optimizer(list(self.model.parameters()), self.cfg_optim)
        self.scheduler  = get_scheduler(self.optimizer, self.cfg_scheduler)
        self.criterion  = get_loss(self.cfg_optim.loss)
        self.data = data
        self.normalizer = normalizer
        # Only synchronise logged values across processes when several GPUs exist.
        self.sync_dist = torch.cuda.device_count() > 1

        # Per-epoch validation buffers, filled in validation_step and
        # emptied in on_validation_epoch_end.
        self.validation_step_yhat = []
        self.validation_step_yref = []
        self.validation_step_yref_invalid = []
        self.x_valid = []
        self.x_invalid = []

        self.m_MSE = MeanSquaredError()
        self.m_PSNR = PeakSignalNoiseRatio(1)
        self.m_SSIM = StructuralSimilarityIndexMeasure()
        self.m_PSNR_ours = psnr
        self.crps = crps_ensemble

    def step(self, batch: Any):
        """
        Run the model on the valid points of one batch and compute the loss.

        Args:
            batch: tuple ``(z, x, yref, idx)``.
                z    (torch.tensor) - sparse observations; channel 2 of the
                       last axis is used to build the validity mask
                       (presumably (b, n_sparse, 3 = [lat, lon, value]) --
                       TODO confirm)
                x    (torch.tensor) - coordinates - (b, n_points, 2)
                yref (torch.tensor) - target field - (b, n_points, 1)
                idx  - sample indices, forwarded to the model
        Returns:
            loss         (torch.tensor) - criterion(yhat, yref_valid)
            yhat         (torch.tensor) - model output at the valid points
            yref_valid   (torch.tensor) - targets selected by the mask
            yref_invalid (torch.tensor) - targets rejected by the mask
            x_valid      (torch.tensor) - coordinates selected by the mask
            x_invalid    (torch.tensor) - coordinates rejected by the mask
        """
        z, x, yref, idx = batch

        # Split targets and coordinates using the mask derived from yref.
        masks = req_masks(yref)
        yref_valid, yref_invalid = req_samples(yref, masks)
        x_valid, x_invalid = req_samples(x, masks)

        # Same split for the sparse input, using a mask built from channel 2.
        y_sparse = z[:, :, 2]
        masks_sparse = req_masks(y_sparse.unsqueeze(2))
        z_valid, _ = req_samples(z, masks_sparse)

        yhat = self.model(x_valid, z_valid, idx)
        loss = self.criterion(yhat, yref_valid)

        return loss, yhat, yref_valid, yref_invalid, x_valid, x_invalid

    def training_step(self, batch: Any, batch_idx: int):
        """Compute the training loss for one batch and log loss and MSE."""
        loss, yhat, yref, _, _, _ = self.step(batch)
        self.log("train/loss", loss, on_step=False, on_epoch=True, sync_dist=self.sync_dist)
        self.log("train/mse", self.m_MSE(yhat, yref), sync_dist=self.sync_dist)
        return {"loss": loss}

    def validation_step(self, batch: Any, batch_idx: int):
        """
        Run one validation batch and buffer its outputs for the epoch end.

        Args:
            batch: see ``step``.
            batch_idx (int): index of the batch (unused).
        Returns:
            dict with "yref" (valid targets) and "yhat" (predictions).
        """
        _, yhat, yref, yref_invalid, x_valid, x_invalid = self.step(batch)

        self.validation_step_yhat.append(yhat)
        self.validation_step_yref.append(yref)
        self.validation_step_yref_invalid.append(yref_invalid)
        self.x_valid.append(x_valid)
        self.x_invalid.append(x_invalid)

        return {"yref": yref, "yhat": yhat}

    def on_validation_epoch_end(self):
        """Log epoch-level MSE/PSNR, write diagnostic plots, reset buffers."""
        # Concatenate along the sample axis (presumably giving
        # (total_samples, n_valid, 1) -- confirm against req_samples).
        yhats_scores = torch.cat(self.validation_step_yhat, dim=0)
        yrefs_scores = torch.cat(self.validation_step_yref, dim=0)
        yrefs_invalid = torch.cat(self.validation_step_yref_invalid, dim=0)
        x_valid = torch.cat(self.x_valid, dim=0)
        x_invalid = torch.cat(self.x_invalid, dim=0)

        # Re-attach the masked-out points so the plots cover the whole grid;
        # at those points the "prediction" is the reference value itself.
        yhats = torch.cat((yhats_scores, yrefs_invalid), dim=1)
        yrefs = torch.cat((yrefs_scores, yrefs_invalid), dim=1)
        x = torch.cat((x_valid, x_invalid), dim=1)

        # Scores are computed on the valid points only.
        self.log("validation/mse", self.m_MSE(yhats_scores, yrefs_scores), sync_dist=self.sync_dist)
        self.log("validation/psnr", self.m_PSNR(yhats_scores, yrefs_scores), sync_dist=self.sync_dist)

        # Plot at most two samples; never index past what was collected.
        n_plot = min(2, yhats.shape[0])
        for idx in range(n_plot):
            x_np = toNumpy(x[idx, :, :])
            yhat_np = toNumpy(torch.squeeze(yhats[idx, :, :]))
            yref_np = toNumpy(torch.squeeze(yrefs[idx, :, :]))

            plotSample2(x_np, yhat_np, yref_np, self.cfg_model.save_dir, f"val_epoch_{self.current_epoch}_idx_{idx}", self.data)
            plot_Subsamples(x_np, yhat_np, yref_np, self.cfg_model.save_dir, self.data)
            plotdiff2(x_np, yhat_np, yref_np, self.cfg_model.save_dir, f"Diff_val_epoch_{self.current_epoch}_idx_{idx}")

        self.validation_step_yhat.clear()
        self.validation_step_yref.clear()
        self.validation_step_yref_invalid.clear()
        self.x_valid.clear()
        self.x_invalid.clear()

    def test_step(self, batch: Any, batch_idx: int):
        """
        Evaluate one test batch and log all test metrics.

        Point-wise scores use the valid points only. SSIM is computed on the
        fields re-gridded to (lat, lon), with NaN cells replaced by the mean
        of the non-NaN cells of the respective tensor.
        """
        _, yhat, yref, yref_invalid, x_valid, x_invalid = self.step(batch)
        thresholds = [-0.75, -0.5, 0.5, 0.75]

        evaluator = Evaluator(thresholds)
        evaluator.evaluate(yref, yhat)
        res_dict = evaluator.done()

        self.log("test/pearson_corr", pearson_corr(yhat, yref), sync_dist=self.sync_dist)
        self.log("test/mse", self.m_MSE(yhat, yref), sync_dist=self.sync_dist)
        self.log("test/psnr", self.m_PSNR(yhat, yref), sync_dist=self.sync_dist)
        self.log("test/our_psnr", self.m_PSNR_ours(yhat, yref, yref.max().item()), sync_dist=self.sync_dist)
        self.log("test/CRPS", self.crps(yhat, yref).item(), sync_dist=self.sync_dist)
        self.log("test/csi", res_dict['csi'], sync_dist=self.sync_dist)
        self.log("test/far", res_dict['far'], sync_dist=self.sync_dist)
        self.log("test/avg_pod", res_dict['avg_pod'], sync_dist=self.sync_dist)
        self.log("test/hss", res_dict['hss'], sync_dist=self.sync_dist)

        # Re-attach the masked-out points so every sample covers the full grid.
        yhat = toNumpy(torch.cat((yhat, yref_invalid), dim=1))
        yref = toNumpy(torch.cat((yref, yref_invalid), dim=1))
        x = toNumpy(torch.cat((x_valid, x_invalid), dim=1))

        ref_grids = []
        hat_grids = []
        for b in range(yhat.shape[0]):
            yhat_b = np.squeeze(yhat[b], axis=-1)
            yref_b = np.squeeze(yref[b], axis=-1)
            # Use a per-sample name: rebinding ``x`` here would make every
            # iteration after the first index an already-sliced array.
            x_b = x[b]

            df = pd.DataFrame({
                'lat': x_b[:, 0],
                'lon': x_b[:, 1],
                'data1': yref_b,
                'data2': yhat_b,
            })

            # Pivot to a 2D grid with latitude sorted descending.
            ref_grid = df.pivot(index='lat', columns='lon', values='data1').sort_index(ascending=False)
            hat_grid = df.pivot(index='lat', columns='lon', values='data2').sort_index(ascending=False)

            ref_grids.append(ref_grid.to_numpy())
            hat_grids.append(hat_grid.to_numpy())

        # Stack to (b, 1, lat, lon) as expected by the image metric.
        ref_final = torch.from_numpy(np.expand_dims(np.stack(ref_grids), axis=1)).float()
        hat_final = torch.from_numpy(np.expand_dims(np.stack(hat_grids), axis=1)).float()

        # Fill empty cells with the mean of the non-NaN cells.
        ref_final[torch.isnan(ref_final)] = torch.nanmean(ref_final)
        hat_final[torch.isnan(hat_final)] = torch.nanmean(hat_final)

        self.log("test/ssim", self.m_SSIM(hat_final, ref_final), sync_dist=self.sync_dist)

    def configure_optimizers(self):
        """Return the optimizer and scheduler created in ``__init__``."""
        return [self.optimizer], [self.scheduler]


def pearson_corr(x, y):
    """
    Pearson correlation coefficient between two tensors.

    Both tensors are centred on their global mean (over all elements). A
    1e-8 term is added to the denominator so constant inputs give 0 instead
    of dividing by zero.

    Args:
        x (torch.tensor): first tensor.
        y (torch.tensor): second tensor, broadcast-compatible with ``x``.
    Returns:
        torch.tensor: scalar correlation.
    """
    x_centered = x - x.mean()
    y_centered = y - y.mean()

    covariance = (x_centered * y_centered).sum()
    x_norm = torch.sqrt((x_centered ** 2).sum())
    y_norm = torch.sqrt((y_centered ** 2).sum())

    return covariance / (x_norm * y_norm + 1e-8)


   