import transformers
import cv2
import os
import json
import torch
import copy

import numpy as np
from PIL import Image
from typing import Dict, Optional, Sequence, List
from torch.utils.data import Dataset
from torchvision import transforms

import pycocotools.mask as maskUtils

from .base_class import DataArguments, rank0_print, preprocess, preprocess_multimodal
from TraceVLM.utils import box_xyxy_expand2square, resize_box, resize_seg, resize_mask, resize_localnarrative, reorder, reorder_seg_and_merge, get_single_centerpoint, get_all_centerpoint, reorder_mask_and_merge
from TraceVLM.utils import is_clockwise, reorder_points, revert_direction, interpolate_polygons, approximate_polygons, polygons_to_string, sample_polygons, calculate_angle
from TraceVLM.constants import REF_TASK_PALCEHOLDER, DEFAULT_SEGES_PLACEHOLDER, DEFAULT_LNC_PLACEHOLDER, DEFAULT_LNC_TOKEN
from skimage.measure import label, regionprops
from skimage.draw import line
import random
from pycocotools.coco import COCO
from scipy import interpolate


def trace2coord(traces, timed_caption, num_points=16):
    """Attach time-averaged trace coordinates to each timed caption segment.

    The trace is treated as two piecewise-linear functions ``x(t)`` and
    ``y(t)``. For every caption segment the functions are sampled at
    ``num_points`` evenly spaced times between ``start_time`` and
    ``end_time`` (linearly extrapolating outside the recorded time range),
    the samples are clamped to ``[0, 1]``, and their time average is
    computed with the trapezoidal rule.

    Args:
        traces: Sequence of dicts with keys ``'x'``, ``'y'`` and ``'t'``.
            Coordinates are clamped to ``[0, 1]`` before interpolation and
            the points do not need to be sorted by time.
        timed_caption: List of dicts with keys ``'start_time'`` and
            ``'end_time'``. Each dict is updated in place.
        num_points: Number of samples taken per caption segment.

    Returns:
        The same ``timed_caption`` list. Every entry gains the keys
        ``integral_x``/``integral_y`` (time-averaged position),
        ``min_x``/``min_y``/``max_x``/``max_y`` (extent of the samples) and
        ``sampled_x``/``sampled_y`` (the samples as plain lists). All of
        them are derived from the clamped samples, so they stay in
        ``[0, 1]`` even when a segment lies outside the traced time range.
    """
    t_arr = np.array([point['t'] for point in traces])
    x_arr = np.clip(np.array([point['x'] for point in traces]), 0.0, 1.0)
    y_arr = np.clip(np.array([point['y'] for point in traces]), 0.0, 1.0)

    # Interpolation is done over time, so order the points by timestamp.
    order = np.argsort(t_arr)
    t_arr = t_arr[order]
    x_arr = x_arr[order]
    y_arr = y_arr[order]

    x_interpolator = interpolate.interp1d(
        t_arr, x_arr, fill_value='extrapolate'
    )
    y_interpolator = interpolate.interp1d(
        t_arr, y_arr, fill_value='extrapolate'
    )

    # ``np.trapz`` was renamed to ``np.trapezoid`` in NumPy 2.0; prefer the
    # new name when it exists and fall back to the old one otherwise.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    for dic in timed_caption:
        t_values = np.linspace(dic['start_time'], dic['end_time'], num=num_points)
        # Extrapolation can leave the unit square, so clamp the samples once
        # and use the clamped values for every derived statistic.
        x_values = np.clip(x_interpolator(t_values), 0.0, 1.0)
        y_values = np.clip(y_interpolator(t_values), 0.0, 1.0)

        duration = t_values[-1] - t_values[0]
        if duration < 1e-5:
            # (Near-)instantaneous segment: the integral is degenerate, so
            # fall back to the plain mean of the samples.
            integral_x = np.mean(x_values)
            integral_y = np.mean(y_values)
        else:
            # Time average = integral over the segment / segment duration.
            integral_x = trapezoid(x_values, t_values) / duration
            integral_y = trapezoid(y_values, t_values) / duration

        dic['integral_x'] = integral_x
        dic['integral_y'] = integral_y
        dic['min_x'] = np.min(x_values)
        dic['min_y'] = np.min(y_values)
        dic['max_x'] = np.max(x_values)
        dic['max_y'] = np.max(y_values)
        dic['sampled_x'] = x_values.tolist()
        dic['sampled_y'] = y_values.tolist()

    return timed_caption

class ImgLocalNarrativeCaptionDataSet(Dataset):
    """Dataset for Local Narratives captions paired with trace trajectories.

    Every entry of ``list_data_dict`` is a JSON string. The keys read by this
    class are ``dataset_id``, ``image_id``, ``caption``, ``timed_caption``
    and ``traces``. Each sample becomes a two-turn conversation: a prompt
    from the "human" side and, from the "gpt" side, the caption utterances
    interleaved with ``<bin_*>`` trajectory tokens.
    """

    def __init__(self, data_path: str,
                 tokenizer: "transformers.PreTrainedTokenizer",
                 data_args: "DataArguments",
                 image_folder: str = None,
                 template_file: str = None,
                 split: str = "<SPL>",
                 prompt: str = None,
                 debug=False,
                 coco_instance_json: str = None,
                 list_data_dict: Optional[List[str]] = None,
                 ):
        """Load the annotation lines, the COCO index and the prompt source.

        Args:
            data_path: File with one JSON record per line; only read when
                ``list_data_dict`` is not supplied.
            tokenizer: Tokenizer handed to ``preprocess``.
            data_args: Data configuration; ``image_processor``,
                ``image_aspect_ratio``, ``image_folder`` and
                ``is_multimodal`` are read from it.
            image_folder: Directory the image file names are joined with.
            template_file: JSON file holding a list of prompt templates; one
                is drawn at random per sample. Mutually exclusive with
                ``prompt``.
            split: Accepted for interface compatibility; not used here.
            prompt: Fixed prompt used for every sample.
            debug: When true, images that go through the manual resize path
                are preprocessed without normalization.
            coco_instance_json: COCO instances annotation file used to look
                up image sizes for COCO samples.
            list_data_dict: Pre-loaded JSON lines; bypasses ``data_path``.

        Raises:
            ValueError: If both ``prompt`` and ``template_file`` are given.
        """
        super().__init__()

        if prompt is not None and template_file is not None:
            raise ValueError("Pass either `prompt` or `template_file`, not both.")

        if list_data_dict is None:
            with open(data_path, "r", encoding="utf-8") as f:
                list_data_dict = f.readlines()
        self.coco = COCO(coco_instance_json)

        rank0_print("Formatting inputs...LocalNarrative Caption and Trajectory Dataset")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args
        self.image_folder = image_folder

        if template_file is not None:
            with open(template_file, 'r', encoding='utf8') as f:
                self.prompts = json.load(f)
            # Fixed seed keeps the template sequence reproducible.
            self.rng = np.random.default_rng(1203)
            self.prompt = None
        else:
            self.prompt = prompt  # e.g. "<image><expr><single>"
            self.prompts = None

        # When the image processor does not resize, resize manually to a
        # square whose side is the processor's "shortest_edge".
        self.debug = debug
        if not self.data_args.image_processor.do_resize:
            edge = self.data_args.image_processor.size["shortest_edge"]
            self.resize = transforms.Resize((edge, edge))

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.list_data_dict)

    def _locate_image(self, item):
        """Return ``(img_path, width, height)`` for one parsed record.

        COCO samples take their size from the COCO index and use a
        12-digit zero-padded file name; every other dataset opens the image
        file to read its size.
        """
        if 'coco' in item['dataset_id']:
            image_id = item['image_id']
            image_filename = f"{str(image_id).zfill(12)}.jpg"
            img_path = os.path.join(self.image_folder, image_filename)
            image_info = self.coco.loadImgs(int(image_id))[0]
            return img_path, image_info['width'], image_info['height']

        image_filename = f"{item['image_id']}.jpg"
        img_path = os.path.join(self.image_folder, image_filename)
        with Image.open(img_path) as img:
            width, height = img.size
        return img_path, width, height

    @staticmethod
    def _format_trajectory(points, scale_size, timed_caption):
        """Serialize ``(x, y)`` points and caption utterances into one string.

        Each point is rounded to integers, divided by ``scale_size``,
        clamped to ``[0, 1]`` and quantized into ``<bin_0>``..``<bin_1000>``
        tokens. The i-th point is paired with the i-th caption entry;
        surplus points or captions are dropped.
        """
        pieces = []
        for point, caption in zip(points, timed_caption):
            x, y = np.round(np.array(point).flatten()).astype(int)
            norm_x = round(max(0.0, min(float(x) / scale_size, 1.0)), 3)
            norm_y = round(max(0.0, min(float(y) / scale_size, 1.0)), 3)
            x_coord = "<bin_{}>".format(round(norm_x * 1000))
            y_coord = "<bin_{}>".format(round(norm_y * 1000))
            pieces.append(f"<trj>{x_coord},{y_coord}</trj><word>{caption['utterance']}</word>")
        return "".join(pieces)

    def transform2conv(self, i):
        """Build the raw conversation sample for record ``i``.

        Returns:
            A dict with ``"image"`` (image path already joined with
            ``self.image_folder``) and ``"conversations"`` (a human prompt
            turn followed by a gpt turn). The gpt turn is the serialized
            trajectory, or the literal string ``"None"`` when the record
            has at most one trace point.
        """
        if self.prompts is not None:
            prompt = self.rng.choice(self.prompts)
        else:
            prompt = self.prompt
        if not prompt.endswith("\n"):
            prompt += "\n"
        if "<image>" not in prompt:
            prompt = "<image>" + prompt

        item = json.loads(self.list_data_dict[i])
        img_path, width, height = self._locate_image(item)

        expr = item["caption"]
        timed_caption = item['timed_caption']
        # Flatten the per-stroke trace segments into one point list.
        traces = [point for trace in item['traces'] for point in trace]

        if len(traces) > 1:
            timed_caption = trace2coord(traces, timed_caption)
            center_np = np.array(
                [[dic['integral_x'], dic['integral_y']] for dic in timed_caption]
            )
            # Normalized averages -> pixel coordinates in the original image.
            scaled_xs = center_np[:, 0] * width
            scaled_ys = center_np[:, 1] * height
            localnarrative_scaled_list = [[x, y] for x, y in zip(scaled_xs, scaled_ys)]

            local_narrative_list, _, _ = resize_localnarrative(
                localnarrative_scaled_list, self.data_args.image_processor.size,
                height, width, pre=True)

            local_narrative_str = self._format_trajectory(
                local_narrative_list, min(height, width), timed_caption)
        else:
            local_narrative_str = "None"

        ret = {
            "image": img_path,
            "conversations": [
                {
                    "from": "human",
                    "value": prompt
                },
                {
                    "from": "gpt",
                    "value": local_narrative_str
                }
            ]
        }
        if i == 0:
            print(prompt)
            print(expr)

        return ret

    def _caption_lengths(self):
        """Return, per record, the caption word count plus a fixed image budget."""
        img_tokens = 128
        return [
            len(json.loads(sample)["caption"].split()) + img_tokens
            for sample in self.list_data_dict
        ]

    @property
    def lengths(self):
        """Approximate token length of every sample (caption words + 128)."""
        return self._caption_lengths()

    @property
    def modality_types(self):
        """Task-type id for every sample.

        Legend from the original code:
        1:LAZY 2:RES 3:GRES 4:IS 5:SEG 6:REC 7:OVN 8:LNC 9:LNS.
        Original note: text-only samples are told apart by giving them a
        length of ``-cur_len``; regions could be handled the same way.
        """
        # NOTE(review): the legend lists 8 as LNC, yet this dataset reports 4;
        # the value is kept unchanged - confirm against the sampler.
        return [4] * len(self.list_data_dict)

    @property
    def modality_lengths(self):
        """Same values as ``lengths``; all samples here carry an image."""
        return self._caption_lengths()

    @staticmethod
    def _expand2square(pil_img, background_color):
        """Pad a PIL image to a square, centering it on ``background_color``."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        if width > height:
            result.paste(pil_img, (0, (width - height) // 2))
        else:
            result.paste(pil_img, ((height - width) // 2, 0))
        return result

    @staticmethod
    def _load_image(image_folder, image_file):
        """Open an image as RGB with PIL, falling back to OpenCV on failure."""
        # NOTE(review): transform2conv already joins the image folder into
        # ``image_file``. os.path.join discards ``image_folder`` when
        # ``image_file`` is absolute, so the joins below only resolve as
        # intended for absolute image folders.
        try:
            return Image.open(
                os.path.join(image_folder, image_file.replace('COCO_train2014_', ''))
            ).convert('RGB')
        except Exception:
            # Best-effort fallback kept from the original implementation.
            bgr = cv2.imread(os.path.join(image_folder, image_file))
            return Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

    def _preprocess_image(self, image):
        """Turn a PIL image into the pixel tensor expected by the model."""
        processor = self.data_args.image_processor
        if self.data_args.image_aspect_ratio == 'pad':
            # Pad to a square first, filling with the processor's mean color.
            image = self._expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
            return processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        if not processor.do_resize:
            image = self.resize(image)
            if self.debug:
                return processor.preprocess(image, do_normalize=False, return_tensors='pt')['pixel_values'][0]
        return processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return the tokenized conversation and image tensor for record ``i``.

        The result holds ``input_ids`` and ``labels`` for the single sample,
        ``source`` (the processed conversation), and - when the sample has an
        image - ``image_path`` and ``image``. A ``mask`` entry is forwarded
        if the raw sample provides one. Without an image, a zero image is
        supplied when ``data_args.is_multimodal`` is set.
        """
        # transform2conv always yields exactly one sample; wrap it in a list
        # because the preprocessing helpers operate on batches.
        sources = [self.transform2conv(i)]
        sample = sources[0]

        has_mask = 'mask' in sample
        mask = sample['mask'] if has_mask else None

        has_image = 'image' in sample
        if has_image:
            image_file = sample['image']
            if self.image_folder is None:
                image_folder = self.data_args.image_folder
            else:
                image_folder = self.image_folder
            image = self._preprocess_image(self._load_image(image_folder, image_file))
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])

        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=has_image,
        )
        # Unwrap the batch dimension added above.
        data_dict = dict(input_ids=data_dict["input_ids"][0],
                         labels=data_dict["labels"][0],
                         )
        if has_mask:
            data_dict['mask'] = mask
        data_dict['source'] = sources
        if has_image:
            data_dict['image_path'] = os.path.join(image_folder, image_file)
            data_dict['image'] = image
        elif self.data_args.is_multimodal:
            # No image in the sample, but the model is multimodal.
            crop_size = self.data_args.image_processor.crop_size
            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
        return data_dict