import os
import pickle
from typing import Literal, Optional, Union, Tuple

import torch
from torch import nn
from clients.base import Client
from transformers import SegformerForSemanticSegmentation
import evaluate
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import numpy as np
from tqdm import trange
from PIL import Image, ImageOps, ImageFilter
import random
from peft import LoraConfig, get_peft_model

from transformers.modeling_outputs import SemanticSegmenterOutput

from utils.model_ops import print_trainable_parameters


# Module-wide mean-IoU metric from the Hugging Face `evaluate` hub, loaded once
# at import time and shared by every evaluation closure created in this file.
# NOTE(review): keep_in_memory=True presumably keeps the metric's intermediate
# results in RAM instead of an on-disk cache file -- confirm against the
# `evaluate.load` documentation.
metric = evaluate.load("mean_iou", keep_in_memory=True)


class customSegformer(torch.nn.Module):
    """SegFormer segmentation model with an additional linear projection head.

    The wrapped ``SegformerForSemanticSegmentation`` produces the usual
    per-pixel class logits.  On request (``projection=True``) the forward pass
    instead returns a fixed-size embedding obtained by mean-pooling the output
    of the last encoder stage and feeding it through ``projection_layer``.

    Parameters
    ----------
    model_name : str
        Checkpoint name or path handed to ``from_pretrained``.
    output_size : int
        Number of output features of the projection head.
    """

    def __init__(self, model_name, output_size=256):
        super().__init__()
        classes = [
            "background",
            "airplane",
            "bicycle",
            "bird",
            "boat",
            "bottle",
            "bus",
            "car",
            "cat",
            "chair",
            "cow",
            "diningtable",
            "dog",
            "horse",
            "motorcycle",
            "person",
            "potted-plant",
            "sheep",
            "sofa",
            "train",
            "tv",
        ]
        id2label = dict(enumerate(classes))
        label2id = {label: i for i, label in id2label.items()}

        # The head size is derived from the class list so the two cannot
        # drift apart; ignore_mismatched_sizes lets the classifier be
        # re-initialised when the checkpoint has a different label count.
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(
            model_name,
            num_labels=len(classes),
            ignore_mismatched_sizes=True,
            id2label=id2label,
            label2id=label2id,
        )

        # Projection head on top of the last encoder stage.
        self.hidden_dim = self.segformer.config.hidden_sizes[-1]
        self.projection_layer = nn.Linear(
            in_features=self.hidden_dim, out_features=output_size
        )

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        projection: Optional[bool] = False
    ) -> Union[Tuple, SemanticSegmenterOutput]:
        """Run the segmentation model and optionally the projection head.

        Parameters
        ----------
        pixel_values : torch.FloatTensor
            Input images, passed unchanged to the wrapped model.
        labels : torch.LongTensor, optional
            Segmentation targets; forwarded to the wrapped model, whose loss
            is returned as ``loss``.
        output_attentions, return_dict : bool, optional
            Only forwarded to the extra encoder pass used when
            ``projection`` is true.
        output_hidden_states : bool, optional
            Accepted for signature compatibility; the extra encoder pass is
            always run with ``output_hidden_states=True``.
        projection : bool, optional
            When true, ``logits`` holds the projected embedding of shape
            ``(batch, output_size)`` instead of the segmentation logits.

        Returns
        -------
        SemanticSegmenterOutput
            ``loss``, ``hidden_states`` and ``attentions`` always come from
            the full segmentation forward pass.
        """
        outputs = self.segformer(pixel_values=pixel_values, labels=labels)

        if projection:
            # Second pass through the bare encoder to get at its last stage.
            encoder_outputs = self.segformer.segformer(
                pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=True,  # we need the intermediate hidden states
                return_dict=return_dict,
            )
            sequence_output = encoder_outputs[0]
            batch_size = sequence_output.shape[0]

            if self.segformer.config.reshape_last_stage:
                # Move channels last so the reshape below groups by channel.
                last_output = sequence_output.permute(0, 2, 3, 1)
            else:
                # Previously this branch left `last_output` unassigned and the
                # next line raised UnboundLocalError.  Without the reshape the
                # encoder output is assumed to already be token-major
                # (batch, tokens, hidden), mirroring Hugging Face's
                # SegformerForImageClassification.
                last_output = sequence_output

            last_output = last_output.reshape(batch_size, -1, self.hidden_dim)
            # Mean-pool over all positions -> (batch, hidden_dim).
            last_output = last_output.mean(dim=1)
            logits = self.projection_layer(last_output)
        else:
            logits = outputs.logits

        return SemanticSegmenterOutput(
            loss=outputs.loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )




def semantic_segmenration_model(base_path: str, proj_dim, lora_rank):
    """Build the LoRA-wrapped segmentation model.

    A ``customSegformer`` is created from ``base_path`` with a projection head
    of width ``proj_dim`` and wrapped by PEFT with LoRA adapters of rank
    ``lora_rank`` on the ``query``/``value``/``key`` modules.  Afterwards every
    parameter whose name mentions the decode-head classifier or the projection
    layer is explicitly marked trainable, and a summary of the trainable
    parameters is printed.

    Returns the PEFT-wrapped model.
    """
    peft_model = get_peft_model(
        customSegformer(base_path, output_size=proj_dim),
        LoraConfig(
            r=lora_rank,
            lora_alpha=1,
            target_modules=["query", "value", "key"],
            bias="none",
        ),
    )

    # Task heads are trained in full alongside the adapters.
    trainable_markers = ("decode_head.classifier", "projection_layer")
    for name, weight in peft_model.named_parameters():
        if any(marker in name for marker in trainable_markers):
            weight.requires_grad = True

    print_trainable_parameters(peft_model)

    return peft_model



def semantic_segmenration_train_function(
    num_epochs, model, optimizer, train_dataloader, device
):
    """Create a zero-argument closure that runs local supervised training.

    Parameters
    ----------
    num_epochs : int
        Number of passes over ``train_dataloader`` per call of the closure.
    model : callable
        Called as ``model(pixel_values=..., labels=...)``; the result must
        expose a ``loss`` attribute supporting ``backward()``.
    optimizer :
        Optimizer providing ``zero_grad()`` and ``step()``.
    train_dataloader : iterable
        Yields ``(img, mask)`` pairs whose elements support ``.to(device)``.
    device :
        Target passed to ``.to`` for every batch.

    Returns
    -------
    callable
        ``train_function()`` returns the loss of the last processed batch, or
        ``None`` when no batch was processed (``num_epochs == 0`` or an empty
        loader).
    """

    def train_function():
        # Pre-bind so that "nothing to train on" yields None instead of an
        # UnboundLocalError at the return statement.
        loss = None
        for epoch in range(num_epochs):
            print("Epoch:", epoch)
            for img, mask in train_dataloader:
                img = img.to(device)
                mask = mask.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                loss = model(pixel_values=img, labels=mask).loss
                loss.backward()
                optimizer.step()
        return loss

    return train_function


def semantic_segmentation_eval_function(model, test_loader, train_loader, device):
    """Create a closure that scores the model on the train or test loader.

    Parameters
    ----------
    model :
        Segmentation model called as ``model(pixel_values=..., labels=...)``;
        its ``logits`` are upsampled to the mask resolution before scoring.
    test_loader, train_loader : iterable
        Loaders yielding ``(img, mask)`` batches.
    device :
        Device the model and every batch are moved to.

    Returns
    -------
    callable
        ``eval_function(evaluation_set)`` with ``evaluation_set`` being
        ``"train"`` or ``"test"``.  It returns a dict with ``"mean_iou"`` and
        ``"mean_accuracy"`` computed over the *whole* loader, or ``None`` if
        the loader yields no batch.  Any other ``evaluation_set`` raises
        ``NotImplementedError``.
    """

    def eval_function(evaluation_set: Literal["train", "test"]):
        if evaluation_set == "test":
            loader = test_loader
        elif evaluation_set == "train":
            loader = train_loader
        else:
            raise NotImplementedError(
                f"unknown evaluation set: {evaluation_set!r}"
            )

        model.to(device)

        # Score in eval mode (no dropout), then put the model back into
        # whatever mode the caller left it in.
        was_training = model.training
        model.eval()
        num_batches = 0
        try:
            with torch.no_grad():
                for img, mask in loader:
                    img = img.to(device)
                    mask = mask.to(device)

                    # forward
                    logits = model(pixel_values=img, labels=mask).logits

                    # bring the logits up to the mask resolution
                    upsampled_logits = nn.functional.interpolate(
                        logits,
                        size=mask.shape[-2:],
                        mode="bilinear",
                        align_corners=False,
                    )
                    predicted = upsampled_logits.argmax(dim=1)

                    # note that the metric expects predictions + mask as numpy arrays
                    metric.add_batch(
                        predictions=predicted.detach().cpu().numpy(),
                        references=mask.detach().cpu().numpy(),
                    )
                    num_batches += 1
        finally:
            model.train(was_training)

        if num_batches == 0:
            # Nothing was queued; keep the historical "no result" value.
            return None

        # The previous version returned from inside the loop, so only the
        # first batch was ever scored while add_batch() kept queuing data
        # that was never consumed.  compute() aggregates -- and clears --
        # everything queued above.  It is slower than the private _compute()
        # (see https://github.com/huggingface/evaluate/pull/328#issuecomment-1286866576)
        # but it is what yields a loader-level score.
        metrics = metric.compute(
            num_labels=COCOSegmentation.NUM_CLASS,
            ignore_index=255,
            reduce_labels=False,  # we've already reduced the labels ourselves
        )

        return {
            "mean_iou": metrics["mean_iou"],
            "mean_accuracy": metrics["mean_accuracy"],
        }

    return eval_function


class SegmentationDataset(object):
    """Common base for segmentation datasets.

    Stores the dataset configuration and provides the joint image/mask
    transforms (validation resize + centre crop, training augmentation) that
    concrete datasets apply in ``__getitem__``.
    """

    def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
        super().__init__()
        self.root, self.split, self.transform = root, split, transform
        # Without an explicit mode the split name doubles as the mode.
        self.mode = split if mode is None else mode
        self.base_size, self.crop_size = base_size, crop_size

    @staticmethod
    def _fit_short_edge(width, height, short_edge):
        """Return ``(new_width, new_height)`` with the shorter side set to ``short_edge``."""
        if width > height:
            return int(1.0 * width * short_edge / height), short_edge
        return short_edge, int(1.0 * height * short_edge / width)

    def _val_sync_transform(self, img, mask):
        """Resize so the short side equals ``crop_size``, then centre-crop a square."""
        side = self.crop_size
        new_w, new_h = self._fit_short_edge(*img.size, side)
        img = img.resize((new_w, new_h), Image.BILINEAR)
        mask = mask.resize((new_w, new_h), Image.NEAREST)

        # centre crop
        cur_w, cur_h = img.size
        left = int(round((cur_w - side) / 2.0))
        top = int(round((cur_h - side) / 2.0))
        box = (left, top, left + side, top + side)
        img = img.crop(box)
        mask = mask.crop(box)

        return self._img_transform(img), self._mask_transform(mask)

    def _sync_transform(self, img, mask):
        """Apply the random training augmentation jointly to image and mask.

        Steps: random horizontal flip, random rescale of the short side to
        between 0.5x and 2x ``base_size``, zero padding up to ``crop_size``,
        random square crop and an optional Gaussian blur on the image only.
        """
        # random mirror
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        side = self.crop_size

        # random scale (short edge)
        low, high = int(self.base_size * 0.5), int(self.base_size * 2.0)
        short_edge = random.randint(low, high)
        new_w, new_h = self._fit_short_edge(*img.size, short_edge)
        img = img.resize((new_w, new_h), Image.BILINEAR)
        mask = mask.resize((new_w, new_h), Image.NEAREST)

        # pad on the right/bottom when the scaled image is smaller than the crop
        if short_edge < side:
            pad_h = max(side - new_h, 0)
            pad_w = max(side - new_w, 0)
            border = (0, 0, pad_w, pad_h)
            img = ImageOps.expand(img, border=border, fill=0)
            mask = ImageOps.expand(mask, border=border, fill=0)

        # random crop of side x side
        cur_w, cur_h = img.size
        left = random.randint(0, cur_w - side)
        top = random.randint(0, cur_h - side)
        box = (left, top, left + side, top + side)
        img = img.crop(box)
        mask = mask.crop(box)

        # gaussian blur as in PSP
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))

        return self._img_transform(img), self._mask_transform(mask)

    def _img_transform(self, img):
        """Convert the image to a NumPy array."""
        return np.array(img)

    def _mask_transform(self, mask):
        """Convert the mask to an int32 NumPy array."""
        return np.array(mask).astype("int32")

    @property
    def num_class(self):
        """Number of categories (reads ``NUM_CLASS`` from the instance/class)."""
        return self.NUM_CLASS

    @property
    def pred_offset(self):
        """Offset applied to predictions; always 0 here."""
        return 0


class COCOSegmentation(SegmentationDataset):
    """COCO Semantic Segmentation Dataset for VOC Pre-training.

    Only the COCO categories listed in ``CAT_LIST`` are kept; each is mapped
    to its position in that list (0 is background), giving ``NUM_CLASS``
    labels.  Images whose mask has at most 1000 labelled pixels are filtered
    out once and the surviving image ids are cached next to the annotation
    file as ``<annotation name>_ids.mx``.

    Parameters
    ----------
    root : string
        Folder containing the image files referenced by the annotations.
    ann_file : string
        Path to the COCO annotation JSON file.
    mode : string, optional
        'train' (random augmentation), 'val' (resize + centre crop) or
        'testval' (no geometric transform).  ``None`` falls back to 'train'.
    transform : callable, optional
        A function that transforms the image
    **kwargs
        Forwarded to ``SegmentationDataset`` (e.g. ``base_size``).

    Examples
    --------
    >>> from torchvision import transforms
    >>> import torch.utils.data as data
    >>> # Transforms for Normalization
    >>> input_transform = transforms.Compose([
    >>>     transforms.ToTensor(),
    >>>     transforms.Normalize((.485, .456, .406), (.229, .224, .225)),
    >>> ])
    >>> # Create Dataset
    >>> trainset = COCOSegmentation(
    >>>     root='./datasets/coco/train2017',
    >>>     ann_file='./datasets/coco/annotations/instances_train2017.json',
    >>>     mode='train', transform=input_transform)
    >>> # Create Training Loader
    >>> train_data = data.DataLoader(
    >>>     trainset, 4, shuffle=True,
    >>>     num_workers=4)
    """

    # COCO category ids; the index in this list is the training label.
    CAT_LIST = [
        0,
        5,
        2,
        16,
        9,
        44,
        6,
        3,
        17,
        62,
        21,
        67,
        18,
        19,
        4,
        1,
        64,
        20,
        63,
        7,
        72,
    ]
    NUM_CLASS = 21

    def __init__(self, root, ann_file, mode=None, transform=None, **kwargs):
        # The base class stores root/transform and resolves mode=None to the
        # split name ("train").  It must not be overwritten with the raw
        # `mode` afterwards: that left self.mode = None and made every
        # __getitem__ call fail its mode assertion.
        super(COCOSegmentation, self).__init__(
            root, "train", mode, transform, crop_size=512, **kwargs
        )
        # lazy import pycocotools
        from pycocotools.coco import COCO
        from pycocotools import mask

        from pathlib import Path

        ann_path = Path(ann_file)
        # Drop the ".json" suffix only.  str.rstrip(".json") strips any
        # trailing run of the characters . j s o n, mangling names such as
        # "person.json".  Cache files written under such a mangled name are
        # simply regenerated once.
        cache_stem = ann_path.name
        if cache_stem.endswith(".json"):
            cache_stem = cache_stem[: -len(".json")]
        ids_file = os.path.join(ann_path.parent.absolute(), cache_stem + "_ids.mx")

        self.coco = COCO(ann_file)
        self.coco_mask = mask
        if os.path.exists(ids_file):
            # NOTE(security): pickle executes arbitrary code on load; this is
            # only acceptable because the file is a cache written by
            # _preprocess below.  Do not point ann_file at an untrusted
            # directory.
            with open(ids_file, "rb") as f:
                self.ids = pickle.load(f)
        else:
            ids = list(self.coco.imgs.keys())
            self.ids = self._preprocess(ids, ids_file)

    def __getitem__(self, index):
        """Return the ``(image, mask)`` pair for position ``index``."""
        coco = self.coco
        img_id = self.ids[index]
        img_metadata = coco.loadImgs(img_id)[0]
        path = img_metadata["file_name"]
        img = Image.open(os.path.join(self.root, path)).convert("RGB")
        cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        mask = Image.fromarray(
            self._gen_seg_mask(
                cocotarget, img_metadata["height"], img_metadata["width"]
            )
        )
        # synchronized transform
        if self.mode == "train":
            img, mask = self._sync_transform(img, mask)
        elif self.mode == "val":
            img, mask = self._val_sync_transform(img, mask)
        else:
            assert self.mode == "testval", f"unsupported mode: {self.mode!r}"
            img, mask = self._img_transform(img), self._mask_transform(mask)
        # general resize, normalize and toTensor
        if self.transform is not None:
            img = self.transform(img)
        return img, mask

    def _mask_transform(self, mask):
        """Convert the mask to a ``torch.LongTensor``."""
        return torch.LongTensor(
            np.array(mask).astype("int32"),
        )

    def _gen_seg_mask(self, target, h, w):
        """Rasterise COCO annotations into an ``(h, w)`` uint8 label mask.

        Annotations are painted in order and only onto pixels that are still
        background, so earlier annotations win where instances overlap.
        """
        mask = np.zeros((h, w), dtype=np.uint8)
        coco_mask = self.coco_mask
        for instance in target:
            cat = instance["category_id"]
            # Check the category first so unused annotations are never decoded.
            if cat not in self.CAT_LIST:
                continue
            c = self.CAT_LIST.index(cat)
            rle = coco_mask.frPyObjects(instance["segmentation"], h, w)
            m = coco_mask.decode(rle)
            if len(m.shape) < 3:
                mask[:, :] += (mask == 0) * (m * c)
            else:
                # Several parts were decoded: merge them along the last axis.
                mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(
                    np.uint8
                )
        return mask

    def _preprocess(self, ids, ids_file):
        """Keep images with more than 1000 labelled pixels and cache their ids.

        The resulting id list is pickled to ``ids_file`` and returned.
        """
        print(
            "Preprocessing mask, this will take a while."
            + "But don't worry, it only run once for each split."
        )
        tbar = trange(len(ids))
        new_ids = []
        for i in tbar:
            img_id = ids[i]
            cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
            img_metadata = self.coco.loadImgs(img_id)[0]
            mask = self._gen_seg_mask(
                cocotarget, img_metadata["height"], img_metadata["width"]
            )
            # more than 1k pixels
            if (mask > 0).sum() > 1000:
                new_ids.append(img_id)
            tbar.set_description(
                "Doing: {}/{}, got {} qualified images".format(
                    i, len(ids), len(new_ids)
                )
            )
        print("Found number of qualified images: ", len(new_ids))
        with open(ids_file, "wb") as f:
            pickle.dump(new_ids, f)
        return new_ids

    def __len__(self) -> int:
        """Number of images that passed the pixel-count filter."""
        return len(self.ids)

    @property
    def classes(self):
        """Category names."""
        return (
            "background",
            "airplane",
            "bicycle",
            "bird",
            "boat",
            "bottle",
            "bus",
            "car",
            "cat",
            "chair",
            "cow",
            "diningtable",
            "dog",
            "horse",
            "motorcycle",
            "person",
            "potted-plant",
            "sheep",
            "sofa",
            "train",
            "tv",
        )


# Image-only post-processing applied after the dataset's own geometric
# transforms: tensor conversion followed by per-channel normalisation.
# Random flipping is not repeated here because the dataset's
# `_sync_transform` already mirrors image and mask together.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Evaluation uses the same conversion and normalisation as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])


def create_semantic_segmentation_client(
    args,
    base_models_path,
    train_ann_files,
    test_ann_files,
    root_train_image_folder,
    root_val_image_folder,
    public_loader,
    lr,
    num_local_epochs,
    device,
    train_func,
    train_args,
    server_lr,
):
    """Assemble a ``Client`` for the semantic-segmentation task.

    Builds the LoRA segmentation model (``args.proj_dim``, ``args.rank``), an
    AdamW optimizer with learning rate ``lr``, COCO train/val datasets and
    their loaders (``args.batch_size``), then wires them into a ``Client``.

    ``train_func`` is a factory called with ``(num_local_epochs, model,
    optimizer, train_loader, public_loader, device, train_args)``; whatever it
    returns becomes the client's local training function.  The local
    evaluation function comes from ``semantic_segmentation_eval_function``.
    """
    model = semantic_segmenration_model(base_models_path, args.proj_dim, args.rank)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Datasets: augmented training split, centre-cropped validation split.
    train_dataset = COCOSegmentation(
        root=root_train_image_folder,
        ann_file=train_ann_files,
        mode="train",
        transform=transform_train,
    )
    test_dataset = COCOSegmentation(
        root=root_val_image_folder,
        ann_file=test_ann_files,
        mode="val",
        transform=transform_test,
    )

    # Both loaders share everything except shuffling.
    loader_options = dict(
        batch_size=args.batch_size,
        num_workers=1,
        drop_last=False,
        pin_memory=True,
    )
    train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_options)
    test_loader = DataLoader(dataset=test_dataset, shuffle=False, **loader_options)

    local_train = train_func(
        num_local_epochs,
        model,
        optimizer,
        train_loader,
        public_loader,
        device,
        train_args,
    )
    local_eval = semantic_segmentation_eval_function(
        model, test_loader, train_loader, device
    )

    return Client(
        task="semantic_segmentation",
        model=model,
        local_optimizer=optimizer,
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        train_loader=train_loader,
        test_loader=test_loader,
        public_loader=public_loader,
        local_train_func=local_train,
        local_eval_fun=local_eval,
        train_args=train_args,
        server_lr=server_lr,
    )
