#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModelForCausalLM, \
                         PretrainedConfig, PreTrainedModel

from transformers.modeling_outputs import ImageClassifierOutput
from transformers.generation.utils import GenerateOutput, ModelOutput

from ..llava_arch import LlavaMetaModel, RLlavaMetaForCausalLM

from transformers import logging
logger = logging.get_logger(__name__)


@dataclass
class ClassifierOutput(ModelOutput):
    """Output of ``RLlavaClassifier.forward`` when ``return_dict`` is truthy.

    Attributes:
        loss: cross-entropy loss; only populated when labels were provided.
        logits: classifier scores, one row per image.
        hidden_states: the features fed to the classifier head (after dropout).
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    # forward() stores a single feature tensor here (not the per-layer tuple
    # used by most Hugging Face outputs), so the annotation admits both forms.
    hidden_states: Optional[Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]] = None


class RLlavaClassifierConfig(PretrainedConfig):
    """Configuration for the ``RLlavaClassifier`` region classifier.

    Args:
        num_labels: number of output classes. ``None`` (the default) leaves the
            value to ``PretrainedConfig``: the length of ``id2label`` when one is
            supplied (e.g. when a saved config is reloaded), otherwise 2.
            An explicit value always wins.
        hidden_size: width of the feature vectors fed to the classifier head.
        classifier_dropout: dropout probability applied before the head.
        region_interpolate: region-feature option copied onto the model
            (default ``"upsample"``); TODO confirm the accepted values.
        region_pooling_method: region pooling option copied onto the model
            (default ``"average"``); TODO confirm the accepted values.
        region_extra: extra region option copied onto the model
            (default ``"none"``); TODO confirm the accepted values.
        log_debug: when true, the model records per-step debug logs.
        **kwargs: forwarded to ``PretrainedConfig``.
    """

    model_type = "Rllava_Classifier"

    def __init__(
        self,
        num_labels: Optional[int] = None,
        hidden_size: int = 2048,
        classifier_dropout: float = 0.1,
        region_interpolate: str = "upsample",
        region_pooling_method: str = "average",
        region_extra: str = "none",
        log_debug: bool = False,
        **kwargs
    ):
        super().__init__(**kwargs)
        # ``num_labels`` is a PretrainedConfig property backed by ``id2label``;
        # it is not serialized, but ``id2label`` is. Assigning a hard-coded
        # default here would overwrite a reloaded N-label ``id2label`` whenever
        # N != 2, so only assign when the caller asked for a value explicitly.
        if num_labels is not None:
            self.num_labels = num_labels
        elif getattr(self, "id2label", None) is None:
            # Defensive: guarantee the historical default of two classes.
            self.num_labels = 2
        self.hidden_size = hidden_size
        self.classifier_dropout = classifier_dropout
        self.region_interpolate = region_interpolate
        self.region_pooling_method = region_pooling_method
        self.region_extra = region_extra
        self.log_debug = log_debug



class RClassifierModel(LlavaMetaModel, PreTrainedModel):
    """Backbone container used by ``RLlavaClassifier``.

    Adds no state of its own: construction simply forwards ``config`` to the
    next class in the MRO (``LlavaMetaModel``).
    """

    config_class = RLlavaClassifierConfig

    def __init__(self, config: RLlavaClassifierConfig):
        super(RClassifierModel, self).__init__(config)


class RLlavaClassifier(PreTrainedModel, RLlavaMetaForCausalLM):
    """Image/region classifier on top of a LLaVA-style vision backbone.

    Images (and optional SAM region masks) are encoded with ``encode_images``
    (presumably provided by ``RLlavaMetaForCausalLM`` — not defined here).
    Each image is reduced to a single feature row, then passed through dropout
    and a linear head producing ``config.num_labels`` logits.
    """

    config_class = RLlavaClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = RClassifierModel(config)
        self.region_interpolate = config.region_interpolate
        self.region_pooling_method = config.region_pooling_method
        self.region_extra = config.region_extra
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_fnc = CrossEntropyLoss()
        self.debug = bool(config.log_debug)
        if self.debug:
            # Parallel lists: one entry per forward() call.
            self.logs = {
                "loss": [],
                "regions": [],
                "idx": [],
            }

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the wrapped ``RClassifierModel`` backbone."""
        return self.model

    def _single_row_feature(self, feature: torch.Tensor) -> torch.Tensor:
        """Coerce one image's features to exactly one row (first row, or zeros if empty)."""
        if feature.shape[0] > 1:
            logger.warning_once("Image has more than one regions")
            return feature[0:1]
        if feature.shape[0] == 0:
            # No region features: substitute a zero vector matching the head's dtype/device.
            return torch.zeros(
                (1, self.config.hidden_size),
                dtype=self.score.weight.dtype,
                device=self.score.weight.device,
            )
        return feature

    def _record_debug(self, loss, idx, sam_masks) -> None:
        """Append this step's idx, loss and per-sample region counts to ``self.logs``.

        Missing values are recorded as ``None`` so the three lists stay aligned.
        """
        idx_value = idx.item() if torch.is_tensor(idx) else idx
        loss_value = loss.item() if loss is not None else None
        regions = [len(s) for s in sam_masks] if sam_masks is not None else None
        self.logs["idx"].append(idx_value)
        self.logs["loss"].append(loss_value)
        self.logs["regions"].append(regions)

    def forward(
        self,
        labels: Optional[torch.LongTensor] = None,
        images: Optional[torch.FloatTensor] = None,
        sam_masks: Optional[Union[List[torch.BoolTensor], torch.BoolTensor]] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        idx: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, ClassifierOutput]:
        """Classify a batch of images, one feature row per image.

        Args:
            labels: target class ids; when given, cross-entropy loss is computed.
            images: image batch passed to ``encode_images``.
            sam_masks: region masks passed to ``encode_images``; in debug mode
                their per-sample ``len`` is logged.
            image_sizes: accepted for interface compatibility; not used here.
            return_dict: truthy -> ``ClassifierOutput``; otherwise (including
                the default ``None``) a plain tuple.
            idx: optional sample index, recorded only in debug mode.

        Returns:
            ``ClassifierOutput`` when ``return_dict`` is truthy, else
            ``(loss, logits, features)`` if labels were given, or
            ``(logits, features)``. ``features`` are the post-dropout inputs
            to the head.
        """
        image_features = self.encode_images(images, sam_masks)
        # The head expects exactly one feature row per image; repair any that differ.
        if not all(feature.shape[0] == 1 for feature in image_features):
            logger.warning_once("Some image features have batch size != 1, fixing")
            image_features = [self._single_row_feature(feature) for feature in image_features]
        features = self.dropout(torch.cat(image_features, dim=0))
        logits = self.score(features)

        loss = self.loss_fnc(logits, labels) if labels is not None else None

        if return_dict:
            output = ClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=features,
            )
        elif loss is not None:
            output = (loss, logits, features)
        else:
            output = (logits, features)

        if self.debug:
            # Uses the local ``loss`` rather than ``output.loss``: ``output`` is a
            # tuple on the default (return_dict=None) path.
            self._record_debug(loss, idx, sam_masks)
        return output


# Expose the classifier through the Auto* factories. The key is taken from the
# config class itself so the registered name can never drift from model_type.
AutoConfig.register(RLlavaClassifierConfig.model_type, RLlavaClassifierConfig)
AutoModelForCausalLM.register(RLlavaClassifierConfig, RLlavaClassifier)
