import os
from typing import Any, Callable, List, Literal, Optional, Tuple, Union

import numpy as np
from sklearn.metrics import f1_score, hamming_loss
import torch
from torch import nn
from torchvision import datasets, transforms
import pickle
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import ViTForImageClassification
from transformers.modeling_outputs import (
    ImageClassifierOutput,
)
from torch.nn import CrossEntropyLoss
from transformers.models.vit.configuration_vit import ViTConfig
from peft import LoraConfig, get_peft_model


from clients.base import Client
from utils.model_ops import print_trainable_parameters


class ModifiedVitForImageClassification(ViTForImageClassification):
    """ViT image classifier with two extra linear heads on the first token.

    Both heads read the embedding at sequence index 0 of the encoder output:

    * ``projection_layer`` maps it straight to ``output_size`` features.
    * ``classifier_layer1`` -> ReLU -> inherited ``classifier`` yields the
      class logits.

    NOTE(review): ``self.classifier`` is inherited from the parent class and
    is fed the ``output_size``-wide activation here; the factory
    ``image_classification_model`` in this file replaces it with a layer of a
    matching input width -- confirm before instantiating the class directly.
    """

    # TODO: Remove unnecessary parts

    def __init__(self, config: ViTConfig, output_size=256) -> None:
        super().__init__(config)

        hidden_dim = config.hidden_size

        # Keep this creation order: projection head first, then the hidden
        # layer of the classification head.
        self.projection_layer = nn.Linear(hidden_dim, output_size)
        self.classifier_layer1 = nn.Linear(hidden_dim, output_size)
        self.relu = nn.ReLU()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        projection: Optional[bool] = False,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        Run the ViT encoder and apply one of the two heads.

        projection (`bool`, *optional*, defaults to `False`):
            When truthy, `logits` is the output of `projection_layer`;
            otherwise it is `classifier(relu(classifier_layer1(...)))`.
        labels (`torch.Tensor`, *optional*):
            When given, a `CrossEntropyLoss` between `logits` and `labels` is
            computed. Cross-entropy is always used (there is no regression
            branch), and it is applied to whichever head produced `logits`.
        return_dict (`bool`, *optional*):
            Falls back to `self.config.use_return_dict` when `None`.

        Returns an `ImageClassifierOutput` when `return_dict` is truthy,
        otherwise a tuple `(loss?, logits, *remaining encoder outputs)` where
        `loss` is only present if `labels` was supplied.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        # Embedding at sequence position 0 of the encoder's first output.
        first_token = encoder_outputs[0][:, 0, :]

        if projection:
            logits = self.projection_layer(first_token)
        else:
            hidden = self.relu(self.classifier_layer1(first_token))
            logits = self.classifier(hidden)

        loss = None
        if labels is not None:
            # Labels follow the logits' device so model parallelism keeps working.
            criterion = CrossEntropyLoss()
            loss = criterion(logits, labels.to(logits.device))

        if return_dict:
            return ImageClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        tuple_output = (logits,) + encoder_outputs[1:]
        if loss is None:
            return tuple_output
        return (loss,) + tuple_output


def image_classification_model(base_path: str, proj_dim, num_classes, lora_rank):
    """Load a pretrained ViT, swap in fresh heads and wrap it with LoRA.

    Args:
        base_path: Location handed to ``from_pretrained``.
        proj_dim: Output width of ``projection_layer`` and
            ``classifier_layer1`` (and input width of ``classifier``).
        num_classes: Output width of the final ``classifier`` layer.
        lora_rank: Rank ``r`` passed to ``LoraConfig``.

    Returns:
        The PEFT-wrapped model with every parameter whose name contains
        ``"classifier"`` or ``"projection_layer"`` marked trainable.
    """
    model = ModifiedVitForImageClassification.from_pretrained(base_path)
    hidden_dim = model.config.hidden_size

    # Freshly initialised heads; the creation order is kept as-is.
    model.classifier_layer1 = nn.Linear(hidden_dim, proj_dim)
    model.projection_layer = nn.Linear(hidden_dim, proj_dim)
    model.classifier = nn.Linear(proj_dim, num_classes)

    model = get_peft_model(
        model,
        LoraConfig(
            r=lora_rank,
            lora_alpha=1,
            target_modules=["query", "value", "key"],
            bias="none",
        ),
    )

    # Presumably get_peft_model freezes the non-LoRA weights (TODO confirm);
    # the heads are explicitly switched back on so they are trained in full.
    trainable_tags = ("classifier", "projection_layer")
    for param_name, weight in model.named_parameters():
        if any(tag in param_name for tag in trainable_tags):
            weight.requires_grad = True

    print_trainable_parameters(model)

    return model


def image_classification_eval_function(model, test_loader, train_loader, device):
    """Build a closure that evaluates ``model`` on the train or test loader.

    Args:
        model: Callable module accepting ``pixel_values=`` and ``labels=`` and
            returning an object with ``loss`` and ``logits`` attributes.
        test_loader: Iterable of ``(pixel_values, labels)`` batches used for
            ``"test"``.
        train_loader: Iterable of ``(pixel_values, labels)`` batches used for
            ``"train"``.
        device: Device the model and batches are moved to during evaluation.

    Returns:
        ``eval_function(evaluation_set)``, which returns a dict with the keys
        ``"accuracy"`` and ``"avg_loss"``, both rounded to 3 decimals.
    """

    def eval_function(evaluation_set: Literal["train", "test"]):
        """Evaluate on the selected split.

        The model is put in eval mode, moved to ``device`` for the pass and
        always moved back to ``"cpu"`` afterwards, even if a batch fails.

        Raises:
            NotImplementedError: If ``evaluation_set`` is neither ``"train"``
                nor ``"test"``.
        """
        if evaluation_set == "test":
            loader = test_loader
        elif evaluation_set == "train":
            loader = train_loader
        else:
            raise NotImplementedError(
                f"Unknown evaluation set {evaluation_set!r}; "
                "expected 'train' or 'test'."
            )

        model.eval()
        model.to(device)
        correct = 0
        total = 0
        loss_sum = 0.0
        try:
            with torch.no_grad():
                for pixel_values, labels in loader:
                    pixel_values = pixel_values.to(device)
                    labels = labels.to(device)

                    outputs = model(pixel_values=pixel_values, labels=labels)
                    _, predicted = torch.max(outputs.logits, 1)

                    batch_size = labels.size(0)
                    total += batch_size
                    correct += (predicted == labels).sum().item()
                    # The loss is a per-batch mean (CrossEntropyLoss default),
                    # so weight it by the batch size: a smaller final batch
                    # must not count as much as a full one.
                    loss_sum += outputs.loss.item() * batch_size
        finally:
            # Release the accelerator even when evaluation is interrupted.
            model.to("cpu")

        if total == 0:
            # Nothing was evaluated: the metrics are undefined rather than 0.
            return {"accuracy": float("nan"), "avg_loss": float("nan")}

        return {
            "accuracy": np.round(correct / total, 3),
            "avg_loss": np.round(loss_sum / total, 3),
        }

    return eval_function


class CifarClassificationDataset(Dataset):
    """Dataset backed by a pickled dict with ``"data"`` and ``"labels"`` entries.

    NOTE: ``pickle.load`` can execute arbitrary code, so only point this class
    at files from a trusted source.
    """

    def __init__(
        self,
        fila_path: str,
        transform: Optional[Callable] = None,
    ) -> None:
        """
        Args:
            fila_path: str: Path to the .pkl file.
            transform (callable, optional): Applied to each PIL image before it
                is returned; skipped when falsy.
        """
        with open(fila_path, "rb") as handle:
            payload = pickle.load(handle)

        self.images = payload["data"]
        self.labels = payload["labels"]
        self.transform = transform

    def __len__(self) -> int:
        """Number of samples, taken from the label collection."""
        return len(self.labels)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return ``(image, label_tensor)`` for position ``idx``.

        The stored image is converted with ``Image.fromarray`` (assumes each
        entry is an array PIL accepts -- TODO confirm), then passed through
        ``transform`` when one is set. The label is wrapped in ``torch.tensor``.
        """
        raw_image, target = self.images[idx], self.labels[idx]

        sample = Image.fromarray(raw_image)
        if self.transform:
            sample = self.transform(sample)

        return sample, torch.tensor(target)


# Settings shared by the train and test pipelines.
_INPUT_SIZE = (224, 224)
# Per-channel mean / std handed to Normalize (presumably CIFAR-10 statistics --
# TODO confirm against the dataset actually used).
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: resize, random horizontal flip, tensor conversion, normalise.
transform_train = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE, antialias=True),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# Test pipeline: same as training but without the random flip.
transform_test = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE, antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)


def create_image_classification_client(
    args,
    base_models_path,
    train_files,
    test_files,
    public_loader,
    lr,
    num_local_epochs,
    num_classes,
    device,
    train_func,
    train_args,
):
    """Assemble a ``Client`` for the image-classification task.

    Args:
        args: Object providing ``proj_dim``, ``rank`` and ``batch_size``.
        base_models_path: Pretrained model location for
            ``image_classification_model``.
        train_files: Pickle path for the training dataset.
        test_files: Pickle path for the test dataset.
        public_loader: Loader forwarded to ``train_func`` and to the client.
        lr: Learning rate for AdamW; also passed on as ``server_lr``.
        num_local_epochs: Forwarded to ``train_func``.
        num_classes: Number of output classes; also part of the task name.
        device: Device forwarded to the train and eval function builders.
        train_func: Factory called with ``(num_local_epochs, model, optimizer,
            train_loader, public_loader, device, train_args)``; its result is
            used as the client's local training function.
        train_args: Forwarded to ``train_func`` and to the client.

    Returns:
        The constructed ``Client``.
    """
    model = image_classification_model(
        base_models_path, args.proj_dim, num_classes, args.rank
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Datasets: the training split gets the augmenting transform.
    train_dataset = CifarClassificationDataset(
        fila_path=train_files, transform=transform_train
    )
    test_dataset = CifarClassificationDataset(
        fila_path=test_files, transform=transform_test
    )

    # Loaders differ only in the dataset and in whether they shuffle.
    shared_loader_options = {
        "batch_size": args.batch_size,
        "num_workers": 1,
        "drop_last": False,
        "pin_memory": True,
    }
    train_loader = DataLoader(
        dataset=train_dataset, shuffle=True, **shared_loader_options
    )
    test_loader = DataLoader(
        dataset=test_dataset, shuffle=False, **shared_loader_options
    )

    # Build the training callable first, then the evaluation closure.
    local_train = train_func(
        num_local_epochs,
        model,
        optimizer,
        train_loader,
        public_loader,
        device,
        train_args,
    )
    local_eval = image_classification_eval_function(
        model, test_loader, train_loader, device
    )

    return Client(
        task=f"image_clasification_n_label_{num_classes}",
        model=model,
        local_optimizer=optimizer,
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        train_loader=train_loader,
        test_loader=test_loader,
        public_loader=public_loader,
        local_train_func=local_train,
        local_eval_fun=local_eval,
        train_args=train_args,
        num_classes=num_classes,
        server_lr=lr,
    )
