import torch
import math
import cv2
import numpy as np
import os

import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset

class CombinedDataset(Dataset):
    """Concatenate two indexable datasets into a single dataset.

    Indices ``0 .. len(dataset_1) - 1`` address ``dataset_1``. The next
    ``len(dataset_2)`` indices address ``dataset_2``. Negative indices count
    from the end of the combined dataset, as with a Python sequence.

    Note: the total length is computed once at construction time, so the two
    datasets are expected not to change size afterwards.
    """

    def __init__(self, dataset_1, dataset_2):
        self.dataset_1 = dataset_1
        self.dataset_2 = dataset_2
        # Combined length of both datasets (cached once).
        self.length = len(dataset_1) + len(dataset_2)

    def __getitem__(self, index):
        """Return the item at ``index`` of the combined dataset.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        position = index
        if position < 0:
            # Negative indices are relative to the end of the *combined*
            # dataset, not to the end of dataset_1.
            position += self.length
        if not 0 <= position < self.length:
            raise IndexError(
                f"index {index} out of range for CombinedDataset of length {self.length}"
            )
        len_1 = len(self.dataset_1)
        if position < len_1:
            return self.dataset_1[position]
        # Shift the index so it starts at 0 within dataset_2.
        return self.dataset_2[position - len_1]

    def __len__(self):
        return self.length

def draw_kps_image(image, kps, color_list=((255, 0, 0), (0, 255, 0), (0, 0, 255))):
    '''
    Draw keypoints and the limbs connecting them onto ``image``.

    Drawing is done in place with OpenCV: ``image`` itself is modified and is
    also returned.

    Args:
        image: array OpenCV can draw into (``load_kps_images`` passes an
            ``H x W x 3`` uint8 array of zeros).
        kps: sequence of keypoints (list, ndarray or tensor). Row ``i`` holds
            the ``(x, y)`` pixel coordinates of keypoint ``i`` in its first two
            entries; any further entries are ignored. At least three keypoints
            are required, because limbs are drawn between keypoints 0-2 and 1-2.
        color_list: one color per keypoint, indexed by keypoint position. It
            must have at least as many entries as there are keypoints. The
            default is an immutable tuple so it cannot be mutated across calls.

    Returns:
        The same ``image`` object, with limbs and keypoints drawn on it.
    '''
    stick_width = 4
    # Limbs as (start, end) keypoint-index pairs: keypoint 0 -> 2 and 1 -> 2.
    limb_seq = np.array([[0, 2], [1, 2]])
    kps = np.array(kps)

    canvas = image

    for start, end in limb_seq:
        # Each limb takes the color of its start keypoint.
        color = color_list[start]

        x = kps[[start, end], 0]
        y = kps[[start, end], 1]
        length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
        angle = int(math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])))
        # A thin ellipse centred on the limb midpoint: its half-length is half
        # the keypoint distance and its half-width is stick_width.
        polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stick_width), angle, 0, 360, 1)
        # Limbs are filled with a darker shade (60%) of the start keypoint's color.
        cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color])

    for idx_kp, kp in enumerate(kps):
        color = color_list[idx_kp]
        # Only the first two entries are coordinates; ignore any extra columns.
        x, y = kp[0], kp[1]
        cv2.circle(canvas, (int(x), int(y)), 4, color, -1)

    return canvas

def load_kps_images(kps_path="", kps_sequence=None, video_length=1, image_height=512, image_width=512):
    '''
    Render a keypoint sequence as a list of image tensors, one per frame.

    The keypoints come either from a file loadable by ``torch.load``
    (``kps_path``) or directly from ``kps_sequence``. ``kps_path`` takes
    precedence when both are given. The sequence is linearly resampled along
    time to ``video_length`` frames, and each frame is drawn with
    ``draw_kps_image`` onto a black ``image_height x image_width x 3`` uint8
    canvas.

    Args:
        kps_path: path to the keypoint file. An empty string (the default) or
            ``None`` means "use ``kps_sequence``".
        kps_sequence: keypoints as a tensor, ndarray or nested list, shaped
            ``[len, num_kps, 2]``. A 2-D ``[num_kps, 2]`` input is treated as a
            single frame. Only the first three keypoints are kept.
        video_length: number of output frames.
        image_height, image_width: size of each rendered frame in pixels.

    Returns:
        A list of ``video_length`` tensors, each produced by
        ``transforms.ToTensor`` from one rendered frame.

    Raises:
        FileNotFoundError: if ``kps_path`` is given but does not exist.
        NotImplementedError: if neither ``kps_path`` nor ``kps_sequence`` is given.
    '''
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    # keypoints prior
    if kps_path:
        if not os.path.exists(kps_path):
            raise FileNotFoundError(f'{kps_path} does not exist')
        # NOTE(review): torch.load unpickles the file; only load trusted paths.
        kps_sequence = torch.load(kps_path)  # [len, 3, 2]
    elif kps_sequence is None:
        raise NotImplementedError('KPS Sequence is not defined!')

    # Accept tensors, ndarrays and nested lists alike. as_tensor avoids the
    # warning that torch.tensor(<tensor>) emits. Linear interpolation needs a
    # floating dtype, so integer coordinates are promoted to float.
    kps_sequence = torch.as_tensor(kps_sequence)
    if not kps_sequence.is_floating_point():
        kps_sequence = kps_sequence.float()

    if kps_sequence.ndim == 2:
        kps_sequence = kps_sequence.unsqueeze(0)
    if kps_sequence.shape[1] > 3:
        kps_sequence = kps_sequence[:, :3, :]
    # Permute to [num_kps, coords, len] so interpolate treats time as the
    # length axis of an (N, C, L) input and resamples it to video_length.
    kps_sequence = torch.nn.functional.interpolate(kps_sequence.permute(1, 2, 0), size=video_length, mode='linear')
    # Back to [video_length, num_kps, coords], moved to host numpy once for OpenCV.
    kps_frames = kps_sequence.permute(2, 0, 1).detach().cpu().numpy()

    kps_images = []
    for frame_kps in kps_frames:
        kps_image = np.zeros((image_height, image_width, 3), dtype=np.uint8)
        kps_image = draw_kps_image(kps_image, frame_kps)
        kps_images.append(transform(kps_image))

    return kps_images
