import torch
import pytorch_lightning as pl
import numpy as np

from typing import Any
from omegaconf import DictConfig
from einops import rearrange
from models.KHINR_net import KHINRNet
from utils import (get_optimizer, get_scheduler, get_loss, toNumpy)
from torchmetrics.regression import MeanSquaredError
from torchmetrics.image import (PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure)
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap


#---------------------------------------------------------
# get model
#---------------------------------------------------------
def get_model(cfg, data_name=None):
    """
    Build the network described by ``cfg``.

    Args:
        cfg: Model configuration. ``cfg.name`` selects the architecture. For
            ``"KHINR"`` the fields ``n_block``, ``n_mode``, ``n_dim``,
            ``n_head``, ``n_layer``, ``x_dim``, ``y1_dim``, ``y2_dim``,
            ``attn`` and ``act`` are forwarded to ``KHINRNet``.
        data_name: Forwarded unchanged to ``KHINRNet`` as its ``data`` argument.
    Returns:
        The instantiated model.
    Raises:
        ValueError: If ``cfg.name`` is not a supported model name.
    """
    # Fail loudly on an unknown name instead of falling through to an
    # unbound local variable.
    if cfg.name != "KHINR":
        raise ValueError(f"Unsupported model name: {cfg.name!r}")

    return KHINRNet(
        n_block=cfg.n_block,
        n_mode=cfg.n_mode,
        n_dim=cfg.n_dim,
        n_head=cfg.n_head,
        n_layer=cfg.n_layer,
        x_dim=cfg.x_dim,
        y1_dim=cfg.y1_dim,
        y2_dim=cfg.y2_dim,
        attn=cfg.attn,
        act=cfg.act,
        data=data_name,
    )


#---------------------------------------------------------
# plot prediction vs. reference
#---------------------------------------------------------
def plotSample(yhat, yref, dir_save, sample_name):
    """
    Save a side-by-side image of the reference (left) and prediction (right).

    Each field is min-max normalised to [0, 1] independently before plotting.
    The two panels therefore compare spatial patterns, not absolute values.

    Args:
        yhat (numpy.ndarray): predicted field. Expected to be 2-D (h, w);
            callers pass ``torch.squeeze`` of a (1, h, w) slice (assumes a
            single channel -- TODO confirm).
        yref (numpy.ndarray): reference field, same shape as ``yhat``.
        dir_save (str): output directory.
        sample_name (str): file stem; the figure is written to
            ``<dir_save>/<sample_name>.png``.
    """
    def _minmax(a):
        # A constant field has zero range. Map it to all-zeros instead of
        # dividing by zero, which would yield NaNs and an empty panel.
        lo = a.min()
        span = a.max() - lo
        if span == 0:
            return np.zeros_like(a, dtype=float)
        return (a - lo) / span

    cmap = plt.get_cmap("RdBu_r")
    plt.close("all")

    yref = _minmax(yref)
    yhat = _minmax(yhat)

    fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

    cset1 = ax0.imshow(yref, cmap=cmap)
    ax0.set_xticks([], [])
    ax0.set_yticks([], [])
    fig.colorbar(cset1, ax=ax0)

    cset2 = ax1.imshow(yhat, cmap=cmap)
    ax1.set_xticks([], [])
    ax1.set_yticks([], [])
    fig.colorbar(cset2, ax=ax1)

    fig.savefig(dir_save + "/" + sample_name + ".png", bbox_inches="tight")
    # Release the figure right away; this is called several times per epoch.
    plt.close(fig)

#---------------------------------------------------------
# plotting difference
#---------------------------------------------------------
def plotdiff(yhat, yref, dir_save, sample_name):
    """
    Save a heat map of ``yhat - yref`` using a blue-white-red colormap.

    The colour limits are symmetric about zero:
    - white marks zero error;
    - blue marks ``yhat < yref``;
    - red marks ``yhat > yref``.

    Args:
        yhat (numpy.ndarray): predicted field (2-D as passed by callers).
        yref (numpy.ndarray): reference field, same shape as ``yhat``.
        dir_save (str): output directory.
        sample_name (str): file stem; the figure is written to
            ``<dir_save>/<sample_name>difference.png``.
    """
    cmap = LinearSegmentedColormap.from_list("blue_red", ["blue", "white", "red"])
    diff = yhat - yref

    # Autoscaling would put white at the midpoint of [min, max], which is in
    # general not zero for a difference field. Force symmetric limits instead.
    # When the difference is identically zero (or NaN), fall back to autoscale.
    bound = float(np.max(np.abs(diff)))
    limits = {"vmin": -bound, "vmax": bound} if bound > 0 else {}

    fig = plt.figure(figsize=(6, 5))
    im = plt.imshow(diff, cmap=cmap, **limits)
    plt.colorbar(im)
    plt.title("Difference")
    plt.savefig(dir_save + "/" + sample_name + "difference.png", bbox_inches="tight")
    plt.close(fig)


class KHINRNetModule(pl.LightningModule):
    """
    Lightning wrapper around the KHINR network.

    The module:
    - builds the model, optimizer, scheduler and loss from the given configs;
    - logs MSE / PSNR / SSIM (plus Pearson correlation at test time);
    - writes diagnostic plots at the end of each validation epoch;
    - saves all test predictions/targets to ``cfg_model.save_dir``.
    """

    # Supported spatial grids, keyed by the number of points per sample:
    # (1) GST: 192 x 288, (2) SST: 901 x 1001.
    _GRID_SHAPES = {192 * 288: (192, 288), 901 * 1001: (901, 1001)}
    # Epoch whose validation plot file names get the architecture tag.
    # Assumes a 200-epoch run -- TODO confirm against the trainer config.
    _FINAL_PLOT_EPOCH = 199
    # Maximum number of validation samples plotted per epoch.
    _N_PLOT_SAMPLES = 3

    def __init__(self,
        normalizer,
        data,
        params_data: DictConfig,
        params_model: DictConfig,
        params_optim: DictConfig,
        params_scheduler: DictConfig,
    ):
        super().__init__()

        self.save_hyperparameters()

        self.cfg_data      = params_data
        self.cfg_model     = params_model
        self.cfg_optim     = params_optim
        self.cfg_scheduler = params_scheduler

        # The model config is given the training-set size from the data config.
        self.cfg_model.n_train = self.cfg_data.n_train_val[0]
        self.model      = get_model(self.cfg_model, data)
        self.optimizer  = get_optimizer(list(self.model.parameters()), self.cfg_optim)
        self.scheduler  = get_scheduler(self.optimizer, self.cfg_scheduler)
        self.criterion  = get_loss(self.cfg_optim.loss)

        self.normalizer = normalizer
        self.sync_dist = torch.cuda.device_count() > 1
        # Per-batch outputs collected during an epoch and consumed by the
        # matching *_epoch_end hook.
        self.validation_step_yhat = []
        self.validation_step_yref = []
        self.test_step_yhat = []
        self.test_step_yref = []
        self.m_MSE = MeanSquaredError()
        self.m_PSNR = PeakSignalNoiseRatio(1)  # data_range = 1
        self.m_SSIM = StructuralSimilarityIndexMeasure()

    @classmethod
    def _to_images(cls, flat):
        """
        Reshape a flattened batch (n, h*w, c) into images (n, c, h, w).

        The grid is selected from the number of points per sample.

        Raises:
            ValueError: If the point count matches no known grid.
        """
        shape = cls._GRID_SHAPES.get(flat.shape[-2])
        if shape is None:
            raise ValueError(
                f"Cannot reshape {flat.shape[-2]} points per sample into a known "
                f"grid; supported grids: {sorted(cls._GRID_SHAPES.values())}"
            )
        h, w = shape
        return rearrange(flat, "n (h w) c -> n c h w", h=h, w=w)

    def step(self, batch: Any):
        """
        Run the model on one batch and compute the loss.

        Args:
            batch: 4-tuple ``(z, x, yref, idx)``. The model is called as
                ``self.model(x, z, idx)``. ``x`` holds coordinates of shape
                (b, n_points, 2) and ``yref`` holds the target of shape
                (b, n_points, 1). The role of ``z`` is defined by KHINRNet
                -- TODO confirm.
        Returns:
            loss (torch.tensor) - scalar loss from ``self.criterion``
            yhat (torch.tensor) - model output, (b, n_points, 1)
            yref (torch.tensor) - target, (b, n_points, 1)
        """
        z, x, yref, idx = batch
        yhat = self.model(x, z, idx)
        loss = self.criterion(yhat, yref)
        return loss, yhat, yref

    def training_step(self, batch: Any, batch_idx: int):
        """Compute the training loss and log epoch-level loss and per-step MSE."""
        loss, yhat, yref = self.step(batch)
        self.log("train/loss", loss, on_step=False, on_epoch=True, sync_dist=self.sync_dist)
        self.log("train/mse", self.m_MSE(yhat, yref), sync_dist=self.sync_dist)
        return {"loss": loss}

    def validation_step(self, batch: Any, batch_idx: int):
        """
        Run one validation batch and buffer its outputs for epoch-end metrics.

        Args:
            batch: same 4-tuple layout as accepted by ``step``.
        Returns:
            dict with ``yref`` and ``yhat``, each (b, h*w, 1).
        """
        _, yhat, yref = self.step(batch)
        self.validation_step_yhat.append(yhat)
        self.validation_step_yref.append(yref)
        return {"yref": yref, "yhat": yhat}

    def on_validation_epoch_end(self):
        """Log image metrics over the whole validation epoch and save plots."""
        if not self.validation_step_yhat:
            return

        # Concatenate along the batch axis instead of stacking: stack would
        # fail when the last batch is smaller than the others.
        yhats = self._to_images(torch.cat(self.validation_step_yhat, dim=0))
        yrefs = self._to_images(torch.cat(self.validation_step_yref, dim=0))
        self.validation_step_yhat.clear()
        self.validation_step_yref.clear()

        self.log("validation/mse", self.m_MSE(yhats, yrefs), sync_dist=self.sync_dist)
        self.log("validation/psnr", self.m_PSNR(yhats, yrefs), sync_dist=self.sync_dist)
        self.log("validation/ssim", self.m_SSIM(yhats, yrefs), sync_dist=self.sync_dist)

        self._plot_validation_samples(yhats, yrefs)

    def _plot_validation_samples(self, yhats, yrefs):
        """Save sample and difference plots for the first few validation images."""
        tag = ""
        if self.current_epoch == self._FINAL_PLOT_EPOCH:
            m = self.cfg_model
            tag = (f"_blk_{m.n_block}_mod_{m.n_mode}_dim_{m.n_dim}"
                   f"_head_{m.n_head}_layer_{m.n_layer}")

        # Clamp to the number of available samples; the sanity check may
        # provide fewer than _N_PLOT_SAMPLES.
        for idx in range(min(self._N_PLOT_SAMPLES, yhats.shape[0])):
            yhat = toNumpy(torch.squeeze(yhats[idx]))
            yref = toNumpy(torch.squeeze(yrefs[idx]))
            name = f"val_epoch_{self.current_epoch}_idx_{idx}{tag}"
            plotSample(yhat, yref, self.cfg_model.save_dir, name)
            plotdiff(yhat, yref, self.cfg_model.save_dir, f"Diff_{name}")

    def test_step(self, batch: Any, batch_idx: int):
        """Log test metrics for one batch and buffer outputs for saving."""
        _, yhat, yref = self.step(batch)
        yhats = self._to_images(yhat)
        yrefs = self._to_images(yref)

        self.log("test/mse", self.m_MSE(yhat, yref), sync_dist=self.sync_dist)
        self.log("test/pearson_corr", pearson_corr(yhats, yrefs), sync_dist=self.sync_dist)
        self.log("test/psnr", self.m_PSNR(yhats, yrefs), sync_dist=self.sync_dist)
        self.log("test/ssim", self.m_SSIM(yhats, yrefs), sync_dist=self.sync_dist)

        # Previously the .npy files were overwritten every batch, keeping only
        # the last one. Buffer on CPU and save everything at epoch end.
        self.test_step_yhat.append(yhat.detach().cpu())
        self.test_step_yref.append(yref.detach().cpu())

    def on_test_epoch_end(self):
        """Save all test predictions and targets, (N, h*w, 1), as .npy files."""
        if not self.test_step_yhat:
            return
        yhat = torch.cat(self.test_step_yhat, dim=0)
        yref = torch.cat(self.test_step_yref, dim=0)
        self.test_step_yhat.clear()
        self.test_step_yref.clear()
        np.save(self.cfg_model.save_dir + "/predictions.npy", toNumpy(yhat))
        np.save(self.cfg_model.save_dir + "/targets.npy", toNumpy(yref))

    def configure_optimizers(self):
        """Return the optimizer and scheduler built in ``__init__``."""
        return [self.optimizer], [self.scheduler]


def pearson_corr(x, y):
    """
    Pearson correlation between all elements of ``x`` and ``y``.

    Both tensors are treated as flat samples: every element contributes, with
    no per-image or per-channel split. A small epsilon in the denominator
    keeps the result finite, giving 0 when either input is constant.
    """
    x_centered = x - x.mean()
    y_centered = y - y.mean()
    covariance = torch.sum(x_centered * y_centered)
    x_norm = torch.sqrt(torch.sum(x_centered ** 2))
    y_norm = torch.sqrt(torch.sum(y_centered ** 2))
    return covariance / (x_norm * y_norm + 1e-8)


   