from dataclasses import dataclass
from transformers import CLIPModel as HFCLIPModel
from transformers import AutoTokenizer

from torch import nn, einsum

from trainer.models.base_model import BaseModelConfig

from transformers import CLIPConfig
from typing import Any, Optional, Tuple, Union
import torch

from trainer.models.cross_modeling import Cross_model

import gc

class XCLIPModel(HFCLIPModel):
    """CLIP model whose feature getters project the encoders' full hidden states.

    Both getters override the parent implementations so that the projection
    layer is applied to the encoder's first output (its last hidden state)
    rather than only to the pooled output.  The constructor is inherited
    unchanged from ``HFCLIPModel``.
    """

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Encode text and project both the hidden states and the pooled output.

        Args:
            input_ids: Forwarded to ``self.text_model``.
            attention_mask: Forwarded to ``self.text_model``.
            position_ids: Forwarded to ``self.text_model``.
            output_attentions: Falls back to ``self.config.output_attentions``
                when ``None``.
            output_hidden_states: Falls back to
                ``self.config.output_hidden_states`` when ``None``.
            return_dict: Falls back to ``self.config.use_return_dict`` when
                ``None``.

        Returns:
            A pair ``(text_features, text_features_EOS)``:
            ``text_features`` is ``self.text_projection`` applied to the text
            encoder's first output (last hidden state), and
            ``text_features_EOS`` is the same projection applied to its second
            output (pooled output).
        """
        # Use CLIP model's config for some fields (if specified) instead of
        # those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project the whole hidden-state sequence, not just the pooled vector.
        last_hidden_state = text_outputs[0]
        text_features = self.text_projection(last_hidden_state)

        # The pooled output is projected as well and returned alongside.
        pooled_output = text_outputs[1]
        text_features_EOS = self.text_projection(pooled_output)

        return text_features, text_features_EOS

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """Encode images and project the vision encoder's hidden states.

        Args:
            pixel_values: Forwarded to ``self.vision_model``.
            output_attentions: Falls back to ``self.config.output_attentions``
                when ``None``.
            output_hidden_states: Falls back to
                ``self.config.output_hidden_states`` when ``None``.
            return_dict: Falls back to ``self.config.use_return_dict`` when
                ``None``.

        Returns:
            ``self.visual_projection`` applied to the vision encoder's first
            output (last hidden state); the pooled output is not used.
        """
        # Use CLIP model's config for some fields (if specified) instead of
        # those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project the whole hidden-state sequence, not just the pooled vector.
        last_hidden_state = vision_outputs[0]
        image_features = self.visual_projection(last_hidden_state)

        return image_features



@dataclass
class ClipModelConfig(BaseModelConfig):
    """Configuration values for the CLIP-based model."""

    # Dotted import path of the model class this configuration points at.
    _target_: str = "trainer.models.clip_model.CLIPModel"
    # Name or path of the pretrained checkpoint.
    # NOTE(review): `CLIPModel.__init__` names its only parameter `ckpt`;
    # confirm how this field is handed to the constructor.
    pretrained_model_name_or_path: str = "openai/clip-vit-base-patch32"


class CLIPModel(nn.Module):
    """CLIP backbone combined with a cross-attention head over image/text tokens.

    ``self.model`` is an ``XCLIPModel`` that yields projected token-level
    features; ``self.cross_model`` consumes image features, text features and
    an additive mask derived from the similarity between the text and a
    condition prompt.
    """

    def __init__(self, ckpt):
        """Load the backbone and build the cross-attention head.

        Args:
            ckpt: Passed unchanged to ``XCLIPModel.from_pretrained``.
        """
        super().__init__()
        self.model = XCLIPModel.from_pretrained(ckpt)
        # NOTE(review): dim=1024 presumably has to match the backbone's
        # projection width (see the shape comments in ``forward``) — confirm.
        self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)

    def get_text_features(self, *args, **kwargs):
        """Delegate to the backbone's ``get_text_features``."""
        return self.model.get_text_features(*args, **kwargs)

    def get_image_features(self, *args, **kwargs):
        """Delegate to the backbone's ``get_image_features``."""
        return self.model.get_image_features(*args, **kwargs)

    def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None):
        """Score two stacked image batches against the text under a condition mask.

        All three inputs are required in practice even though they default to
        ``None``: ``image_inputs.half()`` is called unconditionally.

        Args:
            text_inputs: Passed as the first positional argument of the
                backbone's ``get_text_features``.
            image_inputs: Tensor cast to fp16 and passed to the backbone's
                ``get_image_features``.  Its batch dimension is split in
                half: the first half is scored as ``sim0``, the second as
                ``sim1``.
            condition_inputs: Passed as the first positional argument of the
                backbone's ``get_text_features``; used only to build the mask.

        Returns:
            A 3-tuple ``(text_EOS, sim0[:, 0, :], sim1[:, 0, :])`` where
            ``text_EOS`` is the projected pooled text output and ``sim0`` /
            ``sim1`` are the ``cross_model`` outputs for the two image halves.
        """
        # Shape comments below are carried over from the original author.
        text_f, text_EOS = self.model.get_text_features(text_inputs)  # B*77*1024
        # NOTE(review): the fp16 casts here and on the mask are hard-coded;
        # confirm before running the model in another precision.
        image_f = self.model.get_image_features(image_inputs.half())  # 2B*257*1024
        condition_f, _ = self.model.get_text_features(condition_inputs)  # B*5*1024

        # For every text token keep its best similarity to any condition token.
        sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
        sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
        # Normalise by the single largest value across the whole batch.
        sim_text_condition = sim_text_condition / sim_text_condition.max()
        # Additive mask: 0 keeps a text token, -inf removes it.
        mask = torch.where(sim_text_condition > 0.01, 0, float('-inf'))  # B*1*77

        # Cast once on the small tensor, then expand over the image tokens.
        # The same mask is reused for both halves; this assumes ``cross_model``
        # does not modify its mask argument in place.
        mask = mask.half().repeat(1, image_f.shape[1], 1)  # B*257*77
        bc = image_f.shape[0] // 2

        sim0 = self.cross_model(image_f[:bc], text_f, mask)
        sim1 = self.cross_model(image_f[bc:], text_f, mask)

        # Only the first image-token position of each result is returned.
        return text_EOS, sim0[:, 0, :], sim1[:, 0, :]

    @property
    def logit_scale(self):
        """The backbone's ``logit_scale``."""
        return self.model.logit_scale

    def save(self, path):
        """Save the backbone with ``save_pretrained``.

        Only ``self.model`` is written by this method; ``self.cross_model``
        is not.
        """
        self.model.save_pretrained(path)