import os
import pickle
import numpy as np
import pandas as pd
import imageio
from PIL import Image, ImageDraw, ImageFont

# Downscale frames by a fixed fraction to reduce output file size.
def resize_frames(frames, fraction):
    """Scale every frame by ``fraction`` and return the results as NumPy arrays.

    Args:
        frames: Iterable of PIL ``Image`` objects (anything exposing ``width``,
            ``height`` and ``resize``).
        fraction: Scale factor applied to both width and height, e.g. ``0.5``
            halves each dimension.

    Returns:
        list of ``np.ndarray``, one per input frame, in input order.
    """
    resized_frames = []
    for img in frames:
        # int() truncates, so a small fraction (or a small frame) could yield a
        # 0-pixel dimension, which PIL refuses to resize to; keep at least 1px.
        new_size = (
            max(1, int(img.width * fraction)),
            max(1, int(img.height * fraction)),
        )
        resized_frames.append(np.array(img.resize(new_size)))
    return resized_frames


# Label a frame with its episode number and return it as a NumPy array (used for MPE and Atari).
def _label_with_episode_number(frame, episode_num):
    """Stamp ``Episode: <episode_num + 1>`` onto ``frame`` and return it as an array.

    The text colour adapts to the frame: white when the mean pixel value is
    below 128 (a dark frame), black otherwise. The label is anchored at
    1/20 of the image width and 1/18 of the image height.

    Args:
        frame: Array accepted by ``PIL.Image.fromarray`` (presumably an RGB
            uint8 render, since the colours are RGB tuples; TODO confirm).
        episode_num: Zero-based episode index; displayed one-based.

    Returns:
        np.ndarray holding the annotated frame.
    """
    canvas = Image.fromarray(frame)
    pen = ImageDraw.Draw(canvas)
    if np.mean(frame) < 128:
        label_color = (255, 255, 255)
    else:
        label_color = (0, 0, 0)
    width, height = canvas.size
    pen.text(
        (width / 20, height / 18),
        f"Episode: {episode_num+1}",
        fill=label_color,
    )
    return np.array(canvas)


# Label a Connect4 frame with episode, frame and agent info; returns a PIL image.
def _label_with_episode_number_connect4(frame, episode_num, frame_no, p):
    im = Image.fromarray(frame)
    drawer = ImageDraw.Draw(im)
    text_color = (255, 255, 255)
    font = ImageFont.truetype("./env/connect4/arial.ttf", size=30) #font = ImageFont.load_default()
    drawer.text(
        (100, 5),
        f"Episode: {episode_num+1}     Frame: {frame_no}",
        fill=text_color,
        font=font,
    )
    if p == 1:
        player = "Player 1"
        color = (255, 0, 0)
    if p == 2:
        player = "Player 2"
        color = (100, 255, 150)
    if p is None:
        player = "Self-play"
        color = (255, 255, 255)
    drawer.text((700, 5), f"Agent: {player}", fill=color, font=font)
    return im


def save_pickle(Trajs_data=None, env_name=None, episodes=None):
    """Pickle ``Trajs_data`` to ``data/<env_name>/<episodes>.pkl``.

    The ``data/<env_name>`` directory is always created (relative to the
    current working directory), but the file is only written when
    ``Trajs_data`` is truthy, so ``None`` or an empty list/dict writes nothing.
    NOTE(review): the truthiness check raises for multi-element NumPy arrays
    or DataFrames; pass a list/dict.
    """
    out_dir = f'data/{env_name}'
    os.makedirs(out_dir, exist_ok=True)
    if not Trajs_data:
        return
    with open(f'{out_dir}/{episodes}.pkl', 'wb') as handle:
        pickle.dump(Trajs_data, handle)
        

def save_gif(total_episodes, name, frames, gif_path, *, max_episodes=10, duration=10):
    """Write ``frames`` as ``<gif_path>/<name>.gif``, but only for short runs.

    When ``total_episodes`` exceeds ``max_episodes`` nothing happens: no
    directory is created and no file is written (presumably to keep output
    size down for long evaluations).

    Args:
        total_episodes: Number of episodes in the run; compared to ``max_episodes``.
        name: GIF file name without extension.
        frames: Frame sequence forwarded to ``imageio.mimwrite``.
        gif_path: Output directory; created if missing.
        max_episodes: Largest ``total_episodes`` for which a GIF is written.
            Defaults to the previous hard-coded value, 10.
        duration: Forwarded to ``imageio.mimwrite``. Its unit depends on the
            imageio version/plugin in use; confirm before changing it.
    """
    if total_episodes > max_episodes:
        return
    os.makedirs(gif_path, exist_ok=True)
    imageio.mimwrite(
        os.path.join(gif_path, f"{name}.gif"), frames, duration=duration
    )
        
def save_result(results, name, result_path):
    """Write ``results`` to ``<result_path>/<name>.csv`` without an index column.

    ``results`` can be anything ``pd.DataFrame`` accepts (e.g. a list of dicts
    or a dict of columns). The directory is created if missing; an existing
    file with the same name is overwritten.
    """
    table = pd.DataFrame(results)
    os.makedirs(result_path, exist_ok=True)
    csv_file = os.path.join(result_path, f"{name}.csv")
    table.to_csv(csv_file, index=False)


def save_log(log_text, save_name, save_path):
    """Append ``log_text`` to ``<save_path>/<save_name>``, preceded by a separator.

    Each call writes a line of 70 ``=`` characters plus a blank line, then
    ``log_text`` verbatim (no trailing newline is added). The file is opened
    in append mode with UTF-8 encoding, so repeated calls accumulate entries.

    Args:
        log_text: Text to append.
        save_name: File name inside ``save_path``.
        save_path: Target directory, created if missing. An empty string
            means the current working directory.
    """
    # exist_ok avoids the exists()/makedirs() race and matches the other
    # save_* helpers; skip "" because os.makedirs("") raises.
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    log_file_path = os.path.join(save_path, save_name)
    with open(log_file_path, "a", encoding="utf-8") as file:
        file.write("=" * 70 + "\n\n")
        file.write(log_text)


def round_stats(x):
    """Summarise ``x`` as ``{"min": int, "max": int, "avg": float}``.

    ``min`` and ``max`` are cast with ``int`` (truncating any fractional part).
    An empty ``x`` yields ``{"min": 0, "max": 0, "avg": 0.0}`` instead of raising.
    """
    if len(x) == 0:
        return {"min": 0, "max": 0, "avg": 0.0}
    lowest = int(np.min(x))
    highest = int(np.max(x))
    mean_value = float(np.mean(x))
    return {"min": lowest, "max": highest, "avg": mean_value}

def stat_summary(values):
    """Summarise ``values`` as ``{"avg": float, "min": int, "max": int}``.

    ``values`` is converted to a NumPy array first, so lists, tuples and
    arrays are all accepted. ``min``/``max`` are cast with ``int`` (truncating
    fractional parts). An empty input returns zeros, consistent with
    ``round_stats``, instead of warning in ``np.mean`` and raising in ``np.min``.
    """
    arr = np.asarray(values)
    if arr.size == 0:
        return {"avg": 0.0, "min": 0, "max": 0}
    return {
        "avg": float(np.mean(arr)),
        "min": int(np.min(arr)),
        "max": int(np.max(arr)),
    }
        
def round_stat_summary(round_list):
    """Aggregate per-round stat dicts into one ``{"avg", "min", "max"}`` dict.

    Each element must provide "avg", "min" and "max" keys (the shape produced
    by ``round_stats``). The result's ``avg`` is the unweighted mean of the
    per-round averages, so every round counts equally regardless of its
    sample count. ``min``/``max`` are the extremes across rounds.

    ``round_list`` may be any iterable. It is materialised once, so a
    generator is not exhausted by the first aggregation. An empty input
    returns zeros instead of raising.
    """
    rounds = list(round_list)
    if not rounds:
        return {"avg": 0.0, "min": 0, "max": 0}
    return {
        "avg": float(np.mean([r["avg"] for r in rounds])),
        "min": int(np.min([r["min"] for r in rounds])),
        "max": int(np.max([r["max"] for r in rounds])),
    }

def insert_action_coord(action, action1, idx=1):
    """Return a tuple equal to ``action`` with ``action1`` inserted at ``idx``.

    Insertion follows slice semantics, so ``idx`` may be negative or past the
    end. ``action`` may be any sequence (tuple, list, 1-D NumPy array). It is
    converted to a tuple first: otherwise a list raises TypeError when
    concatenated with a tuple, and a NumPy array broadcast-adds instead of
    concatenating.

    >>> insert_action_coord((0, 2), 9)
    (0, 9, 2)
    """
    action = tuple(action)
    return action[:idx] + (action1,) + action[idx:]