import math
import torch
import os
import shutil

import imageio
import numpy as np

from PIL import Image, ImageDraw

def create_log_gaussian(mean, log_std, t):
    """Return the log-density of ``t`` under a diagonal Gaussian.

    Computes, summed over the last dimension,
        -0.5 * ((t - mean) / exp(log_std))**2 - log_std - 0.5 * log(2*pi)

    Args:
        mean: Tensor of means; its last dimension is the event dimension.
        log_std: Log standard deviations, broadcastable against ``mean``.
        t: Points at which to evaluate the density, broadcastable against
            ``mean``.

    Returns:
        Tensor of log-probabilities with the last dimension reduced.
    """
    # The 0.5 multiplies the squared standardized residual; it must not be
    # squared along with it (that would yield -0.25 * z**2).
    quadratic = -0.5 * ((t - mean) / log_std.exp()).pow(2)
    event_dim = mean.shape[-1]
    log_two_pi_total = event_dim * math.log(2 * math.pi)
    return quadratic.sum(dim=-1) - log_std.sum(dim=-1) - 0.5 * log_two_pi_total

def logsumexp(inputs, dim=None, keepdim=False):
    """Numerically stable ``log(sum(exp(inputs)))``.

    Args:
        inputs: Input tensor.
        dim: Dimension to reduce. If ``None``, the tensor is flattened and
            reduced over all elements.
        keepdim: Keep the reduced dimension with size 1.

    Returns:
        Tensor with ``dim`` reduced (a 0-d tensor when ``dim`` is ``None`` and
        ``keepdim`` is False).
    """
    if dim is None:
        # reshape (unlike view) also works on non-contiguous tensors.
        inputs = inputs.reshape(-1)
        dim = 0
    # torch.logsumexp handles slices whose max is +/-inf; the manual
    # max-shift trick produced NaN there (-inf - -inf).
    return torch.logsumexp(inputs, dim=dim, keepdim=keepdim)

def soft_update(target, source, tau):
    """Blend ``source`` parameters into ``target`` in place (Polyak averaging).

    Every target parameter is overwritten with
    ``(1 - tau) * target + tau * source``. Parameters are paired
    positionally, so both modules are expected to share an architecture.
    """
    keep = 1.0 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)

def hard_update(target, source):
    """Overwrite every ``target`` parameter with the matching ``source`` one.

    Values are copied in place (no tensors are shared); parameters are
    paired positionally.
    """
    pairs = zip(target.parameters(), source.parameters())
    for dst, src in pairs:
        dst.data.copy_(src.data)

def make_dir(*path_parts):
    """Join ``path_parts`` into a path, make sure that directory exists, return it.

    Missing intermediate directories are created as well. An already
    existing directory is fine; any other failure (permission denied, the
    path being a regular file, ...) is raised rather than silently ignored,
    so callers never receive a path to a directory that does not exist.
    """
    dir_path = os.path.join(*path_parts)
    os.makedirs(dir_path, exist_ok=True)
    return dir_path



############## Video related #################
##############################################

def ensure_2d(points):
    """Coerce ``points`` to a two-column array, zero-padding when needed.

    A 1-D input of length N, or a single-column (N, 1) input, gets a zero
    second column and becomes (N, 2). Any other input is returned as a new
    array copy without reshaping.
    """
    arr = np.array(points)
    needs_padding = arr.ndim == 1 or arr.shape[1] == 1
    if not needs_padding:
        return arr
    return np.column_stack([arr, np.zeros_like(arr)])

def _render_scatter_panel(states, current_point, size, point_color):
    """Draw an (N, 2) ``states`` scatter plus a highlighted ``current_point`` on a white canvas.

    Every row of ``states`` becomes a radius-2 dot in ``point_color``; the
    current point is drawn on top as a radius-5 red dot. Each axis is
    min-max normalized over ``states`` and mapped into the canvas with a 10%
    margin on every side; an axis with no variation is centred.
    """
    width, height = size
    margin_width = width * 0.1
    margin_height = height * 0.1
    adjusted_width = width - 2 * margin_width
    adjusted_height = height - 2 * margin_height

    # With no background states, scale around the current point alone
    # (it then lands in the centre) instead of failing on an empty min/max.
    bounds = states if len(states) else current_point.reshape(1, 2)
    min_x, min_y = bounds.min(axis=0)
    max_x, max_y = bounds.max(axis=0)

    def to_canvas(x, y):
        # Zero spread on an axis -> middle of the plot (avoids 0/0).
        x_scaled = 0.5 if max_x == min_x else (x - min_x) / (max_x - min_x)
        y_scaled = 0.5 if max_y == min_y else (y - min_y) / (max_y - min_y)
        plot_x = margin_width + x_scaled * adjusted_width
        # Image rows grow downward; flip so larger y values are drawn higher.
        plot_y = height - (margin_height + y_scaled * adjusted_height)
        return plot_x, plot_y

    panel = Image.new('RGB', size, (255, 255, 255))
    draw = ImageDraw.Draw(panel)
    for x, y in states:
        px, py = to_canvas(x, y)
        draw.ellipse((px - 2, py - 2, px + 2, py + 2), fill=point_color)
    cx, cy = to_canvas(current_point[0], current_point[1])
    draw.ellipse((cx - 5, cy - 5, cx + 5, cy + 5), fill=(255, 0, 0))
    return panel


def create_combined_image(rgb_image, pca_states_psd, pca_states_metra, current_pca_point_psd, current_pca_point_metra):
    """Build one frame laid out as [METRA scatter | RGB image | PSD scatter].

    Args:
        rgb_image: Array accepted by ``PIL.Image.fromarray``; its size sets
            the size of each of the three panels.
        pca_states_psd: All PSD points to scatter (blue). 1-D or (N, 1)
            inputs are zero-padded to (N, 2); only the first two columns
            are plotted.
        pca_states_metra: All METRA points to scatter (green), same rules.
        current_pca_point_psd: The PSD point to highlight in red (array-like).
        current_pca_point_metra: The METRA point to highlight in red.

    Returns:
        An RGB ``PIL.Image`` of size (3 * width, height).
    """
    # Normalize everything to 2-D coordinates; wider inputs keep their
    # first two columns instead of failing on tuple unpacking.
    states_psd = ensure_2d(pca_states_psd)[:, :2]
    states_metra = ensure_2d(pca_states_metra)[:, :2]
    point_psd = ensure_2d(np.asarray(current_pca_point_psd).reshape(1, -1))[0, :2]
    point_metra = ensure_2d(np.asarray(current_pca_point_metra).reshape(1, -1))[0, :2]

    pil_image = Image.fromarray(rgb_image)
    width, height = pil_image.size

    panel_psd = _render_scatter_panel(states_psd, point_psd, pil_image.size, (0, 0, 255))
    panel_metra = _render_scatter_panel(states_metra, point_metra, pil_image.size, (0, 255, 0))

    combined_image = Image.new('RGB', (width * 3, height))
    combined_image.paste(panel_metra, (0, 0))
    combined_image.paste(pil_image, (width, 0))
    combined_image.paste(panel_psd, (width * 2, 0))
    return combined_image

def LatentVideoRecorder(all_rgb_arrays, pca_states_psd, pca_states_metra, video_directory, episode_idx, fps):
    """Write an MP4 pairing each RGB frame with its position in two latent scatters.

    For step t the frame is ``create_combined_image(all_rgb_arrays[t],
    pca_states_psd, pca_states_metra, pca_states_psd[t], pca_states_metra[t])``,
    i.e. the full point clouds with the step-t points highlighted. The video
    is saved as ``<video_directory>/latent_manifold_<episode_idx>th.mp4``;
    the directory is created if missing.

    Frames are paired with ``zip``, so the video has as many frames as the
    shortest of the three sequences.
    """
    # exist_ok avoids the check-then-create race with concurrent writers.
    os.makedirs(video_directory, exist_ok=True)

    video_filename = os.path.join(video_directory, f'latent_manifold_{episode_idx}th.mp4')

    with imageio.get_writer(video_filename, fps=fps) as writer:
        for rgb_image, pca_point_psd, pca_point_metra in zip(all_rgb_arrays, pca_states_psd, pca_states_metra):
            combined_image = create_combined_image(
                rgb_image,
                pca_states_psd,
                pca_states_metra,
                pca_point_psd,
                pca_point_metra,
            )
            writer.append_data(np.array(combined_image))


############## Logging related #################
    
def create_directory(directory):
    """Create ``directory`` (including parents) and print whether it was new.

    Prints "created successfully" when this call made the directory and
    "already exists" when the path was already present. The attempt-then-catch
    form means a directory created concurrently between a check and the
    create can no longer make this call raise.
    """
    try:
        os.makedirs(directory)
    except FileExistsError:
        print(f"Directory '{directory}' already exists.")
    else:
        print(f"Directory '{directory}' created successfully.")

def copy_files_and_directories(source_paths, destination_directory):
    """Copy every path in ``source_paths`` into ``destination_directory``.

    Files are copied with ``shutil.copy``. Directories are copied
    recursively to ``destination_directory/<directory name>`` with
    ``shutil.copytree``, which raises if that target already exists. Paths
    that do not exist are reported and skipped. One line is printed per path.
    """
    # Without this, copying a file to a missing destination would create a
    # *file* named destination_directory instead of a directory containing it.
    os.makedirs(destination_directory, exist_ok=True)
    for path in source_paths:
        if os.path.isfile(path):
            shutil.copy(path, destination_directory)
            print(f"File '{path}' copied to '{destination_directory}'.")
        elif os.path.isdir(path):
            # abspath drops a trailing separator; basename('src/') would be ''
            # and copytree would then target destination_directory itself.
            dir_name = os.path.basename(os.path.abspath(path))
            destination_path = os.path.join(destination_directory, dir_name)
            shutil.copytree(path, destination_path)
            print(f"Directory '{path}' copied to '{destination_path}'.")
        else:
            print(f"'{path}' does not exist.")