import torch
import numpy as np
import random
import pandas as pd
import os
from typing import List, Tuple

from deepsvg.svglib.svg import SVG
from deepsvg.svglib.geom import Point

"""
0: SVG END
1: M
2: L
3: C
"""

# Command ids: positions of the one-hot command flag in columns 0-3 of a
# command token, and the command ids emitted into the pixel-token stream.
SVG_END = 0
MOVE = 1
LINE = 2
CURVE = 3
# Offset added to every flattened pixel index in get_vec_data, so pixel tokens
# start at 4 and never overlap the command ids 0-3 above.
PIX_PAD = 4

# Width of one command token (layout described in the string below).
CMD_TENSOR_DIM = 12

'''Tokens : SVG END, Move, Line, Curve'''
'''Dimension of Token : Command dim(4) + start_pos(2) + c1 (2) + c2(2) + end_pos(2) = 12 '''

# Coordinate extent: coordinates are clipped to [0, BBOX - 1] and normalized
# with (x - BBOX/2) / (BBOX/2); pixel indices are flattened as x + y * BBOX.
BBOX = 200
# Maximum absolute integer shift per axis for random-translation augmentation.
AUG_RANGE = 3

class SVGData(torch.utils.data.Dataset):
    '''SVG Dataset.

    Pairs every SVG file ``<svg_folder>/<id>.svg`` listed in the metadata CSV
    with a tokenized text prompt. Each item is
    ``(command_seq, mask, text, diffusion_mask)``; see ``__getitem__``.
    '''

    def __init__(self, meta_file_path, svg_folder, MAX_LEN, text_len, tokenizer, require_aug):
        '''
        Args:
            meta_file_path: CSV path; rows are filtered on ``len_pix`` and keyed
                by ``id``; ``label`` / ``desc`` are read in ``__getitem__``.
            svg_folder: directory containing ``<id>.svg`` files.
            MAX_LEN: length every command sequence is padded to.
            text_len: ``max_length`` passed to the tokenizer.
            tokenizer: callable tokenizer exposing ``vocab_size``
                (assumed HuggingFace-style — TODO confirm).
            require_aug: if True, randomly translate each SVG by up to
                ``AUG_RANGE`` in x and y.
        '''
        self.maxlen = MAX_LEN

        mf = pd.read_csv(meta_file_path)
        # Keep rows with len_pix > 1 whose len_pix + PIX_PAD fits in 2 * MAX_LEN
        # (len_pix is presumably the pixel-token count — confirm against the CSV).
        mf = mf[(1 < mf.len_pix) & (mf.len_pix + PIX_PAD <= 2 * self.maxlen)]
        # Length every command sequence is padded to in prepare_batch_SVG.
        self.maxlen_pix = MAX_LEN
        self.meta_file = mf
        self.svg_folder = svg_folder

        self.tokenizer = tokenizer
        self.text_len = text_len
        self.num_text_token = self.tokenizer.vocab_size

        # Sorted unique ids; duplicate ids in the CSV collapse into one item.
        self.uids = sorted(set(mf['id'].values))
        self.require_aug = require_aug

    def __len__(self):
        return len(self.uids)

    def prepare_batch_SVG(self, command_v):
        '''Pad a command sequence to ``self.maxlen_pix`` and build its masks.

        Args:
            command_v: array of shape (L, CMD_TENSOR_DIM) with L <= maxlen_pix.

        Returns:
            command_v_flat: float array (maxlen_pix, CMD_TENSOR_DIM), zero-padded.
            mask_idx: bool array (maxlen_pix,), True at padding positions.
            diffusion_mask: bool array (maxlen_pix, 8); column j refers to token
                column 4 + j (start 0-1, control1 2-3, control2 4-5, end 6-7)
                and is True where that slot is used by the row's command.

        Raises:
            ValueError: if L exceeds ``self.maxlen_pix``.
        '''
        num_cmds = len(command_v)
        num_pad = self.maxlen_pix - num_cmds
        if num_pad < 0:
            raise ValueError(
                f"command sequence of length {num_cmds} exceeds maxlen_pix={self.maxlen_pix}"
            )

        padding = np.zeros((num_pad, CMD_TENSOR_DIM), dtype=float)
        command_v_flat = np.concatenate([command_v, padding], axis=0)
        mask_idx = np.arange(self.maxlen_pix) >= num_cmds

        diffusion_mask = np.zeros((len(command_v_flat), 8), dtype=bool)
        # Move: start and end points.
        move_rows = command_v_flat[:, MOVE] == 1
        diffusion_mask[move_rows, :2] = True
        diffusion_mask[move_rows, 6:] = True
        # Line: end point only.
        diffusion_mask[command_v_flat[:, LINE] == 1, 6:] = True
        # Curve: both control points and the end point.
        diffusion_mask[command_v_flat[:, CURVE] == 1, 2:] = True
        return command_v_flat, mask_idx, diffusion_mask

    def _sample_caption(self, uid):
        '''Pick the raw text prompt for ``uid``.

        80%: the '/'-separated ``label`` tags, shuffled and comma-joined;
        10%: the ``desc`` column; 10%: the empty string.
        '''
        rows = self.meta_file[self.meta_file.id == uid]
        rand = torch.rand(1).item()
        if rand < 0.8:
            tags = rows.label.values[0].split('/')  # FIGR
            random.shuffle(tags)
            return ','.join(tags)
        if rand < 0.9:
            return rows.desc.values[0]  # FIGR
        return ''

    def __getitem__(self, idx):
        '''Load, optionally augment, and encode the SVG at position ``idx``.

        Returns:
            command_seq: torch tensor (maxlen_pix, CMD_TENSOR_DIM), float64.
            mask: bool ndarray (maxlen_pix,), True at padding positions.
            text: tokenizer ``input_ids``, squeezed, padded/truncated to text_len.
            diffusion_mask: bool ndarray (maxlen_pix, 8); see prepare_batch_SVG.
        '''
        uid = self.uids[idx]
        text = self._sample_caption(uid)

        encoded_dict = self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.text_len,
            add_special_tokens=True,
            return_token_type_ids=False,  # for RoBERTa
        )
        text = encoded_dict["input_ids"].squeeze()

        svg = SVG.load_svg(os.path.join(self.svg_folder, f'{uid}.svg'))

        if self.require_aug:
            dx = random.randint(-AUG_RANGE, AUG_RANGE)
            dy = random.randint(-AUG_RANGE, AUG_RANGE)
            svg.translate(Point(dx, dy))
        else:
            # NOTE(review): drop_z() only runs when augmentation is off. If the
            # augmented SVG still yields command types other than move/line/
            # curve (e.g. closing segments, if that is what drop_z() strips),
            # get_vec_data encodes them as curves — confirm this is intended.
            svg.drop_z()

        svg_tensors = svg.to_tensor(concat_groups=False, PAD_VAL=0)
        vec_data = get_vec_data(svg_tensors)

        # End-of-sequence token: SVG_END one-hot; every coordinate slot holds
        # (0 - BBOX/2) / (BBOX/2) == -1, the normalized value of pixel 0.
        end_token = np.zeros(CMD_TENSOR_DIM, dtype=float)
        end_token[4:] = (end_token[4:] - (BBOX / 2)) / (BBOX / 2)
        end_token[SVG_END] = 1

        command_tokens = vec_data['se_command']
        command_tokens.append(end_token)
        commands = np.array(command_tokens)

        command_seq, mask, diffusion_mask = self.prepare_batch_SVG(commands)
        command_seq = torch.from_numpy(command_seq)
        return command_seq, mask, text, diffusion_mask



def get_vec_data(svg_tensors):
    '''Convert per-path SVG tensors into command tokens and pixel tokens.

    Command token layout (CMD_TENSOR_DIM = 12): one-hot command in columns
    0-3 (SVG_END, MOVE, LINE, CURVE), then start_pos (4-5), control1 (6-7),
    control2 (8-9) and end_pos (10-11). Each coordinate is clipped to
    [0, BBOX - 1] and normalized with (x - BBOX/2) / (BBOX/2); unused
    coordinate slots hold -1.

    Args:
        svg_tensors: sized iterable of per-path torch tensors; each row is read
            as [cmd, start(2), control1(2), control2(2), end(2), ...]
            (assumed to follow deepsvg's to_tensor layout — confirm).

    Returns:
        dict with
            'len_command': total number of command tokens,
            'num_se': number of paths (len(svg_tensors)),
            'se_pix': one int array of pixel tokens per path,
            'se_command': flat list of float arrays of shape (CMD_TENSOR_DIM,).
    '''
    half = BBOX / 2

    def _norm(v):
        # Map a pixel coordinate in [0, BBOX - 1] to [-1, 1).
        return (v - half) / half

    def _put_xy(token, col, point):
        # Write the normalized (x, y) of `point` into columns col and col + 1.
        # Done element-wise to keep the exact same scalar arithmetic.
        token[col] = _norm(point[0])
        token[col + 1] = _norm(point[1])

    def _pix(point):
        # Round to integer pixel coords, flatten, and shift past the command ids.
        return num2index(np.round(point).astype(int)) + PIX_PAD

    se_pix = []
    se_command = []

    for path_tensor in svg_tensors:
        # Clamps every column (including the command column) into [0, BBOX - 1].
        path_tensor = torch.clip(path_tensor, min=0, max=BBOX - 1)
        path_pix = []
        for i, row in enumerate(path_tensor):
            cmd = row[0].round().int()
            start_pos = row[1:3].numpy()
            control1 = row[3:5].numpy()
            control2 = row[5:7].numpy()
            end_pos = row[7:9].numpy()

            token = np.zeros(CMD_TENSOR_DIM, dtype=float)
            # Unused coordinate slots: (0 - BBOX/2) / (BBOX/2) == -1.
            token[4:] = -1.0

            if cmd == 0:  # Move
                token[MOVE] = 1
                # A path's first command stores end_pos in the start slot as
                # well; later moves store their own start_pos. The pixel tokens
                # use start_pos in both cases.
                _put_xy(token, 4, end_pos if i == 0 else start_pos)
                _put_xy(token, 10, end_pos)
                path_pix.extend([MOVE, _pix(start_pos), _pix(end_pos)])
            elif cmd == 1:  # Line
                token[LINE] = 1
                _put_xy(token, 10, end_pos)
                path_pix.extend([LINE, _pix(end_pos)])
            else:  # Curve
                # NOTE(review): every command index other than 0 and 1 lands
                # here and is encoded as a curve — confirm the input never
                # contains other command types.
                token[CURVE] = 1
                _put_xy(token, 6, control1)
                _put_xy(token, 8, control2)
                _put_xy(token, 10, end_pos)
                path_pix.extend([CURVE, _pix(control1), _pix(control2), _pix(end_pos)])

            se_command.append(token)

        se_pix.append(np.array(path_pix))

    return {
        'len_command': len(se_command),
        'num_se': len(svg_tensors),
        'se_pix': se_pix,
        'se_command': se_command,
    }

def num2index(n: np.ndarray) -> int:
    '''Flatten an integer (x, y) pixel coordinate into a row-major index.

    Returns ``n[0] + n[1] * BBOX``. For coordinates in [0, BBOX - 1] the result
    is in [0, BBOX**2 - 1]. When ``n`` is an integer ndarray the result is a
    NumPy integer scalar.
    '''
    x, y = n[0], n[1]
    return x + y * BBOX