from functools import partial
import numpy
import os
import re
import random
import signal
import csv
import settings
import numpy as np
from collections import OrderedDict
from imageio import imread
from multiprocessing import Pool, cpu_count
from multiprocessing.pool import ThreadPool
from scipy.ndimage.interpolation import zoom
from sklearn.model_selection import train_test_split
from torchvision import transforms
import torch
import pickle
from tqdm import tqdm
import json



from . import data_utils as du

from PIL import ImageEnhance


def load_csv(filename, readfields=None):
    """Read a CSV file into a list of dicts, converting numeric-looking fields.

    Integer strings (``"-12"``) become ``int``; decimal and exponent notation
    (``"0.5"``, ``"1e-3"``, ``"4E+2"``) become ``float``; anything else is
    returned unchanged.

    Args:
        filename: path of the CSV file; the first row is the header.
        readfields: optional list that is extended in place with the header
            field names.

    Returns:
        A list with one dict per data row.
    """

    def convert(value):
        # DictReader fills missing trailing fields with None (restval) and
        # collects extra fields in a list; leave those untouched.
        if not isinstance(value, str):
            return value
        if re.match(r"^-?\d+$", value):
            try:
                return int(value)
            except ValueError:
                pass
        # The exponent part is optional and may be signed either way.
        if re.match(r"^-?[.\d]+(?:[eE][+-]?\d+)?$", value):
            try:
                return float(value)
            except ValueError:
                # e.g. "." or "1.2.3" match the pattern but are not floats.
                pass
        return value

    # newline="" is what the csv module expects for correct quoting handling.
    with open(filename, newline="") as f:
        reader = csv.DictReader(f)
        result = [{k: convert(v) for k, v in row.items()} for row in reader]
        if readfields is not None:
            # fieldnames is None for an empty file.
            readfields.extend(reader.fieldnames or [])
    return result


class AbstractSegmentation:
    """Interface for segmentation datasets.

    Subclasses provide ``all_names`` and ``filename`` and usually override
    ``size``, ``metadata`` and ``resolve_segmentation``.
    """

    def all_names(self, category, j):
        """Return ``[name, *synonyms]`` for label ``j`` of ``category``."""
        raise NotImplementedError

    def size(self, split=None):
        """Return the number of images (optionally in ``split``); 0 here."""
        return 0

    def filename(self, i):
        """Return the image path for index ``i``."""
        raise NotImplementedError

    def metadata(self, i):
        """Return the record handed to ``resolve_segmentation`` for image ``i``.

        Defaults to the image filename.
        """
        return self.filename(i)

    @classmethod
    def resolve_segmentation(cls, m, categories=None):
        """Map metadata ``m`` to ``{category: data}``; empty by default.

        ``categories`` optionally restricts which categories are produced.
        Subclasses may instead return a ``(dict, shape)`` tuple.
        """
        return {}

    def name(self, category, i):
        """Return the primary name of label ``i``, or ``""`` if it has none."""
        all_names = self.all_names(category, i)
        return all_names[0] if len(all_names) else ""

    def segmentation_data(self, category, i, c=0, full=False):
        """Return the segmentation data of ``category`` for image ``i``.

        Returns 0 when the category is absent. Unless ``full`` is set, data
        with three or more dimensions is reduced to its first channel.
        """
        segs = self.resolve_segmentation(self.metadata(i), categories=[category])
        if isinstance(segs, tuple):
            # Some implementations return (segmentations, shape).
            segs = segs[0]
        if category not in segs:
            return 0
        data = segs[category]
        if not full and len(data.shape) >= 3:
            return data[0]
        return data


class SegmentationData(AbstractSegmentation):
    """Segmentation dataset backed by CSV indexes and label-encoded PNGs.

    ``directory`` holds the image index (``settings.INDEX_FILE``),
    ``category.csv``, ``label.csv``, one ``c_<category>.csv`` per loaded
    category, the ``images/`` tree and optionally ``ade20k_scenes.json``.
    Label PNGs store a label number as ``R + 256 * G``.
    """

    # Index columns describing the image itself rather than segmentation data.
    meta_categories = ["image", "split", "ih", "iw", "sh", "sw"]

    def __init__(self, directory, categories=None, require_all=False):
        """Load the dataset indexes.

        Args:
            directory: dataset root (``~`` is expanded).
            categories: category names to load; ``None`` or empty loads every
                category listed in ``category.csv``.
            require_all: keep only images that have data for every loaded
                category (default: for at least one).
        """
        directory = os.path.expanduser(directory)
        self.directory = directory
        with open(os.path.join(directory, settings.INDEX_FILE)) as f:
            self.image = [decode_index_dict(r) for r in csv.DictReader(f)]
        with open(os.path.join(directory, "category.csv")) as f:
            self.category = OrderedDict()
            for row in csv.DictReader(f):
                # No filter means "all categories"; an empty selection would
                # make primary_categories_per_index() fail on max([]).
                if not categories or row["name"] in categories:
                    self.category[row["name"]] = row
        categories = self.category.keys()
        with open(os.path.join(directory, "label.csv")) as f:
            label_data = [decode_label_dict(r) for r in csv.DictReader(f)]
        self.label = build_dense_label_array(label_data)
        # Reverse lookup: label name -> label number.
        self.rev_label = {o["name"]: i for i, o in enumerate(self.label)}

        filter_fn = partial(
            index_has_all_data if require_all else index_has_any_data,
            categories=categories,
        )
        self.image = [row for row in self.image if filter_fn(row)]

        # Per category: code -> label number (unmap), label number -> code
        # (map), and the category's label records indexed by code.
        self.category_map = {}
        self.category_unmap = {}
        self.category_label = {}
        for cat in self.category:
            with open(os.path.join(directory, "c_%s.csv" % cat)) as f:
                c_data = [decode_label_dict(r) for r in csv.DictReader(f)]
            self.category_unmap[cat], self.category_map[cat] = build_numpy_category_map(
                c_data
            )
            self.category_label[cat] = build_dense_label_array(c_data, key="code")

        self.labelcat = du.onehot(self.primary_categories_per_index())

        # Optional image basename -> scene name lookup used by scene().
        scenes_fname = os.path.join(directory, "ade20k_scenes.json")
        if os.path.exists(scenes_fname):
            with open(scenes_fname, "r") as f:
                self.scenes = json.load(f)
        else:
            self.scenes = {}

    @staticmethod
    def _decode_label_png(png):
        """Combine the R and G channels of a label PNG into label numbers.

        Channels are widened first: ``uint8 * 256`` overflows, and NumPy 2
        raises OverflowError for the out-of-range Python integer.
        """
        return png[:, :, 0].astype(numpy.int32) + png[:, :, 1].astype(numpy.int32) * 256

    def all_names(self, category, j):
        """Return ``[name, *synonyms]`` for label ``j``.

        ``j`` is a category code when ``category`` is given, otherwise a
        label number.
        """
        if category is not None:
            j = self.category_unmap[category][j]
        return [self.label[j]["name"]] + self.label[j]["syns"]

    def size(self, split=None):
        """Return the number of images, optionally only those in ``split``."""
        if split is None:
            return len(self.image)
        return len([im for im in self.image if im["split"] == split])

    def filename(self, i):
        """Return the path of image ``i``."""
        return os.path.join(self.directory, "images", self.image[i]["image"])

    def scene(self, i):
        """Return the scene recorded for image ``i``, or ``"unk"``."""
        img_basename = os.path.basename(self.image[i]["image"])
        return self.scenes.get(img_basename, "unk")

    def split(self, i):
        """Return the split name of image ``i``."""
        return self.image[i]["split"]

    def metadata(self, i):
        """Return the ``(directory, index_row)`` pair for image ``i``."""
        return self.directory, self.image[i]

    @classmethod
    def resolve_segmentation(cls, m, categories=None):
        """Load segmentation data described by ``metadata()``.

        Args:
            m: ``(directory, index_row)`` as returned by ``metadata``.
            categories: index columns to resolve; ``None`` resolves all.

        Returns:
            ``(segs, (sh, sw))`` where ``segs[cat]`` is the original list when
            every entry is an int, else an ``int16`` array of shape
            ``(channels, sh, sw)`` with PNG channels decoded.
        """
        directory, row = m
        result = {}
        for cat, d in row.items():
            if cat in cls.meta_categories or not wants(cat, categories):
                continue
            if all(isinstance(data, int) for data in d):
                result[cat] = d
                continue
            out = numpy.empty((len(d), row["sh"], row["sw"]), dtype=numpy.int16)
            for i, channel in enumerate(d):
                if isinstance(channel, int):
                    out[i] = channel
                else:
                    rgb = imread(os.path.join(directory, "images", channel))
                    out[i] = cls._decode_label_png(rgb)
            result[cat] = out
        return result, (row["sh"], row["sw"])

    def label_size(self, category=None):
        """Return the number of labels overall, or codes within ``category``."""
        if category is None:
            return len(self.label)
        return len(self.category_unmap[category])

    def name(self, category, j):
        """Return the primary name of label ``j`` (a code if ``category``)."""
        if category is not None:
            j = self.category_unmap[category][j]
        return self.label[j]["name"]

    def rev_name(self, name):
        """Return the label number for ``name``."""
        return self.rev_label[name]

    def frequency(self, category, j):
        """Return the ``frequency`` field of label/code ``j``."""
        if category is not None:
            return self.category_label[category][j]["frequency"]
        return self.label[j]["frequency"]

    def coverage(self, category, j):
        """Return the ``coverage`` field of label/code ``j``."""
        if category is not None:
            return self.category_label[category][j]["coverage"]
        return self.label[j]["coverage"]

    def category_names(self):
        """Return the loaded category names in file order."""
        return list(self.category.keys())

    def category_frequency(self, category):
        """Return the ``frequency`` column of ``category`` as a float."""
        return float(self.category[category]["frequency"])

    def primary_categories_per_index(self, categories=None):
        """For each label number, return the index of its dominant category.

        The dominant category is the one with the highest coverage for that
        label. Ties go to the later category, because ``max`` compares
        ``(coverage, index)`` tuples. Labels absent from every category
        therefore map to the last category.
        """
        if categories is None:
            categories = self.category_names()
        n_labels = self.label_size(None)
        catmap = {}
        for cat in categories:
            imap = self.category_index_map(cat)
            if len(imap) < n_labels:
                pad = numpy.zeros(n_labels - len(imap), dtype=imap.dtype)
                imap = numpy.concatenate((imap, pad))
            catmap[cat] = imap
        result = []
        for i in range(n_labels):
            _, maxcat = max(
                (self.coverage(cat, catmap[cat][i]) if catmap[cat][i] else 0, ic)
                for ic, cat in enumerate(categories)
            )
            result.append(maxcat)
        return numpy.array(result)

    def segmentation_data(self, category, i, c=0, full=False, out=None):
        """Return channel ``c`` of ``category`` for image ``i`` as an ``(sh, sw)`` map.

        Values are label numbers when ``full`` is set, otherwise category
        codes (mapped through ``category_map``). A missing channel is
        treated as label 0. ``out`` may be a preallocated ``int16`` array.
        """
        row = self.image[i]
        data_channels = row.get(category, ())
        channel = data_channels[c] if c < len(data_channels) else 0
        if out is None:
            out = numpy.empty((row["sh"], row["sw"]), dtype=numpy.int16)
        if isinstance(channel, int):
            if not full:
                channel = self.category_map[category][channel]
            out[:, :] = channel
            return out
        png = imread(os.path.join(self.directory, "images", channel))
        labels = self._decode_label_png(png)
        out[...] = labels if full else self.category_map[category][labels]
        return out

    def full_segmentation_data(self, i, categories=None, max_depth=None, out=None):
        """Stack the segmentation channels of image ``i`` into one array.

        Args:
            i: image index.
            categories: restrict to these index columns; ``None``/empty means
                every non-meta column.
            max_depth: cap on the number of stacked channels.
            out: optional preallocated ``int16`` array ``(depth, sh, sw)``.

        Returns:
            Array ``(depth, sh, sw)`` of unmapped label numbers; integer
            channels are broadcast over the whole image. ``depth`` may be 0.
        """
        row = self.image[i]
        if categories:
            groups = [d for cat, d in row.items() if cat in categories and d]
        else:
            groups = [
                d for cat, d in row.items() if d and cat not in self.meta_categories
            ]
        depth = sum(len(group) for group in groups)
        if max_depth is not None:
            depth = min(depth, max_depth)
        if out is None:
            out = numpy.empty((depth, row["sh"], row["sw"]), dtype=numpy.int16)
        if depth <= 0:
            return out

        filled = 0
        for group in groups:
            for channel in group:
                if isinstance(channel, int):
                    out[filled] = channel
                else:
                    png = imread(os.path.join(self.directory, "images", channel))
                    out[filled] = self._decode_label_png(png)
                filled += 1
                if filled == depth:
                    return out
        return out

    def category_index_map(self, category):
        """Return a copy of the label-number -> code map of ``category``."""
        return numpy.array(self.category_map[category])


def build_dense_label_array(label_data, key="number", allow_none=False):
    """Index label records by their integer ``key`` into a dense list.

    Slot ``k`` holds the record whose ``key`` equals ``k``. Unless
    ``allow_none`` is set, gaps are filled with placeholder records that have
    the same fields as the first record. In a placeholder, ``key`` is set to
    the slot index and every other field to the empty value of its type
    (``0``, ``""``, ``[]``, ``{}``...).

    Args:
        label_data: list of record dicts, each with an int under ``key``.
        key: field holding the record's index.
        allow_none: leave gaps as ``None`` instead of placeholders.

    Returns:
        The dense list (empty when ``label_data`` is empty).
    """
    if not label_data:
        return []
    result = [None] * (max(d[key] for d in label_data) + 1)
    for d in label_data:
        result[d[key]] = d
    if not allow_none:
        example = label_data[0]
        for i, d in enumerate(result):
            if d is None:
                # Compare with ==, not `is`: field names parsed from a CSV
                # header are not interned, so an identity test misses `key`.
                result[i] = {
                    c: i if c == key else type(v)() for c, v in example.items()
                }
    return result


def build_numpy_category_map(map_data, key1="code", key2="number"):
    """Build forward and backward ``int16`` lookup arrays between two keys.

    Returns ``[forward, backward]``, where ``forward[entry[key1]] ==
    entry[key2]`` and ``backward[entry[key2]] == entry[key1]``. Each array is
    sized by the largest value of its index key plus one; unmentioned slots
    stay 0.
    """
    forward_len = max(entry[key1] for entry in map_data) + 1
    backward_len = max(entry[key2] for entry in map_data) + 1
    forward = numpy.zeros(forward_len, dtype=numpy.int16)
    backward = numpy.zeros(backward_len, dtype=numpy.int16)
    for entry in map_data:
        src, dst = entry[key1], entry[key2]
        forward[src] = dst
        backward[dst] = src
    return [forward, backward]


def decode_label_dict(row):
    """Decode one row of a label CSV (``label.csv`` / ``c_<category>.csv``).

    - ``category``: ``"name(count);..."`` -> ``{name: int(count)}``
      (empty entries are skipped, so ``""`` gives ``{}``).
    - ``name``: kept as a string.
    - ``syns``: split on ``;`` into a list.
    - any other field: ``int`` for digit strings, ``float`` for ``"12."`` /
      ``"1.5"`` style decimals, otherwise the string unchanged.

    Raises:
        ValueError: a ``category`` entry is not of the form ``name(count)``.
    """
    result = {}
    for key, val in row.items():
        if key == "category":
            counts = {}
            for part in val.split(";"):
                if not part:
                    continue
                match = re.match(r"^([^(]*)\(([^)]*)\)$", part)
                if match is None:
                    raise ValueError(f"malformed category entry: {part!r}")
                cat, count = match.groups()
                counts[cat] = int(count)
            result[key] = counts
        elif key == "name":
            result[key] = val
        elif key == "syns":
            result[key] = val.split(";")
        elif re.match(r"^\d+$", val):
            result[key] = int(val)
        elif re.match(r"^\d+\.\d*$", val):
            result[key] = float(val)
        else:
            result[key] = val
    return result


def decode_index_dict(row):
    """Decode one row of the image index CSV.

    ``image`` and ``split`` stay strings; ``sw``/``sh``/``iw``/``ih`` become
    ints. Every other column is a ``;``-separated list with empty entries
    dropped, where digit-only entries become ints (label numbers) and the
    rest stay strings (label-image filenames).
    """
    result = {}
    for key, val in row.items():
        if key in ["image", "split"]:
            result[key] = val
        elif key in ["sw", "sh", "iw", "ih"]:
            result[key] = int(val)
        else:
            item = [s for s in val.split(";") if s]
            for i, v in enumerate(item):
                # Raw string: "\d" in a plain literal is an invalid escape.
                if re.match(r"^\d+$", v):
                    item[i] = int(v)
            result[key] = item
    return result


def index_has_any_data(row, categories):
    """Return True if any listed category of ``row`` has a truthy entry."""
    return any(entry for cat in categories for entry in row[cat])


def index_has_all_data(row, categories):
    """Return True if every listed category of ``row`` has a truthy entry."""
    return all(any(row[cat]) for cat in categories)


class SegmentationPrefetcher:
    """Asynchronously loads batches of segmentation records with a worker pool.

    Jobs are executed by ``prefetch_worker`` in a ``multiprocessing.Pool``
    (or a ``ThreadPool`` when ``thread=True``). Up to ``ahead`` batches of
    ``batch_size`` jobs are kept in flight at once.
    """

    def __init__(
        self,
        segmentation,
        split=None,
        randomize=False,
        segmentation_shape=None,
        categories=None,
        once=False,
        start=None,
        end=None,
        batch_size=4,
        ahead=4,
        thread=False,
    ):
        """
        Args:
            segmentation: dataset providing ``size``, ``split``, ``metadata``,
                ``filename``, ``category_index_map`` and ``category_names``.
            split: iterate only images whose split equals this value.
            randomize: ``False`` keeps index order; ``True`` shuffles with an
                unseeded RNG; any other truthy value shuffles using it as seed.
            segmentation_shape: optional ``(h, w)`` to rescale label maps to.
            categories: names to load (``"image"`` requests the RGB image).
                ``None`` means every category of ``segmentation`` plus
                ``"image"``.
            once: stop after one pass instead of cycling forever.
            start, end: half-open range of image indexes to use.
            batch_size: records per batch.
            ahead: number of batches to keep queued.
            thread: use threads instead of processes.
        """
        self.segmentation = segmentation
        self.split = split
        self.randomize = randomize
        self.random = random.Random()
        if randomize is not True:
            # Any value other than True doubles as a reproducible seed.
            self.random.seed(randomize)
        if categories is None:
            categories = list(segmentation.category_names()) + ["image"]
        self.categories = categories
        self.once = once
        self.batch_size = batch_size
        self.ahead = ahead

        n_procs = cpu_count()
        if thread:
            self.pool = ThreadPool(processes=n_procs)
        else:
            # Workers run setup_sigint (SIGINT ignored). The parent ignores
            # SIGINT only while the pool is created, then restores its handler,
            # so KeyboardInterrupt surfaces in fetch_batch, which terminates
            # the pool.
            original_sigint_handler = setup_sigint()
            self.pool = Pool(processes=n_procs, initializer=setup_sigint)
            restore_sigint(original_sigint_handler)

        if start is None:
            start = 0
        if end is None:
            end = segmentation.size()
        # A list, not a range: random.shuffle needs a mutable sequence.
        self.indexes = list(range(start, end))
        if split:
            self.indexes = [i for i in self.indexes if segmentation.split(i) == split]
        if self.randomize:
            self.random.shuffle(self.indexes)
        self.index = 0
        self.result_queue = []
        self.segmentation_shape = segmentation_shape

        self.catmaps = [
            segmentation.category_index_map(cat) if cat != "image" else None
            for cat in categories
        ]

    def next_job(self):
        """Return the next job tuple for ``prefetch_worker``, or ``None``.

        ``None`` is returned once a ``once=True`` pass is exhausted or when
        there are no indexes at all. Otherwise iteration wraps around,
        reshuffling when ``randomize`` is set.
        """
        if self.index < 0 or not self.indexes:
            return None
        j = self.indexes[self.index]
        result = (
            j,
            self.segmentation.__class__,
            self.segmentation.metadata(j),
            self.segmentation.filename(j),
            self.categories,
            self.segmentation_shape,
        )
        self.index += 1
        if self.index >= len(self.indexes):
            if self.once:
                self.index = -1
            else:
                self.index = 0
                if self.randomize:
                    self.random.shuffle(self.indexes)
        return result

    def batches(self):
        """Yield batches (lists of worker records) until jobs run out."""
        while True:
            batch = self.fetch_batch()
            if batch is None:
                return
            yield batch

    def fetch_batch(self):
        """Return the next completed batch, or ``None`` when no jobs remain."""
        try:
            self.refill_tasks()
            if len(self.result_queue) == 0:
                return None
            result = self.result_queue.pop(0)
            # Presumably a (one-year) timeout rather than none so that the
            # wait stays interruptible by Ctrl-C.
            return result.get(31536000)
        except KeyboardInterrupt:
            print("Caught KeyboardInterrupt, terminating workers")
            self.pool.terminate()
            raise

    def fetch_tensor_batch(self, bgr_mean=None, global_labels=False):
        """Fetch the next batch and convert it with ``form_caffe_tensors``."""
        batch = self.fetch_batch()
        return self.form_caffe_tensors(batch, bgr_mean, global_labels)

    def tensor_batches(self, bgr_mean=None, global_labels=False):
        """Yield tensor batches until jobs run out."""
        while True:
            batch = self.fetch_tensor_batch(
                bgr_mean=bgr_mean, global_labels=global_labels
            )
            if batch is None:
                return
            yield batch

    @staticmethod
    def _scene_label(record):
        """Return a 1-element array holding the record's scene id (-1 if none)."""
        scenes = record.get("scene", [])
        if len(scenes) == 0:
            return numpy.array([-1])
        if len(scenes) > 1:
            print(f"Multiple scenes: {record['fn']} {scenes}")
        # Always shape (1,) so all records concatenate to the same rank.
        return numpy.array([scenes[0]])

    def form_caffe_tensors(self, batch, bgr_mean=None, global_labels=False):
        """Stack a batch of worker records into one array per output.

        Outputs follow ``self.categories`` plus a trailing scene output.
        ``"image"`` entries become normalized BGR images. Label categories go
        through ``normalize_label`` and, unless ``global_labels`` is set,
        through the category index map. The scene output holds one scene id
        per record (``-1`` when absent).

        Returns ``None`` when ``batch`` is ``None``.
        """
        if batch is None:
            return None
        cats = [*self.categories, "scene"]
        n_category_outputs = len(self.categories)
        batches = [[] for _ in cats]
        for record in batch:
            default_shape = (1, record["sh"], record["sw"])
            for c, cat in enumerate(cats):
                if cat == "image":
                    # Normalize image with right RGB order and mean
                    batches[c].append(normalize_image(record[cat], bgr_mean))
                elif cat == "scene" and (global_labels or c >= n_category_outputs):
                    # The trailing scene output has no catmap entry.
                    batches[c].append(self._scene_label(record))
                elif global_labels:
                    batches[c].append(
                        normalize_label(record[cat], default_shape, flatten=True)
                    )
                else:
                    catmap = self.catmaps[c]
                    batches[c].append(
                        catmap[
                            normalize_label(record[cat], default_shape, flatten=True)
                        ]
                    )
        return [numpy.concatenate(tuple(m[numpy.newaxis] for m in b)) for b in batches]

    def refill_tasks(self):
        """Submit batches until ``ahead`` are queued or jobs run out."""
        while len(self.result_queue) < self.ahead:
            data = []
            while len(data) < self.batch_size:
                job = self.next_job()
                if job is None:
                    break
                data.append(job)
            if len(data) == 0:
                return
            self.result_queue.append(self.pool.map_async(prefetch_worker, data))

    def close(self):
        """Discard pending batches and shut the worker pool down."""
        while self.result_queue:
            result = self.result_queue.pop(0)
            if result is not None:
                result.wait(0.001)
        self.pool.close()
        # Pool has no cancel_join_thread(); terminate instead of waiting for
        # in-flight batches, then reap the workers.
        self.pool.terminate()
        self.pool.join()


def prefetch_worker(d):
    """Pool worker: resolve one job tuple produced by ``next_job``.

    Returns a dict holding the resolved segmentation data per category, plus
    ``sh``/``sw`` (segmentation size), ``i`` (image index), ``fn`` (image path)
    and, when requested, ``image`` (the decoded image). A ``None`` job yields
    ``None``.
    """
    if d is None:
        return None
    j, typ, m, fn, categories, segmentation_shape = d
    if categories is not None:
        # Always resolve the scene label too; form_caffe_tensors emits it.
        categories = ["scene", *categories]
    segs, shape = typ.resolve_segmentation(m, categories=categories)
    if segmentation_shape is not None:
        for k, v in segs.items():
            segs[k] = scale_segmentation(v, segmentation_shape)
        shape = segmentation_shape
    segs["sh"], segs["sw"] = shape
    segs["i"] = j
    segs["fn"] = fn
    # categories=None means "everything", including the image.
    if categories is None or "image" in categories:
        segs["image"] = imread(fn)
    return segs


def scale_segmentation(segmentation, dims, crop=False):
    """Resize a label map, or a stack of label maps, to ``dims = (h, w)``.

    Nearest-neighbour (order-0) zoom is used so label values are preserved.
    Inputs with fewer than two dimensions (per-image labels) and inputs
    already of size ``dims`` are returned unchanged. A 2-D input gives a 2-D
    result; a ``(levels, h, w)`` input gives ``(levels,) + dims``.

    With ``crop=True`` the input is not stretched. Instead, a centred window
    with the target aspect ratio is cut from it and scaled to ``dims``.

    Args:
        segmentation: label map(s) to resize.
        dims: target ``(h, w)``; any 2-element sequence.
        crop: centre-crop instead of stretching.
    """
    dims = tuple(dims)
    shape = numpy.shape(segmentation)
    if len(shape) < 2 or tuple(shape[-2:]) == dims:
        return segmentation
    peel = len(shape) == 2
    if peel:
        segmentation = segmentation[numpy.newaxis]
    levels = segmentation.shape[0]
    result = numpy.zeros((levels,) + dims, dtype=segmentation.dtype)
    ratio = (1,) + tuple(
        res / float(orig) for res, orig in zip(result.shape[1:], segmentation.shape[1:])
    )
    if not crop:
        safezoom(segmentation, ratio, output=result, order=0)
    else:
        scale = max(ratio[1:])
        height = int(round(dims[0] / scale))
        width = int(round(dims[1] / scale))
        # Axis 0 is the level axis; rows and columns are axes 1 and 2.
        hmargin = (segmentation.shape[1] - height) // 2
        wmargin = (segmentation.shape[2] - width) // 2
        window = segmentation[:, hmargin : hmargin + height, wmargin : wmargin + width]
        # Per-axis factors from the actual window size make the zoomed shape
        # match `result` exactly despite the rounding above.
        safezoom(
            window,
            (1, dims[0] / float(window.shape[1]), dims[1] / float(window.shape[2])),
            output=result,
            order=0,
        )
    if peel:
        result = result[0]
    return result


def safezoom(array, ratio, output=None, order=0):
    """``scipy.ndimage.zoom`` wrapper for level stacks, handling float16.

    Args:
        array: stack whose first axis indexes levels, e.g. ``(levels, h, w)``.
        ratio: per-axis zoom factors, first entry for the level axis.
        output: optional preallocated array, written in place.
        order: spline order (0 = nearest neighbour).

    Returns:
        The zoomed stack, with the same rank as ``array``, cast back to
        ``array``'s dtype (always a new array).
    """
    dtype = array.dtype
    if array.dtype == numpy.float16:
        # Presumably because ndimage does not accept float16 input.
        array = array.astype(numpy.float32)
    if array.shape[0] == 1:
        # Zoom the single plane in 2-D, keeping the leading axis in the result.
        if output is not None:
            zoom(array[0, ...], ratio[1:], output=output[0, ...], order=order)
        else:
            output = zoom(array[0, ...], ratio[1:], order=order)[numpy.newaxis]
    elif output is not None:
        zoom(array, ratio, output=output, order=order)
    else:
        output = zoom(array, ratio, order=order)
    return output.astype(dtype)


def setup_sigint():
    """Ignore SIGINT, if called on the main thread.

    Returns:
        The previous SIGINT handler, or ``None`` when called off the main
        thread, where ``signal.signal`` would raise ``ValueError``.
    """
    import threading

    # Public API instead of the private threading._MainThread class.
    if threading.current_thread() is not threading.main_thread():
        return None
    return signal.signal(signal.SIGINT, signal.SIG_IGN)


def restore_sigint(original):
    """Reinstall ``original`` as the SIGINT handler (main thread only).

    ``None`` restores the default handler. This covers ``setup_sigint``
    returning ``None`` off the main thread, and a previous handler that was
    not installed from Python. Off the main thread this is a no-op.
    """
    import threading

    # Public API instead of the private threading._MainThread class.
    if threading.current_thread() is not threading.main_thread():
        return
    if original is None:
        original = signal.SIG_DFL
    signal.signal(signal.SIGINT, original)


def wants(what, option):
    """Return True if ``what`` is selected by ``option`` (``None`` selects all)."""
    return option is None or what in option


def normalize_image(rgb_image, bgr_mean):
    """Convert an RGB image into a channel-first, mean-subtracted BGR array.

    Args:
        rgb_image: ``(h, w)`` grayscale or ``(h, w, c)`` image in RGB(A)
            order. Grayscale is replicated to three channels, and channels
            beyond the third (e.g. alpha) are dropped.
        bgr_mean: per-channel mean in BGR order, or ``None`` to skip.

    Returns:
        ``float32`` array ``(channels, h, w)``; 3 channels for grayscale,
        RGB or RGBA input.
    """
    img = numpy.array(rgb_image, dtype=numpy.float32)
    if img.ndim == 2:
        img = numpy.repeat(img[:, :, None], 3, axis=2)
    # Keep RGB only, then flip to BGR; otherwise RGBA would yield ABGR.
    img = img[:, :, :3][:, :, ::-1]
    if bgr_mean is not None:
        img -= bgr_mean
    return img.transpose((2, 0, 1))


def normalize_label(label_data, shape, flatten=False):
    """Turn per-image or per-pixel label data into a label-map array.

    Scalars and 2-D data are broadcast into an ``int16`` array of ``shape``.
    1-D data (a list of per-image labels) becomes one broadcast plane per
    entry. With ``flatten``, only its first entry is used (0 if empty).
    3-D data is returned as-is, or, with ``flatten``, reduced to its first
    plane with a new leading axis. Higher-rank data gets a new leading axis.
    """
    rank = len(numpy.shape(label_data))
    if rank > 2:
        if rank == 3:
            if not flatten:
                return label_data
            label_data = label_data[0]
        return label_data[numpy.newaxis]
    if rank == 1:
        if not flatten:
            planes = numpy.asarray(label_data, dtype=numpy.int16)
            base = numpy.ones(shape, dtype=numpy.int16)
            return base * planes[:, numpy.newaxis, numpy.newaxis]
        label_data = label_data[0] if len(label_data) else 0
    # Scalar data on this channel: fill shape
    return numpy.full(shape, label_data, dtype=numpy.int16)


if __name__ == "__main__":
    # Smoke test: load one batch from the local "broden1_227" dataset.
    data = SegmentationData("broden1_227")
    prefetcher = SegmentationPrefetcher(
        data, categories=data.category_names() + ["image"], once=True
    )
    # Python 3 generators have no .next() method; use the next() builtin.
    bs = next(prefetcher.batches())
