import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

from einops import rearrange
# from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
# from flash_attn.ops.fused_dense import FusedMLP, FusedDense
from torch.nn.functional import scaled_dot_product_attention
from huggingface_hub import PyTorchModelHubMixin
from omegaconf import OmegaConf

from . import rotary
from .fused_add_dropout_scale import (
    bias_dropout_add_scale_fused_train, 
    bias_dropout_add_scale_fused_inference, 
    get_bias_dropout_add_scale, 
    modulate_fused,
)

from .transformer import (
    EmbeddingLayer,
    LayerNorm,
)

class DDiTBlockWot(nn.Module):
    """Pre-norm self-attention + MLP block with residual connections.

    Rotary position embeddings are applied to the packed qkv tensor before
    attention. Both residual updates go through the fused
    bias-dropout-add-scale helper with no bias and a constant scale of 1.

    Args:
        dim: Width of the input/output features.
        n_heads: Number of attention heads the packed qkv projection is
            split into.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        dropout: Dropout probability handed to the fused residual helper.
        use_checkpoint: When truthy, ``forward`` runs the block through
            ``torch.utils.checkpoint``.
    """

    def __init__(self, dim, n_heads, mlp_ratio=4, dropout=0.1, use_checkpoint=False):
        super().__init__()
        self.n_heads = n_heads

        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)
        # Not called in the forward pass (dropout is applied through the
        # fused helper); kept so the module layout stays unchanged.
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio * dim, dim, bias=True)
        )
        # Likewise unused in the forward pass.
        self.dropout2 = nn.Dropout(dropout)

        self.dropout = dropout

        self.use_checkpoint = use_checkpoint

    def _get_bias_dropout_scale(self):
        """Return the fused residual helper matching the current train/eval mode."""
        return (
            bias_dropout_add_scale_fused_train
            if self.training
            else bias_dropout_add_scale_fused_inference
        )

    def forward(self, x, rotary_cos_sin, seqlens=None):
        """Run the block, optionally under activation checkpointing.

        Args:
            x: Input features; ``_forward`` treats it as
                (batch, seq, dim).
            rotary_cos_sin: ``(cos, sin)`` pair for the rotary embedding.
            seqlens: Passed through to ``_forward``, which does not use it.

        Returns:
            The block output produced by ``_forward``.
        """
        if not self.use_checkpoint:
            return self._forward(x, rotary_cos_sin, seqlens)

        # Import the submodule explicitly: a bare ``import torch`` is not
        # guaranteed to expose ``torch.utils.checkpoint`` as an attribute.
        from torch.utils.checkpoint import checkpoint
        return checkpoint(self._forward, x, rotary_cos_sin, seqlens)

    def _forward(self, x, rotary_cos_sin, seqlens=None):
        """Compute attention and MLP sub-layers with residual connections.

        ``seqlens`` is accepted for interface compatibility and is unused.
        """
        batch_size = x.shape[0]

        bias_dropout_scale_fn = self._get_bias_dropout_scale()
        # Constant scale of 1 shared by both residual updates.
        unit_scale = torch.tensor([1.], device=x.device)

        # attention operation
        x_skip = x
        x = self.norm1(x)

        qkv = self.attn_qkv(x)
        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.n_heads)
        # The rotary embedding is applied with autocast switched off; cos/sin
        # are cast to whatever dtype qkv already has.
        with torch.cuda.amp.autocast(enabled=False):
            cos, sin = rotary_cos_sin
            qkv = rotary.apply_rotary_pos_emb(
                qkv, cos.to(qkv.dtype), sin.to(qkv.dtype)
            )
        qkv = rearrange(qkv, 'b s three h d -> three b h s d')
        q, k, v = qkv[0], qkv[1], qkv[2]
        x = scaled_dot_product_attention(q, k, v)
        x = rearrange(x, 'b h s d-> b s (h d)', b=batch_size)

        x = bias_dropout_scale_fn(self.attn_out(x), None, unit_scale, x_skip, self.dropout)

        # mlp operation
        x = bias_dropout_scale_fn(self.mlp(self.norm2(x)), None, unit_scale, x, self.dropout)
        return x



class DDitFinalLayerWot(nn.Module):
    """Output head: a layer norm followed by a zero-initialised linear map.

    Args:
        hidden_size: Width of the incoming features.
        out_channels: Width of the projected output.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        projection = nn.Linear(hidden_size, out_channels)
        # Start from an all-zero projection (weight and bias).
        for param in (projection.weight, projection.bias):
            param.data.zero_()
        self.linear = projection

    def forward(self, x):
        """Normalise ``x`` and project it to ``out_channels`` features."""
        return self.linear(self.norm_final(x))


class SEDDWot(nn.Module, PyTorchModelHubMixin):
    """Token-embedding transformer whose forward pass takes ``indices`` and ``sigma``.

    The model embeds the token indices, runs them through a stack of
    ``DDiTBlockWot`` blocks and a ``DDitFinalLayerWot`` head, optionally
    rescales the output using ``sigma``, and finally zeroes the output
    entry at each position's own input token.

    Args:
        config: A dict or OmegaConf config. Reads ``graph.type``, ``tokens``
            and ``model.{hidden_size, n_heads, n_blocks, dropout,
            scale_by_sigma}``; ``model.use_checkpoint`` is optional.
    """

    def __init__(self, config):
        super().__init__()

        # hack to make loading in configs easier
        if isinstance(config, dict):
            config = OmegaConf.create(config)

        self.config = config

        self.absorb = config.graph.type == "absorb"
        # The absorbing graph adds one extra vocabulary slot.
        vocab_size = config.tokens + (1 if self.absorb else 0)

        self.vocab_embed = EmbeddingLayer(config.model.hidden_size, vocab_size)
        self.rotary_emb = rotary.Rotary(config.model.hidden_size // config.model.n_heads)

        try:
            # The config from Huggingface has no config.model.use_checkpoint
            use_checkpoint = config.model.use_checkpoint
        except (AttributeError, KeyError):
            use_checkpoint = False

        self.blocks = nn.ModuleList([
            DDiTBlockWot(config.model.hidden_size, config.model.n_heads, dropout=config.model.dropout,
                      use_checkpoint=use_checkpoint) for _ in range(config.model.n_blocks)
        ])

        self.output_layer = DDitFinalLayerWot(config.model.hidden_size, vocab_size)
        self.scale_by_sigma = config.model.scale_by_sigma

    def _get_bias_dropout_scale(self):
        """Return the fused residual helper matching the current train/eval mode."""
        return (
            bias_dropout_add_scale_fused_train
            if self.training
            else bias_dropout_add_scale_fused_inference
        )

    def _run_backbone(self, x, rotary_cos_sin, dtype):
        """Apply every block and the output layer under CUDA autocast with ``dtype``."""
        with torch.cuda.amp.autocast(dtype=dtype):
            for block in self.blocks:
                x = block(x, rotary_cos_sin, seqlens=None)
            return self.output_layer(x)

    def forward(self, indices, sigma):
        """Compute the model output for ``indices`` at noise level ``sigma``.

        Args:
            indices: Integer token ids; used for the embedding lookup and as
                the scatter index along the last output dimension.
            sigma: Noise level tensor, indexed as ``[:, None, None]`` before
                being subtracted (presumably one value per batch element;
                confirm against the caller).

        Returns:
            The output tensor with the entry at each position's own input
            token set to zero.
        """
        embedded = self.vocab_embed(indices)
        rotary_cos_sin = self.rotary_emb(embedded)

        try:
            x = self._run_backbone(embedded, rotary_cos_sin, torch.bfloat16)
        except RuntimeError:
            # bfloat16 autocast failed: retry in float16, restarting from the
            # untouched embedding so no block is applied twice.
            x = self._run_backbone(embedded, rotary_cos_sin, torch.float16)

        if self.scale_by_sigma:
            assert self.absorb, "Haven't configured this to work."
            # log(exp(sigma) - 1); expm1 is used for small sigma.
            esigm1_log = torch.where(sigma < 0.5, torch.expm1(sigma), sigma.exp() - 1).log().to(x.dtype)[:, None, None]
            x = x - esigm1_log - np.log(x.shape[-1] - 1)# this will be approximately averaged at 0

        # Zero the output entry corresponding to each position's input token.
        x = torch.scatter(x, -1, indices[..., None], torch.zeros_like(x[..., :1]))

        return x


class SEDDWotSM(nn.Module, PyTorchModelHubMixin):
    """Variant of the token transformer that (log-)softmaxes its output.

    The backbone is identical for all three entry points: embed the token
    indices, run the ``DDiTBlockWot`` stack and the ``DDitFinalLayerWot``
    head. They differ only in the post-processing of the output:

    * ``forward`` - log-softmax over all but the last vocabulary entry,
      minus ``log(exp(sigma) - 1)``; own-token entry set to 0.
    * ``get_log_condition`` - log-softmax only; own-token entry set to 0.
    * ``get_condition`` - softmax only; own-token entry set to 1.

    The (log-)softmax step is applied only when ``scale_by_sigma`` is set.

    Args:
        config: A dict or OmegaConf config. Reads ``graph.type``, ``tokens``
            and ``model.{hidden_size, n_heads, n_blocks, dropout,
            scale_by_sigma}``; ``model.use_checkpoint`` is optional.
    """

    def __init__(self, config):
        super().__init__()

        # hack to make loading in configs easier
        if isinstance(config, dict):
            config = OmegaConf.create(config)

        self.config = config

        self.absorb = config.graph.type == "absorb"
        # The absorbing graph adds one extra vocabulary slot.
        vocab_size = config.tokens + (1 if self.absorb else 0)

        self.vocab_embed = EmbeddingLayer(config.model.hidden_size, vocab_size)
        self.rotary_emb = rotary.Rotary(config.model.hidden_size // config.model.n_heads)

        try:
            # The config from Huggingface has no config.model.use_checkpoint
            use_checkpoint = config.model.use_checkpoint
        except (AttributeError, KeyError):
            use_checkpoint = False

        self.blocks = nn.ModuleList([
            DDiTBlockWot(config.model.hidden_size, config.model.n_heads, dropout=config.model.dropout,
                      use_checkpoint=use_checkpoint) for _ in range(config.model.n_blocks)
        ])

        self.output_layer = DDitFinalLayerWot(config.model.hidden_size, vocab_size)
        self.scale_by_sigma = config.model.scale_by_sigma

    def _get_bias_dropout_scale(self):
        """Return the fused residual helper matching the current train/eval mode."""
        return (
            bias_dropout_add_scale_fused_train
            if self.training
            else bias_dropout_add_scale_fused_inference
        )

    def _run_blocks(self, x, rotary_cos_sin, dtype):
        """Apply every block and the output layer under CUDA autocast with ``dtype``."""
        with torch.cuda.amp.autocast(dtype=dtype):
            for block in self.blocks:
                x = block(x, rotary_cos_sin, seqlens=None)
            return self.output_layer(x)

    def _backbone(self, indices):
        """Embed ``indices`` and return the raw output-layer result."""
        embedded = self.vocab_embed(indices)
        rotary_cos_sin = self.rotary_emb(embedded)
        try:
            return self._run_blocks(embedded, rotary_cos_sin, torch.bfloat16)
        except RuntimeError:
            # bfloat16 autocast failed: retry in float16, restarting from the
            # untouched embedding so no block is applied twice.
            return self._run_blocks(embedded, rotary_cos_sin, torch.float16)

    def forward(self, indices, sigma):
        """Compute the sigma-scaled log output for ``indices``.

        Args:
            indices: Integer token ids; used for the embedding lookup and as
                the scatter index along the last output dimension.
            sigma: Noise level tensor, indexed as ``[:, None, None]`` before
                being subtracted (presumably one value per batch element;
                confirm against the caller).

        Returns:
            The output tensor with the entry at each position's own input
            token set to zero.
        """
        x = self._backbone(indices)

        if self.scale_by_sigma:
            assert self.absorb, "Haven't configured this to work."
            # Normalise over every entry except the last (the extra slot
            # added for the absorbing graph), which keeps its raw value.
            x[:, :, :-1] = x[:, :, :-1].log_softmax(dim=-1)
            # log(exp(sigma) - 1); expm1 is used for small sigma.
            esigm1_log = torch.where(sigma < 0.5, torch.expm1(sigma), sigma.exp() - 1).log().to(x.dtype)[:, None, None]
            x = x - esigm1_log

        x = torch.scatter(x, -1, indices[..., None], torch.zeros_like(x[..., :1]))

        return x

    def get_log_condition(self, indices):
        """Return the log-softmaxed output without any sigma term.

        The entry at each position's own input token is set to zero.
        """
        x = self._backbone(indices)

        if self.scale_by_sigma:
            assert self.absorb, "Haven't configured this to work."
            x[:, :, :-1] = x[:, :, :-1].log_softmax(dim=-1)

        x = torch.scatter(x, -1, indices[..., None], torch.zeros_like(x[..., :1]))

        return x

    def get_condition(self, indices):
        """Return the softmaxed output without any sigma term.

        The entry at each position's own input token is set to one.
        """
        x = self._backbone(indices)

        if self.scale_by_sigma:
            assert self.absorb, "Haven't configured this to work."
            x[:, :, :-1] = F.softmax(x[:, :, :-1], dim=-1)

        x = torch.scatter(x, -1, indices[..., None], torch.ones_like(x[..., :1]))

        return x