import math
import torch
import os
import shutil

import imageio
import numpy as np

from PIL import Image, ImageDraw

def create_log_gaussian(mean, log_std, t):
    """Return the log-density of ``t`` under a diagonal Gaussian.

    Computes ``log N(t; mean, diag(exp(log_std) ** 2))``, summed over the last
    dimension.

    Args:
        mean: Tensor of means; its last dimension is the event dimension.
        log_std: Tensor of log standard deviations, broadcastable with ``mean``.
        t: Tensor of points at which to evaluate the density.

    Returns:
        Tensor of log-probabilities with the last dimension reduced.
    """
    # Standardized residual. The Gaussian exponent is -0.5 * z**2; the old
    # form -(0.5 * z)**2 squared the 0.5 too and yielded -0.25 * z**2.
    z_score = (t - mean) / log_std.exp()
    quadratic = -0.5 * z_score.pow(2)
    event_dim = mean.shape[-1]
    log_norm = 0.5 * event_dim * math.log(2 * math.pi)
    return quadratic.sum(dim=-1) - log_std.sum(dim=-1) - log_norm

def logsumexp(inputs, dim=None, keepdim=False):
    """Numerically stable ``log(sum(exp(inputs)))`` along ``dim``.

    Args:
        inputs: Input tensor.
        dim: Dimension to reduce. If ``None``, the tensor is flattened and
            reduced over all of its elements.
        keepdim: Whether to keep the reduced dimension (size 1) in the output.

    Returns:
        Tensor of log-sum-exp values.
    """
    if dim is None:
        # reshape (not view) so non-contiguous tensors can be flattened too.
        inputs = inputs.reshape(-1)
        dim = 0
    # torch.logsumexp applies the max-shift trick internally and, unlike a
    # manual `s + log(sum(exp(x - s)))`, yields -inf rather than NaN when
    # every element along `dim` is -inf.
    return torch.logsumexp(inputs, dim=dim, keepdim=keepdim)

def soft_update(target, source, tau):
    """Blend ``source`` parameters into ``target`` in place.

    Each target parameter becomes ``target * (1 - tau) + source * tau``.
    Parameters are paired positionally via ``parameters()``.
    """
    keep = 1.0 - tau
    pairs = zip(target.parameters(), source.parameters())
    for dst, src in pairs:
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)

def hard_update(target, source):
    """Overwrite ``target``'s parameters with ``source``'s, pairwise and in place."""
    paired = zip(target.parameters(), source.parameters())
    for dst_param, src_param in paired:
        dst_param.data.copy_(src_param.data)

def make_dir(*path_parts):
    """Join ``path_parts`` into a directory path, create it, and return it.

    Missing parent directories are created as well. A directory that already
    exists is not an error.

    Raises:
        OSError: if the directory cannot be created (e.g. permission denied,
            or a non-directory already exists at that path).
    """
    dir_path = os.path.join(*path_parts)
    # exist_ok=True replaces the old blanket `except OSError: pass`, which also
    # hid real failures (missing parent, permissions) and returned a path that
    # did not exist.
    os.makedirs(dir_path, exist_ok=True)
    return dir_path



############## Video related #################
##############################################

############ Camera image ############
class VideoRecorder(object):
    """Buffer frames in memory and write them to a video file with imageio.

    Recording is disabled whenever ``dir_name`` is ``None``.
    """

    def __init__(self, dir_name, height=256, width=256, fps=200):
        self.dir_name = dir_name
        # Only touch the filesystem when a directory was given; previously
        # os.path.exists(None) raised TypeError, even though init() treats
        # dir_name=None as a valid "disabled" configuration.
        if dir_name is not None:
            try:
                os.makedirs(dir_name, exist_ok=True)
            except OSError:
                print('Error: Creating directory. ' + dir_name)

        # height/width are stored but not used by the methods of this class.
        self.height = height
        self.width = width
        self.fps = fps
        # Initialize frames and the enabled flag so record()/save() work even
        # if init() is never called (previously self.enabled was unset and
        # record() raised AttributeError).
        self.init()

    def init(self, enabled=True):
        """Clear buffered frames; enable recording if ``enabled`` and ``dir_name`` is set."""
        self.frames = []
        self.enabled = self.dir_name is not None and enabled

    def record(self, obs):
        """Append ``obs`` to the frame buffer when recording is enabled."""
        if self.enabled:
            self.frames.append(obs)

    def save(self, file_name):
        """Write buffered frames to ``dir_name/file_name`` at ``self.fps`` (no-op when disabled)."""
        if self.enabled:
            path = os.path.join(self.dir_name, file_name)
            imageio.mimsave(path, self.frames, fps=self.fps)

############ Camera + Latent image ############
        

def LatentVideoRecorder(all_rgb_arrays, all_rgb_arrays_train, pca_states, video_directory, episode_idx, fps):
    """Write ``latent_manifold_{episode_idx}th.mp4`` combining camera frames and a PCA plot.

    Each output frame is built by ``create_combined_image`` from one element
    each of ``all_rgb_arrays``, ``all_rgb_arrays_train`` and ``pca_states``
    (iteration stops at the shortest of the three). The full ``pca_states``
    is passed every frame so the whole latent trajectory is drawn, with the
    current point highlighted.

    Args:
        all_rgb_arrays: sequence of image arrays for the first panel.
        all_rgb_arrays_train: sequence of image arrays for the second panel.
        pca_states: per-frame PCA points; also used as the background scatter.
        video_directory: output directory, created if missing.
        episode_idx: value interpolated into the output file name.
        fps: frames per second of the output video.
    """
    os.makedirs(video_directory, exist_ok=True)

    video_filename = os.path.join(video_directory, f'latent_manifold_{episode_idx}th.mp4')

    # The context manager closes (finalizes) the video even if building or
    # appending a frame raises; previously the writer leaked on error.
    with imageio.get_writer(video_filename, fps=fps) as writer:
        for rgb_image, rgb_train_image, pca_point in zip(all_rgb_arrays, all_rgb_arrays_train, pca_states):
            combined_image = create_combined_image(rgb_image, rgb_train_image, pca_states, pca_point)
            writer.append_data(np.array(combined_image))

def create_combined_image(rgb_image, rgb_train_image, pca_states, current_pca_point):
    """Compose ``[camera | train camera | PCA scatter]`` into one RGB image.

    Panels are separated by 2-pixel black gaps. The PCA panel has the same
    size as ``rgb_image``. It plots every row of ``pca_states`` (first two
    columns) as a small blue dot and ``current_pca_point`` as a larger red
    dot, with a 10% margin on each side.

    Args:
        rgb_image: image array accepted by ``PIL.Image.fromarray`` (left panel).
        rgb_train_image: image array accepted by ``PIL.Image.fromarray`` (middle panel).
        pca_states: array-like with rows of at least two components.
        current_pca_point: indexable; entries 0 and 1 give the highlighted point.

    Returns:
        The combined ``PIL.Image.Image`` in RGB mode.
    """
    pil_image = Image.fromarray(rgb_image)
    pil_train_image = Image.fromarray(rgb_train_image)

    pca_image = Image.new('RGB', pil_image.size, (255, 255, 255))
    draw = ImageDraw.Draw(pca_image)

    width, height = pil_image.size

    # 10% margin around the scatter plot.
    margin_width = width * 0.1
    margin_height = height * 0.1
    adjusted_width = width - 2 * margin_width
    adjusted_height = height - 2 * margin_height

    # Only the first two components are plotted; slicing lets callers pass
    # higher-dimensional projections without breaking the 2-value unpack.
    points_2d = np.asarray(pca_states)[:, :2]
    min_x, min_y = points_2d.min(axis=0)
    max_x, max_y = points_2d.max(axis=0)

    def transform_coordinates(x, y):
        # Normalize to [0, 1] (centre the axis when all values coincide), then
        # map into the panel with y pointing up.
        x_scaled = (x - min_x) / (max_x - min_x) if max_x - min_x != 0 else 0.5
        y_scaled = (y - min_y) / (max_y - min_y) if max_y - min_y != 0 else 0.5
        return margin_width + x_scaled * adjusted_width, height - (margin_height + y_scaled * adjusted_height)

    for px, py in points_2d:
        x, y = transform_coordinates(px, py)
        draw.ellipse((x - 2, y - 2, x + 2, y + 2), fill=(0, 0, 255))

    x, y = transform_coordinates(current_pca_point[0], current_pca_point[1])
    draw.ellipse((x - 5, y - 5, x + 5, y + 5), fill=(255, 0, 0))

    separator_width = 2

    # Each paste offset accumulates the real widths of the preceding panels.
    # (The PCA panel used to be offset by pca_image.width instead of
    # pil_train_image.width, misplacing it when the camera images differed in width.)
    train_x = pil_image.width + separator_width
    pca_x = train_x + pil_train_image.width + separator_width
    total_width = pca_x + pca_image.width
    # Tall enough for both camera images so neither gets cropped.
    total_height = max(pil_image.height, pil_train_image.height)

    combined_image = Image.new('RGB', (total_width, total_height), (0, 0, 0))
    combined_image.paste(pil_image, (0, 0))
    combined_image.paste(pil_train_image, (train_x, 0))
    combined_image.paste(pca_image, (pca_x, 0))

    return combined_image

############ Camera + Latent image for eval ############
    
def create_combined_image_eval(rgb_image, pca_states, current_pca_point, history_points):
    """Compose ``[camera | PCA trail]`` side by side for evaluation videos.

    The PCA panel has the same size as ``rgb_image``. Its axes are scaled to
    the range of the first two columns of ``pca_states``, with a 10% margin.
    Each ``(point, alpha)`` in ``history_points`` is drawn as a blue dot,
    alpha-blended onto the white background with opacity ``alpha``.
    ``current_pca_point`` is drawn as an opaque red dot.

    Args:
        rgb_image: image array accepted by ``PIL.Image.fromarray``.
        pca_states: array-like with rows of at least two components; only
            used to determine the axis ranges.
        current_pca_point: indexable; entries 0 and 1 give the highlighted point.
        history_points: iterable of ``(point, alpha)``; alpha is expected in
            [0, 1] (it is scaled by 255).

    Returns:
        The combined ``PIL.Image.Image`` in RGB mode.
    """
    pil_image = Image.fromarray(rgb_image)
    pca_image = Image.new('RGB', pil_image.size, (255, 255, 255))
    # 'RGBA' draw mode makes Pillow alpha-blend fills onto the RGB canvas;
    # without it the alpha component is ignored and the trail never fades.
    draw = ImageDraw.Draw(pca_image, 'RGBA')

    width, height = pil_image.size

    # set margin (10% of width and height)
    margin_width = width * 0.1
    margin_height = height * 0.1

    # adjusted width and height
    adjusted_width = width - 2 * margin_width
    adjusted_height = height - 2 * margin_height

    # min / max of the first two PCA components
    points_2d = np.asarray(pca_states)[:, :2]
    min_x, min_y = points_2d.min(axis=0)
    max_x, max_y = points_2d.max(axis=0)

    def transform_coordinates(x, y):
        # Normalize to [0, 1]; centre the axis when all values coincide
        # (previously a degenerate range divided by zero).
        x_span = max_x - min_x
        y_span = max_y - min_y
        x_scaled = (x - min_x) / x_span if x_span != 0 else 0.5
        y_scaled = (y - min_y) / y_span if y_span != 0 else 0.5
        return margin_width + x_scaled * adjusted_width, height - (margin_height + y_scaled * adjusted_height)

    # Plot history PCA states with fading opacity.
    for point, alpha in history_points:
        x, y = transform_coordinates(point[0], point[1])
        draw.ellipse((x - 2, y - 2, x + 2, y + 2), fill=(0, 0, 255, int(alpha * 255)))

    # Emphasize current PCA state. It must be opaque: the old fill
    # (255, 0, 0, 0) only showed because alpha was being ignored.
    x, y = transform_coordinates(current_pca_point[0], current_pca_point[1])
    draw.ellipse((x - 5, y - 5, x + 5, y + 5), fill=(255, 0, 0, 255))

    # Combine
    combined_image = Image.new('RGB', (pil_image.width * 2, pil_image.height))
    combined_image.paste(pil_image, (0, 0))
    combined_image.paste(pca_image, (pil_image.width, 0))

    return combined_image

def LatentVideoRecorder_eval(all_rgb_arrays, pca_states, video_directory, video_name, fps, alpha_decay=0.8):
    """Write ``latent_manifold_{video_name}.mp4`` showing camera frames and a fading PCA trail.

    On every frame, the opacity of each trail point is multiplied by
    ``alpha_decay``, and points whose opacity falls to 0.001 or below are
    dropped. The point highlighted on the previous frame is then added at
    full opacity.

    Args:
        all_rgb_arrays: sequence of image arrays, one per frame.
        pca_states: per-frame PCA points (iteration stops at the shorter of
            this and ``all_rgb_arrays``); also used for axis scaling.
        video_directory: output directory, created if missing.
        video_name: value interpolated into the output file name.
        fps: frames per second of the output video.
        alpha_decay: per-frame opacity multiplier for trail points.
    """
    os.makedirs(video_directory, exist_ok=True)

    # Combine the path and video filename
    video_filename = os.path.join(video_directory, f'latent_manifold_{video_name}.mp4')

    history_points = []
    previous_point = None

    # The context manager closes (finalizes) the video even if a frame fails.
    with imageio.get_writer(video_filename, fps=fps) as writer:
        for frame_idx, (rgb_image, pca_point) in enumerate(zip(all_rgb_arrays, pca_states)):
            # Fade existing trail points and drop ones that are effectively invisible.
            history_points = [
                (point, alpha * alpha_decay)
                for point, alpha in history_points
                if alpha * alpha_decay > 0.001
            ]

            # The red dot from the previous frame joins the trail as a blue dot.
            if frame_idx > 0:
                history_points.append((previous_point, 1.0))

            combined_image = create_combined_image_eval(rgb_image, pca_states, pca_point, history_points)
            writer.append_data(np.array(combined_image))
            previous_point = pca_point


############## Logging related #################
    
def create_directory(directory):
    """Create ``directory`` (including parents) and print whether it was created or already existed."""
    try:
        os.makedirs(directory)
    except FileExistsError:
        # EAFP instead of exists()-then-makedirs(): removes the race where the
        # path appears between the check and the creation.
        print(f"Directory '{directory}' already exists.")
    else:
        print(f"Directory '{directory}' created successfully.")

def copy_files_and_directories(source_paths, destination_directory):
    """Copy each file or directory tree in ``source_paths`` into ``destination_directory``.

    Files are copied with ``shutil.copy``. Directories are copied with
    ``shutil.copytree`` to ``destination_directory/<basename>``. Paths that
    are neither are reported and skipped. A status line is printed for every
    entry.
    """
    for src in source_paths:
        if os.path.isfile(src):
            shutil.copy(src, destination_directory)
            print(f"File '{src}' copied to '{destination_directory}'.")
            continue

        if not os.path.isdir(src):
            print(f"'{src}' does not exist.")
            continue

        target = os.path.join(destination_directory, os.path.basename(src))
        shutil.copytree(src, target)
        print(f"Directory '{src}' copied to '{target}'.")