import numpy as np
import os
import json
import random
import requests
import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    BitsAndBytesConfig,
    Qwen2VLProcessor,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    # Qwen2_5_VLCausalLMOutputWithPast,
    Qwen2_5_VLPreTrainedModel,
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from torch.nn import CrossEntropyLoss

class BrainCereConfig(PretrainedConfig):
    """Configuration for the BrainCere model.

    Stores the training / attention flags and checks that they are mutually
    consistent. The ``paligemma_config`` and ``gemma_expert_config`` arguments
    are accepted for interface compatibility but are currently ignored.
    """

    model_type = "BrainCereModel"
    # NOTE(review): neither attribute named here is assigned by ``__init__``; any code
    # that does ``getattr(config, key)`` for each sub-config key will raise — confirm
    # whether these sub-configs should actually be built and stored.
    sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}

    def __init__(
        self,
        paligemma_config: dict | None = None,
        gemma_expert_config: dict | None = None,
        freeze_vision_encoder: bool = True,
        train_expert_only: bool = True,
        attention_implementation: str = "eager",
        **kwargs,
    ):
        """Build the config.

        Args:
            paligemma_config: Currently ignored (see NOTE below).
            gemma_expert_config: Currently ignored (see NOTE below).
            freeze_vision_encoder: Whether the vision encoder should be frozen.
            train_expert_only: Whether only the expert should be trained; requires
                ``freeze_vision_encoder=True``.
            attention_implementation: One of ``"eager"``, ``"fa2"`` or ``"flex"``.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        Raises:
            ValueError: if the flag combination or attention implementation is invalid.
        """
        # Keep the flags as attributes so they are serialized with the config and
        # available to ``_validate``.
        self.freeze_vision_encoder = freeze_vision_encoder
        self.train_expert_only = train_expert_only
        self.attention_implementation = attention_implementation

        # NOTE(review): construction of the PaliGemma / Gemma-expert sub-configs (drafted
        # with ``CONFIG_MAPPING``, which this file does not import) is not implemented;
        # ``paligemma_config`` and ``gemma_expert_config`` are intentionally unused.
        super().__init__(**kwargs)

        # ``PretrainedConfig`` is not a dataclass, so ``__post_init__`` is not invoked
        # automatically; validate explicitly.
        self._validate()

    def _validate(self) -> None:
        """Raise ``ValueError`` if the stored flags are inconsistent."""
        if self.train_expert_only and not self.freeze_vision_encoder:
            raise ValueError(
                "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
            )

        if self.attention_implementation not in ["eager", "fa2", "flex"]:
            raise ValueError(
                f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
            )

    def __post_init__(self):
        """Dataclass-style hook kept for callers that invoke it explicitly; re-runs validation."""
        # The parent class may not define ``__post_init__``; only chain when it exists.
        parent_post_init = getattr(super(), "__post_init__", None)
        if parent_post_init is not None:
            parent_post_init()
        self._validate()


class BrainCereModel(Qwen2_5_VLPreTrainedModel):
    """Wrapper around a pretrained Qwen2.5-VL model with its own forward pass.

    ``__init__`` loads a ``Qwen2_5_VLForConditionalGeneration`` into ``self.brain``
    and re-exposes its vision tower, decoder and LM head as ``self.visual``,
    ``self.model`` and ``self.lm_head``. ``forward`` drives those sub-modules
    directly instead of calling ``self.brain.forward``.
    """

    config_class = PretrainedConfig  # Qwen2_5_VLConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`

    def __init__(self, model_path, config, model_kwargs=None):
        """Load the wrapped Qwen2.5-VL model.

        Args:
            model_path: Path or hub id passed to
                ``Qwen2_5_VLForConditionalGeneration.from_pretrained``.
            config: Config handed to the base ``PreTrainedModel`` initializer; it is
                replaced by the loaded model's config once loading finishes.
            model_kwargs: Extra keyword arguments forwarded to ``from_pretrained``.
                ``None`` (the default) means no extra arguments.
        """
        super().__init__(config=config)
        self.brain = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, **(model_kwargs or {}))
        # NOTE(review): these aliases register the same sub-modules a second time, so
        # ``state_dict()`` lists each weight under both ``brain.*`` and the alias prefix.
        # Kept as-is because saved checkpoints may rely on these keys.
        self.visual = self.brain.visual
        self.model = self.brain.model
        self.vocab_size = self.brain.config.vocab_size
        self.lm_head = self.brain.lm_head
        self.rope_deltas = None  # set by forward() whenever the RoPE index is recomputed
        self.config = self.brain.config

    def _scatter_visual_features(self, input_ids, inputs_embeds, pixels, grid_thw, token_id, label):
        """Encode ``pixels`` with the vision tower and write the features into the ``token_id`` slots.

        ``label`` ("Image" or "Video") is used only to build the mismatch error message.

        Raises:
            ValueError: if the number of ``token_id`` tokens in ``input_ids`` differs from
                the number of feature rows returned by the vision tower.
        """
        pixels = pixels.type(self.visual.dtype)
        features = self.visual(pixels, grid_thw=grid_thw)

        token_mask = input_ids == token_id
        n_tokens = token_mask.sum().item()
        n_features = features.shape[0]
        if n_tokens != n_features:
            raise ValueError(
                f"{label} features and {label.lower()} tokens do not match: tokens: {n_tokens}, features {n_features}"
            )

        # Broadcast the per-token mask over the embedding dimension for masked_scatter.
        scatter_mask = token_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        features = features.to(inputs_embeds.device, inputs_embeds.dtype)
        return inputs_embeds.masked_scatter(scatter_mask, features)

    def _compute_position_ids(
        self,
        input_ids,
        inputs_embeds,
        attention_mask,
        past_key_values,
        cache_position,
        image_grid_thw,
        video_grid_thw,
        second_per_grid_ts,
    ):
        """Return RoPE position ids (leading axis of size 3), caching rope deltas on ``self``.

        On the first step (or when nothing is cached) the full index is computed by the
        wrapped model's ``get_rope_index``. Later steps reuse ``self.rope_deltas`` and
        offset a plain ``arange`` by ``cache_position[0]``.
        """
        recompute = (
            (cache_position is not None and cache_position[0] == 0)
            or self.rope_deltas is None
            or (past_key_values is None or past_key_values.get_seq_length() == 0)
        )
        if recompute:
            position_ids, rope_deltas = self.brain.get_rope_index(
                input_ids,
                image_grid_thw,
                video_grid_thw,
                second_per_grid_ts,
                attention_mask,
            )
            self.rope_deltas = rope_deltas
            return position_ids

        batch_size, seq_length, _ = inputs_embeds.shape
        delta = 0  # without a cache position there is no offset
        if cache_position is not None:
            delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
            delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
        position_ids = torch.arange(seq_length, device=inputs_embeds.device)
        position_ids = position_ids.view(1, -1).expand(batch_size, -1).add(delta)
        return position_ids.unsqueeze(0).expand(3, -1, -1)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
        """Embed text + visual inputs, run the decoder and LM head, and optionally compute the loss.

        ``rope_deltas`` is accepted for interface compatibility; the value actually used
        and returned is the one cached on ``self.rope_deltas``.

        Returns:
            ``Qwen2_5_VLCausalLMOutputWithPast``, or when ``return_dict`` is false a tuple
            ``(loss, logits, *decoder_outputs[1:])`` (``loss`` omitted when ``labels`` is None).

        Raises:
            ValueError: if image/video placeholder token counts do not match the number of
                visual features.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.model.get_input_embeddings()(input_ids)
            if pixel_values is not None:
                inputs_embeds = self._scatter_visual_features(
                    input_ids, inputs_embeds, pixel_values, image_grid_thw, self.config.image_token_id, "Image"
                )
            if pixel_values_videos is not None:
                inputs_embeds = self._scatter_visual_features(
                    input_ids, inputs_embeds, pixel_values_videos, video_grid_thw, self.config.video_token_id, "Video"
                )
            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            position_ids = self._compute_position_ids(
                input_ids,
                inputs_embeds,
                attention_mask,
                past_key_values,
                cache_position,
                image_grid_thw,
                video_grid_thw,
                second_per_grid_ts,
            )

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast before the loss to avoid precision issues; the upcast logits are also returned.
            logits = logits.float()
            # Next-token objective: logits at position t are scored against label t + 1.
            shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[..., 1:].contiguous().view(-1)
            # Labels may live on another device under model parallelism.
            shift_labels = shift_labels.to(shift_logits.device)
            loss = CrossEntropyLoss()(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Qwen2_5_VLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )
        