import os
from typing import Any, Callable, List, Literal, Optional, Tuple, Union

import numpy as np
from sklearn.metrics import f1_score, hamming_loss
import torch
from torch import nn
from torchvision import datasets, transforms
import evaluate
from torch.utils.data import DataLoader
from PIL import Image
from transformers import ViTForImageClassification
from transformers.modeling_outputs import (
    ImageClassifierOutput,
)
from torch.nn import BCEWithLogitsLoss
from transformers.models.vit.configuration_vit import ViTConfig
from peft import LoraConfig, get_peft_model


from clients.base import Client
from utils.model_ops import print_trainable_parameters

# Module-level metric handle created via `evaluate.load` when this module is imported.
# NOTE(review): "mean_iou" is not referenced by the code in this file; confirm
# whether another module imports `metric` before removing this import-time load.
metric = evaluate.load("mean_iou")


class VitForMultiLabelClassification(ViTForImageClassification):
    """ViT image model with a multi-label classification head and a projection head.

    ``forward`` takes the first token of the backbone's last hidden state
    (the [CLS] token in ViT) and routes it either through ``projection_layer``
    (``projection=True``) or through ``classifier_layer1`` -> ReLU ->
    ``classifier``. When ``labels`` are supplied the loss is
    ``BCEWithLogitsLoss`` on multi-hot targets.
    """

    # TODO: Remove unnecessary parts

    def __init__(self, config: ViTConfig, output_size=256) -> None:
        """Create the backbone (via the parent class) and the extra heads.

        Args:
            config: ViT configuration forwarded to ``ViTForImageClassification``.
            output_size: Output width of ``projection_layer`` and
                ``classifier_layer1``.
        """
        super().__init__(config)

        # Projection head
        self.projection_layer = nn.Linear(config.hidden_size, output_size)
        # NOTE(review): ``forward`` feeds this layer's ``output_size`` outputs into
        # the inherited ``self.classifier``, which presumably expects
        # ``config.hidden_size`` inputs — confirm. ``multi_label_classification_model``
        # replaces both this layer and ``classifier`` (proj_dim -> num_classes) so
        # the shapes agree there.
        self.classifier_layer1 = nn.Linear(config.hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        projection: Optional[bool] = False,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.Tensor`, *optional*):
            Multi-hot targets with the same shape as `logits`, i.e.
            `(batch_size, num_labels)` on the classification path (1 for each
            present label, 0 otherwise). A `BCEWithLogitsLoss` is computed
            against them. They are moved to the logits' device and cast to the
            logits' dtype, so integer or boolean targets are accepted.
        projection (`bool`, *optional*, defaults to `False`):
            If `True`, `logits` is the output of `projection_layer` applied to
            the first token instead of the classification logits. Any `labels`
            must then match the projection width.

        Returns:
            An `ImageClassifierOutput` when `return_dict` is truthy; otherwise a
            tuple `(loss?, logits, *backbone_extras)`, where `loss` is present
            only if `labels` were given.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        # First token of every sequence (ViT's [CLS] token).
        cls_token = sequence_output[:, 0, :]

        if projection:
            logits = self.projection_layer(cls_token)
        else:
            x = self.classifier_layer1(cls_token)
            x = self.relu(x)
            logits = self.classifier(x)

        loss = None
        if labels is not None:
            # Move labels to the logits' device to enable model parallelism, and
            # cast them to the logits' dtype: BCEWithLogitsLoss rejects integer targets.
            labels = labels.to(device=logits.device, dtype=logits.dtype)
            loss_fct = BCEWithLogitsLoss()
            loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


def multi_label_classification_model(base_path: str, proj_dim, num_classes, lora_rank):
    """Load a ViT checkpoint, attach fresh heads and wrap it with LoRA adapters.

    Args:
        base_path: Checkpoint location passed to ``from_pretrained``.
        proj_dim: Output width of ``classifier_layer1`` and ``projection_layer``,
            and input width of ``classifier``.
        num_classes: Number of outputs of the final ``classifier``.
        lora_rank: LoRA rank ``r`` applied to the "query", "value" and "key" modules.

    Returns:
        The PEFT-wrapped model. Every parameter whose name contains
        "classifier" or "projection_layer" is explicitly marked trainable.
    """
    vit = VitForMultiLabelClassification.from_pretrained(base_path)
    hidden = vit.config.hidden_size

    # Swap in new heads; creation order is kept so random initialisation
    # consumes the global RNG exactly as before.
    vit.classifier_layer1 = nn.Linear(hidden, proj_dim)
    vit.projection_layer = nn.Linear(hidden, proj_dim)
    vit.classifier = nn.Linear(proj_dim, num_classes)

    peft_model = get_peft_model(
        vit,
        LoraConfig(
            target_modules=["query", "value", "key"],
            r=lora_rank,
            lora_alpha=1,
            bias="none",
        ),
    )

    # The LoRA wrapper is expected to freeze non-adapter weights; make the
    # task heads trainable again.
    head_markers = ("classifier", "projection_layer")
    for name, parameter in peft_model.named_parameters():
        if any(marker in name for marker in head_markers):
            parameter.requires_grad = True

    print_trainable_parameters(peft_model)
    return peft_model


def multi_label_classification_eval_function(model, test_loader, train_loader, device):
    """Build a closure that evaluates ``model`` on the train or the test loader.

    Args:
        model: Called as ``model(pixel_values=..., labels=...)``; its output must
            expose ``.loss`` (a scalar tensor) and ``.logits``.
        test_loader: Iterable of ``(pixel_values, labels)`` batches used for "test".
        train_loader: Iterable of ``(pixel_values, labels)`` batches used for "train".
        device: Device the model and each batch are moved to during evaluation.

    Returns:
        ``eval_function(evaluation_set)``, which returns a dict with
        ``"hamming"`` (Hamming loss), ``"micro_f1"`` (micro-averaged F1) and
        ``"avg_loss"`` (per-sample average loss), each rounded to 3 decimals.
        The model is left in eval mode and moved back to CPU afterwards.
    """

    def eval_function(evaluation_set: Literal["train", "test"]):
        if evaluation_set == "test":
            loader = test_loader
        elif evaluation_set == "train":
            loader = train_loader
        else:
            raise NotImplementedError(
                f"Unknown evaluation_set={evaluation_set!r}; expected 'train' or 'test'"
            )

        model.eval()
        model.to(device)
        all_labels = []
        all_predictions = []
        total_loss = 0.0
        total_samples = 0
        try:
            with torch.no_grad():
                for pixel_values, labels in loader:
                    pixel_values = pixel_values.to(device)
                    labels = labels.to(device)

                    outputs = model(pixel_values=pixel_values, labels=labels)

                    # Threshold sigmoid probabilities at 0.5 to get multi-hot predictions.
                    probabilities = nn.functional.sigmoid(outputs.logits)
                    predictions = (probabilities > 0.5).float()

                    all_labels.append(labels.cpu())
                    all_predictions.append(predictions.cpu())

                    # outputs.loss is a per-batch mean; weight it by batch size so a
                    # smaller final batch does not skew the dataset-level average.
                    batch_size = labels.size(0)
                    total_loss += outputs.loss.item() * batch_size
                    total_samples += batch_size
        finally:
            # Return the model to CPU even if evaluation fails part-way through.
            model.to("cpu")

        y_true = torch.cat(all_labels).numpy()
        y_pred = torch.cat(all_predictions).numpy()

        hamming = hamming_loss(y_true, y_pred)
        micro_f1 = f1_score(y_true, y_pred, average="micro")
        avg_loss = total_loss / total_samples

        return {
            "hamming": np.round(hamming, 3),
            "micro_f1": np.round(micro_f1, 3),
            "avg_loss": np.round(avg_loss, 3),
        }

    return eval_function


class CocoMultiLabelDataset(datasets.VisionDataset):
    """COCO images paired with multi-hot category vectors for multi-label classification.

    Annotations are read with ``pycocotools`` (imported lazily in ``__init__``,
    so it must be installed). Each item is ``(image, label_vector)``, where
    ``label_vector`` is a float tensor of length ``num_classes`` holding 1.0 at
    every annotated ``category_id`` and 0.0 elsewhere.

    Args:
        root (string): Directory the image ``file_name`` entries are relative to.
        annFile (string): Path to the COCO-format JSON annotation file.
        transform (callable, optional): Transform applied to the PIL image,
            e.g. ``transforms.PILToTensor``.
        target_transform (callable, optional): Transform applied to the list of
            annotation dicts.
        transforms (callable, optional): Transform applied jointly to
            ``(image, annotations)``.
        num_classes (int): Length of the label vector. Category ids are used
            directly as indices (index 0 is treated as background), so it must
            be larger than every ``category_id`` in ``annFile``.
    """

    def __init__(
        self,
        root: str,
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
        num_classes: int = 91,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        # Sorted image ids give a stable index -> image mapping.
        self.ids = sorted(self.coco.imgs)
        self.num_classes = num_classes

    def _load_image(self, id: int) -> Image.Image:
        """Open the image with COCO id ``id`` from ``root`` as RGB."""
        file_name = self.coco.loadImgs(id)[0]["file_name"]
        image_path = os.path.join(self.root, file_name)
        return Image.open(image_path).convert("RGB")

    def _load_target(self, id: int) -> List[Any]:
        """Return every annotation dict attached to image ``id``."""
        annotation_ids = self.coco.getAnnIds(id)
        return self.coco.loadAnns(annotation_ids)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, label_vector)`` for the ``index``-th image id.

        Raises:
            Exception: if the image has no annotations.
        """
        image_id = self.ids[index]
        image = self._load_image(image_id)
        annotations = self._load_target(image_id)

        if self.transforms is not None:
            image, annotations = self.transforms(image, annotations)

        label_vector = torch.zeros(self.num_classes)

        # NOTE(review): COCO annotation files can contain images without any
        # annotations; unless annFile is pre-filtered, such images raise here.
        if len(annotations) == 0:
            raise Exception(f"Missing id={image_id} in root={self.root}")

        # Mark each annotated category (duplicates simply re-set 1.0).
        for annotation in annotations:
            label_vector[annotation["category_id"]] = 1.0

        return image, label_vector

    def __len__(self) -> int:
        """Number of image ids in the annotation file."""
        return len(self.ids)


# Shared preprocessing settings for the train and test pipelines.
# NOTE(review): these mean/std values appear to be the commonly used CIFAR-10
# statistics rather than COCO/ImageNet ones; confirm they match what the
# pretrained ViT checkpoint expects.
_IMAGE_SIZE = (224, 224)
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)


def _make_transform(augment: bool) -> transforms.Compose:
    """Resize -> (optional random horizontal flip) -> tensor -> normalize."""
    steps = [transforms.Resize(_IMAGE_SIZE, antialias=True)]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps.extend([transforms.ToTensor(), transforms.Normalize(_NORM_MEAN, _NORM_STD)])
    return transforms.Compose(steps)


# Training applies a random horizontal flip; evaluation is deterministic.
transform_train = _make_transform(augment=True)
transform_test = _make_transform(augment=False)


def create_multi_label_classification_client(
    args,
    base_models_path,
    train_ann_files,
    test_ann_files,
    root_train_image_folder,
    root_val_image_folder,
    public_loader,
    lr,
    num_local_epochs,
    device,
    train_func,
    train_args,
):
    """Assemble a federated ``Client`` for COCO multi-label classification.

    Builds the LoRA-wrapped ViT (``args.proj_dim``, ``args.rank``), an AdamW
    optimizer over all model parameters, COCO train/test datasets and loaders
    (``args.batch_size``), the local training closure from ``train_func`` and
    the local evaluation closure, and passes them all to ``Client``.
    """
    # Label-vector length shared by the model head and both datasets.
    num_classes = 91

    model = multi_label_classification_model(
        base_models_path, args.proj_dim, num_classes, args.rank
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    train_dataset = CocoMultiLabelDataset(
        root=root_train_image_folder,
        annFile=train_ann_files,
        transform=transform_train,
        num_classes=num_classes,
    )
    test_dataset = CocoMultiLabelDataset(
        root=root_val_image_folder,
        annFile=test_ann_files,
        transform=transform_test,
        num_classes=num_classes,
    )

    # Identical loader settings; only shuffling differs between train and test.
    loader_options = {
        "batch_size": args.batch_size,
        "num_workers": 1,
        "drop_last": False,
        "pin_memory": True,
    }
    train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_options)
    test_loader = DataLoader(dataset=test_dataset, shuffle=False, **loader_options)

    # Built in the same order as the Client keyword arguments are evaluated.
    local_train_func = train_func(
        num_local_epochs,
        model,
        optimizer,
        train_loader,
        public_loader,
        device,
        train_args,
    )
    local_eval_fun = multi_label_classification_eval_function(
        model, test_loader, train_loader, device
    )

    # The task string (including its spelling) is a runtime identifier; keep it as-is.
    return Client(
        task=f"multi_label_clasification_n_label_{num_classes}",
        model=model,
        local_optimizer=optimizer,
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        train_loader=train_loader,
        test_loader=test_loader,
        public_loader=public_loader,
        local_train_func=local_train_func,
        local_eval_fun=local_eval_fun,
        train_args=train_args,
        num_classes=num_classes,
        server_lr=lr,
    )
