from transformers.models.llama.modeling_llama import (
    LlamaModel,
    LlamaForCausalLM,
    LlamaConfig,
    LlamaDecoderLayer,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    LlamaMLP,
)
from torch import nn
import torch
from typing import Callable, List, Optional, Tuple, Union
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.processing_utils import Unpack
import random
from transformers.activations import ACT2FN


class Router(nn.Module):
    """Score every position of the input against a fixed number of experts.

    A single bfloat16 linear layer maps the last dimension of the input to
    ``num_experts`` logits, which are softmax-normalized into routing weights.

    Attributes:
        fc: The linear projection producing the expert logits.
        use_soft_counts: When True, per-expert counts are the sum of routing
            probabilities; when False (default) they are arg-max counts.
        last_counts: Per-expert counts computed by the most recent call to
            ``forward`` (``None`` before the first call).
    """

    def __init__(self, input_dim, num_experts):
        super().__init__()
        # Maps the input features to one logit per expert.
        self.fc = nn.Linear(input_dim, num_experts, dtype=torch.bfloat16)
        self.use_soft_counts = False
        self.last_counts = None

    def forward(self, x):
        """Return softmax routing weights for ``x``.

        Args:
            x: Tensor whose last dimension is ``input_dim``; the counts are
                reduced over dim 1, so the expected layout is
                (batch, seq_len, hidden).

        Returns:
            Routing probabilities with the last dimension of size
            ``num_experts``. The per-expert counts are stored on
            ``self.last_counts`` rather than returned.
        """
        logits = self.fc(x)
        routing_weights = nn.functional.softmax(logits, dim=-1)

        if self.use_soft_counts:
            # Soft count: sum of probabilities routed to each expert.
            counts = routing_weights.sum(dim=1)
        else:
            # Hard count: number of positions whose arg-max is each expert.
            # torch.argmax is non-differentiable.
            preferred_expert = torch.argmax(routing_weights, dim=-1)
            one_hot = nn.functional.one_hot(
                preferred_expert, num_classes=routing_weights.shape[-1]
            ).float()
            counts = one_hot.sum(dim=1)

        # Kept for inspection; previously the counts were printed and the
        # process was terminated with exit().
        self.last_counts = counts
        return routing_weights


class LoopedLlamaMLP(nn.Module):
    """
    Gated MLP whose output width may differ from its input width.

    Three bias-free linear layers are used: ``gate_proj`` and ``up_proj`` map
    ``hidden_size`` to ``intermediate_size``; ``down_proj`` maps
    ``intermediate_size`` to ``output_size`` (``hidden_size`` when
    ``output_size`` is None).
    """

    def __init__(self, hidden_size, intermediate_size, hidden_act, output_size=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        # Fall back to a square MLP when no explicit output width is given.
        if output_size is None:
            output_size = hidden_size
        self.output_size = output_size
        self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
        # Activation looked up by name in the transformers registry.
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        """Return ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)


class LoopedLlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer whose attention ``layer_idx`` can be overridden per call."""

    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)

    def forward(self, *args, cache_layer_idx=None, **kwargs):
        """Run the parent forward, optionally under a temporary ``layer_idx``.

        Args:
            *args: Positional arguments forwarded to the parent ``forward``.
            cache_layer_idx: When not None, ``self.self_attn.layer_idx`` is set
                to this value for the duration of the call and restored
                afterwards.
            **kwargs: Keyword arguments forwarded to the parent ``forward``.

        Returns:
            Whatever the parent ``forward`` returns.
        """
        if cache_layer_idx is None:
            return super().forward(*args, **kwargs)

        attention = self.self_attn
        saved_cache_idx = attention.layer_idx
        attention.layer_idx = cache_layer_idx
        try:
            return super().forward(*args, **kwargs)
        finally:
            # Restore even when the parent forward raises, so a failed call
            # cannot leave the layer pointing at the wrong index.
            attention.layer_idx = saved_cache_idx


class LoopLinearAdapter(nn.Module):
    """Residual linear mix of two hidden states.

    The two inputs are concatenated on the last dimension, projected back to
    ``hidden_size`` by a bias-free bfloat16 linear layer, and added to ``x1``.
    The projection starts at zero, so the initial output equals ``x1``.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        width = self.config.hidden_size
        projection = nn.Linear(
            width * 2,
            width,
            bias=False,
            dtype=torch.bfloat16,
        )
        # Zero init: the adapter contributes nothing until it is trained.
        nn.init.zeros_(projection.weight)
        self.skip_connection_net = projection
        # Created (and therefore part of the state dict) but not applied in forward.
        self.norm = nn.RMSNorm(config.hidden_size).to(dtype=torch.bfloat16)

    def forward(self, x1, x2):
        """Return ``skip_connection_net(cat(x1, x2)) + x1``."""
        stacked = torch.cat((x1, x2), dim=-1)
        return self.skip_connection_net(stacked) + x1


class LoopMLPAdapter(nn.Module):
    """Residual gated-MLP mix of two hidden states.

    The two inputs are concatenated on the last dimension, passed through a
    bfloat16 ``LoopedLlamaMLP`` that maps ``2 * hidden_size`` back to
    ``hidden_size``, and added to ``x1``. Every linear weight of the MLP is
    zero-initialized.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        mlp = LoopedLlamaMLP(
            self.config.hidden_size * 2,
            self.config.intermediate_size,
            self.config.hidden_act,
            self.config.hidden_size,
        )
        self.skip_connection_net = mlp.to(dtype=torch.bfloat16)
        # Visit every sub-module and zero the linear weights.
        self.skip_connection_net.apply(self.initialize_weights_small)
        # Created (and therefore part of the state dict) but not applied in forward.
        self.norm = nn.RMSNorm(config.hidden_size).to(dtype=torch.bfloat16)

    def initialize_weights_small(self, module):
        """Zero the weight of ``module`` when it is an ``nn.Linear``; otherwise do nothing."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.zeros_(module.weight)

    def forward(self, x1, x2):
        """Return ``skip_connection_net(cat(x1, x2)) + x1``."""
        stacked = torch.cat((x1, x2), dim=-1)
        return self.skip_connection_net(stacked) + x1


class LoopAddAdapter(nn.Module):
    """Combine two hidden states by addition followed by a bfloat16 RMSNorm."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        rms_norm = nn.RMSNorm(config.hidden_size)
        self.norm = rms_norm.to(dtype=torch.bfloat16)

    def forward(self, x1, x2):
        """Return ``norm(x1 + x2)``."""
        summed = x1 + x2
        return self.norm(summed)


class LoopConfig:
    """Settings for the looped (prelude / recurrent block / coda) model.

    Defaults are class attributes; values passed to the constructor are set on
    the instance and override them.

    Raises (from ``__init__``):
        ValueError: If the supplied dict contains a key that is not one of the
            configuration fields declared on the class.
        AssertionError: If ``huginn_competitor`` is truthy while ``coda_size``
            is None. Note that the defaults alone trigger this, so
            ``coda_size`` must be supplied (or ``huginn_competitor`` disabled).
    """

    num_rec = 2
    start_index = 4
    block_size = 2
    # These two previously ended with a trailing comma, which turned them into
    # the tuples (None,) and (True,) and defeated the None check below.
    coda_size = None
    huginn_competitor = True
    use_adapter = "none"
    randomize_forward = -1.0
    throttle = None
    curriculum = False
    remove_layers = "none"
    use_router = False
    skip_connection = "none"
    train_rec_only = 0
    only_adapter = False
    use_end_cache_only = False
    skip_connection_after_first_loop = True

    # remove_layers example (prelude / recurrent block / coda indices):
    #   all layers:
    #     [0, 1, 2, ..., 26]
    #   none
    #     [0, 1, 2, 3, 4]
    #     [5, 6]
    #     [7, 8, 9, ..., 26]
    #   after
    #     [0, 1, 2, 3, 4]
    #     [5, 6]
    #     [9, 10, 11, ..., 26]
    #   before
    #     [0, 1, 2]
    #     [5, 6]
    #     [7, 8, 9, ..., 26]

    def __init__(self, config: dict = None):
        if config:
            known = self._field_names()
            for key, value in config.items():
                # Only declared fields may be set; methods and dunder names
                # are rejected as well.
                if key not in known:
                    raise ValueError(f"Unknown config key: {key}")
                setattr(self, key, value)
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        if self.huginn_competitor and self.coda_size is None:
            raise AssertionError("coda size cannot be none for Huginn competitor")

    @classmethod
    def _field_names(cls):
        """Return the declared configuration field names in declaration order."""
        names = []
        for klass in reversed(cls.__mro__):
            for name, value in vars(klass).items():
                if name.startswith("_") or callable(value):
                    continue
                if name not in names:
                    names.append(name)
        return names

    def __repr__(self):
        # Report every field, including those still at their class default.
        return str({key: getattr(self, key) for key in self._field_names()})


class LoopedLlamaForCausalLM(LlamaForCausalLM):
    """Causal-LM wrapper that uses ``LoopedLlamaModel`` as its backbone."""

    def __init__(self, config):
        super().__init__(config)
        # Replace the backbone and head built by the parent constructor.
        self.model = LoopedLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def rec_post_init(self, args, extra_tensors=None):
        """Forward the loop settings and optional extra state dicts to the backbone.

        Args:
            args: Passed through to ``LoopedLlamaModel.rec_post_init``.
            extra_tensors: Optional mapping passed through to the backbone;
                an empty dict is used when omitted.
        """
        # A fresh dict per call avoids sharing one mutable default object.
        if extra_tensors is None:
            extra_tensors = {}
        self.model.rec_post_init(args, extra_tensors)

    def set_num_rec(self, new_num_rec):
        """Set ``num_rec`` on the backbone's loop config."""
        self.model.loop_config.num_rec = new_num_rec

    def get_num_rec(self):
        """Return ``num_rec`` from the backbone's loop config."""
        return self.model.loop_config.num_rec

    def get_latest_rep(self):
        """Return the backbone's ``latest_rep``, or None when it has not been set."""
        return getattr(self.model, "latest_rep", None)


class LoopedLlamaModel(LlamaModel):
    def __init__(self, config):
        """Build the backbone with ``LoopedLlamaDecoderLayer`` decoder layers.

        The embedding, layer stack, final norm and rotary embedding are all
        (re)assigned here after the parent constructor has run.
        """
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )

        # One looped decoder layer per hidden layer, in index order.
        decoder_stack = []
        for layer_idx in range(config.num_hidden_layers):
            decoder_stack.append(LoopedLlamaDecoderLayer(config, layer_idx))
        self.layers = nn.ModuleList(decoder_stack)

        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_split(self, num_rec, remove_layers, i, j):
        """Partition the layer indices into prelude, recurrent block and coda.

        Args:
            num_rec: Number of times the recurrent block is repeated.
            remove_layers: ``"before"``, ``"after"`` or ``"both"`` to drop
                layers adjacent to the recurrent block on that side; any other
                value keeps every layer.
            i: Index of the first layer of the recurrent block.
            j: Number of layers in the recurrent block.

        Returns:
            ``(prelude, rec_layers, coda)`` as lists of layer indices.

        Raises:
            AssertionError: If ``"both"`` is used with an even ``num_rec``, or
                if the requested removal reaches past the start of the model.
        """
        all_layers = list(range(self.config.num_hidden_layers))
        rec_layers = all_layers[i : i + j]

        if remove_layers == "both":
            # The old message claimed the opposite of what is checked.
            assert (
                num_rec % 2 == 1
            ), f"num_rec {num_rec} must be odd to use remove_layers='both'"
            blocks_per_side = (num_rec // 2) + 1
            prelude_end = i - (j * (blocks_per_side - 1))
            # A negative bound would silently slice from the end of the list.
            assert prelude_end >= 0, "remove_layers='both' is too large"
            prelude = all_layers[:prelude_end]
            coda = all_layers[i + (j * blocks_per_side) :]
            return prelude, rec_layers, coda

        if remove_layers == "before":
            prelude_end = i - (j * (num_rec - 1))
            assert prelude_end > 0, "remove_layers_before is too large"
            prelude = all_layers[:prelude_end]
        else:
            prelude = all_layers[:i]

        if remove_layers == "after":
            coda = all_layers[i + (j * num_rec) :]
        else:
            coda = all_layers[i + j :]

        return prelude, rec_layers, coda

    def rec_post_init(self, args, extra_tensors=None):
        """Configure the looped structure after the base model has been built.

        Creates ``self.loop_config`` from ``args``, builds the optional adapter,
        skip-connection network and router, and regroups ``self.layers`` into
        ``self.prelude``, ``self.rec_block`` and ``self.coda``. The chosen layer
        indices are printed.

        Args:
            args: Dict of overrides passed to ``LoopConfig``.
            extra_tensors: Optional mapping that may hold state dicts under the
                keys ``"rec_adapter"`` and ``"skip_connection_net"``.

        Raises:
            ValueError: If ``skip_connection`` is not one of ``"none"``,
                ``"linear"``, ``"ffn"`` or ``"add"``, or if ``extra_tensors``
                carries weights for a module the configuration did not create.
            AssertionError: If both an adapter and the router are requested.
        """
        if extra_tensors is None:
            extra_tensors = {}

        self.loop_config = LoopConfig(args)
        loop_cfg = self.loop_config
        hidden_size = self.config.hidden_size

        # --- optional adapter appended to the recurrent block ---------------
        if loop_cfg.use_adapter == "linear":
            self.rec_adapter = nn.Linear(
                hidden_size,
                hidden_size,
                bias=False,
                dtype=torch.bfloat16,
            )
        elif loop_cfg.use_adapter == "mlp":
            self.rec_adapter = LlamaMLP(self.config).to(dtype=torch.bfloat16)
        else:
            self.rec_adapter = None

        if "rec_adapter" in extra_tensors:
            # Fail with a clear message instead of an AttributeError on None.
            if self.rec_adapter is None:
                raise ValueError(
                    "extra_tensors contains 'rec_adapter' but use_adapter="
                    f"{loop_cfg.use_adapter!r} creates no adapter"
                )
            self.rec_adapter.load_state_dict(extra_tensors["rec_adapter"])
        if self.rec_adapter is not None:
            # Same device handling as the skip-connection network below.
            self.rec_adapter.to(self.device)

        # --- optional skip-connection network -------------------------------
        skip_builders = {
            "linear": LoopLinearAdapter,
            "ffn": LoopMLPAdapter,
            "add": LoopAddAdapter,
        }
        skip_kind = loop_cfg.skip_connection
        if skip_kind in skip_builders:
            self.skip_connection_net = skip_builders[skip_kind](self.config)
        elif skip_kind == "none":
            self.skip_connection_net = None
        else:
            raise ValueError(f"Unknown skip_connection: {skip_kind!r}")

        if "skip_connection_net" in extra_tensors:
            if self.skip_connection_net is None:
                raise ValueError(
                    "extra_tensors contains 'skip_connection_net' but "
                    "skip_connection='none' creates no network"
                )
            self.skip_connection_net.load_state_dict(
                extra_tensors["skip_connection_net"]
            )
        if self.skip_connection_net is not None:
            self.skip_connection_net.to(self.device)

        # --- split the layer stack ------------------------------------------
        i, j = loop_cfg.start_index, loop_cfg.block_size
        prelude_ind, rec_layers_ind, coda_ind = self.get_split(
            loop_cfg.num_rec, loop_cfg.remove_layers, i, j
        )
        if loop_cfg.coda_size is not None:
            # Keep only the first coda_size layers after the recurrent block.
            coda_ind = coda_ind[: loop_cfg.coda_size]

        print(f"regular model: {list(range(self.config.num_hidden_layers))}")
        print(f"prelude: {prelude_ind}")
        print(f"rec block: {rec_layers_ind}")
        print(f"coda: {coda_ind}")
        self.prelude = nn.ModuleList([self.layers[idx] for idx in prelude_ind])
        rec_layers = [self.layers[idx] for idx in rec_layers_ind]

        assert not (
            (self.rec_adapter is not None) and (loop_cfg.use_router)
        ), "Can't have a router and adapter"
        if self.rec_adapter is not None:
            rec_layers.append(self.rec_adapter)
        elif loop_cfg.use_router:
            # 2 options: repeat, continue. TODO add a 3rd: skip next layer?
            self.router = Router(hidden_size, 2)
            self.router.to(self.device)
            rec_layers.append(self.router)

        self.rec_block = nn.ModuleList(rec_layers)
        self.coda = nn.ModuleList([self.layers[idx] for idx in coda_ind])

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the inputs, run prelude / recurrent block / coda, and apply the final norm.

        Exactly one of ``input_ids`` and ``inputs_embeds`` must be given.
        ``use_cache``, ``output_attentions`` and ``output_hidden_states`` fall
        back to the model config when None. ``rec_post_init`` must have been
        called first, because the layer groups and loop config it creates are
        used here.

        Returns:
            A ``BaseModelOutputWithPast`` holding the normalized final hidden
            state, the cache (when ``use_cache``), and the optional per-layer
            hidden states and attention weights.

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are supplied.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            # No module-level `logger` is visible in this file, so obtain one
            # here rather than hitting a NameError on this path.
            from transformers.utils import logging as hf_logging

            hf_logging.get_logger(__name__).warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            # Continue numbering positions after whatever is already cached.
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        hidden_states, all_hidden_states, all_self_attns = self.prelude_rec_coda(
            output_hidden_states,
            all_hidden_states,
            hidden_states,
            causal_mask,
            position_ids,
            past_key_values,
            output_attentions,
            use_cache,
            cache_position,
            position_embeddings,
            flash_attn_kwargs,
            all_self_attns,
        )

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def prelude_rec_coda(
        self,
        output_hidden_states,
        all_hidden_states,
        hidden_states,
        causal_mask,
        position_ids,
        past_key_values,
        output_attentions,
        use_cache,
        cache_position,
        position_embeddings,
        flash_attn_kwargs,
        all_self_attns,
    ):
        """Run the prelude once, the recurrent block ``num_rec`` times, then the coda once.

        Every module in each group is called with the same decoder-layer
        keyword arguments and ``cache_layer_idx=None``; the first element of
        its output becomes the next hidden state. When ``output_hidden_states``
        is set, a detached copy of the hidden state is recorded before each
        module; when ``output_attentions`` is set, the second output element is
        recorded after each module.

        NOTE(review): ``self.skip_connection_net`` is never applied here, and
        any adapter or router that ``rec_post_init`` appended to ``rec_block``
        is called like a decoder layer - confirm this is intended before
        enabling those options.
        NOTE(review): with ``cache_layer_idx=None`` no cache-index override is
        requested on repeated passes - confirm the cache behaves as intended
        when ``num_rec > 1``.

        Returns:
            ``(hidden_states, all_hidden_states, all_self_attns)``.
        """
        stages = (
            (self.prelude, 1),
            (self.rec_block, self.loop_config.num_rec),
            (self.coda, 1),
        )

        for stage_layers, repetitions in stages:
            for _ in range(repetitions):
                for block in stage_layers:
                    if output_hidden_states:
                        # Snapshot the input of this module.
                        all_hidden_states += (hidden_states.clone().detach(),)

                    block_outputs = block(
                        hidden_states,
                        attention_mask=causal_mask,
                        position_ids=position_ids,
                        past_key_value=past_key_values,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        cache_position=cache_position,
                        position_embeddings=position_embeddings,
                        **flash_attn_kwargs,
                        cache_layer_idx=None,
                    )
                    hidden_states = block_outputs[0]

                    if output_attentions:
                        all_self_attns += (block_outputs[1],)

        return hidden_states, all_hidden_states, all_self_attns