from typing import Tuple, Optional, Dict
import logging
import os
import shutil
from os import path
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np

import pycocotools.mask as mask_util
from threading import Thread
from queue import Queue
from dataclasses import dataclass
import copy

from feature_extractor.cutie.cutie.utils.pano_utils import ID2RGBConverter
from feature_extractor.cutie.cutie.utils.palette import davis_palette_np
from feature_extractor.cutie.cutie.inference.object_manager import ObjectManager
from feature_extractor.cutie.cutie.inference.object_info import ObjectInfo

# Module-wide logger. getLogger() without a name returns the root logger.
log = logging.getLogger()

# hickle is an optional dependency: in this file it is only used by save_result
# when saver.save_scores is set. If the import fails we only warn here, and
# `hkl` is left undefined, so a later attempt to dump scores raises NameError.
try:
    import hickle as hkl
except ImportError:
    log.warning('Failed to import hickle. Fine if not using multi-scale testing.')


class ResultSaver:
    """Turns per-frame probabilities into index masks and queues them for writing.

    The actual disk I/O is done by ``save_result`` running in a daemon thread;
    ``process`` only prepares a ``ResultArgs`` item and puts it on a bounded
    queue. Call ``end`` once all frames have been submitted.
    """
    def __init__(self,
                 output_root,
                 video_name,
                 *,
                 dataset,
                 object_manager: ObjectManager,
                 use_long_id,
                 palette=None,
                 save_mask=True,
                 save_scores=False,
                 score_output_root=None,
                 visualize_output_root=None,
                 visualize=False,
                 init_json=None):
        """Store the output configuration and start the background writer.

        Args:
            output_root: directory under which ``<video_name>/`` mask PNGs are written.
            video_name: sub-directory name used for every output kind.
            dataset: dataset name; lower-cased, and BURST-style JSON recording
                is enabled when it contains ``'burst'``.
            object_manager: supplies the ``tmp_id_to_obj`` / ``obj_to_tmp_id`` mappings.
            use_long_id: if true, masks are saved as RGB images produced by
                ``ID2RGBConverter`` instead of single-channel index images.
            palette: optional palette applied to index PNGs and used as
                visualization colors.
            save_mask: whether mask images are written.
            save_scores: whether probabilities are dumped (requires hickle).
            score_output_root: root directory for score dumps.
            visualize_output_root: root directory for blended visualizations.
            visualize: whether blended visualizations are written.
            init_json: required when ``dataset`` contains ``'burst'``; must hold
                ``'segmentations'`` and ``'annotated_image_paths'``.
        """
        # output locations
        self.output_root = output_root
        self.video_name = video_name
        self.score_output_root = score_output_root
        self.visualize_output_root = visualize_output_root

        # behavior switches
        self.dataset = dataset.lower()
        self.use_long_id = use_long_id
        self.save_mask = save_mask
        self.save_scores = save_scores
        self.visualize = visualize

        self.palette = palette
        self.object_manager = object_manager

        # colors are only needed (and only defined) when visualizing
        if self.visualize:
            self.colors = (davis_palette_np if self.palette is None else np.array(
                self.palette, dtype=np.uint8).reshape(-1, 3))

        self.need_remapping = True
        self.json_style = None
        self.id2rgb_converter = ID2RGBConverter()

        if 'burst' in self.dataset:
            assert init_json is not None
            self.input_segmentations = init_json['segmentations']
            # one (initially empty) output dict per input frame entry
            self.segmentations = [{} for _ in init_json['segmentations']]
            self.annotated_frames = init_json['annotated_image_paths']
            # copy every top-level field except the segmentations, then attach ours
            self.video_json = {
                key: value
                for key, value in init_json.items() if key != 'segmentations'
            }
            self.video_json['segmentations'] = self.segmentations
            self.json_style = 'burst'

        # bounded queue: process() blocks once 10 frames are waiting to be written
        self.queue = Queue(maxsize=10)
        self.thread = Thread(target=save_result, args=(self.queue, ), daemon=True)
        self.thread.start()

    def process(self,
                prob: torch.Tensor,
                frame_name: str,
                resize_needed: bool = False,
                shape: Optional[Tuple[int, int]] = None,
                last_frame: bool = False,
                path_to_image: str = None):
        """Convert ``prob`` to an index mask and queue it for the writer thread.

        Args:
            prob: probabilities; the arg-max is taken over dim 0
                (presumably laid out as (num_labels, H, W) -- confirm with caller).
            frame_name: file name of the frame; the writer derives output names from it.
            resize_needed: if true, ``prob`` is bilinearly resized to ``shape`` first.
            shape: target size passed to ``F.interpolate`` when resizing.
            last_frame: forwarded to the writer, which uses it when saving scores.
            path_to_image: path of the input image; needed by the writer for visualization.
        """
        if resize_needed:
            resized = F.interpolate(prob.unsqueeze(1),
                                    shape,
                                    mode='bilinear',
                                    align_corners=False)
            prob = resized[:, 0]

        # probabilities -> index mask
        mask = torch.argmax(prob, dim=0)

        # the probabilities only travel to the writer when scores are requested
        prob = prob.cpu() if self.save_scores else None

        # translate temporary ids into the persistent object ids
        if self.need_remapping:
            remapped = torch.zeros_like(mask)
            for tmp_id, obj in self.object_manager.tmp_id_to_obj.items():
                remapped[mask == tmp_id] = obj.id
            mask = remapped

        # the id mappings are deep-copied so the writer sees them as of this call
        item = ResultArgs(saver=self,
                          prob=prob,
                          mask=mask.cpu(),
                          frame_name=frame_name,
                          path_to_image=path_to_image,
                          tmp_id_to_obj=copy.deepcopy(self.object_manager.tmp_id_to_obj),
                          obj_to_tmp_id=copy.deepcopy(self.object_manager.obj_to_tmp_id),
                          last_frame=last_frame)
        self.queue.put(item)

    def end(self):
        """Send the stop sentinel and wait for the writer thread to finish."""
        self.queue.put(None)
        self.queue.join()
        self.thread.join()


@dataclass
class ResultArgs:
    """One frame's work item, built by ``ResultSaver.process`` and consumed by ``save_result``."""
    # the ResultSaver that queued this item; the worker reads its settings from it
    saver: ResultSaver
    # CPU copy of the probabilities when saver.save_scores is set;
    # NOTE: ResultSaver.process passes None otherwise, despite the annotation
    prob: torch.Tensor
    # index mask (arg-max of the probabilities, remapped to object ids), on CPU
    mask: torch.Tensor
    # frame file name; the worker derives output file names from it
    frame_name: str
    # path of the input image, read only for visualization;
    # NOTE: defaults to None in ResultSaver.process, despite the annotation
    path_to_image: str
    # deep copies of the object manager's id mappings, taken when the item was queued
    tmp_id_to_obj: Dict[int, ObjectInfo]
    obj_to_tmp_id: Dict[ObjectInfo, int]
    # forwarded from ResultSaver.process; the worker checks it when saving scores
    last_frame: bool


def save_result(queue: Queue):
    """Worker loop: take ``ResultArgs`` items off ``queue`` and write their outputs.

    Runs until a ``None`` sentinel is received. ``queue.task_done()`` is called
    exactly once for every item taken off the queue (including the sentinel),
    even when writing fails, so that ``queue.join()`` in the producer can never
    block forever on a failed frame. Failures are logged with their traceback
    and the loop moves on to the next frame.

    Args:
        queue: queue of ``ResultArgs`` items terminated by ``None``.
    """
    while True:
        args: ResultArgs = queue.get()
        try:
            if args is None:
                # sentinel: the producer has no more frames for us
                break
            _save_frame(args)
        except Exception:
            # Letting the error escape would kill this thread without calling
            # task_done(); the producer would then hang in queue.join(), or in
            # queue.put() once the bounded queue fills up.
            log.exception('Failed to save results for frame %s',
                          getattr(args, 'frame_name', None))
        finally:
            queue.task_done()


def _save_frame(args: 'ResultArgs') -> None:
    """Write every output enabled on ``args.saver`` for a single frame."""
    saver = args.saver
    mask = args.mask
    # splitext instead of [:-4] so extensions of any length are stripped correctly
    frame_stem = path.splitext(args.frame_name)[0]
    all_obj_ids = [obj.id for obj in args.obj_to_tmp_id]

    # record output in the json file
    if saver.json_style == 'burst':
        _record_burst_segments(saver, mask, args.frame_name, all_obj_ids)

    # The numpy masks are needed both for saving and for visualizing, so build
    # them whenever either is requested (visualizing with save_mask=False used
    # to fail because they were only created while saving).
    out_mask = rgb_mask = None
    if saver.save_mask or saver.visualize:
        out_mask, rgb_mask = _render_masks(saver, mask, all_obj_ids)

    if saver.save_mask:
        _save_mask_image(saver, out_mask, rgb_mask, frame_stem)

    # save scores for multi-scale testing
    if saver.save_scores:
        _save_scores(saver, args.prob, args.obj_to_tmp_id, args.last_frame, frame_stem)

    if saver.visualize:
        _save_visualization(saver, out_mask, rgb_mask, all_obj_ids, args.path_to_image,
                            frame_stem)


def _record_burst_segments(saver: 'ResultSaver', mask, frame_name: str,
                           all_obj_ids: list) -> None:
    """Store RLE-encoded segments of an annotated frame in ``saver.segmentations``."""
    if frame_name not in saver.annotated_frames:
        return

    frame_index = saver.annotated_frames.index(frame_name)
    input_segments = saver.input_segmentations[frame_index]
    frame_segments = saver.segmentations[frame_index]

    for obj_id in all_obj_ids:
        if obj_id in input_segments:
            # if this frame has been given as input, just copy
            frame_segments[obj_id] = input_segments[obj_id]
            continue

        segment_mask = (mask == obj_id)
        # objects with an empty mask in this frame are not recorded
        if segment_mask.sum() > 0:
            coco_mask = mask_util.encode(np.asfortranarray(segment_mask.numpy()))
            frame_segments[obj_id] = {'rle': coco_mask['counts'].decode('utf-8')}


def _render_masks(saver: 'ResultSaver', mask, all_obj_ids: list):
    """Return ``(out_mask, rgb_mask)`` numpy arrays for the index ``mask``.

    With ``saver.use_long_id`` the index mask is uint32 and ``rgb_mask`` holds the
    per-object colors from ``saver.id2rgb_converter``; otherwise the index mask is
    uint8 and ``rgb_mask`` is ``None``.
    """
    if not saver.use_long_id:
        return mask.numpy().astype(np.uint8), None

    out_mask = mask.numpy().astype(np.uint32)
    rgb_mask = np.zeros((*out_mask.shape[-2:], 3), dtype=np.uint8)
    for obj_id in all_obj_ids:
        _, color = saver.id2rgb_converter.convert(obj_id)
        rgb_mask[out_mask == obj_id] = color
    return out_mask, rgb_mask


def _save_mask_image(saver: 'ResultSaver', out_mask, rgb_mask, frame_stem: str) -> None:
    """Save the mask as ``<output_root>/<video_name>/<frame_stem>.png``."""
    if saver.use_long_id:
        out_img = Image.fromarray(rgb_mask)
    else:
        out_img = Image.fromarray(out_mask)
        if saver.palette is not None:
            out_img.putpalette(saver.palette)

    this_out_path = path.join(saver.output_root, saver.video_name)
    os.makedirs(this_out_path, exist_ok=True)
    out_img.save(path.join(this_out_path, frame_stem + '.png'))


def _save_scores(saver: 'ResultSaver', prob, obj_to_tmp_id, last_frame: bool,
                 frame_stem: str) -> None:
    """Dump the probabilities (scaled to uint8) with hickle; add the id mapping on the last frame."""
    this_out_path = path.join(saver.score_output_root, saver.video_name)
    os.makedirs(this_out_path, exist_ok=True)

    prob = (prob.detach().numpy() * 255).astype(np.uint8)

    if last_frame:
        # object id -> temporary id. This has to iterate obj_to_tmp_id: the
        # previous code unpacked (obj, tmp_id) from tmp_id_to_obj.items(), which
        # yields (tmp_id, obj), so `obj.id` was evaluated on an int.
        obj_id_to_tmp_id = {obj.id: tmp_id for obj, tmp_id in obj_to_tmp_id.items()}
        hkl.dump(obj_id_to_tmp_id, path.join(this_out_path, 'backward.hkl'), mode='w')

    hkl.dump(prob, path.join(this_out_path, f'{frame_stem}.hkl'), mode='w', compression='lzf')


def _save_visualization(saver: 'ResultSaver', out_mask, rgb_mask, all_obj_ids: list,
                        path_to_image, frame_stem: str) -> None:
    """Blend the colored mask over the input image and save it as a JPEG.

    Raises:
        ValueError: if ``path_to_image`` is ``None``.
    """
    if path_to_image is None:
        raise ValueError('Cannot visualize without path_to_image')

    # the context manager closes the underlying file once the pixels are copied
    with Image.open(path_to_image) as image:
        image_np = np.array(image)

    if rgb_mask is None:
        # we need to apply a palette
        rgb_mask = np.zeros((*out_mask.shape, 3), dtype=np.uint8)
        for obj_id in all_obj_ids:
            rgb_mask[out_mask == obj_id] = saver.colors[obj_id]

    # background pixels keep the image untouched (alpha 1.0); object pixels are
    # a 50/50 blend of image and mask color
    alpha = (out_mask == 0).astype(np.float32) * 0.5 + 0.5
    alpha = alpha[:, :, None]
    blend = (image_np * alpha + rgb_mask * (1 - alpha)).astype(np.uint8)

    # find a place to save the visualization
    this_vis_path = path.join(saver.visualize_output_root, saver.video_name)
    os.makedirs(this_vis_path, exist_ok=True)
    Image.fromarray(blend).save(path.join(this_vis_path, frame_stem + '.jpg'))


def make_zip(dataset, run_dir, exp_id, mask_output_root):
    """Pack the predictions into ``<run_dir>/<exp_id>_<dataset>.zip`` for known datasets.

    YouTubeVOS-style datasets (name starting with ``'y'``) and ``'lvos-test'``
    archive the ``Annotations`` folder relative to ``run_dir``. ``'d17-test-dev'``
    and ``'mose-val'`` archive the contents of ``mask_output_root`` directly.
    Any other dataset is only logged and nothing is archived.
    """
    # exact dataset name -> (log message, whether to zip `Annotations` from run_dir)
    exact_matches = {
        'd17-test-dev': ('Making zip for DAVIS test-dev...', False),
        'mose-val': ('Making zip for MOSE validation...', False),
        'lvos-test': ('Making zip for LVOS test...', True),
    }

    # the YouTubeVOS prefix check takes precedence over the exact names
    if dataset.startswith('y'):
        plan = ('Making zip for YouTubeVOS...', True)
    else:
        plan = exact_matches.get(dataset)

    if plan is None:
        log.info(f'Not making zip for {dataset}.')
        return

    message, keep_annotations_folder = plan
    log.info(message)

    archive_base = path.join(run_dir, f'{exp_id}_{dataset}')
    if keep_annotations_folder:
        shutil.make_archive(archive_base, 'zip', run_dir, 'Annotations')
    else:
        # zip from within the mask folder
        shutil.make_archive(archive_base, 'zip', mask_output_root)
