import json
import torch
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset
import os
from PIL import Image
import torch.nn as nn


class TextEncoder(nn.Module):
    """Wrap a Hugging Face encoder and expose token features or mean-pooled embeddings.

    The tokenizer and model are loaded with ``from_pretrained`` from a local
    directory. ``forward`` accepts either raw strings (tokenized here) or
    already-tokenized tensors.
    """

    def __init__(self, local_model_path=None):
        """Load the tokenizer and encoder.

        Args:
            local_model_path: Directory passed to ``from_pretrained``. When
                ``None``, ``checkpoints/text_encoder/local_model_dir`` is
                resolved relative to the directory four levels above this file.
        """
        super().__init__()
        if local_model_path is None:
            # Use relative path from release directory.
            base_dir = os.path.dirname(os.path.dirname(os.path.dirname(
                os.path.dirname(os.path.abspath(__file__)))))
            local_model_path = os.path.join(base_dir, "checkpoints/text_encoder/local_model_dir")
        self.tokenizer = AutoTokenizer.from_pretrained(local_model_path)
        self.model = AutoModel.from_pretrained(local_model_path)
        # Explicitly put the wrapped encoder in training mode.
        self.model.train()
        self.embedding_dim = self.model.config.hidden_size

    def gradient_checkpointing_enable(self):
        """Forward to the wrapped model's ``gradient_checkpointing_enable`` when it has one."""
        if hasattr(self.model, 'gradient_checkpointing_enable'):
            self.model.gradient_checkpointing_enable()

    def _prep_inputs(self, x: torch.Tensor, vocab_size: int) -> torch.Tensor:
        """Cast ``x`` to int64 and clamp it into ``[0, vocab_size - 1]``."""
        return torch.clamp(x.long(), 0, vocab_size - 1)

    def validate_text_inputs(self, text_item):
        """Sanitize tokenized inputs and move them to the model's device.

        NaN/Inf entries are replaced by zero, and every tensor is cast to int64
        and clamped into the index range of the embedding table it feeds. The
        caller's mapping is left unmodified.

        Args:
            text_item: Mapping with ``input_ids`` and ``attention_mask`` tensors
                and, optionally, ``token_type_ids`` (all zeros when absent).

        Returns:
            ``(inputs, original_shape)``: a new dict holding ``input_ids``,
            ``attention_mask`` and ``token_type_ids`` on the model's device, and
            the shape of ``input_ids``.

        Raises:
            KeyError: If ``input_ids`` or ``attention_mask`` is missing.
        """
        device = next(self.model.parameters()).device
        config = self.model.config
        vocab_size = config.vocab_size
        # Segment ids must index the token-type embedding table; fall back to
        # 2 (the historical hard-coded bound) when the config lacks the field.
        type_vocab_size = getattr(config, 'type_vocab_size', None) or 2

        # Work on a copy so NaN/Inf replacement never mutates the caller's dict.
        sanitized = {}
        for key in ('input_ids', 'attention_mask', 'token_type_ids'):
            if key not in text_item:
                continue
            tensor = text_item[key]
            bad = torch.isnan(tensor) | torch.isinf(tensor)
            if bad.any():
                print(f"Warning: TextEncoder {key} contains NaN/Inf, replacing with zeros")
                tensor = torch.where(bad, torch.zeros_like(tensor), tensor)
            sanitized[key] = tensor

        input_ids = self._prep_inputs(sanitized['input_ids'], vocab_size)
        attention_mask = self._prep_inputs(sanitized['attention_mask'], 2)
        token_type_ids = sanitized.get('token_type_ids')
        if token_type_ids is None:
            # Some tokenizers do not emit segment ids; use a single segment.
            token_type_ids = torch.zeros_like(input_ids)
        token_type_ids = self._prep_inputs(token_type_ids, type_vocab_size)

        return {
            'input_ids': input_ids.to(device),
            'attention_mask': attention_mask.to(device),
            'token_type_ids': token_type_ids.to(device),
        }, input_ids.shape

    def _pool_sequence(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Mean-pool ``hidden_states`` over positions where ``attention_mask`` is set.

        Args:
            hidden_states: ``[B*, L, H]`` token features.
            attention_mask: ``[B*, L]`` 0/1 mask of valid tokens.

        Returns:
            ``[B*, H]`` pooled features (float32). A row whose mask is all zero
            is divided by 1 instead of 0 and therefore pools to zeros.
        """
        bad = torch.isnan(hidden_states) | torch.isinf(hidden_states)
        if bad.any():
            print("Warning: TextEncoder hidden_states contains NaN/Inf, replacing with zeros")
            hidden_states = torch.where(bad, torch.zeros_like(hidden_states), hidden_states)

        mask = attention_mask.unsqueeze(-1).float()            # [B*, L, 1]
        denom = mask.sum(dim=1).clamp(min=1.0)                  # [B*, 1]
        mean_vec = (hidden_states * mask).sum(dim=1) / denom    # [B*, H]

        bad = torch.isnan(mean_vec) | torch.isinf(mean_vec)
        if bad.any():
            # Last-resort guard; note the random fill makes the output nondeterministic.
            print("Warning: TextEncoder pooled output contains NaN/Inf, replacing with small random values")
            mean_vec = torch.where(bad, torch.randn_like(mean_vec) * 0.01, mean_vec)

        return mean_vec

    def _run_model(self, input_ids, token_type_ids, attention_mask):
        """Run the wrapped encoder and return its ``last_hidden_state``.

        In training mode the call is wrapped in non-reentrant activation
        checkpointing, trading recomputation during backward for memory.
        """
        from torch.utils.checkpoint import checkpoint

        def run(ids, tt, am):
            # Only the last layer is consumed; skipping the per-layer hidden
            # states avoids keeping every layer's activations alive.
            return self.model(
                input_ids=ids,
                token_type_ids=tt,
                attention_mask=am,
                output_hidden_states=False,
                return_dict=True,
            ).last_hidden_state

        if self.model.training:
            return checkpoint(run, input_ids, token_type_ids, attention_mask, use_reentrant=False)
        return run(input_ids, token_type_ids, attention_mask)

    def _forward_preprocessed(self, text_item, return_dict):
        """Encode tokenized inputs of rank 1 (``[L]``), 2 (``[B, L]``) or 3 (``[B, N, L]``).

        Other ranks keep the leading dimension and flatten the rest into L.
        """
        inputs, original_shape = self.validate_text_inputs(text_item)
        ids = inputs['input_ids']
        am = inputs['attention_mask']
        tt = inputs['token_type_ids']

        group_shape = None  # (B, N) when each sample carries N texts
        if len(original_shape) == 3:
            B, N, L = original_shape
            group_shape = (B, N)
            ids = ids.view(B * N, L)
            am = am.view(B * N, L)
            tt = tt.view(B * N, L)
        elif len(original_shape) == 1:
            ids = ids.unsqueeze(0)
            am = am.unsqueeze(0)
            tt = tt.unsqueeze(0)
        elif len(original_shape) != 2:
            ids = ids.view(ids.size(0), -1)
            B, L = ids.shape
            am = am.view(B, L)
            tt = tt.view(B, L)

        last_hidden_state = self._run_model(ids, tt, am)

        if return_dict:
            # For [B, N, L] input these tensors stay flattened to B * N rows.
            return {
                "last_hidden_state": last_hidden_state,
                "attention_mask": am,
                "input_ids": ids,
            }

        pooled = self._pool_sequence(last_hidden_state, am)  # [B*N, H] or [B, H]
        if group_shape is not None:
            # Average the N per-text embeddings of each sample -> [B, H].
            pooled = pooled.view(*group_shape, pooled.size(-1)).mean(dim=1)
        return pooled

    def _forward_texts(self, texts, batch_size, max_length, return_dict):
        """Tokenize ``texts`` in chunks of ``batch_size`` and encode every chunk."""
        if isinstance(texts, str):
            # A bare string would otherwise be sliced into character chunks.
            texts = [texts]
        if not texts:
            raise ValueError("'texts' must contain at least one string.")

        hidden_chunks, mask_chunks, id_chunks = [], [], []
        for start in range(0, len(texts), batch_size):
            encoded = self.tokenizer(
                texts[start:start + batch_size],
                return_tensors='pt',
                padding='max_length',
                truncation=True,
                max_length=max_length,
            )
            # validate_text_inputs also moves everything to the model's device.
            inputs, _ = self.validate_text_inputs(encoded)
            hidden_chunks.append(self._run_model(
                inputs['input_ids'], inputs['token_type_ids'], inputs['attention_mask']))
            mask_chunks.append(inputs['attention_mask'])
            id_chunks.append(inputs['input_ids'])

        last_hidden_state = torch.cat(hidden_chunks, dim=0)
        attention_mask = torch.cat(mask_chunks, dim=0)
        input_ids = torch.cat(id_chunks, dim=0)

        if return_dict:
            return {
                "last_hidden_state": last_hidden_state,
                "attention_mask": attention_mask,
                "input_ids": input_ids,
            }
        return self._pool_sequence(last_hidden_state, attention_mask)

    def forward(self, texts=None, text_item=None, batch_size=32, max_length=512, return_dict=True, preprocessed=False):
        """Encode text into token-level features or mean-pooled embeddings.

        Args:
            texts: A string or a list of strings; used when ``preprocessed`` is False.
            text_item: Tokenized inputs (``input_ids``, ``attention_mask`` and
                optionally ``token_type_ids``); used when ``preprocessed`` is True.
            batch_size: Number of strings tokenized and encoded per chunk.
            max_length: Padding/truncation length passed to the tokenizer.
            return_dict: If True, return a dict with ``last_hidden_state``,
                ``attention_mask`` and ``input_ids``; otherwise return
                mean-pooled embeddings.
            preprocessed: Selects ``text_item`` (True) or ``texts`` (False).

        Returns:
            A dict of tensors, or a ``[B, H]`` tensor of pooled embeddings.

        Raises:
            ValueError: If the input required by ``preprocessed`` is missing,
                or ``texts`` is empty.
        """
        if preprocessed:
            if text_item is None:
                raise ValueError("When preprocessed=True, 'text_item' must be provided.")
            return self._forward_preprocessed(text_item, return_dict)

        if texts is None:
            raise ValueError("When preprocessed=False, 'texts' must be provided.")
        return self._forward_texts(texts, batch_size, max_length, return_dict)

if __name__ == "__main__":
    # Example usage: load the encoder from the default checkpoint location and
    # print the shape of the mean-pooled sentence embeddings.
    base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
    local_model_path = os.path.join(base_dir, "checkpoints/text_encoder/local_model_dir")
    text_encoder = TextEncoder(local_model_path)
    # Inference only: eval mode skips the training-mode activation checkpointing.
    text_encoder.eval()

    sample_texts = ["This is a sample text.", "Another example of text encoding."]
    with torch.no_grad():
        # return_dict=False yields one mean-pooled embedding per input text.
        embeddings = text_encoder(texts=sample_texts, return_dict=False)

    print("Encoded Text Embeddings Shape:", embeddings.shape)  # (number_of_texts, hidden_size)