import matplotlib.pyplot as plt
from torchvision.transforms import transforms
import numpy as np
import torch
import os
from forward_process import *
from dataset import *
from sample import *

from noise import *

def visualalize_rgb(image1,image2,image):
    """Save a five-panel figure: the three channels of ``image``, then ``image1`` and ``image2``.

    ``image`` is indexed as a 4-D tensor with the channel on dimension 1;
    channels 0, 1 and 2 are plotted under the titles 'r', 'g' and 'b'.
    Every panel is converted with ``show_tensor_image``.

    The figure is written to ``results/rgb{k}.png`` where ``k`` is the
    smallest non-negative integer for which no such file exists yet.

    Args:
        image1: tensor shown in the fourth panel.
        image2: tensor shown in the fifth panel.
        image: tensor whose channels fill the first three panels.
    """
    panels = [
        (image[:, 0, :, :].unsqueeze(1), 'r'),
        (image[:, 1, :, :].unsqueeze(1), 'g'),
        (image[:, 2, :, :].unsqueeze(1), 'b'),
        (image1, 'image1'),
        (image2, 'image2'),
    ]

    fig, axes = plt.subplots(1, len(panels), figsize=(11, 11))
    try:
        for ax, (tensor, title) in zip(axes, panels):
            ax.axis('off')
            ax.imshow(show_tensor_image(tensor))
            ax.set_title(title)

        # savefig does not create missing directories.
        os.makedirs('results', exist_ok=True)
        k = 0
        while os.path.exists('results/rgb{}.png'.format(k)):
            k += 1
        fig.savefig('results/rgb{}.png'.format(k))
    finally:
        # Always release the figure, even if conversion or saving failed.
        plt.close(fig)

def visualalize_distance(output, condition, target):
    """Save a three-panel figure of ``output``, ``condition`` and ``target``.

    The panels are titled 'input image', 'condition image' and
    'generated image' respectively.
    NOTE(review): the titles do not mirror the parameter names (``output`` is
    labelled 'input image', ``target`` 'generated image') — confirm with the
    callers that this is intended.

    The figure is written to ``results/heatmap{k}.png`` where ``k`` is the
    smallest non-negative integer for which no such file exists yet.

    Args:
        output: tensor shown in the first panel.
        condition: tensor shown in the second panel.
        target: tensor shown in the third panel.
    """
    panels = (
        (output, 'input image'),
        (condition, 'condition image'),
        (target, 'generated image'),
    )

    fig, axes = plt.subplots(1, len(panels), figsize=(11, 11))
    try:
        for ax, (tensor, title) in zip(axes, panels):
            ax.axis('off')
            ax.imshow(show_tensor_image(tensor))
            ax.set_title(title)

        # savefig does not create missing directories.
        os.makedirs('results', exist_ok=True)
        k = 0
        while os.path.exists('results/heatmap{}.png'.format(k)):
            k += 1
        fig.savefig('results/heatmap{}.png'.format(k))
    finally:
        # Always release the figure, even if conversion or saving failed.
        plt.close(fig)

def visualize_reconstructed(input, data,s):
    """Save a grid showing ``input`` followed by every item of ``data``.

    The grid has six columns and is filled row by row: the first cell shows
    ``input`` (titled 'input'), the following cells show ``data[i]`` titled
    with ``str(i)``. Cells left over in the last row are removed.

    The number of rows is the ceiling of ``(len(data) + 1) / 6`` so the grid
    always has room for every image. The previous ``len(data) / 5`` sizing
    failed for fewer than ten items and for many larger lengths.

    The figure is written to ``results/reconstructed{k}{s}.png`` where ``k``
    is the smallest non-negative integer for which no such file exists yet.

    Args:
        input: tensor shown in the first cell.
        data: sized iterable of tensors, one cell each.
        s: suffix inserted in the file name after the running index.
    """
    n_cols = 6
    n_used = len(data) + 1  # the input image plus one cell per item
    n_rows = (n_used + n_cols - 1) // n_cols  # ceiling division

    # squeeze=False keeps a 2-D axes array even when there is a single row.
    fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)
    cells = axs.ravel()  # row-major order == fill order
    try:
        labelled = [(input, 'input')]
        labelled.extend((img, str(i)) for i, img in enumerate(data))

        for ax, (tensor, title) in zip(cells, labelled):
            ax.imshow(show_tensor_image(tensor))
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            ax.set_title(title)

        # Drop the unused trailing cells of the last row.
        for ax in cells[n_used:]:
            ax.remove()

        fig.subplots_adjust(left=0.1,
                            bottom=0.1,
                            right=0.9,
                            top=0.9,
                            wspace=0.4,
                            hspace=0.4)

        # savefig does not create missing directories.
        os.makedirs('results', exist_ok=True)
        k = 0
        while os.path.exists(f'results/reconstructed{k}{s}.png'):
            k += 1
        fig.savefig(f'results/reconstructed{k}{s}.png')
    finally:
        # Always release the figure, even if conversion or saving failed.
        plt.close(fig)



def _dynamic_tag(config, step_list, idx):
    """Return the file-name tag for dynamic stepping: '', '_dynamic' or '_dynamic_big'."""
    if not config.model.dynamic_steps:
        return ''
    # Samples whose step value is 7 or more get the '_big' marker.
    return '_dynamic_big' if int(step_list[idx]) >= 7 else '_dynamic'


def _save_bare_image(array, path):
    """Save ``array`` as an axis-free, tightly cropped image at ``path``."""
    fig, ax = plt.subplots()
    try:
        ax.imshow(array)
        ax.axis('off')  # This hides the axes
        fig.savefig(path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)


def _save_panel_row(panels, figsize, path):
    """Save one row of axis-free panels to ``path``.

    Each entry of ``panels`` is either ``None`` (the cell stays empty) or a
    ``(array, title, title_kwargs)`` tuple.
    """
    fig, axes = plt.subplots(1, len(panels), figsize=figsize)
    try:
        for ax, panel in zip(axes, panels):
            ax.axis('off')
            if panel is None:
                continue
            array, title, title_kwargs = panel
            ax.imshow(array)
            ax.set_title(title, **title_kwargs)
        fig.savefig(path)
    finally:
        plt.close(fig)


def visualize(image, noisy_image, GT, pred_mask, anomaly_map, category, config, orig_img, step_list, filename_list, anomaly_map_recon_list, anomaly_map_latent_list, anomaly_map_feature_list,KNN_feature_list) :
    """Write the per-sample result images for one batch into ``results/<category>/``.

    For every index ``idx`` of ``image`` the function writes, with the common
    prefix ``results/<category>/<category>sample<idx>``:

    * if ``config.model.visual_all``: individual images for the clear input,
      the reconstruction, the ground-truth mask and the predicted mask. When
      ``config.model.dynamic_steps`` is set, the combined, latent, feature
      and KNN heat maps are saved as well;
    * a comparison figure (clear image next to the reconstruction, or only
      the clear image when ``config.model.latent`` is set and the latent
      backbone is not "VAE");
    * a heat-map figure with six panels when
      ``config.model.distance_metric_eval == "combined"``, otherwise four.

    File names carry ``_dynamic`` / ``_dynamic_big`` (step value >= 7) when
    dynamic stepping is on, ``_step<skip2>`` and, with dynamic stepping and
    ``config.model.dynamic_condition``, ``_dyn_con``.

    The target directory is created if it does not exist.

    Args:
        image: per-sample tensors; its length drives the loop.
        noisy_image: per-sample reconstructions.
        GT: per-sample ground-truth masks.
        pred_mask: per-sample predicted masks; an all-zero mask is titled 'normal'.
        anomaly_map: per-sample (combined) heat maps.
        category: name used for the sub-directory and the file prefix.
        config: object exposing the ``config.model.*`` flags read here.
        orig_img: per-sample original images.
        step_list: per-sample step values; only read when dynamic stepping is on.
        filename_list: per-sample names shown in the comparison title.
        anomaly_map_recon_list: per-sample heat maps for the four-panel figure.
        anomaly_map_latent_list: per-sample latent heat maps.
        anomaly_map_feature_list: per-sample feature heat maps.
        KNN_feature_list: per-sample KNN maps; only read when dynamic stepping is on.
    """
    # imsave / savefig do not create missing directories.
    os.makedirs('results/{}'.format(category), exist_ok=True)

    model_cfg = config.model
    dynamic = model_cfg.dynamic_steps
    # dynamic_condition is only consulted when dynamic stepping is enabled.
    dyn_con_tag = '_dyn_con' if dynamic and model_cfg.dynamic_condition else ''

    for idx, _ in enumerate(image):
        stem = 'results/{}/{}sample{}'.format(category, category, idx)
        dyn_tag = _dynamic_tag(config, step_list, idx)

        # 1) Individual images.
        if model_cfg.visual_all:
            if dynamic:
                base = stem + '_dyn_save_all'
            else:
                base = '{}_{}_save_all'.format(stem, model_cfg.skip2)

            plt.imsave(base + '_clear.png', show_tensor_image(orig_img[idx]))
            plt.imsave(base + '_recon.png', show_tensor_image(noisy_image[idx]))

            if dynamic:
                heatmaps = (
                    ('combined', anomaly_map),
                    ('latent', anomaly_map_latent_list),
                    ('feature', anomaly_map_feature_list),
                    ('KNN', KNN_feature_list),
                )
                for tag, maps in heatmaps:
                    _save_bare_image(show_tensor_image(maps[idx]),
                                     '{}_heatmap_{}.png'.format(base, tag))

            plt.imsave(base + '_GT_mask.png', show_tensor_mask(GT[idx], config))
            plt.imsave(base + '_pred_mask.png', show_tensor_mask(pred_mask[idx], config))

        # 2) Clear image vs. reconstruction.
        if model_cfg.latent and model_cfg.latent_backbone != "VAE":
            # NOTE(review): the second panel is left empty, as in the original code.
            comparison = [
                (show_tensor_image(orig_img[idx]), 'clear image', {}),
                None,
            ]
            comparison_path = '{}{}.png'.format(stem, dyn_tag)
        else:
            comparison = [
                (show_tensor_image(image[idx]), f'clear image {filename_list[idx]}', {}),
                (show_tensor_image(noisy_image[idx]), 'reconstructed image', {}),
            ]
            comparison_path = '{}{}_step{}{}.png'.format(
                stem, dyn_tag, model_cfg.skip2, dyn_con_tag)
        _save_panel_row(comparison, (11, 11), comparison_path)

        # 3) Masks and heat maps.
        is_normal = bool(torch.max(pred_mask[idx]) == 0)
        if is_normal:
            verdict, verdict_style = 'normal', {'color': 'g'}
        else:
            verdict, verdict_style = 'abnormal', {'color': 'r'}

        overview = [
            (show_tensor_mask(GT[idx], config), 'ground truth', {}),
            (show_tensor_mask(pred_mask[idx], config), verdict, verdict_style),
        ]
        if model_cfg.distance_metric_eval == "combined":
            figsize = (15, 11)
            overview += [
                (show_tensor_image(anomaly_map[idx]), 'heat map combined', {}),
                (show_tensor_image(anomaly_map_latent_list[idx]), 'heat map latent', {}),
                (show_tensor_image(anomaly_map_feature_list[idx]), 'heat map feature', {}),
            ]
            if dynamic:
                overview.append(
                    (show_tensor_image(KNN_feature_list[idx]), 'KNN feature', {}))
            else:
                # NOTE(review): re-plots the feature map under another title;
                # kept exactly as in the original code.
                overview.append(
                    (show_tensor_image(anomaly_map_feature_list[idx]),
                     'heat map feature double ignore', {}))
        else:
            figsize = (11, 11)
            overview += [
                (show_tensor_image(anomaly_map[idx]), 'heat map', {}),
                (show_tensor_image(anomaly_map_recon_list[idx]), 'heat map recon', {}),
            ]

        overview_path = '{}{}_heatmap_step{}{}.png'.format(
            stem, dyn_tag, model_cfg.skip2, dyn_con_tag)
        _save_panel_row(overview, figsize, overview_path)



def show_tensor_image(image):
    """Convert a CHW tensor scaled to [-1, 1] into an HWC ``uint8`` NumPy array.

    Values are mapped with ``(t + 1) / 2 * 255`` and clamped to [0, 255]
    before the cast, so out-of-range inputs saturate instead of wrapping
    around. The tensor is detached and moved to the CPU, which lets tensors
    that require grad or live on another device be passed directly.

    Args:
        image: a 3-D (C, H, W) tensor, or a 4-D batch of which only the
            first element is converted.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with shape (H, W, C).
    """
    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]

    scaled = (image.detach() + 1) / 2 * 255.
    hwc = scaled.permute(1, 2, 0)  # CHW to HWC
    # Saturate before the cast: uint8 conversion of out-of-range floats wraps.
    return hwc.clamp(0, 255).cpu().numpy().astype(np.uint8)

def show_tensor_mask(image, config):
    """Convert a CHW mask tensor into an HWC ``uint8`` NumPy array without rescaling.

    Values are cast to ``uint8`` as they are (no [-1, 1] mapping), so
    fractional values are truncated. When ``config.model.visual_all`` is set,
    a channel axis of size one is squeezed away, giving an (H, W) array —
    presumably because the masks are then written with ``plt.imsave``
    (confirm with the callers).

    The tensor is detached first so masks that require grad can be passed.

    Args:
        image: a 3-D (C, H, W) tensor, or a 4-D batch of which only the
            first element is converted.
        config: object exposing ``config.model.visual_all``.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8``; shape (H, W, C), or (H, W) when
        ``visual_all`` is set and C is 1.
    """
    drop_channel_axis = config.model.visual_all

    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]

    mask = image.detach().permute(1, 2, 0)  # CHW to HWC
    if drop_channel_axis:
        mask = mask.squeeze(2)  # no-op unless the channel axis has size 1
    return mask.cpu().numpy().astype(np.uint8)
        

