import os
import time
import torch
import argparse
import numpy as np
import pandas as pd

from model.decoder_eos import SVGDecoder
from deepsvg.difflib.tensor import SVGTensor
from deepsvg.svglib.svg import SVG
from deepsvg.svglib.geom import Bbox
from transformers import AutoTokenizer


# Disable HuggingFace tokenizers' internal parallelism for this script.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# --- Sampling configuration -------------------------------------------------
NUM_SAMPLE = 16   # SVGs generated per prompt
BS = 4            # samples drawn per decoder call
BBOX = 200        # viewbox size; decoded coordinates are mapped onto [0, BBOX]
str_num = 30      # NOTE(review): not referenced in the code shown
# Prompts to generate SVGs for: write your own or keep the provided example.
texts = ['phone']

# Indices of the command-class flags inside a decoded token.
SVG_END, MOVE, LINE, CURVE = range(4)

def _save_svg(path_datas, out_path):
    """Merge the path tensors of one generated sample into a single SVG file.

    Each entry of ``path_datas`` is converted with ``SVGTensor.from_data`` and
    ``SVG.from_tensor`` on a ``Bbox(BBOX)`` viewbox with fill enabled. The path
    groups of all entries are concatenated into one ``SVG`` saved to
    ``out_path``. Raises ``IndexError`` when ``path_datas`` is empty.
    """
    paths = []
    for d in path_datas:
        path = SVGTensor.from_data(d)
        path = SVG.from_tensor(path.data, viewbox=Bbox(BBOX))
        path.fill_(True)
        paths.append(path)
    path_groups = paths[0].svg_path_groups
    for extra in paths[1:]:
        path_groups.extend(extra.svg_path_groups)
    svg = SVG(path_groups, viewbox=Bbox(BBOX))
    svg.save_svg(out_path)


def sample(args, cfg):
    """Generate ``NUM_SAMPLE`` SVGs for each prompt in ``texts`` and save them.

    Args:
        args: namespace providing ``svg_weight`` (directory containing
            ``pytorch_model.bin``) and ``output`` (root output directory).
        cfg: mapping with ``tokenizer_name``, ``command_len``, ``text_len``,
            ``word_emb_path`` and ``pos_emb_path``.

    Files are written as ``<output>/<prompt>/<index:05d>.svg``, where any '/'
    in the prompt is replaced by ','. The average number of rasterized commands
    per successfully rasterized sample is printed at the end.
    """
    # Fall back to CPU so the script still runs on machines without CUDA.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(cfg['tokenizer_name'])

    total_command_len = 0
    generated_sample_num = 0

    svg_decoder = SVGDecoder(
        config={
            'hidden_dim': 1024,
            'embed_dim': 512,
            'num_layers': 16,
            'num_heads': 8,
            'dropout_rate': 0.1
        },
        command_len=cfg['command_len'],
        text_len=cfg['text_len'],
        num_text_token=tokenizer.vocab_size,
        word_emb_path=cfg['word_emb_path'],
        pos_emb_path=cfg['pos_emb_path'],
        length_loss_weight=0.001,
        eos_alpha=0.005,
    )
    weight_path = os.path.join(args.svg_weight, 'pytorch_model.bin')
    # Load onto CPU first so the checkpoint does not require the device it was saved from.
    svg_decoder.load_state_dict(torch.load(weight_path, map_location='cpu'))
    svg_decoder = svg_decoder.to(device).eval()

    os.makedirs(args.output, exist_ok=True)

    for text in texts:
        # A '/' in the prompt would create nested directories; use ',' instead.
        text = text.replace('/', ',')
        print(f'Generate SVG for "{text}"...')

        output_dir = os.path.join(args.output, text)
        os.makedirs(output_dir, exist_ok=True)

        # tokenize text input
        encoded_dict = tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=cfg['text_len'],
            add_special_tokens=True,
            return_token_type_ids=False,  # for RoBERTa
        )
        # [0] drops the batch axis of the single prompt (safe even when text_len == 1),
        # then the token row is tiled BS times to condition a whole batch.
        tokenized_text = encoded_dict["input_ids"][0].repeat(BS, 1).to(device)

        # sample SVG; inference only, so no autograd graph is needed
        generated_svg = []
        start_time = time.time()
        with torch.no_grad():
            while len(generated_svg) < NUM_SAMPLE:
                generated_svg += svg_decoder.sample(n_samples=BS, text=tokenized_text)
        end_time = time.time()
        # Honour NUM_SAMPLE exactly even when it is not a multiple of BS.
        generated_svg = generated_svg[:NUM_SAMPLE]
        print(f'Generate {len(generated_svg)} svg in {end_time - start_time} seconds')

        # convert token sequence into SVG
        print('Rendering...')
        gen_data = []
        for sample_command in generated_svg:
            result = raster_svg(sample_command)
            # raster_svg reports its own errors; an empty/falsy result means the
            # sample could not be converted and is excluded from the statistics.
            data, command_len = result if result else ([], 0)
            if not data:
                continue
            gen_data += data
            total_command_len += command_len
            generated_sample_num += 1

        print('Saving...')
        for index, data in enumerate(gen_data):
            try:
                _save_svg(data, os.path.join(output_dir, f'{str(index).zfill(5)}.svg'))
            except Exception as err_msg:
                # Best effort: one malformed sample must not abort the rest.
                print(err_msg)

    if generated_sample_num:
        print('Average command length : %.2f' % (total_command_len / generated_sample_num))
    else:
        print('No sample could be rasterized; average command length unavailable.')

# Command classes encoded in the flag positions of a decoded token
# (same indices as SVG_END / MOVE / LINE / CURVE):
#   0: SVG END
#   1: M (move)
#   2: L (line)
#   3: C (cubic Bezier curve)


def raster_svg(batch_commands):
    """Convert one decoded sample into per-path command tensors.

    ``batch_commands`` is iterated as a sequence of chunks, each a sequence of
    tokens. A token is read at these positions:

    * ``[MOVE]``, ``[LINE]``, ``[CURVE]``: command flags. The first one above
      0.5 (checked in that order) selects the command. Tokens with no flag
      above 0.5 (e.g. SVG END) are skipped.
    * ``[4:6]`` start position, ``[6:8]`` / ``[8:10]`` control points 1 / 2,
      ``[10:12]`` end position.

    NOTE(review): an older note described tokens as 13-dim with a 5-dim
    command part, but only indices up to 11 are read here. Confirm against
    the decoder.

    Coordinates are mapped with ``x * BBOX/2 + BBOX/2``, which sends [-1, 1]
    onto [0, BBOX] (presumably the decoder emits normalized coordinates).
    Negative results are clamped to 0.

    Each output row holds 9 floats: ``[cmd, 0, 0, c1x, c1y, c2x, c2y, x, y]``,
    with cmd 0 = move, 1 = line, 2 = curve. Columns 1:3 are never written.

    A move whose start and end round to the same point closes the current
    (non-empty) path and starts a new one.

    Returns:
        ``([paths], command_len)``. ``paths`` is a list of ``torch.Tensor`` of
        shape (n_commands, 9), and ``command_len`` is the total number of rows.
        On any exception the error is printed and ``([], 0)`` is returned, so
        callers can always unpack two values.
    """
    half = BBOX / 2

    def to_canvas(coords):
        # Affine map from model space onto the [0, BBOX] canvas.
        return coords * half + half

    try:
        svg_tensors = []
        path_tensor = []
        command_len = 0
        # NOTE(review): path_tensor is not reset between chunks, so one path can
        # continue across chunk boundaries. Confirm this is intended.
        for commands in batch_commands:
            for command in commands:
                if command[MOVE] > 0.5:  # Move
                    cmd_tensor = np.zeros(9)
                    cmd_tensor[0] = 0
                    cmd_tensor[7:9] = to_canvas(command[10:12])
                    start_pos = np.round(to_canvas(command[4:6]))
                    end_pos = np.round(cmd_tensor[7:9])
                    cmd_tensor[cmd_tensor < 0] = 0
                    # A move whose start and end coincide closes the current path.
                    if np.all(start_pos == end_pos) and path_tensor:
                        svg_tensors.append(torch.tensor(path_tensor))
                        command_len += len(path_tensor)
                        path_tensor = []
                    path_tensor.append(cmd_tensor.tolist())
                elif command[LINE] > 0.5:  # Line
                    cmd_tensor = np.zeros(9)
                    cmd_tensor[0] = 1
                    cmd_tensor[7:9] = to_canvas(command[10:12])
                    cmd_tensor[cmd_tensor < 0] = 0
                    path_tensor.append(cmd_tensor.tolist())
                elif command[CURVE] > 0.5:  # Curve
                    cmd_tensor = np.zeros(9)
                    cmd_tensor[0] = 2
                    cmd_tensor[3:5] = to_canvas(command[6:8])
                    cmd_tensor[5:7] = to_canvas(command[8:10])
                    cmd_tensor[7:9] = to_canvas(command[10:12])
                    cmd_tensor[cmd_tensor < 0] = 0
                    path_tensor.append(cmd_tensor.tolist())
        # Flush the last open path; skip it if it carries no commands.
        if path_tensor:
            svg_tensors.append(torch.tensor(path_tensor))
            command_len += len(path_tensor)
        return [svg_tensors], command_len
    except Exception as error_msg:
        print(error_msg)
        print('with the error raster finished')
        return [], 0


def parse_args(argv=None):
    """Parse the command-line options of the sampling script.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``output`` and ``svg_weight`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Sample SVGs from a text-conditioned SVG decoder."
    )
    parser.add_argument(
        "--output", type=str, required=True,
        help="Root directory for generated SVGs (one sub-directory per prompt).",
    )
    parser.add_argument(
        "--svg_weight", type=str, required=True,
        help="Directory containing the decoder checkpoint 'pytorch_model.bin'.",
    )
    return parser.parse_args(argv)


def default_config():
    """Return the model/tokenizer configuration passed to ``sample``."""
    return {
        'command_len': 256,
        'text_len': 50,

        'tokenizer_name': 'google/bert_uncased_L-12_H-512_A-8',
        'word_emb_path': 'ckpts/word_embedding_512.pt',
        'pos_emb_path': None,
    }


if __name__ == "__main__":
    sample(parse_args(), default_config())
