import torch
from torch import nn
from pathlib import Path
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from safetensors.torch import load_file, save_file
from huggingface_hub import (
    list_repo_files,
    snapshot_download,
    hf_hub_download,
)
from transformers.modeling_utils import load_sharded_checkpoint
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.integrations.flash_attention import flash_attention_forward
from transformers.modeling_outputs import CausalLMOutput

from flash_attn.flash_attn_interface import flash_attn_func

from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.ops.swiglu import LigerSiLUMulFunction
from liger_kernel.transformers.functional import liger_fused_linear_cross_entropy

class MLP(nn.Module):
    """Gated feed-forward block computing ``down_proj(silu(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free bf16 ``nn.Linear`` layers sized from
    ``config.hidden_size`` and ``config.intermediate_size``.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.hidden_size
        d_ff = config.intermediate_size
        linear_kwargs = dict(bias=False, dtype=torch.bfloat16)
        self.gate_proj = nn.Linear(d_model, d_ff, **linear_kwargs)
        self.up_proj = nn.Linear(d_model, d_ff, **linear_kwargs)
        self.down_proj = nn.Linear(d_ff, d_model, **linear_kwargs)

    def forward(self, x):
        # LigerSiLUMulFunction (imported at file top) could presumably fuse the
        # silu(gate) * up product into one kernel; the plain PyTorch form is used.
        gated = F.silu(self.gate_proj(x))
        projected_up = self.up_proj(x)
        return self.down_proj(gated * projected_up)


class Attention(nn.Module):
    """Self-attention with rotary position embeddings, run through flash_attention_forward.

    Queries use ``num_attention_heads`` heads while keys/values use
    ``num_key_value_heads`` heads. No key/value repetition happens here, so this
    relies on ``flash_attention_forward`` accepting fewer KV heads than query
    heads -- NOTE(review): confirm for the installed transformers version.
    """

    def __init__(self, config, layer_idx=None):
        """
        Args:
            config: model config; reads hidden_size, num_attention_heads,
                num_key_value_heads, attention_dropout and (optionally) head_dim.
            layer_idx: index of the decoder layer owning this module. Only
                needed when ``config.use_sliding_window`` is enabled, because the
                window applies to layers ``>= config.max_window_layers``.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Treat a missing *or* None head_dim the same way, as RotaryEmbedding does;
        # otherwise a config carrying head_dim=None would fail on None ** -0.5.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = config.attention_dropout
        # Presumably read by flash_attention_forward (module.is_causal) to apply causal masking.
        self.is_causal = True
        self.q_proj = nn.Linear(
            config.hidden_size,
            config.num_attention_heads * self.head_dim,
            bias=True,
            dtype=torch.bfloat16,
        )
        self.k_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=True,
            dtype=torch.bfloat16,
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=True,
            dtype=torch.bfloat16,
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            dtype=torch.bfloat16,
        )

    def _resolve_sliding_window(self):
        """Return the sliding-window size for this layer, or None for full attention.

        The window applies only when ``config.use_sliding_window`` is truthy,
        ``config.sliding_window`` is not None, and ``layer_idx >= config.max_window_layers``.

        Raises:
            ValueError: sliding windows are enabled but ``layer_idx`` is unknown.
        """
        config = self.config
        if not getattr(config, "use_sliding_window", False):
            return None
        window = getattr(config, "sliding_window", None)
        if window is None:
            return None
        if self.layer_idx is None:
            raise ValueError(
                "config.use_sliding_window is enabled but this Attention module has no "
                "layer_idx; construct it with Attention(config, layer_idx=...)"
            )
        return window if self.layer_idx >= config.max_window_layers else None

    def forward(self, hidden_states, position_embeddings, attention_mask):
        """Attend over ``hidden_states`` and project back to ``hidden_size``.

        Args:
            hidden_states: assumed (batch, seq, hidden_size) -- the view/transpose
                below split the last dim into heads and swap dims 1 and 2.
            position_embeddings: ``(cos, sin)`` pair applied with Liger RoPE.
            attention_mask: padding mask handed to flash_attention_forward as-is.

        Returns:
            ``o_proj`` output with the same leading dims as ``hidden_states``.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Split heads, then move the head axis in front of the sequence axis.
        query = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query, key = liger_rotary_pos_emb(query, key, cos, sin)

        attn_output, _ = flash_attention_forward(
            self,
            query,
            key,
            value,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self._resolve_sliding_window(),  # main diff with Llama
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        return self.o_proj(attn_output)


from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm


class Layer(nn.Module):
    """One pre-norm decoder block: RMSNorm -> attention -> residual, then RMSNorm -> MLP -> residual."""

    def __init__(self, config):
        super().__init__()
        self.self_attn = Attention(config)
        self.mlp = MLP(config)
        self.input_layernorm = LigerRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LigerRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask, position_ids, position_embeddings):
        """Apply the attention and MLP sub-blocks, each with a residual connection.

        ``position_ids`` is accepted for call-site compatibility but unused:
        rotary information arrives precomputed in ``position_embeddings``.
        Normalized activations are cast to bf16 before each sub-block.
        """
        attn_in = self.input_layernorm(hidden_states).to(torch.bfloat16)
        attn_out = self.self_attn(
            hidden_states=attn_in,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        after_attn = hidden_states + attn_out

        mlp_in = self.post_attention_layernorm(after_attn).to(torch.bfloat16)
        return after_attn + self.mlp(mlp_in)

class RotaryEmbedding(nn.Module):
    """Compute rotary position embedding (RoPE) cos/sin tables for given positions."""

    def __init__(self, config, device=None):
        super().__init__()
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        exponents = torch.arange(0, head_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float32) / head_dim
        inv_freq = 1.0 / (config.rope_theta ** exponents)
        # inv_freq is derived entirely from the config, so keep it out of
        # state_dict(); otherwise it is written into every saved checkpoint as
        # an extra non-weight key.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to ``x.dtype``.

        Args:
            x: reference tensor; only its device and dtype are used.
            position_ids: 2-D integer positions indexed as (batch, seq).

        Returns:
            cos, sin: each shaped (batch, seq, head_dim); the frequency table is
            duplicated along the last dim (``cat((freqs, freqs))``).
        """
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        # Keep the angle computation in float32 even under autocast.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

class Transformer(nn.Module):
    """Decoder stack: bf16 token embeddings, ``num_hidden_layers`` Layer blocks, final RMSNorm."""

    def __init__(self, config, device=None):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id, dtype=torch.bfloat16)
        self.layers = nn.ModuleList(
            [Layer(config) for _ in range(config.num_hidden_layers)]
        )
        # Attention.forward reads self.layer_idx when config.use_sliding_window is
        # enabled, but Layer(config) does not pass an index down; record it here
        # so that code path does not fail with AttributeError.
        for layer_idx, layer in enumerate(self.layers):
            layer.self_attn.layer_idx = layer_idx
        self.norm = LigerRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = RotaryEmbedding(config, device)
        # When True, decoder layers are recomputed during backward instead of
        # storing their activations (only while autograd is recording).
        self.gradient_checkpointing = True

    def forward(self, input_ids, attention_mask):
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: 2-D tensor of token ids (unpacked as batch, seq).
            attention_mask: padding mask forwarded unchanged to each layer.

        Returns:
            Hidden states after the final ``self.norm``.
        """
        _, seqlen = input_ids.shape
        hidden_states = self.embed_tokens(input_ids)

        # Positions 0..seqlen-1 shared by every row (shape (1, seqlen)), independent of padding.
        position_ids = torch.arange(0, seqlen, device=input_ids.device).unsqueeze(0)
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        layer_kwargs = dict(
            attention_mask=attention_mask,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
        )
        # Checkpointing only pays off when a backward pass can follow; under
        # no_grad it would just add per-layer overhead.
        use_checkpoint = self.gradient_checkpointing and torch.is_grad_enabled()
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if use_checkpoint:
                hidden_states = checkpoint(decoder_layer, hidden_states, use_reentrant=False, **layer_kwargs)
            else:
                hidden_states = decoder_layer(hidden_states, **layer_kwargs)

        return self.norm(hidden_states)


class Model(nn.Module):
    """Causal language model: ``Transformer`` backbone plus an ``lm_head`` projection.

    Submodule names (``model.*``, ``lm_head``) define the state-dict keys that
    ``from_pretrained`` loads and ``save_pretrained`` writes.
    """

    def __init__(self, config, device=None):
        super().__init__()
        self.device = device
        self.config = config
        self.model = Transformer(config, device)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, dtype=torch.bfloat16)

        self.tie_weights()  # tie weights if necessary
        self.to(device)

    def tie_weights(self):
        """Make ``lm_head`` share the embedding matrix when ``config.tie_word_embeddings`` is set."""
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

    @staticmethod
    def _shifted_labels(input_ids, attention_mask):
        """Build next-token targets aligned with each input position.

        Position ``t`` is trained to predict ``input_ids[..., t + 1]``. The
        target is -100 (ignored by the loss) at the last position, where the
        next token is padding, and where position ``t`` itself is padding. The
        last rule keeps left-padded rows from training a pad position to predict
        the first real token, so the loss does not depend on padding.
        """
        labels = F.pad(input_ids, (0, 1), value=-100)[..., 1:]
        next_is_pad = F.pad(attention_mask, (0, 1), value=0)[..., 1:] == 0
        current_is_pad = attention_mask == 0
        return labels.masked_fill(next_is_pad | current_is_pad, -100).contiguous()

    def forward(self, input_ids, attention_mask=None, compute_loss=False):
        """Run the model.

        Args:
            input_ids: 2-D token ids (batch, seq).
            attention_mask: 1 for real tokens, 0 for padding; all ones if None.
            compute_loss: if True, return the mean next-token cross-entropy
                (via Liger's fused linear+CE, so logits are never materialized)
                and ``logits=None``; otherwise return bf16 logits and ``loss=None``.

        Returns:
            ``CausalLMOutput`` with ``loss`` and ``logits``.
        """
        if attention_mask is None:
            # Same integer dtype a tokenizer produces.
            attention_mask = torch.ones_like(input_ids)

        hidden_states = self.model(input_ids, attention_mask)

        if compute_loss:
            hidden_size = hidden_states.shape[-1]
            shift_labels = self._shifted_labels(input_ids, attention_mask)
            loss = liger_fused_linear_cross_entropy(
                hidden_states.reshape(-1, hidden_size),
                self.lm_head.weight,
                shift_labels.view(-1),
            )
            logits = None
        else:
            logits = self.lm_head(hidden_states.bfloat16())
            loss = None

        return CausalLMOutput(loss=loss, logits=logits)

    @classmethod
    def from_pretrained(cls, model_id, config=None, device='cuda'):
        """Build a model and load safetensors weights from a Hub repo id or local directory.

        Sharded checkpoints (``*.safetensors.index.json`` present) go through
        ``load_sharded_checkpoint``; otherwise ``model.safetensors`` is loaded.
        Loading is non-strict, but a ``UserWarning`` lists any keys that do not
        line up, ignoring ``lm_head.weight`` for tied embeddings and rotary
        ``inv_freq`` buffers (neither needs to be stored).
        """
        import gc
        import warnings

        if config is None:
            config = AutoConfig.from_pretrained(model_id)
        model = cls(config, device)

        # Local directories (e.g. the output of save_pretrained) are read directly
        # instead of being sent to the Hub API.
        local_dir = Path(model_id)
        if local_dir.is_dir():
            files = [p.name for p in local_dir.iterdir()]
        else:
            local_dir = None
            files = list_repo_files(model_id)

        if any(f.endswith(".safetensors.index.json") for f in files):
            if local_dir is not None:
                ckpt_dir = local_dir
            else:
                ckpt_dir = snapshot_download(model_id, allow_patterns=["*.safetensors", "*.json"])
            result = load_sharded_checkpoint(model, ckpt_dir, strict=False, prefer_safe=True)
        else:
            if local_dir is not None:
                sd_path = str(local_dir / "model.safetensors")
            else:
                sd_path = hf_hub_download(model_id, "model.safetensors")
            state = load_file(sd_path, device=device)
            result = model.load_state_dict(state, strict=False)
            del state
            gc.collect()
            torch.cuda.empty_cache()

        missing = [
            key for key in getattr(result, "missing_keys", [])
            if not key.endswith("rotary_emb.inv_freq")
            and not (config.tie_word_embeddings and key == "lm_head.weight")
        ]
        unexpected = list(getattr(result, "unexpected_keys", []))
        if missing or unexpected:
            warnings.warn(
                f"Weights from {model_id!r} do not match the model: "
                f"{len(missing)} missing key(s) {missing[:10]}, "
                f"{len(unexpected)} unexpected key(s) {unexpected[:10]}",
                stacklevel=2,
            )

        return model

    def save_pretrained(self, output_dir: str, safe_filename: str = "model.safetensors"):
        """Write the config and a single safetensors weight file into ``output_dir``.

        With tied embeddings, ``lm_head.weight`` is the same tensor as
        ``model.embed_tokens.weight``. It is omitted here: safetensors refuses to
        save tensors that share memory (the CPU case), and ``tie_weights``
        restores it on load. Tensors are copied to CPU and made contiguous, as
        safetensors requires.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        self.config.save_pretrained(output_dir)

        state_dict = self.state_dict()
        if self.config.tie_word_embeddings:
            state_dict.pop("lm_head.weight", None)
        state_dict_cpu = {k: v.detach().to("cpu").contiguous() for k, v in state_dict.items()}
        save_file(state_dict_cpu, str(output_dir / safe_filename), metadata={"format": "pt"})



if __name__ == '__main__':
    # Parity check of this implementation against the Hugging Face reference.
    # Needs a CUDA device, flash-attn, and access to the Hugging Face Hub.
    device = torch.device('cuda')

    model_name = 'Qwen/Qwen2.5-1.5B'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype='auto', attn_implementation='flash_attention_2')
    config = AutoConfig.from_pretrained(model_name)
    model.to(device)
    model.eval()

    mymodel = Model.from_pretrained(model_name, config=config)
    mymodel.to(device)
    # nn.Module defaults to training mode; switch to eval for the inference checks below.
    mymodel.eval()

    tokens = tokenizer('hello world', return_tensors='pt').to(device)
    input_ids = tokens['input_ids']
    attention_mask = tokens['attention_mask']
    seqlen = input_ids.shape[1]

    # Comparisons only: no autograd graph is needed.
    with torch.no_grad():
        assert torch.allclose(model.model.embed_tokens(input_ids), mymodel.model.embed_tokens(input_ids))

        # Random activations shaped like the tokenized input, so attention_mask lines up.
        hidden_states = torch.rand(1, seqlen, config.hidden_size, device=device).bfloat16()
        position_ids = torch.arange(0, seqlen, device=device).unsqueeze(0)

        assert torch.allclose(model.model.rotary_emb.inv_freq, mymodel.model.rotary_emb.inv_freq)

        orig_position_emb = model.model.rotary_emb(hidden_states, position_ids)
        my_position_emb = mymodel.model.rotary_emb(hidden_states, position_ids)
        assert torch.allclose(orig_position_emb[0], my_position_emb[0])
        assert torch.allclose(orig_position_emb[1], my_position_emb[1])

        assert torch.allclose(model.model.layers[0].mlp(hidden_states), mymodel.model.layers[0].mlp(hidden_states))

        assert torch.allclose(
            model.model.layers[0].self_attn(
                hidden_states=hidden_states,
                position_embeddings=orig_position_emb,
                attention_mask=attention_mask.bool(),
            )[0],
            mymodel.model.layers[0].self_attn(
                hidden_states=hidden_states,
                position_embeddings=orig_position_emb,
                attention_mask=attention_mask,
            ),
        )

        # NOTE(review): some transformers versions return a tensor rather than a
        # tuple from decoder layers; confirm the [0] indexing for the installed version.
        for i in range(config.num_hidden_layers):
            assert torch.allclose(
                model.model.layers[i](
                    hidden_states=hidden_states,
                    position_embeddings=orig_position_emb,
                    attention_mask=attention_mask.bool(),
                    position_ids=position_ids,
                )[0],
                mymodel.model.layers[i](
                    hidden_states=hidden_states,
                    position_embeddings=orig_position_emb,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                ),
            )

        # Model.forward returns a CausalLMOutput, so compare its .logits tensor.
        assert torch.allclose(
            model(input_ids, attention_mask).logits,
            mymodel(input_ids, attention_mask).logits,
            atol=1e-7,
        )

