
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops.layers.torch import Rearrange

import torch_geometric
from omegaconf import OmegaConf
from nn import inr
from nn.rt_transformer import *


def inputs_to_flat(weights, biases):
    """
    Pack per-layer weight and bias tensors into one 2-D tensor.

    Every tensor is flattened from dim 1 onwards and the results are
    concatenated along dim 1, weights first, then biases.

    :param weights: sequence of tensors (dim 0 is kept, the rest is flattened).
    :param biases: sequence of tensors, handled like ``weights``.
    :return: ``(flat_tensor, shapes)`` where ``shapes`` lists the original
             shape of every tensor in concatenation order.
    """
    ordered = [*weights, *biases]
    shapes = [param.shape for param in ordered]
    flat_tensor = torch.cat([param.flatten(1) for param in ordered], dim=1)
    return flat_tensor, shapes

def flat_to_inputs(flat_tensor, shapes):
    """
    Split a flat 2-D tensor back into tensors of the given shapes.

    Consecutive column ranges of ``flat_tensor`` are viewed as ``shapes[i]``;
    the width of each range is the product of ``shapes[i][1:]``.

    :param flat_tensor: 2-D tensor whose columns hold the packed parameters.
    :param shapes: full target shapes (including dim 0), in packing order.
    :return: ``(first_half, second_half)`` of the recovered tensors; the list
             is split in the middle, so callers are expected to pass as many
             weight shapes as bias shapes.
    """
    params = []
    start = 0
    for shape in shapes:
        # Plain Python int: avoids allocating a tensor per shape, keeps the
        # slice bounds integral, and yields 1 when there are no trailing dims.
        size = torch.Size(shape[1:]).numel()
        params.append(flat_tensor[:, start:start + size].view(shape))
        start += size
    half = len(params) // 2
    return params[:half], params[half:]

def modulate(x, scale, shift):
    """
    Apply the affine modulation ``x * (1 + scale) + shift``.

    For 3-D ``x`` one singleton axis is inserted at position 1 of ``scale``
    and ``shift``; for 4-D ``x`` two are inserted. For any other rank the
    operands are combined as given (ordinary broadcasting).

    :param x: tensor to modulate.
    :param scale: multiplicative term (added to 1 before multiplying).
    :param shift: additive term.
    :return: the modulated tensor.
    """
    if x.ndim in (3, 4):
        # One extra axis per dimension of x beyond (batch, feature).
        for _ in range(x.ndim - 2):
            scale = scale.unsqueeze(1)
            shift = shift.unsqueeze(1)
    return x * (1 + scale) + shift


class FiLM(nn.Module):
    """
    Linear layer followed by LayerNorm whose output is modulated by a
    conditioning vector.

    ``is_time_cond`` is set so that ``TimeSequential`` passes the
    conditioning argument to this module.
    """

    def __init__(self, d_in, d_out, d_cond) -> None:
        """
        :param d_in: input feature size of the linear layer.
        :param d_out: output feature size (also the LayerNorm size).
        :param d_cond: feature size of the conditioning vector.
        """
        super().__init__()
        self.d_in, self.d_out = d_in, d_out
        self.lin = nn.Linear(d_in, d_out)
        # Produces scale and shift stacked along the last dimension.
        cond_proj = nn.Linear(d_cond, 2 * d_out)
        self.scale_shift = nn.Sequential(nn.SiLU(), cond_proj)
        self.norm = nn.LayerNorm(d_out)
        self.is_time_cond = True

    def forward(self, x, cond):
        """
        :param x: input features.
        :param cond: conditioning vector fed to ``scale_shift``.
        :return: ``modulate(norm(lin(x)), scale, shift)``.
        """
        film_params = self.scale_shift(cond)
        scale, shift = film_params.chunk(2, dim=-1)
        normed = self.norm(self.lin(x))
        return modulate(normed, scale, shift)


class TimeSequential(nn.Sequential):
    """
    ``nn.Sequential`` variant that forwards an extra argument ``t`` to the
    children that declare a truthy ``is_time_cond`` attribute; every other
    child is called with the running activation only.
    """

    def forward(self, x, t=None):
        """
        :param x: input to the first child.
        :param t: optional conditioning passed to ``is_time_cond`` children.
        :return: output of the last child.
        """
        out = x
        for layer in self:
            wants_time = getattr(layer, "is_time_cond", False)
            out = layer(out, t) if wants_time else layer(out)
        return out
    

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A sinusoidal embedding of size ``frequency_embedding_size`` is computed
    first and then mapped to ``hidden_size`` by a two-layer MLP with SiLU.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        """
        :param hidden_size: output size of the embedding MLP.
        :param frequency_embedding_size: size of the sinusoidal embedding
                          that is fed to the MLP.
        """
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        # Imported here because this file has no module-level `import math`;
        # the name must not depend on what a wildcard import happens to expose.
        import math

        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        # Cosine half first, then sine half.
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd output size: pad one zero column so the width equals `dim`.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        """
        :param t: a 1-D Tensor of timesteps, one per batch element.
        :return: an (N, hidden_size) Tensor of timestep embeddings.
        """
        # NOTE(review): a rescaling step (`t = t * 1000`) is disabled here;
        # confirm the range of `t` the callers pass.
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
    

class FinalLayer(nn.Module):
    """
    Output head: a non-affine LayerNorm, a conditioning-driven modulation,
    and a two-layer MLP (Linear -> SiLU -> Linear) projecting to ``d_out``.
    """
    def __init__(self, hidden_size, d_out):
        """
        :param hidden_size: feature size of the input and of the conditioning
                          vector consumed by ``adaLN_modulation``.
        :param d_out: feature size of the output.
        """
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        head_layers = [
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, d_out, bias=True),
        ]
        self.linear = nn.Sequential(*head_layers)
        cond_proj = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), cond_proj)

    def forward(self, x, c):
        """
        :param x: features to project.
        :param c: conditioning tensor; the projection is split in two along
                          dim 1.
        :return: projected features with last dimension ``d_out``.
        """
        # The first half of the projection lands in modulate()'s `scale` slot
        # (multiplicative), the second half in its `shift` slot (additive).
        scale, shift = self.adaLN_modulation(c).chunk(2, dim=1)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, scale, shift))
    

class RTTransformerTime(nn.Module):
    """
    Time-conditioned transformer over a flattened set of network parameters.

    ``forward`` unpacks the flat input into weights and biases, builds a graph
    with ``GraphConstructor``, runs ``n_layers`` ``RTLayerTime`` blocks
    conditioned on an embedded timestep, maps the graph back to per-layer
    weights/biases with ``graphs_to_batch`` and projects each of them with its
    own ``FinalLayer`` before flattening again.

    NOTE(review): the ``FinalLayer`` heads are built with ``d_hid`` as their
    conditioning width but receive the ``time_cond_dim``-wide timestep
    embedding, so this appears to require ``time_cond_dim == d_hid`` -- confirm.
    """

    def __init__(
        self,
        d_in,
        d_hid,
        d_out,
        n_layers,
        n_heads,
        layer_layout,
        time_cond_dim=0,
        dropout=0.0,
        node_update_type="rt",
        edge_update_type="rt",
        disable_edge_updates=False,
        rev_edge_features=False,
        num_probe_features=0,
        zero_out_bias=False,
        zero_out_weights=False,
        bias_ln=False,
        weight_ln=False,
        sin_emb=False,
        use_topomask=False,
        input_layers=1,
    ):
        """
        :param d_in: forwarded to ``GraphConstructor``.
        :param d_hid: hidden feature size of graph, layers and output heads.
        :param d_out: output feature size of every output head.
        :param n_layers: number of ``RTLayerTime`` blocks.
        :param n_heads: forwarded to every ``RTLayerTime``.
        :param layer_layout: forwarded to ``GraphConstructor``; one weight head
                          and one bias head are created per consecutive pair
                          (``len(layer_layout) - 1`` of each).
        :param time_cond_dim: size of the timestep embedding.
        :param dropout: forwarded to every ``RTLayerTime``.
        :param node_update_type: forwarded to every ``RTLayerTime``.
        :param edge_update_type: forwarded to every ``RTLayerTime``.
        :param disable_edge_updates: forwarded to every ``RTLayerTime``.
        :param use_topomask: forwarded to every ``RTLayerTime``.

        The remaining arguments are forwarded unchanged to ``GraphConstructor``.
        """
        super().__init__()
        self.rev_edge_features = rev_edge_features
        self.nodes_per_layer = layer_layout
        self.construct_graph = GraphConstructor(
            d_in=d_in,
            d_hid=d_hid,
            layer_layout=layer_layout,
            rev_edge_features=rev_edge_features,
            num_probe_features=num_probe_features,
            zero_out_bias=zero_out_bias,
            zero_out_weights=zero_out_weights,
            bias_ln=bias_ln,
            weight_ln=weight_ln,
            sin_emb=sin_emb,
            input_layers=input_layers,
        )

        self.layers = nn.ModuleList(
            [
                RTLayerTime(
                    d_hid,
                    n_heads,
                    dropout,
                    node_update_type=node_update_type,
                    edge_update_type=edge_update_type,
                    disable_edge_updates=disable_edge_updates,
                    use_topomask=use_topomask,
                    time_cond_dim=time_cond_dim,
                )
                for _ in range(n_layers)
            ]
        )
        # Optional shared heads, currently disabled; initialize_params() and
        # forward() still contain the matching (guarded / commented) hooks.
        # self.proj_out_edges = FinalLayer(d_hid, d_hid)
        # self.proj_out_nodes = FinalLayer(d_hid, d_hid)

        self.proj_time = TimestepEmbedder(hidden_size=time_cond_dim)
        # Not used by forward(); kept so parameter names and the order of
        # random initialisation stay unchanged.
        self.proj_edge = nn.Linear(d_hid, d_hid)

        n_heads_out = len(layer_layout) - 1
        self.proj_out_weight = nn.ModuleList(
            [FinalLayer(d_hid, d_out) for _ in range(n_heads_out)])
        self.proj_out_biases = nn.ModuleList(
            [FinalLayer(d_hid, d_out) for _ in range(n_heads_out)])

        self.initialize_params()

    @staticmethod
    def _zero_linear(linear):
        """Set weight and bias of a linear layer to zero."""
        nn.init.constant_(linear.weight, 0)
        nn.init.constant_(linear.bias, 0)

    def initialize_params(self):
        """
        Initialise the timestep MLP weights with N(0, 0.02) and zero the last
        layer of every adaLN modulation that exists on the blocks and on the
        optional ``proj_out_edges`` / ``proj_out_nodes`` heads.
        """
        # Initialize timestep embedding MLP:
        nn.init.normal_(self.proj_time.mlp[0].weight, std=0.02)
        nn.init.normal_(self.proj_time.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in the blocks:
        for layer in self.layers:
            for attr in ("adaLN_modulation_edges", "adaLN_modulation_nodes"):
                if hasattr(layer, attr):
                    self._zero_linear(getattr(layer, attr)[-1])

        # Zero-out the optional shared output heads when they are enabled.
        for attr in ("proj_out_edges", "proj_out_nodes"):
            if not hasattr(self, attr):
                continue
            head = getattr(self, attr)
            self._zero_linear(head.adaLN_modulation[-1])
            # FinalLayer.linear is an nn.Sequential, which has no .weight of
            # its own; zero its last Linear (a bare Linear is accepted too).
            out_proj = head.linear
            if isinstance(out_proj, nn.Sequential):
                out_proj = out_proj[-1]
            self._zero_linear(out_proj)

    def forward(self, inputs, time_cond, input_shapes):
        """
        :param inputs: flat parameter tensor, unpacked with ``flat_to_inputs``.
        :param time_cond: timesteps passed to the ``TimestepEmbedder``.
        :param input_shapes: shapes used to unpack ``inputs``.
        :return: flat tensor produced by ``inputs_to_flat`` from the projected
                          weights and biases.
        """
        inputs = flat_to_inputs(inputs, input_shapes)
        node_features, edge_features, mask = self.construct_graph(inputs)
        time_cond = self.proj_time(time_cond)
        for layer in self.layers:
            node_features, edge_features = layer(node_features, edge_features, mask, time_cond)

        # edge_features = self.proj_out_edges(edge_features, time_cond)
        # node_features = self.proj_out_nodes(node_features, time_cond)
        weights, biases = graphs_to_batch(edge_features, node_features, *inputs)
        weights = [proj(w, time_cond) for proj, w in zip(self.proj_out_weight, weights)]
        biases = [proj(b, time_cond) for proj, b in zip(self.proj_out_biases, biases)]
        return inputs_to_flat(weights, biases)[0]

    


class RTLayerTime(nn.Module):
    """
    One transformer block over node and edge features with optional
    conditioning on a time embedding.

    ``forward`` first updates the nodes (attention + MLP, arranged according
    to ``node_update_type``) and then, unless ``disable_edge_updates`` is set,
    updates the edges from the concatenation of each edge, its reverse edge
    and its two endpoint nodes.
    """

    def __init__(
        self,
        d_hid,
        n_heads,
        dropout,
        node_update_type="rt",
        edge_update_type="rt",
        disable_edge_updates=False,
        use_topomask=False,
        time_cond_dim=0,
    ):
        """
        :param d_hid: feature size of nodes and edges.
        :param n_heads: number of heads passed to ``RTAttention``.
        :param dropout: dropout probability used after attention and in MLPs.
        :param node_update_type: one of ``"norm_first"``, ``"norm_last"``,
                          ``"rt"``, ``"rt_modulation"``.
        :param edge_update_type: edge update variant (see ``edge_updates``).
        :param disable_edge_updates: if true, no edge modules are created and
                          edges pass through ``forward`` unchanged.
        :param use_topomask: forwarded to ``RTAttention``.
        :param time_cond_dim: size of the time embedding; when non-zero (and
                          the update type is not ``"rt_modulation"``) the first
                          layer of each MLP is a ``FiLM`` instead of a Linear.
        """
        super().__init__()
        self.d_hid = d_hid
        self.node_update_type = node_update_type
        self.edge_update_type = edge_update_type
        self.disable_edge_updates = disable_edge_updates
        self.time_cond_dim = time_cond_dim

        self.self_attn = torch.jit.script(RTAttention(d_hid, d_hid, d_hid, n_heads, use_topomask=use_topomask))
        self.lin0 = nn.Linear(d_hid, d_hid)
        self.dropout0 = nn.Dropout(dropout)
        self.node_ln0 = nn.LayerNorm(d_hid)
        self.node_ln1 = nn.LayerNorm(d_hid)

        act_fn = nn.GELU

        # A plain Linear is used when there is no time conditioning or when
        # the conditioning enters through adaLN modulation instead of FiLM.
        use_lin = time_cond_dim == 0 or node_update_type == "rt_modulation"

        self.node_mlp = TimeSequential(
            nn.Linear(d_hid, 2 * d_hid, bias=False) if use_lin else FiLM(d_hid, 2 * d_hid, time_cond_dim),
            act_fn(),
            nn.Linear(2 * d_hid, d_hid),
            nn.Dropout(dropout),
        )

        if node_update_type == "rt_modulation":
            # NOTE(review): the modulation consumes the time embedding but its
            # input width is d_hid, so this appears to require
            # time_cond_dim == d_hid -- confirm.
            self.adaLN_modulation_nodes = nn.Sequential(
                nn.SiLU(),
                nn.Linear(d_hid, 8 * d_hid, bias=True)
            )
            self.edge_norm0 = nn.LayerNorm(d_hid)

        if not self.disable_edge_updates:
            use_lin = time_cond_dim == 0 or edge_update_type == "rt_modulation"
            if edge_update_type == "rt_modulation":
                self.adaLN_modulation_edges = nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(d_hid, 15 * d_hid, bias=True)
                )
                self.edge_norm1 = nn.LayerNorm(d_hid)
                self.edge_norm2 = nn.LayerNorm(d_hid)
            self.reverse_edge = Rearrange("b n m d -> b m n d")
            self.edge_mlp0 = TimeSequential(
                nn.Linear(4 * d_hid, d_hid, bias=False) if use_lin else FiLM(4 * d_hid, d_hid, time_cond_dim),
                act_fn(),
                nn.Linear(d_hid, d_hid),
                nn.Dropout(dropout),
            )
            self.edge_mlp1 = TimeSequential(
                nn.Linear(d_hid, 2 * d_hid, bias=False) if use_lin else FiLM(d_hid, 2 * d_hid, time_cond_dim),
                act_fn(),
                nn.Linear(2 * d_hid, d_hid),
                nn.Dropout(dropout),
            )
            self.eln0 = nn.LayerNorm(d_hid)
            self.eln1 = nn.LayerNorm(d_hid)

    def node_updates(self, node_features, edge_features, mask, time_cond):
        """
        Update the node features with attention followed by the node MLP.

        :param node_features: node features.
        :param edge_features: edge features passed to the attention.
        :param mask: mask passed to the attention.
        :param time_cond: time embedding (may be ``None`` when the MLP has no
                          time-conditioned layer).
        :return: updated node features.
        :raises ValueError: for an unknown ``node_update_type``.
        """
        if self.node_update_type == "norm_first":
            node_features = node_features + self.self_attn(
                self.node_ln0(node_features), edge_features, mask
            )
            # time_cond must reach the MLP: its first layer is a FiLM when
            # time_cond_dim > 0. TimeSequential ignores it otherwise.
            node_features = node_features + self.node_mlp(
                self.node_ln1(node_features), time_cond
            )
        elif self.node_update_type == "norm_last":
            node_features = self.node_ln0(
                node_features + self.self_attn(node_features, edge_features, mask)
            )
            node_features = self.node_ln1(
                node_features + self.node_mlp(node_features, time_cond)
            )
        elif self.node_update_type == "rt":
            node_features = self.node_ln0(
                node_features
                + self.dropout0(
                    self.lin0(self.self_attn(node_features, edge_features, mask))
                )
            )
            node_features = self.node_ln1(node_features + self.node_mlp(node_features, time_cond))
        elif self.node_update_type == "rt_modulation":
            # Eight d_hid-wide chunks. The variable names follow the chunk
            # order; note that modulate() takes (x, scale, shift), so the
            # chunk named shift_* is the one used multiplicatively.
            (shift_msa_n, scale_msa_n, shift_msa_e, scale_msa_e,
                gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation_nodes(time_cond).chunk(8, dim=1)
            mnf = modulate(self.node_ln0(node_features), shift_msa_n, scale_msa_n)
            mef = modulate(self.edge_norm0(edge_features), shift_msa_e, scale_msa_e)
            node_features = node_features + gate_msa.unsqueeze(1) * self.self_attn(mnf, mef, mask)
            # node_mlp starts with a plain Linear in this mode, so no
            # time_cond is needed here.
            node_features = (node_features + gate_mlp.unsqueeze(1) *
                self.node_mlp(modulate(self.node_ln1(node_features), shift_mlp, scale_mlp)))
        else:
            raise ValueError(f"Unknown node update type: {self.node_update_type}")
        return node_features

    def edge_updates(self, node_features, edge_features, mask, time_cond):
        """
        Update the edge features from edges, reversed edges and endpoint nodes.

        :param node_features: node features; expanded along a new axis to pair
                          every source node with every target node.
        :param edge_features: edge features; the rearrange pattern
                          ``"b n m d -> b m n d"`` implies four dimensions.
        :param mask: unused here.
        :param time_cond: time embedding passed to the edge MLPs.
        :return: updated edge features; returned unchanged when
                          ``edge_update_type`` matches no branch below.
        """
        source_nodes = node_features.unsqueeze(-2).expand(
            -1, -1, node_features.size(-2), -1
        )
        target_nodes = node_features.unsqueeze(-3).expand(
            -1, node_features.size(-2), -1, -1
        )
        reversed_edge_features = self.reverse_edge(edge_features)
        input_features = torch.cat(
            [edge_features, reversed_edge_features, source_nodes, target_nodes],
            dim=-1,
        )
        if self.edge_update_type == "rt":
            edge_features = self.eln0(edge_features + self.edge_mlp0(input_features, time_cond))
            edge_features = self.eln1(edge_features + self.edge_mlp1(edge_features, time_cond))
        # NOTE(review): __init__ creates adaLN_modulation_edges / edge_norm1 /
        # edge_norm2 for edge_update_type == "rt_modulation", but this branch
        # tests "rt_modulate". As written, "rt_modulation" falls through and
        # leaves the edges untouched, while "rt_modulate" reaches modules that
        # were never created. The split sizes below also add up to 12 * d
        # whereas the projection emits 15 * d_hid features. Left unchanged
        # until the intended design is confirmed.
        elif self.edge_update_type == "rt_modulate":
            (shift0, scale0, gate0, shift1,
                scale1, gate1) = self.adaLN_modulation_edges(time_cond).split(
                [input_features.size(-1)] * 2 + [edge_features.size(-1)] * 4, dim=1)
            edge_features = edge_features + gate0 * self.edge_mlp0(
                modulate(self.edge_norm1(edge_features), shift0, scale0))
            edge_features = edge_features + gate1 * self.edge_mlp1(
                modulate(self.edge_norm2(edge_features), shift1, scale1))
        return edge_features

    def forward(self, node_features, edge_features, mask, time_cond=None):
        """
        :param node_features: node features.
        :param edge_features: edge features.
        :param mask: mask passed to the attention.
        :param time_cond: optional time embedding.
        :return: ``(node_features, edge_features)`` after the updates.
        """
        node_features = self.node_updates(node_features, edge_features, mask, time_cond)

        if not self.disable_edge_updates:
            edge_features = self.edge_updates(node_features, edge_features, mask, time_cond)

        return node_features, edge_features