
import torch.nn.functional as F

from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModelForCausalLM, \
                         Phi3Model, Phi3Config, Phi3ForCausalLM


from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast

from transformers.generation.utils import GenerateOutput

from transformers.models.phi3.modeling_phi3 import Phi3DecoderLayer,Phi3MLP
from transformers.cache_utils import Cache, DynamicCache
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM

from transformers.cache_utils import Cache
from transformers.models.phi3.modeling_phi3 import Phi3RMSNorm
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from typing_extensions import Unpack
import math

class Phi3HyperLoraMLP(Phi3MLP):
    """Phi-3 MLP whose projections are augmented with hyper-generated LoRA deltas.

    Three operating modes are selected from the config:

    * ``no_hyper``: plain gated MLP, no LoRA contribution at all.
    * ``soft_moe`` (and not ``no_hyper``): one LoRA layer per task for each
      projection; the per-task deltas are mixed with token-wise softmax
      weights produced by ``up_router`` / ``down_router``.
    * otherwise: a single LoRA layer per projection.

    The LoRA layers hold no weights of their own; callers must install the
    generated parameters (``clear_lora`` / ``apply_lora_params``) before
    ``forward`` is invoked.

    When ``record_hlora`` is set, the LoRA delta added to the down projection
    is stored (detached) in ``saved_hlora`` as
    ``{"down": {"delta_mix": tensor}}``, which is the layout
    ``Phi3AdapterModel.forward`` reads back.
    """

    def __init__(self, config):
        super().__init__(config)

        self.sevenb = getattr(config, "sevenb", False)
        self.tasks = ["Advice","CKD","Dental","Diabetes","Drug_recommendation","Food","heart","Obesity","Skin","Sleep","Symptom2diagnosis"]

        # Recording switch and slot for the last down-projection LoRA delta.
        self.record_hlora = False
        self.saved_hlora = None

        hyper_lora_r = config.hyper_lora_r
        hyper_lora_alpha = config.hyper_lora_alpha
        hyper_lora_num = config.hyper_lora_num
        hyper_lora_dropout = config.hyper_lora_dropout
        embed_size = config.embed_size
        compress_dim = config.compress_dim
        hidden_size = config.hidden_size
        self.soft_moe = config.soft_moe
        self.no_hyper = config.no_hyper

        def make_generator(proj):
            # Hypernetwork emitting LoRA weights sized for `proj`.
            return Phi3ParameterGenerator(
                embed_size, compress_dim, hidden_size,
                hyper_lora_r, hyper_lora_num,
                proj.in_features, proj.out_features,
            )

        def make_lora(proj):
            # Weight-less LoRA layer shaped like `proj`.
            return Phi3HyperLoRALayer(
                in_features=proj.in_features,
                out_features=proj.out_features,
                r=hyper_lora_r,
                lora_alpha=hyper_lora_alpha,
                lora_num=hyper_lora_num,
                lora_dropout=hyper_lora_dropout,
            )

        if self.soft_moe and not self.no_hyper:
            self.up_router = nn.Linear(self.gate_up_proj.in_features, len(self.tasks))
            self.down_router = nn.Linear(self.down_proj.in_features, len(self.tasks))
            self.up_hyper_nets = nn.ModuleDict()
            self.down_hyper_nets = nn.ModuleDict()
            self.phi3_up_lora_layers = nn.ModuleDict()
            self.phi3_down_lora_layers = nn.ModuleDict()
            # Creation order is kept stable so random initialisation and
            # parameter registration order do not change.
            for task in self.tasks:
                self.up_hyper_nets[task] = make_generator(self.gate_up_proj)
                self.down_hyper_nets[task] = make_generator(self.down_proj)
                self.phi3_up_lora_layers[task] = make_lora(self.gate_up_proj)
                self.phi3_down_lora_layers[task] = make_lora(self.down_proj)
        else:
            # NOTE: this branch also runs when `no_hyper` is set, so the
            # single-adapter modules exist (unused by forward) in that mode.
            self.up_hyper_net = make_generator(self.gate_up_proj)
            self.down_hyper_net = make_generator(self.down_proj)
            self.phi3_up_lora_layer = make_lora(self.gate_up_proj)
            self.phi3_down_lora_layer = make_lora(self.down_proj)

    def _record_down_delta(self, delta):
        """Store the down-projection LoRA delta when recording is enabled."""
        if self.record_hlora:
            # Detached so the stored tensor does not keep the autograd graph alive.
            self.saved_hlora = {"down": {"delta_mix": delta.detach()}}

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Run the MLP, adding the LoRA deltas selected by the configured mode.

        Args:
            hidden_states: input activations fed to ``gate_up_proj``.

        Returns:
            The output of ``down_proj`` plus, unless ``no_hyper`` is set, the
            LoRA delta for the down projection.
        """
        if self.record_hlora:
            # Drop any stale recording before this pass.
            self.saved_hlora = None

        if self.no_hyper:
            gate, up_states = self.gate_up_proj(hidden_states).chunk(2, dim=-1)
            up_states = up_states * self.activation_fn(gate)
            return self.down_proj(up_states)

        if self.soft_moe:
            x = hidden_states
            dtype = x.dtype

            up_base = self.gate_up_proj(x)
            # Per-task deltas stacked on a trailing task axis.
            up_all = torch.stack(
                [self.phi3_up_lora_layers[task](x) for task in self.tasks], dim=-1
            )
            # Softmax is evaluated in float32 and cast back to the input dtype.
            w_up = F.softmax(self.up_router(x), dim=-1, dtype=torch.float32).to(dtype)
            up_states = up_base + (up_all * w_up.unsqueeze(-2)).sum(dim=-1)
            gate, up_states = up_states.chunk(2, dim=-1)
            up_states = up_states * self.activation_fn(gate)

            down_base = self.down_proj(up_states)
            down_all = torch.stack(
                [self.phi3_down_lora_layers[task](up_states) for task in self.tasks], dim=-1
            )
            w_down = F.softmax(self.down_router(up_states), dim=-1, dtype=torch.float32).to(dtype)
            down_delta_mix = (down_all * w_down.unsqueeze(-2)).sum(dim=-1)

            self._record_down_delta(down_delta_mix)
            return down_base + down_delta_mix

        up_states = self.gate_up_proj(hidden_states) + self.phi3_up_lora_layer(hidden_states)
        if self.sevenb:
            # With `sevenb` the activation is applied directly, no gate/up split.
            up_states = self.activation_fn(up_states)
        else:
            gate, up_states = up_states.chunk(2, dim=-1)
            up_states = up_states * self.activation_fn(gate)

        down_base = self.down_proj(up_states)
        down_delta = self.phi3_down_lora_layer(up_states)
        self._record_down_delta(down_delta)
        return down_base + down_delta



class Phi3AdapterDecoderLayer(Phi3DecoderLayer):
    """Phi-3 decoder layer that uses ``Phi3HyperLoraMLP`` as its ``mlp``."""

    def __init__(self, config: Phi3Config, block_id):
        """Build the parent layer, then set ``mlp`` to the hyper-LoRA MLP.

        Args:
            config: model configuration, forwarded to the parent and the MLP.
            block_id: index of this layer; forwarded to the parent constructor
                and kept on the instance as ``block_id``.
        """
        super().__init__(config, block_id)

        # Assigned after the parent constructor so this `mlp` is the one kept.
        self.mlp = Phi3HyperLoraMLP(config)
        self.block_id = block_id


class Phi3AdapterModel(Phi3Model):
    """Phi-3 decoder stack whose MLPs receive hypernetwork-generated LoRA weights.

    On every forward pass the input embeddings are mean-pooled over the
    sequence axis, the pooled vector is fed to each layer's hypernetwork(s),
    and the resulting LoRA parameters are installed on that layer's LoRA
    module(s) before the decoder layers are executed.

    If the instance attribute ``_return_hyperlora`` is truthy, the last
    layer's MLP is asked to record its down-projection LoRA delta, which is
    then exposed as ``_last_hlora_token`` and ``_last_hlora_sentence``.
    """

    def __init__(self, config: Phi3Config):
        super(Phi3AdapterModel, self).__init__(config)
        self.soft_moe = config.soft_moe
        self.no_hyper = config.no_hyper

        self.tasks = ["Advice","CKD","Dental","Diabetes","Drug_recommendation","Food","heart","Obesity","Skin","Sleep","Symptom2diagnosis"]

        # Overrides `self.layers` with adapter-aware decoder layers.
        self.layers = nn.ModuleList([Phi3AdapterDecoderLayer(config, i) for i in range(config.num_hidden_layers)])

    @staticmethod
    def _load_lora(hyper_net, lora_layer, pooling_states):
        """Generate LoRA parameters with `hyper_net` and install them on `lora_layer`."""
        params = hyper_net(pooling_states)
        lora_layer.clear_lora()
        lora_layer.apply_lora_params(params[0], params[1])

    def _install_hyper_lora(self, pooling_states):
        """Install freshly generated LoRA parameters on every decoder layer's MLP."""
        if self.no_hyper:
            return
        for decoder_layer in self.layers:
            mlp = decoder_layer.mlp
            if self.soft_moe:
                for task in self.tasks:
                    self._load_lora(mlp.up_hyper_nets[task], mlp.phi3_up_lora_layers[task], pooling_states)
                    self._load_lora(mlp.down_hyper_nets[task], mlp.phi3_down_lora_layers[task], pooling_states)
            else:
                self._load_lora(mlp.up_hyper_net, mlp.phi3_up_lora_layer, pooling_states)
                self._load_lora(mlp.down_hyper_net, mlp.phi3_down_lora_layer, pooling_states)

    def _capture_hlora(self, attention_mask):
        """Copy the last layer's recorded LoRA delta into `_last_hlora_*` attributes."""
        rec = getattr(self.layers[-1].mlp, "saved_hlora", None) if len(self.layers) > 0 else None
        if rec is not None and "down" in rec and "delta_mix" in rec["down"]:
            hyper_token = rec["down"]["delta_mix"]
            self._last_hlora_token = hyper_token.detach()

            # NOTE(review): `attention_mask` here is the value passed to the
            # layers; on the non-flash path it was already replaced by the
            # prepared causal mask, so the masked mean presumably only applies
            # on the flash-attention path - confirm.
            if attention_mask is not None and attention_mask.dim() == 2:
                mask = attention_mask.unsqueeze(-1).to(hyper_token.dtype)
                denom = mask.sum(1).clamp_min(1e-6)
                hyper_sent = (hyper_token * mask).sum(1) / denom
            else:
                hyper_sent = hyper_token.mean(1)

            self._last_hlora_sentence = hyper_sent.detach()
        else:
            self._last_hlora_token = None
            self._last_hlora_sentence = None

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack with per-call hyper-generated LoRA weights.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.

        Returns:
            ``BaseModelOutputWithPast`` when ``return_dict`` resolves to true,
            otherwise a tuple of the non-``None`` entries among
            (hidden states, cache, all hidden states, all attentions).

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are supplied, or if right-padded inputs are used with
                flash-attention and caching.
        """
        record_hlora = bool(getattr(self, "_return_hyperlora", False))
        # Derived from the actual stack instead of assuming 32 layers.
        last_idx = len(self.layers) - 1
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_key_values_length = 0

        if self.gradient_checkpointing and self.training:
            if use_cache:
                # The module defines no `logger`; obtain one locally so this
                # branch warns instead of raising NameError.
                from transformers.utils import logging as hf_logging

                hf_logging.get_logger(__name__).warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to "
                    " call `tokenizer.padding_side  = 'left'` before tokenizing the input. "
                )

        if self._attn_implementation == "flash_attention_2":
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # Conditioning vector for the hypernetworks: mean of the input
        # embeddings over dim 1, detached so no gradient reaches the embeddings
        # through this path.
        # NOTE(review): the mean covers every position of the current call
        # (padding is not masked out) - confirm this is intended.
        pooling_states = hidden_states.mean(dim=1).detach()

        # Install the LoRA parameters on all layers before running any of them.
        self._install_hyper_lora(pooling_states)

        # Only the last layer records; flags are refreshed on every call so a
        # previous recording request does not linger.
        for i, lyr in enumerate(self.layers):
            lyr.mlp.record_hlora = record_hlora and i == last_idx

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if record_hlora:
            self._capture_hlora(attention_mask)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class LlavaPhiConfig(Phi3Config):
    """``Phi3Config`` extended with the hyper-LoRA / adapter hyperparameters.

    All extra arguments are stored verbatim as attributes of the same name;
    everything else is forwarded to ``Phi3Config`` through ``**kwargs``.
    ``hidden_size`` is a named argument here, so it is not forwarded to the
    parent and is assigned on the instance after the parent constructor ran.

    NOTE(review): ``Phi3HyperLoRALayer`` scales its output by
    ``lora_alpha / r``, so the default ``hyper_lora_alpha = 0`` yields a
    scaling of 0 - confirm that callers always override it.
    """

    model_type = "llava_phi"

    def __init__(
        self,
        compress_dim=16,
        adapter_size=16,
        layer_num=32,
        hidden_size=3072,
        embed_size=3072,
        hyper_lora_r=16,
        hyper_lora_alpha=0,
        hyper_lora_num=4,
        hyper_lora_dropout=0.05,
        soft_moe=False,
        no_hyper=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Same attributes, same assignment order as a plain list of
        # `self.x = x` statements.
        extra_fields = (
            ("layer_num", layer_num),
            ("hidden_size", hidden_size),
            ("embed_size", embed_size),
            ("compress_dim", compress_dim),
            ("adapter_size", adapter_size),
            ("soft_moe", soft_moe),
            ("no_hyper", no_hyper),
            ("hyper_lora_r", hyper_lora_r),
            ("hyper_lora_alpha", hyper_lora_alpha),
            ("hyper_lora_num", hyper_lora_num),
            ("hyper_lora_dropout", hyper_lora_dropout),
        )
        for field_name, field_value in extra_fields:
            setattr(self, field_name, field_value)


class LlavaPhiModel(LlavaMetaModel, Phi3AdapterModel):
    """Combines ``LlavaMetaModel`` with the hyper-LoRA Phi-3 backbone."""

    config_class = LlavaPhiConfig

    def __init__(self, config: Phi3Config):
        """Initialise both bases through the cooperative ``super()`` chain."""
        super().__init__(config)


class LlavaPhiForCausalLM(Phi3ForCausalLM, LlavaMetaForCausalLM):
    """Causal-LM head on top of ``LlavaPhiModel`` with multimodal input preparation."""

    config_class = LlavaPhiConfig

    def __init__(self, config):
        # Start the MRO lookup after Phi3ForCausalLM, so its own constructor is
        # not executed; the backbone and LM head are created below instead.
        super(Phi3ForCausalLM, self).__init__(config)
        self.model = LlavaPhiModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_model(self):
        """Return the wrapped ``LlavaPhiModel``."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[bool]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        return_hyperlora: Optional[bool] = False
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Prepare multimodal inputs if needed and delegate to the parent forward.

        ``return_hyperlora`` is stored on the backbone as
        ``_return_hyperlora`` before the call. ``cache_position`` is accepted
        but not forwarded to the parent.
        """
        self.model._return_hyperlora = bool(return_hyperlora)

        if inputs_embeds is None:
            # No embeddings supplied: let the multimodal helper build them.
            prepared = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes
            )
            input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels = prepared

        parent_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return super().forward(**parent_kwargs)

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate from token ids (optionally with images) via input embeddings.

        Raises:
            NotImplementedError: if ``inputs_embeds`` is passed in ``kwargs``.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is None:
            # Text-only: embed the token ids directly.
            inputs_embeds = self.get_model().embed_tokens(inputs)
        else:
            merged = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
            inputs, position_ids, attention_mask, _, inputs_embeds, _ = merged

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Build the parent's generation inputs and re-attach image arguments.

        ``images`` and ``image_sizes`` are removed from ``kwargs`` before the
        parent call and added to the returned dict when they are not ``None``.
        """
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)

        model_inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )

        for key, value in (("images", images), ("image_sizes", image_sizes)):
            if value is not None:
                model_inputs[key] = value
        return model_inputs




class Phi3HyperLoRALayer(nn.Module):
    """LoRA layer with one shared A matrix and several gated B matrices.

    The layer owns no LoRA weights. Per-sample matrices are installed from
    outside through :meth:`apply_lora_params` and removed with
    :meth:`clear_lora`. The only trainable parameters are those of ``router``,
    which produces a softmax gate over the ``lora_num`` B matrices for every
    input position.

    Args:
        in_features: size of the last input dimension.
        out_features: size of the last output dimension.
        r: LoRA rank.
        lora_alpha: numerator of the output scaling ``lora_alpha / r``.
        lora_num: number of B matrices mixed by the router.
        lora_dropout: dropout probability applied to the input.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int,
        lora_alpha: int = 1,
        lora_num: int = 4,
        lora_dropout: float=0.0
    ):
        super(Phi3HyperLoRALayer, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_num = lora_num
        self.lora_dropout = nn.Dropout(lora_dropout)

        # Installed at runtime; plain attributes, not registered parameters.
        self.lora_A = None
        self.lora_B_list = []
        self.router = nn.Linear(in_features, lora_num)

        self.scaling = self.lora_alpha / self.r

    def clear_lora(self):
        """Drop any installed LoRA matrices."""
        self.lora_A = None
        self.lora_B_list = []

    def apply_lora_params(self, a, b):
        """Install generated LoRA matrices, replacing any previous ones.

        Args:
            a: flat tensor viewable as ``(batch, in_features, r)``.
            b: sequence of ``lora_num`` flat tensors, each viewable as
                ``(batch, r, out_features)``.

        Raises:
            ValueError: if the number of tensors in ``b`` differs from
                ``lora_num`` (the router's output width).
        """
        batch_size = a.shape[0]
        b_mats = [bflat.view(batch_size, self.r, self.out_features) for bflat in b]
        if len(b_mats) != self.lora_num:
            raise ValueError(
                f"expected {self.lora_num} LoRA B tensors, got {len(b_mats)}"
            )
        self.lora_A = a.view(batch_size, self.in_features, self.r)
        # Assign rather than append so repeated calls never accumulate stale
        # matrices, even without an intervening clear_lora().
        self.lora_B_list = b_mats

    def forward(self, x):
        """Return the gated LoRA delta for ``x``.

        Raises:
            RuntimeError: if no LoRA parameters have been installed.
        """
        if self.lora_A is None or not self.lora_B_list:
            raise RuntimeError(
                "LoRA parameters are not set; call apply_lora_params() before forward()"
            )

        x_a = self.lora_dropout(x)
        x_b = x_a @ self.lora_A
        gates = torch.softmax(self.router(x_a), dim=-1)

        # One delta per B matrix, stacked on a new axis at dim 2.
        deltas = torch.stack([x_b @ b_mat for b_mat in self.lora_B_list], dim=2)

        mixed = torch.sum(deltas * gates.unsqueeze(-1), dim=2)
        return mixed * self.scaling





def hyperfanin_init_weight(linear_layer, hypernet_in, mainnet_in):
    """Re-initialise ``linear_layer`` in place.

    The weight is drawn uniformly from ``[-limit, limit]`` with
    ``limit = 1e-3 * sqrt(3 / (hypernet_in * mainnet_in))``; the bias is set
    to zero.
    """
    fan_product = hypernet_in * mainnet_in
    limit = 1e-3 * math.sqrt(3 / fan_product)
    nn.init.uniform_(linear_layer.weight, a=-limit, b=limit)
    nn.init.zeros_(linear_layer.bias)


def hyperfanin_init_bias(linear_layer, hypernet_in):
    """Re-initialise ``linear_layer`` in place.

    The weight is drawn uniformly from ``[-limit, limit]`` with
    ``limit = 1e-3 * sqrt(3 / hypernet_in)``; the bias is set to zero.
    """
    limit = 1e-3 * math.sqrt(3 / (hypernet_in))
    nn.init.uniform_(linear_layer.weight, a=-limit, b=limit)
    nn.init.zeros_(linear_layer.bias)


class SimpleGenerator(nn.Module):
    """Hypernetwork head producing flat LoRA weights from a conditioning vector.

    The input is compressed with ``linear1`` + ReLU; the compressed code feeds
    one linear head emitting ``proj_in_dim * hyper_lora_r`` values and
    ``hyper_lora_num`` linear heads each emitting
    ``proj_out_dim * hyper_lora_r`` values.

    ``norm`` is constructed and ``gate_temp`` is stored, but neither is used
    by ``forward``.
    """

    def __init__(self, input_dim, compress_dim, hyper_lora_r, hyper_lora_num, proj_in_dim,proj_out_dim, gate_temp=1.0):
        super().__init__()

        self.input_dim = input_dim
        self.compress_dim = compress_dim
        self.proj_in_dim = proj_in_dim
        self.proj_out_dim = proj_out_dim
        self.hyper_lora_r = hyper_lora_r
        self.hyper_lora_num = hyper_lora_num
        self.gate_temp = gate_temp

        down_width = proj_in_dim * hyper_lora_r
        up_width = proj_out_dim * hyper_lora_r

        # Sub-modules are created in a fixed order (linear1, activation, norm,
        # weight_down, weight_up) so registration and random init are stable.
        self.linear1 = nn.Linear(input_dim, compress_dim)
        self.activation_fn = nn.ReLU()
        self.norm = Phi3RMSNorm(input_dim)
        self.weight_down = nn.Linear(compress_dim, down_width)
        self.weight_up = nn.ModuleList(
            nn.Linear(compress_dim, up_width) for _ in range(hyper_lora_num)
        )

        # Small-magnitude re-initialisation of every linear layer.
        hyperfanin_init_weight(self.linear1, input_dim, compress_dim)
        hyperfanin_init_weight(self.weight_down, compress_dim, down_width)
        for up_head in self.weight_up:
            hyperfanin_init_weight(up_head, compress_dim, up_width)

    def forward(self, x):
        """Return ``(flat_A, [flat_B_0, ..., flat_B_{n-1}])`` for input ``x``."""
        code = self.activation_fn(self.linear1(x))
        flat_a = self.weight_down(code)
        flat_b_list = [up_head(code) for up_head in self.weight_up]
        return flat_a, flat_b_list
        


class Phi3ParameterGenerator(nn.Module):
    """Thin wrapper that owns a ``SimpleGenerator`` as ``decoder``.

    ``hidden_size`` is stored on the instance but is not passed on to the
    generator.
    """

    def __init__(self, embed_size, compress_dim, hidden_size, hyper_lora_r,hyper_lora_num, proj_in_dim, proj_out_dim):
        super().__init__()

        self.embed_size = embed_size
        self.compress_dim = compress_dim
        self.hidden_size = hidden_size
        self.hyper_lora_r = hyper_lora_r
        self.hyper_lora_num = hyper_lora_num
        self.proj_in_dim = proj_in_dim
        self.proj_out_dim = proj_out_dim

        # The generator consumes vectors of width `embed_size`.
        self.decoder = SimpleGenerator(
            embed_size,
            compress_dim,
            hyper_lora_r,
            hyper_lora_num,
            proj_in_dim,
            proj_out_dim,
        )

    def forward(self, hidden_inputs):
        """Return whatever ``decoder`` produces for ``hidden_inputs``."""
        return self.decoder(hidden_inputs)


# Register the custom config under the "llava_phi" model type and bind the
# causal-LM class to that config, so the transformers Auto* classes can
# resolve them. This runs as an import-time side effect of this module.
AutoConfig.register("llava_phi", LlavaPhiConfig)
AutoModelForCausalLM.register(LlavaPhiConfig, LlavaPhiForCausalLM)
