import torch
import matplotlib.pyplot as plt
import os
import numpy as np
import re
from utils import tonumpy_denormalize
from matplotlib.colors import ListedColormap


# Colormap used by visualize_fwi for velocity maps (256 colours, presumably,
# per the file name — confirm). Loaded once at import time from a path that is
# relative to the current working directory, so importing this module raises
# if ./datasets/rainbow256.npy is not reachable from the CWD.
rainbow_cmap = ListedColormap(np.load('./datasets/rainbow256.npy'))


def visualize_airfoil(args, frame, data, predictions, save_path, x_normalizer=None, y_normalizer=None):
    """Render true vs. predicted airfoil fields and save them as ``image_{frame:03d}.png``.

    One row per sample and three ``pcolormesh`` panels per row. Every panel is
    cropped to ``[nx:-nx, 0:n]`` along the first two spatial axes:

    1. true mesh + true field          ("True")
    2. predicted mesh + true field     ("Pred_x + True_y")
    3. true mesh + predicted field     ("True_x + Pred_y")

    Args:
        args: Unused; kept so all visualizers share one signature.
        frame: Integer used to number the output file.
        data: ``(x_train, y_train)``. The indexing below assumes ``x_train`` is
            rank 4 with mesh coordinates in the last axis (0 = x, 1 = y) and
            ``y_train`` is rank 4 with channel 0 plotted — TODO confirm.
        predictions: ``(pred_x, pred_y)`` laid out like ``data``.
        save_path: Directory the PNG is written into.
        x_normalizer: Optional object with ``.decode`` applied to the meshes.
        y_normalizer: Optional object with ``.decode`` applied to the fields.
            Each normalizer is applied independently when it is not ``None``.
    """
    x_train, y_train = data
    pred_x, pred_y = predictions
    num_vis = x_train.shape[0]
    nx, n = 40, 20  # crop: drop nx cells at both ends of axis 1, keep first n of axis 2

    # Keep only channel 0 of the field, so y_normalizer.decode sees the rank-3 slice.
    y_train, pred_y = y_train[:, :, :, 0], pred_y[:, :, :, 0]

    # Decode each quantity with its own normalizer so either may be None.
    if x_normalizer is not None:
        x_train = x_normalizer.decode(x_train)
        pred_x = x_normalizer.decode(pred_x)
    if y_normalizer is not None:
        y_train = y_normalizer.decode(y_train)
        pred_y = y_normalizer.decode(pred_y)

    save_file = os.path.join(save_path, f"image_{frame:03d}.png")

    # (title, mesh source, field source) for each column, left to right.
    columns = (
        ('True', x_train, y_train),
        ('Pred_x + True_y', pred_x, y_train),
        ('True_x + Pred_y', x_train, pred_y),
    )
    num_cols = len(columns)

    fig = plt.figure(figsize=(8, num_vis * 2))
    try:
        for row in range(num_vis):
            for col, (title, mesh, field) in enumerate(columns):
                ax = fig.add_subplot(num_vis, num_cols, row * num_cols + col + 1)
                ax.pcolormesh(mesh[row, nx:-nx, 0:n, 0],
                              mesh[row, nx:-nx, 0:n, 1],
                              field[row, nx:-nx, 0:n],
                              shading='gouraud')
                if row == 0:
                    ax.set_title(title)
        fig.tight_layout()
        fig.savefig(save_file)
    finally:
        # Close only the figure created here. Calling plt.clf() after closing
        # would spawn a fresh empty figure, and close('all') would also tear
        # down figures owned by other code.
        plt.close(fig)

def visualize_pipe(args, frame, data, predictions, save_path, x_normalizer=None, y_normalizer=None):
    """Render true vs. predicted pipe-flow fields and save them as ``image_{frame:03d}.png``.

    One row per sample, four ``pcolormesh`` panels per row:

    1. true mesh + true field, cropped to ``[nx:-nx, 0:n]``       ("True")
    2. predicted mesh + true field, cropped                       ("Pred_x + True_y")
    3. true mesh + true field, full domain                        ("True")
    4. true mesh + predicted field, full domain                   ("True_x + Pred_y")

    NOTE(review): ``nx, n = 40, 20`` are the same crop constants used by
    visualize_airfoil; confirm they are meaningful for pipe meshes.

    Args:
        args: Unused; kept so all visualizers share one signature.
        frame: Integer used to number the output file.
        data: ``(x_train, y_train)``. The indexing assumes ``x_train`` is rank 4
            with mesh coordinates in the last axis (0 = x, 1 = y) and
            ``y_train`` is rank 4 with channel 0 plotted — TODO confirm.
        predictions: ``(pred_x, pred_y)`` laid out like ``data``.
        save_path: Directory the PNG is written into.
        x_normalizer: Optional object with ``.decode`` applied to the meshes.
        y_normalizer: Optional object with ``.decode`` applied to the fields.
            Each normalizer is applied independently when it is not ``None``.
    """
    x_train, y_train = data
    pred_x, pred_y = predictions
    num_vis = x_train.shape[0]
    nx, n = 40, 20

    # Keep only channel 0 of the field, so y_normalizer.decode sees the rank-3 slice.
    y_train, pred_y = y_train[:, :, :, 0], pred_y[:, :, :, 0]

    # Decode each quantity with its own normalizer so either may be None.
    if x_normalizer is not None:
        x_train = x_normalizer.decode(x_train)
        pred_x = x_normalizer.decode(pred_x)
    if y_normalizer is not None:
        y_train = y_normalizer.decode(y_train)
        pred_y = y_normalizer.decode(pred_y)

    save_file = os.path.join(save_path, f"image_{frame:03d}.png")

    # Index tuples for the two spatial axes.
    crop = (slice(nx, -nx), slice(0, n))
    full = (slice(None), slice(None))

    # (title, mesh source, field source, spatial region) per column.
    columns = (
        ('True', x_train, y_train, crop),
        ('Pred_x + True_y', pred_x, y_train, crop),
        ('True', x_train, y_train, full),
        ('True_x + Pred_y', x_train, pred_y, full),
    )
    num_cols = len(columns)

    fig = plt.figure(figsize=(8, num_vis * 2))
    try:
        for row in range(num_vis):
            for col, (title, mesh, field, region) in enumerate(columns):
                idx = (row,) + region
                ax = fig.add_subplot(num_vis, num_cols, row * num_cols + col + 1)
                ax.pcolormesh(mesh[idx + (0,)],
                              mesh[idx + (1,)],
                              field[idx],
                              shading='gouraud')
                if row == 0:
                    ax.set_title(title)
        fig.tight_layout()
        fig.savefig(save_file)
    finally:
        # Close only the figure created here. Calling plt.clf() after closing
        # would spawn a fresh empty figure, and close('all') would also tear
        # down figures owned by other code.
        plt.close(fig)
    
def visualize_fwi(args, frame, data, predictions, save_path):
    """Plot predicted velocity maps and predicted seismic channels for FWI.

    One row per sample. Column 1 shows ``pred_y`` channel 0 with
    ``rainbow_cmap``, scaled to the min/max of the matching ground-truth
    velocity. Columns 2-6 show ``pred_x`` channels 0-4 in grayscale, with the
    colour range fixed to [-1e-5, 1e-5]. The output is written to
    ``save_path/image_{frame:03d}.png``.

    Args:
        args: Must provide ``seis_min``/``seis_max`` and ``vel_min``/``vel_max``,
            which are passed to ``tonumpy_denormalize``.
        frame: Integer used to number the output file.
        data: ``(x_train, y_train)`` — seismic input and true velocity.
        predictions: ``(pred_x, pred_y)`` — predicted seismic and velocity.
        save_path: Directory the PNG is written into.
    """
    x_train, y_train = data
    pred_x, pred_y = predictions

    x_train = tonumpy_denormalize(x_train, args.seis_min, args.seis_max, exp=True)
    pred_x = tonumpy_denormalize(pred_x, args.seis_min, args.seis_max, exp=True)
    y_train = tonumpy_denormalize(y_train, args.vel_min, args.vel_max, exp=False)
    pred_y = tonumpy_denormalize(pred_y, args.vel_min, args.vel_max, exp=False)

    num_vis = x_train.shape[0]
    num_seis = 5                 # seismic channels shown per row
    num_cols = 1 + num_seis      # velocity panel + seismic panels

    save_file = os.path.join(save_path, f"image_{frame:03d}.png")
    # Stretch seismic images so axis 2 vs axis 1 proportions are preserved.
    aspect = pred_x.shape[2] / pred_x.shape[1]

    fig = plt.figure(figsize=(12, num_vis * 2))
    try:
        for row in range(num_vis):
            base = row * num_cols
            true_vel = y_train[row, :, :, 0]

            ax = fig.add_subplot(num_vis, num_cols, base + 1)
            ax.imshow(pred_y[row, :, :, 0], cmap=rainbow_cmap,
                      vmax=np.max(true_vel), vmin=np.min(true_vel))
            if row == 0:
                ax.set_title('Velocity')

            for ch in range(num_seis):
                ax = fig.add_subplot(num_vis, num_cols, base + ch + 2)
                ax.imshow(pred_x[row, :, :, ch], cmap='gray',
                          vmin=-1e-5, vmax=1e-5, aspect=aspect)
                if row == 0:
                    ax.set_title(f'Seis_{ch + 1}')

        fig.tight_layout()
        fig.savefig(save_file)
    finally:
        # Close only the figure created here. Calling plt.clf() after closing
        # would spawn a fresh empty figure, and close('all') would also tear
        # down figures owned by other code.
        plt.close(fig)

        
    
    
 
def visualize_results(args, frame, data, predictions, save_path, x_normalizer=None, y_normalizer=None):
    """Route a visualization request to the renderer selected by ``args.dataset``.

    Every renderer receives ``(args, frame, data, predictions, save_path)``,
    followed by the normalizer(s) it accepts:
    - both normalizers: navier_stokes, ns_range, airfoil, pipe and ns3d_twoway;
    - ``x_normalizer`` only: conv, helm, ns3d and ks;
    - neither: fwi.
    For ``'ns3d'``, the truthiness of ``args.interp`` (read only for that
    dataset) chooses between the plain and the interpolation renderer.

    NOTE(review): visualize_navier_stokes, visualize_navier_stokes_range and
    visualize_navier_stokes_3d_interp are not defined in the visible source.
    Selecting those datasets raises NameError unless the functions exist
    elsewhere — verify.

    Raises:
        ValueError: if ``args.dataset`` names no known renderer.
    """
    shared = (args, frame, data, predictions, save_path)

    if args.dataset == 'ns3d':
        # Only the selected branch of the conditional is evaluated, so the
        # other mode's renderer name is never looked up.
        renderer = visualize_navier_stokes_3d_interp if args.interp else visualize_navier_stokes_3d
        renderer(*shared, x_normalizer)
        return

    # Lambdas defer each renderer's name lookup until it is actually chosen.
    handlers = {
        'navier_stokes': lambda: visualize_navier_stokes(*shared, x_normalizer, y_normalizer),
        'ns_range': lambda: visualize_navier_stokes_range(*shared, x_normalizer, y_normalizer),
        'airfoil': lambda: visualize_airfoil(*shared, x_normalizer, y_normalizer),
        'pipe': lambda: visualize_pipe(*shared, x_normalizer, y_normalizer),
        'conv': lambda: visualize_conv(*shared, x_normalizer),
        'helm': lambda: visualize_helm(*shared, x_normalizer),
        'ks': lambda: visualize_ks(*shared, x_normalizer),
        'fwi': lambda: visualize_fwi(*shared),
        'ns3d_twoway': lambda: visualize_navier_stokes_3d_twoway(*shared, x_normalizer, y_normalizer),
    }
    handler = handlers.get(args.dataset)
    if handler is None:
        raise ValueError(f"Unknown dataset: {args.dataset}")
    handler()

def visualize_conv(args, frame, target_data, pred_data, save_path, normalizer=None):
    """Save a two-column (truth | prediction) heat-map grid for the 'conv' dataset.

    Row ``r`` shows channel 0 of sample ``r``, transposed and drawn with
    ``origin='lower'``, for the ground truth and the prediction. Both use a
    fixed 'rainbow' colour range of [0, 2]. The image is written to
    ``save_path/image_{frame:03d}.png`` at 300 dpi.

    Args:
        args: Unused; kept so all visualizers share one signature.
        frame: Integer used to number the output file.
        target_data: Ground truth; indexing assumes rank 4 with channels in
            the last axis — TODO confirm.
        pred_data: Prediction laid out like ``target_data``.
        save_path: Directory the PNG is written into.
        normalizer: Optional object whose ``.decode`` is applied to both inputs.
    """
    rows = target_data.shape[0]

    if normalizer is not None:
        target_data = normalizer.decode(target_data)
        pred_data = normalizer.decode(pred_data)

    out_file = os.path.join(save_path, f"image_{frame:03d}.png")

    image_kwargs = dict(interpolation='nearest', cmap='rainbow', vmin=0, vmax=2,
                        origin='lower', aspect='auto')
    panels = (('True', target_data), ('Prediction', pred_data))

    fig = plt.figure(figsize=(8, 2 * rows))
    for r in range(rows):
        for c, (label, arr) in enumerate(panels):
            ax = fig.add_subplot(rows, 2, 2 * r + c + 1)
            ax.imshow(arr[r, :, :, 0].T, **image_kwargs)
            if r == 0:
                ax.set_title(label)

    fig.tight_layout()
    fig.savefig(out_file, bbox_inches='tight', dpi=300)
    plt.close(fig)


def visualize_helm(args, frame, target_data, pred_data, save_path, normalizer=None):
    """Save a two-column (truth | prediction) heat-map grid for the 'helm' dataset.

    Each row shows channel 0 of one sample, transposed and drawn with
    ``origin='lower'``. Both panels use a fixed 'rainbow' colour range of
    [-1, 1]. The image is written to ``save_path/image_{frame:03d}.png`` at
    300 dpi.

    Args:
        args: Unused; kept so all visualizers share one signature.
        frame: Integer used to number the output file.
        target_data: Ground truth; indexing assumes rank 4 with channels in
            the last axis — TODO confirm.
        pred_data: Prediction laid out like ``target_data``.
        save_path: Directory the PNG is written into.
        normalizer: Optional object whose ``.decode`` is applied to both inputs.
    """
    n_rows = target_data.shape[0]

    if normalizer is not None:
        target_data = normalizer.decode(target_data)
        pred_data = normalizer.decode(pred_data)

    target_file = os.path.join(save_path, f"image_{frame:03d}.png")

    style = {
        'interpolation': 'nearest',
        'cmap': 'rainbow',
        'vmin': -1,
        'vmax': 1,
        'origin': 'lower',
        'aspect': 'auto',
    }

    fig = plt.figure(figsize=(4, 2 * n_rows))
    for idx in range(n_rows):
        left = fig.add_subplot(n_rows, 2, 2 * idx + 1)
        left.imshow(target_data[idx, :, :, 0].T, **style)
        right = fig.add_subplot(n_rows, 2, 2 * idx + 2)
        right.imshow(pred_data[idx, :, :, 0].T, **style)
        if idx == 0:
            left.set_title('True')
            right.set_title('Prediction')

    fig.tight_layout()
    fig.savefig(target_file, bbox_inches='tight', dpi=300)
    plt.close(fig)

   
def visualize_navier_stokes_3d(args, frame, target_data, pred_data, save_path, x_normalizer):
    """Plot 30 time steps of the first sample as ground truth vs. prediction.

    Only sample 0 / channel 0 is plotted (``[0, :, :, :, 0]``), which leaves a
    field indexed ``[:, :, t]``. The grid has ``num_time`` (10) rows and six
    columns:
    - columns 1-3: ground truth at time indices ``t``, ``t+10`` and ``t+20``;
    - columns 4-6: prediction at the same indices.
    The last time axis therefore needs at least 30 entries.

    Args:
        args: Must provide ``t1``, the time offset shown in the column titles.
        frame: Integer used to number the output file.
        target_data: Ground-truth volume.
        pred_data: Predicted volume laid out like ``target_data``.
        save_path: Directory the PNG is written into.
        x_normalizer: Optional object whose ``.decode`` is applied to both volumes.
    """
    num_time = 10
    offsets = (0, 10, 20)  # time-index offset of each column group
    save_file = os.path.join(save_path, f"image_{frame:03d}.png")

    if x_normalizer is not None:
        target_data = x_normalizer.decode(target_data)
        pred_data = x_normalizer.decode(pred_data)

    true_field = target_data[0, :, :, :, 0]
    pred_field = pred_data[0, :, :, :, 0]

    # Each title is derived from the array its column draws, so true and
    # predicted columns cannot be mislabelled.
    columns = ([('True', true_field, off) for off in offsets]
               + [('Pred', pred_field, off) for off in offsets])
    num_cols = len(columns)

    fig = plt.figure(figsize=(12, 20))
    try:
        for t_idx in range(num_time):
            for col, (label, field, off) in enumerate(columns):
                ax = fig.add_subplot(num_time, num_cols, t_idx * num_cols + col + 1)
                ax.imshow(field[:, :, t_idx + off], cmap='jet', interpolation='nearest')
                if t_idx == 0:
                    ax.set_title(f'T = {args.t1 + off} ~ {args.t1 + off + 10} ({label})')
        fig.tight_layout()
        fig.savefig(save_file)
    finally:
        # Close only the figure created here. Calling plt.clf() after closing
        # would spawn a fresh empty figure, and close('all') would also tear
        # down figures owned by other code.
        plt.close(fig)
    


    
def visualize_ks(args, frame, target_data, pred_data, save_path, normalizer=None):
    """Save a truth | prediction grid with colour bars for the 'ks' dataset.

    Row ``i`` shows channel 0 of sample ``i``, transposed, with
    ``origin='lower'``. The horizontal extent is fixed to [-1, 1] and the
    vertical extent is ``[0, shape[2]]``. Both panels in a row share one
    colour range: the min/max over all of ``target_data[i]``. The image is
    written to ``save_path/image_{frame:03d}.png`` at 300 dpi.

    Args:
        args: Unused; kept so all visualizers share one signature.
        frame: Integer used to number the output file.
        target_data: Ground truth, either a torch tensor or a numpy array.
            Indexing assumes rank 4 with channels in the last axis — TODO confirm.
        pred_data: Prediction laid out like ``target_data``.
        save_path: Directory the PNG is written into.
        normalizer: Optional object whose ``.decode`` is applied to both inputs.
    """
    num_vis = target_data.shape[0]
    save_file = os.path.join(save_path, f"image_{frame:03d}.png")

    if normalizer is not None:
        target_data = normalizer.decode(target_data)
        pred_data = normalizer.decode(pred_data)

    panels = (('True', target_data), ('Prediction', pred_data))

    fig = plt.figure(figsize=(8, 2 * num_vis))
    try:
        for i in range(num_vis):
            # .max()/.min() exist on both torch tensors and numpy arrays
            # (torch.max only accepts tensors); float() hands matplotlib plain scalars.
            vmax = float(target_data[i].max())
            vmin = float(target_data[i].min())
            for col, (title, source) in enumerate(panels):
                ax = fig.add_subplot(num_vis, 2, i * 2 + col + 1)
                im = ax.imshow(
                    source[i, :, :, 0].T,
                    origin="lower",
                    aspect="auto",
                    cmap="viridis",
                    vmin=vmin,
                    vmax=vmax,
                    extent=[-1, 1, 0, source.shape[2]],
                )
                ax.set_xlim(-1, 1)
                if i == 0:
                    ax.set_title(title)
                fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        fig.tight_layout()
        fig.savefig(save_file, bbox_inches='tight', dpi=300)
    finally:
        plt.close(fig)
    


   
def visualize_navier_stokes_3d_twoway(args, frame, target_data, pred_data, save_path, x_normalizer, y_normalizer):
    """Plot two-way (x -> y and y -> x) Navier-Stokes results for the first sample.

    Only sample 0 / channel 0 of each volume is plotted (``[0, :, :, :, 0]``).
    The grid has ``num_time`` (10) rows (time index ``t``) and four columns:
    true x, predicted x, true y, predicted y. Per the titles, x spans
    ``t1 ~ t1+10`` and y spans ``t1+10 ~ t1+20``.

    Args:
        args: Must provide ``t1``, the time offset shown in the titles.
        frame: Integer used to number the output file.
        target_data: ``(x_true, y_true)`` ground-truth volumes.
        pred_data: ``(x_pred, y_pred)`` predicted volumes.
        save_path: Directory the PNG is written into.
        x_normalizer: Optional ``.decode`` for the x volumes (``None`` = already decoded).
        y_normalizer: Optional ``.decode`` for the y volumes (``None`` = already decoded).
    """
    num_time = 10
    save_file = os.path.join(save_path, f"image_{frame:03d}.png")

    x_true, y_true = target_data
    x_pred, y_pred = pred_data

    # Start from the raw arrays so a None normalizer simply leaves them as-is
    # instead of leaving the decoded names unbound.
    if x_normalizer is not None:
        x_true = x_normalizer.decode(x_true)
        x_pred = x_normalizer.decode(x_pred)
    if y_normalizer is not None:
        y_true = y_normalizer.decode(y_true)
        y_pred = y_normalizer.decode(y_pred)

    # (title, field) per column, left to right.
    columns = (
        (f'T = {args.t1} ~ {args.t1+10} (True)', x_true[0, :, :, :, 0]),
        (f'T = {args.t1} ~ {args.t1+10} (Pred)', x_pred[0, :, :, :, 0]),
        (f'T = {args.t1+10} ~ {args.t1+20} (True)', y_true[0, :, :, :, 0]),
        (f'T = {args.t1+10} ~ {args.t1+20} (Pred)', y_pred[0, :, :, :, 0]),
    )
    num_cols = len(columns)

    fig = plt.figure(figsize=(8, 20))
    try:
        for t_idx in range(num_time):
            for col, (title, field) in enumerate(columns):
                ax = fig.add_subplot(num_time, num_cols, t_idx * num_cols + col + 1)
                ax.imshow(field[:, :, t_idx], cmap='jet', interpolation='nearest')
                if t_idx == 0:
                    ax.set_title(title)
        fig.tight_layout()
        fig.savefig(save_file)
    finally:
        # Close only the figure created here. Calling plt.clf() after closing
        # would spawn a fresh empty figure, and close('all') would also tear
        # down figures owned by other code.
        plt.close(fig)
    
