from PIL import Image, ImageDraw, ImageFont
import os
import random
import math
import copy
import json
import itertools
from collections import deque
from constants import TASK_DIR,NUM_SESSIONS,REVERSAL_TOTAL_TRIALS,STROOP_TOTAL_TRIALS,NBACK_TOTAL_TRIALS,NBACK_NUM_MATCHES,NBACK_N,NBACK_STIMULUS_SET,TOWER_TOTAL_TRIALS,GAMBLING_TOTAL_TRIALS,GAMBLING_STIMULUS_SET

def _load_font(size=60):
    """Return an Arial TrueType font at ``size``, falling back to Pillow's default.

    The original hard-coded Windows path is tried first. Next comes the bare
    file name, which Pillow also searches for in the platform's font
    directories. If neither can be opened, ``ImageFont.load_default()`` is
    returned so that merely importing this module does not crash on machines
    without Arial.
    """
    for candidate in ("C:/Windows/Fonts/arial.ttf", "arial.ttf"):
        try:
            return ImageFont.truetype(candidate, size)
        except OSError:
            continue
    # NOTE(review): the default bitmap font may ignore ``size``; acceptable as a
    # last resort since ``font`` is not used by the card-drawing code in this file.
    return ImageFont.load_default()


font = _load_font()

def draw_star(draw, bbox, fill, border_thickness):
    """Draw a filled, black-outlined five-pointed star centred in ``bbox``.

    The star has these proportions:

    * The tips lie on a circle whose radius is half the box width plus 2 px.
    * The inner vertices lie at half that radius.
    * The first tip points straight up.

    The fill is painted first, then the outline is drawn over the same vertices.

    Args:
        draw: Drawing context providing ``polygon`` (e.g. ``ImageDraw.ImageDraw``).
        bbox: ``(x1, y1, x2, y2)`` bounding box.
        fill: Fill colour forwarded to ``draw.polygon``.
        border_thickness: Outline width in pixels.
    """
    left, top, right, bottom = bbox
    cx = (left + right) / 2
    cy = (top + bottom) / 2
    outer_r = (right - left) / 2 + 2

    # Ten vertices, alternating tip (even k) and inner notch (odd k), 36 degrees apart.
    vertices = []
    for k in range(10):
        theta = math.radians(k * 36 - 90)
        r = outer_r if k % 2 == 0 else outer_r * 0.5
        vertices.append((cx + r * math.cos(theta), cy + r * math.sin(theta)))

    draw.polygon(vertices, fill=fill)
    draw.polygon(vertices, outline="black", width=border_thickness)

def create_card(shape, color, count, width=400, height=600):
    """Render one WCST card showing ``count`` copies of ``shape`` filled with ``color``.

    Args:
        shape: One of ``"circle"``, ``"cross"``, ``"triangle"``, ``"star"``.
        color: Fill colour in any form Pillow accepts (e.g. ``"red"``).
        count: Number of shapes on the card, 1 to 4.
        width: Card width in pixels.
        height: Card height in pixels.

    Returns:
        A new RGB ``PIL.Image.Image`` with a white background.

    Raises:
        ValueError: If ``shape`` or ``count`` is not supported. This is checked
            before any image is allocated.
    """
    shape_size = 120
    adjustment = shape_size / 2
    b = 6  # border thickness in pixels
    cross_width = 44  # thickness of each cross arm, including its black border

    # Shape centres for each supported count, derived from the card size.
    centres = {
        1: [(width // 2, height // 2)],
        2: [(width // 4, height // 4), (3 * width // 4, 3 * height // 4)],
        3: [(width // 4, height // 4), (width // 4, 3 * height // 4),
            (3 * width // 4, height // 2)],
        4: [(width // 4, height // 4), (3 * width // 4, height // 4),
            (width // 4, 3 * height // 4), (3 * width // 4, 3 * height // 4)],
    }

    # The renderers close over ``draw``, which is assigned below before any of
    # them is called.
    def draw_circle(bbox, fill):
        draw.ellipse(bbox, fill=fill)
        # Outline a box grown by the border width so the ring frames the fill.
        draw.ellipse([bbox[0] - b, bbox[1] - b, bbox[2] + b, bbox[3] + b],
                     outline="black", width=b + 1)

    def draw_cross(bbox, fill):
        mid_x = bbox[0] + shape_size / 2
        mid_y = bbox[1] + shape_size / 2
        # Wide black arms (extended by the border) first, then narrower coloured
        # arms on top, leaving a black rim of width ``b`` around the fill.
        draw.line([mid_x, bbox[1] - b, mid_x, bbox[3] + b], fill="black", width=cross_width)
        draw.line([bbox[0] - b, mid_y, bbox[2] + b, mid_y], fill="black", width=cross_width)
        draw.line([mid_x, bbox[1], mid_x, bbox[3]], fill=fill, width=cross_width - b * 2)
        draw.line([bbox[0], mid_y, bbox[2], mid_y], fill=fill, width=cross_width - b * 2)

    def draw_triangle(bbox, fill):
        apex = (bbox[0] + shape_size / 2, bbox[1])
        bottom_left = (bbox[0], bbox[3])
        bottom_right = (bbox[2], bbox[3])
        draw.polygon([*apex, *bottom_left, *bottom_right], fill=fill)
        for start, end in ((apex, bottom_left), (bottom_left, bottom_right),
                           (bottom_right, apex)):
            draw.line([start, end], fill="black", width=b)

    def draw_star_shape(bbox, fill):
        draw_star(draw, bbox, fill, b)

    renderers = {
        "circle": draw_circle,
        "cross": draw_cross,
        "triangle": draw_triangle,
        "star": draw_star_shape,
    }

    # Validate before allocating the image so bad input fails fast and clearly.
    if shape not in renderers:
        raise ValueError(f"unsupported shape {shape!r}; expected one of {sorted(renderers)}")
    if count not in centres:
        raise ValueError(f"unsupported count {count!r}; expected one of {sorted(centres)}")
    render = renderers[shape]

    img = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(img)

    for cx, cy in centres[count]:
        x0 = cx - adjustment
        y0 = cy - adjustment
        render([x0, y0, x0 + shape_size, y0 + shape_size], color)

    return img

def create_WSCT(out_path, num_sequences=1, seed=None):
    """Generate WCST stimulus images and metadata for one or more sequences.

    The deck is every shape x colour x count combination (4 x 4 x 4 = 64
    cards). For each sequence ``s`` (1-based):

    * The deck is shuffled.
    * Each card is pasted below the row of four fixed reference cards on a
      black 2000x2400 canvas.
    * Images are saved to ``out_path + f"trial{s}/cards/"``.
    * Per-image metadata is saved to ``out_path + f"trial{s}"`` / ``cards.json``.

    Args:
        out_path: Output prefix. Paths are built by plain string concatenation
            (``out_path + "trial1"``), so it should normally end with a path
            separator.
        num_sequences: Number of shuffled sequences to generate.
        seed: Optional seed for a private ``random.Random`` so sequences can be
            reproduced. When ``None`` (default) the global ``random`` module is
            used.
    """
    shapes = ["triangle", "star", "cross", "circle"]
    colors = ["red", "green", "yellow", "blue"]
    counts = [1, 2, 3, 4]

    # Canvas layout in pixels. Reference cards are 400 px wide (the create_card
    # default) separated by 100 px; the stimulus card goes in the lower-left.
    canvas_size = (2000, 2400)
    margin = 50
    card_step = 400 + 100
    stimulus_pos = (50, 1700)

    shuffle = random.Random(seed).shuffle if seed is not None else random.shuffle

    # Reference card i shows shapes[i] in colors[i] with counts[i] items, so the
    # *Rule indices recorded below equal the index of the matching reference card.
    fixed_cards = [create_card(shape, color, count)
                   for shape, color, count in zip(shapes, colors, counts)]

    # The deck also contains the four reference combinations (not filtered out).
    deck = [(create_card(shape, color, count), shape, color, count)
            for shape in shapes for color in colors for count in counts]

    # The reference row is identical on every canvas: render it once and copy it.
    template = Image.new("RGB", canvas_size, "black")
    for i, card in enumerate(fixed_cards):
        template.paste(card, (margin + i * card_step, margin))

    for sequence in range(1, num_sequences + 1):
        shuffle(deck)
        trial_dir = out_path + f"trial{sequence}"
        output_folder = trial_dir + "/cards"
        os.makedirs(output_folder, exist_ok=True)

        metadata = {}
        for index, (card_img, shape, color, count) in enumerate(deck):
            filename = f"{index}_{shape}_{color}_{count}.png"
            canvas = template.copy()
            canvas.paste(card_img, stimulus_pos)
            canvas.save(os.path.join(output_folder, filename))

            metadata[filename] = {
                "color": color,
                "shape": shape,
                "number": count,
                "trialNumber": index + 1,
                "image": f"{trial_dir}/cards/{filename}",
                "colorRule": colors.index(color),
                "shapeRule": shapes.index(shape),
                "numberRule": counts.index(count),
            }

        with open(os.path.join(trial_dir, "cards.json"), "w") as f:
            json.dump(metadata, f, indent=2)
    
def main():
    """Generate the WCST card dataset, one shuffled sequence per session.

    Only the WCST task is generated here, even though this module imports
    constants for other tasks as well.
    """
    create_WSCT(out_path=TASK_DIR['WCST'], num_sequences=NUM_SESSIONS)


if __name__ == "__main__":
    main()