import os
import wandb
from PIL import Image
import torch
import numpy as np
from typing import List, Optional, Tuple
from visualizing.visualize_utils import numpy_to_img, VIZ_IMAGE_SIZE

import matplotlib.pyplot as plt
import seaborn as sn
import pandas as pd
from sklearn.metrics import confusion_matrix


def visualize_dist_pred(
    batch_obs_images: np.ndarray,
    batch_goal_images: np.ndarray,
    batch_dist_preds: np.ndarray,
    batch_dist_labels: np.ndarray,
    eval_type: str,
    save_folder: Optional[str],
    epoch: int,
    num_images_preds: int = 8,
    use_wandb: bool = True,
    display: bool = False,
    rounding: int = 4,
    dist_error_threshold: float = 3.0,
):
    """
    Visualize the distance predictions and labels for observation-goal image pairs.

    One figure is drawn per sample (up to num_images_preds). When save_folder is
    given, figure i is written to
    {save_folder}/visualize/{eval_type}/epoch{epoch}/dist_classification/{i}.png.

    Args:
        batch_obs_images (np.ndarray): batch of observation images [batch_size, height, width, channels]
        batch_goal_images (np.ndarray): batch of goal images [batch_size, height, width, channels]
        batch_dist_preds (np.ndarray): batch of distance predictions [batch_size]
        batch_dist_labels (np.ndarray): batch of distance labels [batch_size]
        eval_type (string): {data_type}_{eval_type} (e.g. recon_train, gs_test, etc.)
        save_folder (str): folder to save the images. If None, will not save the images
        epoch (int): current epoch number
        num_images_preds (int): number of images to visualize
        use_wandb (bool): whether to use wandb to log the images (requires save_folder,
            because the images are handed to wandb as saved file paths)
        display (bool): whether to display the images
        rounding (int): number of decimal places to round the distance predictions and labels
        dist_error_threshold (float): distance error threshold for classifying the distance prediction as correct or incorrect (only used for visualization purposes)

    Raises:
        ValueError: if use_wandb is True but save_folder is None.
        AssertionError: if the batch inputs do not all have the same length.
    """
    # wandb.Image below is built from the saved file, so there is nothing to
    # log when no file is written. Fail early with an explicit message.
    if use_wandb and save_folder is None:
        raise ValueError(
            "use_wandb=True requires save_folder to be set: "
            "images are logged to wandb from the saved files"
        )
    assert (
        len(batch_obs_images)
        == len(batch_goal_images)
        == len(batch_dist_preds)
        == len(batch_dist_labels)
    ), "all batch inputs must have the same length"

    # Only touch the filesystem when the caller actually asked for files.
    visualize_path = None
    if save_folder is not None:
        visualize_path = os.path.join(
            save_folder,
            "visualize",
            eval_type,
            f"epoch{epoch}",
            "dist_classification",
        )
        os.makedirs(visualize_path, exist_ok=True)

    batch_size = batch_obs_images.shape[0]
    wandb_list = []
    for i in range(min(batch_size, num_images_preds)):
        dist_pred = np.round(batch_dist_preds[i], rounding)
        dist_label = np.round(batch_dist_labels[i], rounding)
        obs_image = numpy_to_img(batch_obs_images[i])
        goal_image = numpy_to_img(batch_goal_images[i])

        save_path = None
        if visualize_path is not None:
            save_path = os.path.join(visualize_path, f"{i}.png")

        # Red title marks predictions whose error exceeds the threshold.
        text_color = "black"
        if abs(dist_pred - dist_label) > dist_error_threshold:
            text_color = "red"

        display_distance_pred(
            [obs_image, goal_image],
            ["Observation", "Goal"],
            dist_pred,
            dist_label,
            text_color,
            save_path,
            display,
        )
        if use_wandb:
            wandb_list.append(wandb.Image(save_path))
    if use_wandb:
        wandb.log({f"{eval_type}_dist_prediction": wandb_list})


def visualize_dist_pairwise_pred(
    batch_obs_images: np.ndarray,
    batch_close_images: np.ndarray,
    batch_far_images: np.ndarray,
    batch_close_preds: np.ndarray,
    batch_far_preds: np.ndarray,
    batch_close_labels: np.ndarray,
    batch_far_labels: np.ndarray,
    eval_type: str,
    save_folder: Optional[str],
    epoch: int,
    num_images_preds: int = 8,
    use_wandb: bool = True,
    display: bool = False,
    rounding: int = 4,
):
    """
    Visualize pairwise distance predictions and labels for observation / close-goal / far-goal triplets.

    One figure is drawn per sample (up to num_images_preds). The figure title is
    black when the close prediction is strictly smaller than the far prediction,
    and red otherwise. When save_folder is given, figure i is written to
    {save_folder}/visualize/{eval_type}/epoch{epoch}/dist_classification/{i}.png.

    Args:
        batch_obs_images (np.ndarray): batch of observation images [batch_size, height, width, channels]
        batch_close_images (np.ndarray): batch of close goal images [batch_size, height, width, channels]
        batch_far_images (np.ndarray): batch of far goal images [batch_size, height, width, channels]
        batch_close_preds (np.ndarray): batch of close predictions [batch_size]
        batch_far_preds (np.ndarray): batch of far predictions [batch_size]
        batch_close_labels (np.ndarray): batch of close labels [batch_size]
        batch_far_labels (np.ndarray): batch of far labels [batch_size]
        eval_type (string): {data_type}_{eval_type} (e.g. recon_train, gs_test, etc.)
        save_folder (str): folder to save the images. If None, will not save the images
        epoch (int): current epoch number
        num_images_preds (int): number of images to visualize
        use_wandb (bool): whether to use wandb to log the images (requires save_folder,
            because the images are handed to wandb as saved file paths)
        display (bool): whether to display the images
        rounding (int): number of decimal places to round the distance predictions and labels

    Raises:
        ValueError: if use_wandb is True but save_folder is None.
        AssertionError: if the batch inputs do not all have the same length.
    """
    # wandb.Image below is built from the saved file, so there is nothing to
    # log when no file is written. Fail early with an explicit message.
    if use_wandb and save_folder is None:
        raise ValueError(
            "use_wandb=True requires save_folder to be set: "
            "images are logged to wandb from the saved files"
        )
    assert (
        len(batch_obs_images)
        == len(batch_close_images)
        == len(batch_far_images)
        == len(batch_close_preds)
        == len(batch_far_preds)
        == len(batch_close_labels)
        == len(batch_far_labels)
    ), "all batch inputs must have the same length"

    # Only touch the filesystem when the caller actually asked for files.
    # NOTE(review): this is the same "dist_classification" directory and {i}.png
    # naming used by visualize_dist_pred, so calling both with the same
    # save_folder/eval_type/epoch overwrites files — confirm whether intended.
    visualize_path = None
    if save_folder is not None:
        visualize_path = os.path.join(
            save_folder,
            "visualize",
            eval_type,
            f"epoch{epoch}",
            "dist_classification",
        )
        os.makedirs(visualize_path, exist_ok=True)

    batch_size = batch_obs_images.shape[0]
    wandb_list = []
    for i in range(min(batch_size, num_images_preds)):
        close_dist_pred = np.round(batch_close_preds[i], rounding)
        far_dist_pred = np.round(batch_far_preds[i], rounding)
        close_dist_label = np.round(batch_close_labels[i], rounding)
        far_dist_label = np.round(batch_far_labels[i], rounding)
        obs_image = numpy_to_img(batch_obs_images[i])
        close_image = numpy_to_img(batch_close_images[i])
        far_image = numpy_to_img(batch_far_images[i])

        save_path = None
        if visualize_path is not None:
            save_path = os.path.join(visualize_path, f"{i}.png")

        # Correct ordering (close ranked nearer than far) is drawn in black.
        text_color = "black" if close_dist_pred < far_dist_pred else "red"

        display_distance_pred(
            [obs_image, close_image, far_image],
            ["Observation", "Close Goal", "Far Goal"],
            f"close_pred = {close_dist_pred}, far_pred = {far_dist_pred}",
            f"close_label = {close_dist_label}, far_label = {far_dist_label}",
            text_color,
            save_path,
            display,
        )
        if use_wandb:
            wandb_list.append(wandb.Image(save_path))
    if use_wandb:
        wandb.log({f"{eval_type}_pairwise_classification": wandb_list})


def display_distance_pred(
    imgs: list,
    titles: list,
    dist_pred: float,
    dist_label: float,
    text_color: str = "black",
    save_path: Optional[str] = None,
    display: bool = False,
):
    """
    Draw a row of images with a "prediction / label" title and optionally save it.

    Args:
        imgs (list): images to show side by side, one subplot per image
        titles (list): subplot titles, paired with imgs in order
        dist_pred: value shown after "prediction:" in the figure title; it is only
            formatted into the title, so preformatted strings (as passed by
            visualize_dist_pairwise_pred) work as well as numbers
        dist_label: value shown after "label:" in the figure title
        text_color (str): color of the figure title
        save_path (str): file path to save the figure to. If None, the figure is not saved
        display (bool): if False the figure is closed before returning; if True it is
            left open so the caller can show it
    """
    # squeeze=False always yields a 2D array of axes, so a single image is
    # handled the same way as several (a lone Axes object is not iterable).
    fig, axes = plt.subplots(1, len(imgs), squeeze=False)

    fig.suptitle(f"prediction: {dist_pred}\nlabel: {dist_label}", color=text_color)

    for axis, img, title in zip(axes[0], imgs, titles):
        axis.imshow(img)
        axis.set_title(title)
        axis.xaxis.set_visible(False)
        axis.yaxis.set_visible(False)

    # make the plot large
    fig.set_size_inches((18.5 / 3) * len(imgs), 10.5)

    if save_path is not None:
        fig.savefig(
            save_path,
            bbox_inches="tight",
        )
    if not display:
        plt.close(fig)


# def dist_confusion_matrix(
#     model,
#     eval_loader,
#     device,
#     project_folder: str,
#     plot_name: str,
#     distance_categories: List[str],
#     regress_distance: bool = False,
#     num_images_log: int = 32,
#     use_wandb: bool = True,
#     is_vibing: bool = False,
# ):
#     if regress_distance:
#         bins = []
#         for i in range(1, len(distance_categories)):
#             bins.append(np.mean([distance_categories[i - 1], distance_categories[i]]))
#         convert_pred = (lambda bins: lambda pred: np.digitize(pred, bins=bins))(bins)
#         whole_bins = [-np.inf] + bins + [np.inf]
#         pred_categories = [
#             f"{whole_bins[i - 1]}-{whole_bins[i]}" for i in range(1, len(whole_bins))
#         ]
#         value_to_index = {}
#         for i in range(len(distance_categories)):
#             value_to_index[distance_categories[i]] = i

#         convert_label = convert_pred
#         # def convert_label(x):
#         #     # for loop through the entire array and change labels
#         #     with np.nditer(x, op_flags=["readwrite"]) as it:
#         #         for z in it:
#         #             z[...] = value_to_index[int(z)]
#         #     return x

#     else:
#         convert_pred = lambda pred: np.argmax(pred, axis=1)
#         pred_categories = distance_categories
#         convert_label = lambda x: x
#     examples_folder = os.path.join(project_folder, "conf_matrix", plot_name)
#     if not os.path.isdir(examples_folder):
#         os.makedirs(examples_folder)

#     preds = []
#     labels = []
#     image_list = {}
#     for label in distance_categories:
#         image_list[label] = {}
#         label_folder = os.path.join(examples_folder, f"label_{label}")
#         for pred in pred_categories:
#             pred_folder = os.path.join(label_folder, f"pred_{pred}")
#             if not os.path.isdir(pred_folder):
#                 os.makedirs(pred_folder)
#             image_list[label][pred] = []

#     with torch.no_grad():
#         for data in eval_loader:
#             (
#                 obs_images,
#                 goal_images,
#                 transf_obs_images,
#                 transf_goal_images,
#                 dist_labels,
#             ) = data[:5]
#             transf_obs_images = transf_obs_images.to(device)
#             transf_goal_images = transf_goal_images.to(device)
#             dist_labels = dist_labels.to(device)

#             dist_pred, _ = model(transf_obs_images, transf_goal_images)
#             if is_vibing:
#                 dist_pred = dist_pred.mean
#             preds.append(convert_pred(to_numpy(dist_pred)))
#             labels.append(convert_label(to_numpy(dist_labels)))
#             i = 0

#             for pred, label in zip(preds[-1], labels[-1]):
#                 equal = int(pred) == int(label)
#                 pred = pred_categories[int(pred)]
#                 label = distance_categories[int(label)]

#                 if len(image_list[label][pred]) < num_images_log:
#                     image_path = os.path.join(
#                         examples_folder,
#                         f"label_{label}",
#                         f"pred_{pred}",
#                         f"{len(image_list[label][pred])}.png",
#                     )
#                     if not equal:
#                         text_color = "red"
#                     else:
#                         text_color = "black"
#                     obs_image = Image.fromarray(
#                         np.transpose(np.uint8(255 * to_numpy(obs_images[i])), (1, 2, 0))
#                     )
#                     goal_image = Image.fromarray(
#                         np.transpose(
#                             np.uint8(255 * to_numpy(goal_images[i])), (1, 2, 0)
#                         )
#                     )
#                     if regress_distance:
#                         display_pred = np.round(to_numpy(dist_pred)[i], 4)
#                     else:
#                         display_pred = pred
#                     display_distance_pred(
#                         [obs_image, goal_image],
#                         display_pred,
#                         label,
#                         image_path,
#                         text_color,
#                     )
#                     image_list[label][pred].append(wandb.Image(image_path))
#                 i += 1

#     if use_wandb:
#         for label in distance_categories:
#             for pred in pred_categories:
#                 wandb.log(
#                     {f"{plot_name}_label_{label}_pred_{pred}": image_list[label][pred]}
#                 )

#     str_distance_cats = []
#     for i in range(len(pred_categories)):
#         str_distance_cats.append(str(pred_categories[i]))

#     labels = np.concatenate(labels, axis=0)
#     preds = np.concatenate(preds, axis=0)
#     conf_mat_path = generate_conf_mat(
#         pred=preds,
#         y_true=labels,
#         class_labels=str_distance_cats,
#         folder_path=examples_folder,
#         title=plot_name,
#     )
#     if use_wandb:
#         wandb.log({plot_name + "conf_mat": wandb.Image(conf_mat_path)})


# def generate_conf_mat(
#     pred: List[int],
#     y_true: List[int],
#     class_labels: str,
#     folder_path: str,
#     title: str,
#     normalize: Optional[bool] = None,
#     figsize: Optional[Tuple[int]] = (15, 12),
# ):
#     conf_mat = confusion_matrix(y_true, pred, normalize=normalize)
#     df_conf_mat = pd.DataFrame(conf_mat, index=class_labels, columns=class_labels)
#     if not os.path.exists(folder_path):
#         os.makedirs(folder_path)
#     df_conf_mat.to_csv(os.path.join(folder_path, title))
#     plt.figure(figsize=figsize)
#     sn.heatmap(df_conf_mat, annot=True)
#     plt.title(title)
#     plt.xlabel("Predicted")
#     plt.ylabel("Labels")
#     path = os.path.join(folder_path, title)
#     plt.savefig(path)
#     plt.close()
#     return path + ".png"
