from typing import Dict, Any, Optional, Union
import torch
import pathlib

from collections import defaultdict
import PIL.Image

from rigl_torch.datasets import _data_stem

import torchvision

# We disable the beta transforms warning as it will print many times
torchvision.disable_beta_transforms_warning()
# from torchvision import datasets  # noqa: E402
import torchvision.transforms.v2 as transforms  # noqa: E402
from ._coco_detection_v2 import CocoDetectionV2  # noqa: E402


class CocoSegmentationDataStem(_data_stem.ABCDataStem):
    """Data stem building the COCO 2017 train/val datasets and transforms.

    The train split is read from ``<data_path>/train2017`` with
    ``annotations/instances_train2017.json``, and the val split from
    ``<data_path>/val2017`` with ``annotations/instances_val2017.json``.
    Both use ``CocoDetectionV2``. The image IDs in ``_TRAIN_NO_ANN_IDS`` /
    ``_VAL_NO_ANN_IDS`` are handed to it through ``no_add_ids``.
    """

    def __init__(
        self,
        cfg: Dict[str, Any],
        data_path_override: Optional[Union[str, pathlib.Path]] = None,
    ):
        # All setup is done by the base class; this override only spells out
        # the accepted constructor arguments.
        super().__init__(cfg, data_path_override)

    def _get_datasets(self):
        """Build and return ``(train_dataset, test_dataset)``.

        Side effect: registers ``collate_fn`` in ``self.train_kwargs`` and
        ``self.test_kwargs`` once both datasets have been constructed.
        """
        train_transformer = self._get_transform()
        test_transformer = self._get_test_transform()

        # CocoDetectionV2 (./_coco_detection_v2.py) is used instead of
        # torchvision's CocoDetection. Per the original notes, it therefore
        # does not need ``datasets.wrap_dataset_for_transforms_v2``. See:
        # https://pytorch.org/vision/0.15/auto_examples/plot_transforms_v2_e2e.html  # noqa
        base_dir = self.data_path
        ann_dir = base_dir / "annotations"
        train_dataset = CocoDetectionV2(
            root=base_dir / "train2017",
            annFile=ann_dir / "instances_train2017.json",
            transforms=train_transformer,
            no_add_ids=_TRAIN_NO_ANN_IDS,
        )
        test_dataset = CocoDetectionV2(
            root=base_dir / "val2017",
            annFile=ann_dir / "instances_val2017.json",
            transforms=test_transformer,
            no_add_ids=_VAL_NO_ANN_IDS,
        )

        self._append_collate_fn_to_dataloader_kwargs()
        return train_dataset, test_dataset

    def _get_transform(self):
        """Return the training-time v2 transform pipeline."""
        # Padding for RandomZoomOut: PIL images are padded with this RGB
        # colour; every other input type falls back to 0.
        zoom_out_fill = defaultdict(
            lambda: 0, {PIL.Image.Image: (123, 117, 104)}
        )
        steps = [
            transforms.RandomPhotometricDistort(),
            transforms.RandomZoomOut(fill=zoom_out_fill),
            # RandomIoUCrop is intentionally omitted: it was observed to
            # delete lots of bboxes. PopulateNullAnnotations is also not
            # used; it would have to run before the dataset is wrapped by
            # CocoDetectionV2.
            transforms.RandomHorizontalFlip(),
            transforms.ToImageTensor(),
            transforms.ConvertImageDtype(torch.float32),
            # ClampBoundingBox is not used. NOTE(review): the original note
            # "Doesn't work without all samples containing annotations" sat
            # just above this step; presumably it concerns
            # SanitizeBoundingBox, hence the annotation-less ID filtering.
            # Confirm.
            transforms.SanitizeBoundingBox(),
        ]
        return transforms.Compose(steps)

    def _get_test_transform(self):
        """Return the evaluation pipeline: tensor conversion + float32 cast."""
        steps = [
            transforms.ToImageTensor(),
            transforms.ConvertImageDtype(torch.float32),
        ]
        return transforms.Compose(steps)

    def _append_collate_fn_to_dataloader_kwargs(self) -> None:
        """Register the tuple-transposing ``collate_fn`` for both loaders."""
        for loader_kwargs in (self.train_kwargs, self.test_kwargs):
            loader_kwargs.update(collate_fn=collate_fn)


def collate_fn(batch):
    """Transpose a batch of samples into per-field tuples.

    ``[(img0, tgt0), (img1, tgt1)]`` becomes ``((img0, img1), (tgt0, tgt1))``.
    Samples are not stacked into tensors, presumably so that images and
    targets of differing sizes can share a batch. An empty batch gives ``()``.
    """
    columns = zip(*batch)
    return tuple(columns)


# NOTE: Currently unused; images without annotations are instead filtered out
# by ID (see the `no_add_ids` lists at the bottom of this file).
class PopulateNullAnnotations(torch.nn.Module):
    """Insert empty placeholders for annotation fields a sample is missing.

    Takes an ``(x, ann)`` pair. For every expected field absent from ``ann``
    it inserts an empty value: an empty list for the COCO-style fields, and a
    zero-length tensor for the tensor-valued target fields. Fields that are
    already present are never overwritten. ``ann`` is updated in place and
    returned.
    """

    # COCO-style fields that receive a fresh empty list when missing.
    _LIST_KEYS = (
        "segmentation",
        "area",
        "iscrowd",
        "image_id",
        "bbox",
        "category_id",
        "id",
    )
    # Tensor-valued target fields that receive an empty tensor when missing.
    _TENSOR_KEYS = (
        "boxes",
        "masks",
        "labels",
    )

    def __init__(self):
        super().__init__()

    @staticmethod
    def _empty_tensor(key):
        """Return a zero-length tensor with the layout expected for ``key``."""
        if key == "boxes":
            # Boxes are laid out as (N, 4). A (0,)-shaped tensor fails shape
            # checks such as torchvision detection models' boxes.shape[-1]==4.
            return torch.zeros((0, 4))
        if key == "labels":
            # Labels are integer class indices; torch.Tensor([]) would be a
            # float tensor.
            return torch.zeros((0,), dtype=torch.int64)
        # NOTE(review): masks keep the original 1-D empty float tensor because
        # the image height/width are not known here. Downstream code likely
        # expects (0, H, W); confirm.
        return torch.Tensor([])

    def forward(self, sample):
        """Fill in missing annotation fields.

        Args:
            sample: ``(x, ann)`` pair. ``ann`` is assumed to be a dict-like
                mapping supporting ``in`` and item assignment.

        Returns:
            ``(x, ann)``, where ``ann`` is the same object, now containing
            every key in ``_LIST_KEYS`` and ``_TENSOR_KEYS``.
        """
        x, ann = sample
        for key in self._LIST_KEYS:
            if key not in ann:
                ann[key] = []
        for key in self._TENSOR_KEYS:
            if key not in ann:
                ann[key] = self._empty_tensor(key)
        return (x, ann)


# COCO 2017 val image IDs that have no annotations (48 IDs).
# CocoSegmentationDataStem._get_datasets passes this list to CocoDetectionV2
# as ``no_add_ids`` for the val2017 split. Presumably CocoDetectionV2 skips
# these images so every sample carries targets (see ./_coco_detection_v2.py).
# NOTE(review): this could be derived from instances_val2017.json at load
# time instead of being hard-coded. TODO consider.
_VAL_NO_ANN_IDS = [
    25593,
    41488,
    42888,
    49091,
    58636,
    64574,
    98497,
    101022,
    121153,
    127135,
    173183,
    176701,
    198915,
    200152,
    226111,
    228771,
    240767,
    260657,
    261796,
    267946,
    268996,
    270386,
    278006,
    308391,
    310622,
    312549,
    320706,
    330554,
    344611,
    370999,
    374727,
    382734,
    402096,
    404601,
    447789,
    458790,
    461275,
    476491,
    477118,
    481404,
    502910,
    514540,
    528977,
    536343,
    542073,
    550939,
    556498,
    560371,
]

# COCO 2017 train image IDs that have no annotations (1021 IDs, unsorted).
# CocoSegmentationDataStem._get_datasets passes this list to CocoDetectionV2
# as ``no_add_ids`` for the train2017 split. Presumably CocoDetectionV2 skips
# these images so every sample carries targets, which SanitizeBoundingBox in
# the training transform appears to rely on (see ./_coco_detection_v2.py).
# NOTE(review): this could be derived from instances_train2017.json at load
# time instead of being hard-coded. TODO consider.
_TRAIN_NO_ANN_IDS = [
    262184,
    262189,
    262284,
    262623,
    508,
    524927,
    525020,
    1111,
    263478,
    1472,
    250,
    526104,
    1997,
    526292,
    264720,
    526890,
    264886,
    265207,
    3640,
    3692,
    3799,
    3941,
    266091,
    401670,
    266274,
    4308,
    266518,
    266611,
    4481,
    4517,
    267353,
    529578,
    268496,
    6426,
    530783,
    268693,
    268735,
    530894,
    268770,
    531232,
    576354,
    269372,
    531617,
    7625,
    7823,
    271198,
    533896,
    9759,
    9778,
    9809,
    10108,
    10263,
    10420,
    10440,
    534771,
    534918,
    273123,
    11076,
    274233,
    536587,
    349579,
    274957,
    13035,
    537684,
    13466,
    275710,
    537918,
    13789,
    276267,
    276645,
    276731,
    539016,
    539390,
    277329,
    15404,
    539755,
    15830,
    540388,
    540476,
    16689,
    278907,
    541094,
    16903,
    279103,
    279263,
    280413,
    542614,
    280825,
    280879,
    281035,
    543236,
    134236,
    281188,
    281582,
    281939,
    544188,
    544597,
    282586,
    20490,
    172971,
    545006,
    283147,
    545310,
    21166,
    21382,
    309496,
    284128,
    22098,
    284351,
    22456,
    22559,
    547072,
    23017,
    547760,
    285717,
    547903,
    23648,
    135042,
    548171,
    24181,
    548690,
    24499,
    549012,
    25378,
    287822,
    288131,
    26094,
    26266,
    550869,
    550878,
    26767,
    289071,
    551550,
    289943,
    552135,
    28156,
    552776,
    28645,
    28662,
    28806,
    291056,
    291149,
    29056,
    553420,
    29594,
    30175,
    292572,
    573113,
    442078,
    555470,
    31524,
    293671,
    293946,
    31926,
    294370,
    294431,
    557263,
    5500,
    557387,
    33123,
    267680,
    33422,
    167819,
    33442,
    33554,
    558269,
    34089,
    297058,
    559214,
    34985,
    559576,
    35422,
    297667,
    297736,
    35793,
    298190,
    36480,
    298836,
    298840,
    561119,
    299045,
    299371,
    299382,
    299441,
    37854,
    300090,
    562357,
    562582,
    38691,
    300929,
    563076,
    39068,
    224970,
    39475,
    564317,
    40474,
    302945,
    40962,
    565807,
    566025,
    303892,
    566103,
    304036,
    42345,
    566670,
    487702,
    567234,
    305159,
    531488,
    305871,
    43947,
    306477,
    568863,
    45075,
    307264,
    569451,
    45335,
    45668,
    570045,
    45822,
    570207,
    46633,
    308828,
    138859,
    571242,
    309222,
    47396,
    309571,
    572100,
    572546,
    572585,
    310688,
    48546,
    573053,
    576119,
    49255,
    139284,
    184965,
    49725,
    311877,
    49741,
    49883,
    574385,
    50637,
    51652,
    576017,
    51730,
    139695,
    314068,
    533006,
    52726,
    379190,
    315110,
    577748,
    53733,
    316091,
    54652,
    317120,
    317130,
    317575,
    55559,
    55567,
    55776,
    318144,
    318596,
    581087,
    28095,
    402809,
    319749,
    57977,
    58133,
    58554,
    59250,
    321603,
    59476,
    59550,
    359465,
    577949,
    315846,
    60434,
    322887,
    61045,
    61567,
    61575,
    272421,
    324094,
    324460,
    62805,
    62824,
    325125,
    325357,
    325368,
    325690,
    10654,
    64356,
    326613,
    326793,
    578819,
    578852,
    65380,
    327802,
    535233,
    65916,
    328084,
    328098,
    579023,
    66543,
    329462,
    579247,
    330535,
    68715,
    68838,
    69373,
    331600,
    69514,
    331876,
    69911,
    70125,
    332585,
    361323,
    333198,
    71516,
    333841,
    71879,
    334603,
    334642,
    465057,
    203191,
    72912,
    72978,
    335669,
    405483,
    335826,
    143425,
    336777,
    336873,
    12494,
    75083,
    75256,
    337506,
    75426,
    75481,
    75493,
    12641,
    338067,
    76150,
    37458,
    76923,
    339192,
    143995,
    339740,
    77849,
    340119,
    340375,
    340781,
    537397,
    78947,
    79331,
    79362,
    79671,
    79913,
    493956,
    342335,
    342624,
    342998,
    343035,
    81107,
    82388,
    344618,
    344705,
    344730,
    82756,
    345063,
    345155,
    83246,
    345391,
    345711,
    346061,
    84018,
    346615,
    84638,
    84819,
    347007,
    276311,
    101623,
    348853,
    86818,
    86831,
    86836,
    349083,
    349097,
    349352,
    87847,
    350334,
    552870,
    430665,
    88517,
    88768,
    529570,
    89485,
    90026,
    90171,
    90280,
    352564,
    160137,
    91372,
    91492,
    91705,
    354041,
    102729,
    92554,
    92604,
    59211,
    93994,
    94148,
    356673,
    356834,
    94792,
    404696,
    103341,
    90479,
    357948,
    321854,
    358795,
    96809,
    96923,
    359104,
    359184,
    359207,
    359276,
    234649,
    359774,
    97779,
    97785,
    98121,
    98155,
    98268,
    98679,
    16449,
    99010,
    99364,
    191329,
    361774,
    361831,
    362154,
    362257,
    362351,
    362696,
    362881,
    362986,
    101011,
    101073,
    29564,
    101535,
    364158,
    148102,
    102316,
    102899,
    279350,
    365487,
    365631,
    103910,
    366810,
    104829,
    104880,
    105246,
    367537,
    367998,
    106464,
    368750,
    368884,
    369618,
    107585,
    498563,
    107918,
    107941,
    370151,
    370305,
    108169,
    192817,
    370736,
    108697,
    371307,
    291702,
    371863,
    109942,
    110001,
    280532,
    542695,
    111290,
    111813,
    375096,
    113185,
    375363,
    375611,
    499525,
    114624,
    376835,
    377132,
    377234,
    115250,
    115566,
    115654,
    378632,
    378849,
    379037,
    379138,
    117664,
    118615,
    119438,
    151020,
    381842,
    381984,
    382115,
    382333,
    120235,
    382656,
    120683,
    121107,
    545235,
    383450,
    64024,
    122159,
    385265,
    123239,
    123424,
    149215,
    386200,
    124145,
    124240,
    386613,
    124509,
    124780,
    124983,
    125009,
    125084,
    125182,
    108259,
    125997,
    126210,
    388788,
    127104,
    195993,
    389811,
    405815,
    128740,
    391537,
    56695,
    129903,
    129988,
    130192,
    392534,
    130654,
    130712,
    393212,
    393762,
    371482,
    371484,
    394126,
    467050,
    132531,
    395124,
    395185,
    133693,
    554382,
    133885,
    396166,
    397089,
    415714,
    397187,
    397278,
    397287,
    135849,
    136173,
    398454,
    136779,
    136977,
    507195,
    285068,
    400309,
    138486,
    401212,
    401381,
    572025,
    139326,
    401623,
    402386,
    140603,
    402869,
    67163,
    140922,
    403104,
    140974,
    403279,
    141139,
    141316,
    403851,
    404462,
    404871,
    405104,
    143054,
    405459,
    405662,
    405856,
    143780,
    405945,
    406015,
    406217,
    109558,
    144480,
    406677,
    406709,
    242558,
    407030,
    199063,
    407976,
    505003,
    243018,
    409614,
    68295,
    409953,
    148527,
    410743,
    148622,
    410797,
    148703,
    149102,
    411349,
    412704,
    150616,
    150779,
    413090,
    413120,
    413222,
    413232,
    112605,
    287378,
    414089,
    414416,
    414754,
    152732,
    152858,
    415659,
    154349,
    416555,
    154885,
    154924,
    417689,
    200736,
    288137,
    156299,
    223432,
    156606,
    419106,
    420070,
    158614,
    159073,
    159480,
    421673,
    421970,
    160034,
    160298,
    423028,
    39847,
    162020,
    424528,
    162539,
    162768,
    424980,
    163055,
    425263,
    425439,
    425670,
    425933,
    464210,
    114741,
    427094,
    164999,
    427727,
    427992,
    428379,
    166260,
    428495,
    166524,
    354726,
    167118,
    429386,
    429568,
    429691,
    115475,
    431026,
    168905,
    431234,
    508844,
    431692,
    169722,
    432370,
    432373,
    432647,
    433129,
    171082,
    433546,
    433971,
    434129,
    435435,
    275245,
    173685,
    436048,
    174406,
    174902,
    175129,
    175193,
    547047,
    176149,
    176168,
    176193,
    176649,
    176943,
    177014,
    440269,
    440484,
    440771,
    179430,
    292081,
    441788,
    441863,
    38389,
    443294,
    181462,
    443871,
    443941,
    444302,
    182607,
    183617,
    445775,
    445898,
    446452,
    446646,
    184874,
    185437,
    447701,
    186441,
    380651,
    449082,
    449316,
    449546,
    187882,
    187934,
    450098,
    450343,
    188685,
    188832,
    451373,
    189740,
    452652,
    452746,
    452821,
    337653,
    453087,
    453286,
    453348,
    453566,
    191501,
    191661,
    192062,
    454230,
    454827,
    192764,
    455075,
    192974,
    193077,
    193451,
    193631,
    193704,
    193732,
    455882,
    181035,
    399262,
    317283,
    194574,
    194897,
    457219,
    195266,
    195286,
    195595,
    458540,
    207313,
    196701,
    267229,
    459408,
    459590,
    197774,
    198037,
    198514,
    201632,
    464261,
    464296,
    252194,
    444783,
    202848,
    15318,
    465211,
    465218,
    203221,
    203652,
    252406,
    514557,
    466511,
    204435,
    466935,
    466958,
    382191,
    468064,
    468935,
    470672,
    208708,
    209420,
    209630,
    209969,
    210766,
    428399,
    473495,
    211423,
    211439,
    211665,
    474147,
    474398,
    212675,
    35475,
    166596,
    554841,
    475929,
    476113,
    254124,
    214036,
    214461,
    215450,
    478982,
    217005,
    217118,
    479263,
    479316,
    217649,
    218557,
    443735,
    220160,
    482363,
    220527,
    429995,
    482826,
    220739,
    220932,
    221360,
    221618,
    221680,
    221828,
    222157,
    222330,
    222383,
    222757,
    124731,
    81043,
    224136,
    483334,
    486632,
    486769,
    224742,
    77318,
    487399,
    487516,
    225859,
    226128,
    226629,
    488924,
    489343,
    387416,
    227547,
    59953,
    227699,
    489907,
    489914,
    7592,
    228407,
    228415,
    228727,
    491114,
    491482,
    435360,
    518855,
    229782,
    229981,
    38435,
    230639,
    230795,
    252101,
    230968,
    231119,
    231840,
    7721,
    494812,
    495053,
    495235,
    532277,
    234138,
    234981,
    234988,
    514111,
    497257,
    235529,
    235783,
    498239,
    498371,
    236698,
    498969,
    498975,
    499446,
    499697,
    237718,
    301765,
    237860,
    500079,
    238006,
    238141,
    500780,
    501121,
    239505,
    239942,
    502325,
    502479,
    240436,
    240830,
    503200,
    241209,
    503483,
    127626,
    241595,
    503860,
    504025,
    242092,
    504524,
    433616,
    504886,
    242752,
    242969,
    505583,
    505637,
    243625,
    505865,
    346456,
    506066,
    244108,
    244160,
    506714,
    244636,
    244885,
    409890,
    507257,
    245373,
    507686,
    245560,
    507828,
    245810,
    246348,
    508771,
    303287,
    509036,
    246973,
    247177,
    509423,
    247504,
    247624,
    509792,
    509815,
    248464,
    522127,
    434813,
    249977,
    391200,
    250289,
    513111,
    25816,
    513149,
    251107,
    251132,
    8379,
    252122,
    260501,
    129507,
    252996,
    25885,
    515427,
    253435,
    253520,
    253688,
    516490,
    516542,
    254415,
    516777,
    95772,
    516974,
    517366,
    252918,
    518487,
    256655,
    257034,
    519359,
    258108,
    520316,
    258450,
    520737,
    521098,
    521132,
    259439,
    522013,
    259988,
    522527,
    260484,
    130950,
    523581,
    218357,
    524050,
]
