import argparse
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from .block_fourier import TetraFourierBlock
from .utils_fourier import ToTetraFourierQuarterBatch, FromTetraFourierQuarterBatch
from .groups import PLATONIC_GROUPS
from .linear import PlatonicLinear
from .platoformer import PlatonicTransformer
from .io import to_dense_and_mask, pool, lift, to_scalars_vectors
from .ape import APE


class TetraFourierTransformer(nn.Module):
    """
    A Transformer equivariant to the symmetry group of the tetrahedron, whose encoder
    blocks operate on a group-Fourier representation of the features.

    This model processes point cloud data. Input scalar and vector node features are
    first "lifted" into a group-equivariant feature space and embedded with a
    PlatonicLinear layer (optionally followed by an absolute position embedding). The
    features are then transformed to the Fourier domain, processed by a stack of
    TetraFourierBlocks, and transformed back. Separate scalar and vector readout heads
    produce the outputs; for graph-level tasks the node axis is pooled, for node-level
    tasks one prediction per node is kept.

    The symmetry group is fixed to the tetrahedron: unlike ``PlatonicTransformer`` there
    is no ``solid_name`` argument (a ``solid_name`` keyword would land in ``**kwargs``
    and be forwarded to every block).

    Args:
        input_dim (int): Number of input scalar features per node.
        input_dim_vec (int): Number of input vector features per node.
        hidden_dim (int): Hidden dimension used throughout the model (``d_model`` of
            each block). This appears to be the total flattened dimension (group order
            times channels per group element); the callers in this file use multiples
            of 12.
        output_dim (int): Number of output scalar features.
        output_dim_vec (int): Number of output vector features.
        nhead (int): Number of attention heads in each block.
        num_layers (int): Number of blocks.
        spatial_dim (int): Dimensionality of positions and vector features. Default: 3.
        dense_mode (bool): If True, always convert the inputs with
            ``to_dense_and_mask``, even if ``batch`` is provided.
        scalar_task_level (str): "node" or "graph". Pooling strategy for scalar outputs.
        vector_task_level (str): "node" or "graph". Pooling strategy for vector outputs.
        post_pool_readout (bool): If True, pool first and then apply the readout heads;
            otherwise apply the readout heads per node and pool afterwards.
        ffn_readout (bool): If True, use multi-layer (GELU) readout heads; otherwise a
            single PlatonicLinear per head.
        mean_aggregation (bool): Forwarded to every block and to ``pool``.
        dropout (float): Dropout rate.
        norm_first (bool): If True, use pre-normalization in the blocks.
        drop_path_rate (float): Stochastic depth rate. Default: 0.0.
        layer_scale_init_value (Optional[float]): Initial value for LayerScale. Default: None.
        attention (bool): Forwarded to every block.
        attention_type (str): 'equivariant', 'invariant', or a 'first-rest' pair such as
            'equivariant-invariant': the first entry is used for the first layer and
            the second entry for all remaining layers.
        ffn_dim_factor (int): Multiplier for the feed-forward hidden dimension,
            relative to ``hidden_dim``.
        rope_sigma (Optional[float]): Forwarded to every block as ``freq_sigma``.
        ape_sigma (Optional[float]): If not None, an absolute position embedding (APE)
            with this sigma is added after the input embedding.
        learned_freqs (bool): Forwarded to the APE and to every block.
        fourier_type (str): Fourier transform variant. Only "quarter_batch" is
            implemented.
        **kwargs: Additional keyword arguments forwarded to every TetraFourierBlock.

    Raises:
        ValueError: For an invalid task level, ``fourier_type`` or ``attention_type``.
        NotImplementedError: For ``fourier_type='standard'``.
    """
    def __init__(self,
                 # Basic/essential specification:
                 input_dim: int,
                 input_dim_vec: int,
                 hidden_dim: int,
                 output_dim: int,
                 output_dim_vec: int,
                 nhead: int,
                 num_layers: int,
                 spatial_dim: int = 3,
                 dense_mode: bool = False,  # force dense mode, even if batch is provided
                 # Pooling and readout specification:
                 scalar_task_level: str = "graph",
                 vector_task_level: str = "node",
                 post_pool_readout: bool = True,
                 ffn_readout: bool = True,
                 # Attention block specification:
                 mean_aggregation: bool = False,
                 dropout: float = 0.1,
                 norm_first: bool = True,
                 drop_path_rate: float = 0.0,
                 layer_scale_init_value: Optional[float] = None,
                 attention: bool = False,
                 attention_type: str = 'equivariant',
                 ffn_dim_factor: int = 4,
                 # RoPE and APE specification:
                 rope_sigma: Optional[float] = 1.0,  # if None it is not used
                 ape_sigma: Optional[float] = None,  # if None it is not used
                 learned_freqs: bool = True,
                 fourier_type: str = "quarter_batch",
                 **kwargs):
        super().__init__()

        if scalar_task_level not in ("node", "graph"):
            raise ValueError("scalar_task_level must be 'node' or 'graph'.")

        if vector_task_level not in ("node", "graph"):
            raise ValueError("vector_task_level must be 'node' or 'graph'.")

        # --- Group and Dimension Setup ---
        # Fixed to the tetrahedron: the blocks and Fourier transforms used below are
        # the Tetra* variants.
        solid_name = "tetrahedron"
        self.group = PLATONIC_GROUPS[solid_name]
        self.num_G = self.group.G
        self.hidden_dim = hidden_dim
        self.scalar_task_level = scalar_task_level
        self.vector_task_level = vector_task_level
        self.dense_mode = dense_mode
        self.output_dim = output_dim
        self.output_dim_vec = output_dim_vec
        self.mean_aggregation = mean_aggregation
        self.post_pool_readout = post_pool_readout

        # Global position embedding for fixed patching ViTs
        if ape_sigma is not None:
            self.ape = APE(hidden_dim, ape_sigma, spatial_dim, learned_freqs)
        else:
            # Registering a None buffer keeps `self.ape` defined (and None) so that
            # forward() can test for it.
            self.register_buffer('ape', None)

        Block = TetraFourierBlock
        if fourier_type == 'standard':
            raise NotImplementedError("TODO")
        elif fourier_type == 'quarter_batch':
            self.to_fourier = ToTetraFourierQuarterBatch()
            self.from_fourier = FromTetraFourierQuarterBatch()
        else:
            raise ValueError(
                f"Invalid fourier_type: {fourier_type!r}. Must be 'quarter_batch'."
            )
        self.fourier_type = fourier_type

        # --- Modules ---
        # 1. Input Embedding: applied after lifting to the group.
        # The lifted input has (input_dim + input_dim_vec * spatial_dim) channels per
        # group element, i.e. that many times num_G in total.
        self.x_embedder = PlatonicLinear(
            (input_dim + input_dim_vec * spatial_dim) * self.num_G,
            self.hidden_dim,
            solid_name,
            bias=False,
        )

        # 2. Equivariant Encoder Layers
        # The blocks operate on the total flattened dimension (G * C).
        dim_feedforward = int(self.hidden_dim * ffn_dim_factor)

        allowed_attention_types = [
            'equivariant', 'invariant',
            'equivariant-invariant', 'invariant-equivariant',
            'equivariant-equivariant', 'invariant-invariant',  # Redundant but for consistent format
        ]

        if attention_type not in allowed_attention_types:
            raise ValueError(f"Invalid attention_type: {attention_type}. Must be one of {allowed_attention_types}.")

        # "first-rest" selects the attention type of the first layer and of all others.
        if '-' in attention_type:
            first_layer_attn, rest_layers_attn = attention_type.split('-', 1)
        else:
            first_layer_attn = rest_layers_attn = attention_type

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            current_attn_type = first_layer_attn if i == 0 else rest_layers_attn
            self.layers.append(Block(
                d_model=self.hidden_dim,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                norm_first=norm_first,
                drop_path=drop_path_rate,
                layer_scale_init_value=layer_scale_init_value,
                freq_sigma=rope_sigma,
                learned_freqs=learned_freqs,
                spatial_dims=spatial_dim,
                mean_aggregation=mean_aggregation,
                attention=attention,
                attention_type=current_attn_type,
                fourier_type=self.fourier_type,
                **kwargs
            ))

        # 3. Readout heads: each produces num_G * (per-element output channels).
        if ffn_readout:
            self.scalar_readout = nn.Sequential(
                PlatonicLinear(self.hidden_dim, self.hidden_dim, solid_name),
                nn.GELU(),
                PlatonicLinear(self.hidden_dim, self.num_G * output_dim, solid_name)
            )

            self.vector_readout = nn.Sequential(
                PlatonicLinear(self.hidden_dim, self.hidden_dim, solid_name),
                nn.GELU(),
                PlatonicLinear(self.hidden_dim, self.hidden_dim, solid_name),
                nn.GELU(),
                PlatonicLinear(self.hidden_dim, self.num_G * output_dim_vec * spatial_dim, solid_name)
            )
        else:
            self.scalar_readout = PlatonicLinear(self.hidden_dim, self.num_G * output_dim, solid_name)
            self.vector_readout = PlatonicLinear(self.hidden_dim, self.num_G * output_dim_vec * spatial_dim, solid_name)

    def forward(self, x: Tensor, pos: Tensor, batch: Optional[torch.Tensor] = None,
                mask: Optional[Tensor] = None, vec: Optional[Tensor] = None,
                avg_num_nodes=1.0) -> Tuple[Tensor, Tensor]:
        """
        Forward pass for the Tetra Fourier Transformer.

        Args:
            x (Tensor): Input scalar node features, (N, input_dim) together with
                ``batch``, or (B, N, input_dim) for dense inputs.
            pos (Tensor): Node positions, (N, spatial_dim) or (B, N, spatial_dim).
            batch (Tensor, optional): Batch index for each node, shape (N,).
            mask (Tensor, optional): Currently ignored. In dense mode it is replaced
                by the mask returned from ``to_dense_and_mask``; otherwise it is set
                to None.
            vec (Tensor, optional): Input vector node features, presumably
                (..., N, input_dim_vec, spatial_dim) — TODO confirm against ``lift``.
            avg_num_nodes (float): Forwarded to every block and to ``pool``.

        Returns:
            Tuple[Tensor, Tensor]: ``(scalars, vectors)`` extracted by
            ``to_scalars_vectors``. Graph-level outputs have one entry per graph,
            node-level outputs one entry per node.
        """
        # 1. Convert to dense format if needed
        if self.dense_mode:
            self._input_was_dense_format = batch is None
            x, vec, pos, mask = to_dense_and_mask(x, vec, pos, batch)
            batch = None
        else:
            self._input_was_dense_format = False
            # NOTE(review): a caller-supplied mask is discarded here.
            mask = None

        # 2. Lift scalars and vectors, then embed
        x = lift(x, vec, self.group)
        x = self.x_embedder(x)
        if self.ape is not None:
            x = x + self.ape(pos)  # Add absolute position embedding

        # 3. Equivariant encoder; the blocks operate on Fourier-domain features.
        x = self.to_fourier(x)
        for layer in self.layers:
            x = layer(x=x, pos=pos, batch=batch, mask=mask, avg_num_nodes=avg_num_nodes)
        x = self.from_fourier(x)
        # TODO: instead of transforming back here, one could readout irrep quantities which are explicitly scalars and vectors.

        # 4. Readout heads and node pooling, in the order selected by post_pool_readout
        if self.post_pool_readout:
            scalar_x = self._reduce_nodes(x, self.scalar_task_level, batch, mask, avg_num_nodes)
            vector_x = self._reduce_nodes(x, self.vector_task_level, batch, mask, avg_num_nodes)
            scalar_x = self.scalar_readout(scalar_x)
            vector_x = self.vector_readout(vector_x)
        else:
            scalar_x = self.scalar_readout(x)
            vector_x = self.vector_readout(x)
            scalar_x = self._reduce_nodes(scalar_x, self.scalar_task_level, batch, mask, avg_num_nodes)
            vector_x = self._reduce_nodes(vector_x, self.vector_task_level, batch, mask, avg_num_nodes)

        # 5. Extract the scalar and vector parts
        scalars = to_scalars_vectors(scalar_x, self.output_dim, 0, self.group)[0]
        vectors = to_scalars_vectors(vector_x, 0, self.output_dim_vec, self.group)[1]

        return scalars, vectors

    def _reduce_nodes(self, h: Tensor, task_level: str, batch: Optional[Tensor],
                      mask: Optional[Tensor], avg_num_nodes) -> Tensor:
        """Pool ``h`` over nodes for "graph" tasks; for "node" tasks keep per-node values.

        When the input was sparse but converted to dense in forward(), node-level
        values are indexed with ``mask`` so that only real nodes are returned
        (presumably ``mask`` is a boolean (B, N) padding mask — TODO confirm).
        """
        if task_level == "graph":
            return pool(h, batch, mask, avg_num_nodes, self.dense_mode, self.mean_aggregation)
        if self.dense_mode and not self._input_was_dense_format:
            return h[mask]
        return h

def test_equivalence():
    """Check that TetraFourierTransformer reproduces PlatonicTransformer's outputs.

    Both models are built with matching hyper-parameters. The Fourier model's
    weights are then initialised from the reference model, and the norm layers of
    both are replaced by ``nn.Identity`` because they are not equivalent between the
    implementations. The scalar and vector outputs on a random dense batch must
    agree within a loose tolerance.
    """
    B = 16
    N = 32
    in_c = 32
    in_c_v = 16
    embed_dim = 288
    num_heads = 4*12
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    with torch.inference_mode():
        x = torch.randn([B, N, in_c], device=device)
        v = torch.randn([B, N, in_c_v, 3], device=device)
        pos = torch.randn([B, N, 3], device=device)

        model = PlatonicTransformer(
            input_dim=in_c,
            input_dim_vec=in_c_v,
            hidden_dim=embed_dim,
            output_dim=in_c,
            output_dim_vec=in_c_v,
            nhead=num_heads,
            num_layers=2,
            solid_name="tetrahedron",
            dense_mode=True,
        ).to(device)

        model_f2 = TetraFourierTransformer(
            input_dim=in_c,
            input_dim_vec=in_c_v,
            hidden_dim=embed_dim,
            output_dim=in_c,
            output_dim_vec=in_c_v,
            nhead=num_heads,
            num_layers=2,
            dense_mode=True,
            fourier_type="quarter_batch",
        ).to(device)

        # Copy the encoder weights block by block into the Fourier model.
        for block, block_f2 in zip(model.layers, model_f2.layers):
            block.norm1 = nn.Identity()  # Norm layers are currently not equivalent between implementations
            block.norm2 = nn.Identity()
            block_f2.interaction.init_from_non_fourier(block.interaction)
            block_f2.linear1.reset_parameters(block.linear1)
            block_f2.linear2.reset_parameters(block.linear2)
            block_f2.norm1 = nn.Identity()  # Norm layers are currently not equivalent between implementations
            block_f2.norm2 = nn.Identity()

        if model.ape is not None:
            # load_state_dict expects a state dict, not the module itself.
            model_f2.ape.load_state_dict(model.ape.state_dict())

        model_f2.x_embedder.load_state_dict(model.x_embedder.state_dict())

        # Readouts: only the PlatonicLinear layers carry parameters (GELUs have none).
        for layer, layer_f2 in zip(model.scalar_readout, model_f2.scalar_readout):
            if isinstance(layer, PlatonicLinear):
                layer_f2.load_state_dict(layer.state_dict())
        for layer, layer_f2 in zip(model.vector_readout, model_f2.vector_readout):
            if isinstance(layer, PlatonicLinear):
                layer_f2.load_state_dict(layer.state_dict())

        model.eval()
        y, y_vec = model(x, pos, vec=v)
        print(y[5,:32])
        print(y_vec[5,2,:8])
        model_f2.eval()
        y_f2, y_vec_f2 = model_f2(x, pos, vec=v)
        print(y_f2[5,:32])
        print(y_vec_f2[5,2,:8])
        # Big atol values below due to no normalization layers
        assert torch.allclose(y, y_f2, atol=5e-2), "Outputs should be equal"
        assert torch.allclose(y_vec, y_vec_f2, atol=5e-2), "Outputs should be equal"

def benchmark_layer(layer, dummy_input, backward=False, num_runs=100, warmups=10):
    """Benchmark the average time of one forward (and optionally backward) pass.

    Args:
        layer: Callable invoked as ``layer(scalars, positions, vec=vectors)``. When
            ``backward`` is True it must return a tuple whose first element is a
            tensor supporting ``.sum().backward()``.
        dummy_input: Sequence ``(scalars, vectors, positions)``.
        backward (bool): If True, also run a backward pass on the summed first output.
        num_runs (int): Number of timed iterations.
        warmups (int): Number of untimed warm-up iterations run first.

    Returns:
        float: Average time per iteration in milliseconds. CUDA events are used when
        the scalar input lives on a CUDA device; otherwise ``time.perf_counter`` is
        used, so the benchmark also works on CPU-only machines.
    """
    import time

    scalars, vectors, positions = dummy_input[0], dummy_input[1], dummy_input[2]

    def run_once():
        # One benchmark step; the forward result is discarded unless we backprop.
        if backward:
            y, _ = layer(scalars, positions, vec=vectors)
            y.sum().backward()
        else:
            layer(scalars, positions, vec=vectors)

    if getattr(scalars, "is_cuda", False):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        torch.cuda.synchronize()
        for _ in range(warmups):
            run_once()

        torch.cuda.synchronize()
        start_event.record()
        for _ in range(num_runs):
            run_once()
        end_event.record()
        torch.cuda.synchronize()
        return start_event.elapsed_time(end_event) / num_runs

    # CPU path: kernels run synchronously, so wall-clock timing is sufficient.
    for _ in range(warmups):
        run_once()
    start = time.perf_counter()
    for _ in range(num_runs):
        run_once()
    elapsed_s = time.perf_counter() - start
    return elapsed_s * 1000.0 / num_runs


def print_results(results):
    """Print a two-column table of model names and their timings.

    Args:
        results (dict): Mapping with a ``"times"`` entry that maps each model name
            to its measured time in milliseconds.
    """
    timings = results["times"]

    # The name column fits the longest name (or the header) plus 4 spaces of padding.
    longest_name = max(map(len, timings))
    name_width = max(longest_name, len("Model")) + 4

    print(f"{'Model':<{name_width}}{'Time (ms)':<12}")
    # The separator is 10 characters wider than the header line.
    print("-" * (name_width + 22))

    for label, elapsed in timings.items():
        print(f"{label:<{name_width}}{elapsed:<12.3f}")


def small_speed_benchmark(
    channel_sizes=(64, 96, 128),
    batch_size=64,
    sequence_length=64,
    input_scalars=128,
    output_scalars=128,
    input_vectors=3,
    output_vectors=3,
    num_layers=14,
    num_heads=48,
    matmul_precision="highest",
    dtype=torch.float32,
    torch_compile=False,
    backward=False,
):
    """Time PlatonicTransformer against TetraFourierTransformer for several widths.

    For every entry of ``channel_sizes`` both models are built with
    ``hidden_dim = 12 * hidden_channels`` and timed with ``benchmark_layer`` on the
    same random dense input. A results table is printed for each width.

    Args:
        channel_sizes (Iterable[int]): Channel counts to benchmark; each is
            multiplied by 12 to obtain ``hidden_dim``.
        batch_size (int): Number of point clouds per batch.
        sequence_length (int): Number of points per cloud.
        input_scalars (int): Number of input scalar features.
        output_scalars (int): Number of output scalar features.
        input_vectors (int): Number of input vector features.
        output_vectors (int): Number of output vector features.
        num_layers (int): Number of blocks in each model.
        num_heads (int): Number of attention heads.
        matmul_precision (str): Passed to ``torch.set_float32_matmul_precision``
            (a process-wide setting).
        dtype (torch.dtype): Dtype of the models and inputs.
        torch_compile (bool): If True, wrap both models in ``torch.compile``.
        backward (bool): If True, also time the backward pass; otherwise the
            benchmark runs under ``torch.inference_mode``.

    Returns:
        dict: ``{hidden_channels: {"times": {model_name: time_ms}}}``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    use_cuda = device == 'cuda'
    torch.set_float32_matmul_precision(matmul_precision)

    results = {}
    # Autograd is only needed when the backward pass is benchmarked.
    with torch.inference_mode(not backward):
        for hidden_channels in channel_sizes:
            print(f"== Benchmarking for {hidden_channels} channels ==")

            # Initialize layers
            model = PlatonicTransformer(
                input_dim=input_scalars,
                input_dim_vec=input_vectors,
                hidden_dim=12*hidden_channels,
                output_dim=output_scalars,
                output_dim_vec=output_vectors,
                nhead=num_heads,
                num_layers=num_layers,
                solid_name="tetrahedron",
                dense_mode=True,
            ).to(device).to(dtype)

            fourier_model = TetraFourierTransformer(
                input_dim=input_scalars,
                input_dim_vec=input_vectors,
                hidden_dim=12*hidden_channels,
                output_dim=output_scalars,
                output_dim_vec=output_vectors,
                nhead=num_heads,
                num_layers=num_layers,
                fourier_type="quarter_batch",
                dense_mode=True,
            ).to(device).to(dtype)

            if torch_compile:
                model = torch.compile(
                    model,
                    mode="max-autotune-no-cudagraphs",
                )
                fourier_model = torch.compile(
                    fourier_model,
                    mode="max-autotune-no-cudagraphs",
                )

            # Create dummy input data
            scalars = torch.randn([batch_size, sequence_length, input_scalars], device=device, dtype=dtype)
            vectors = torch.randn([batch_size, sequence_length, input_vectors, 3], device=device, dtype=dtype)
            positions = torch.randn([batch_size, sequence_length, 3], device=device, dtype=dtype)

            benchmark_configs = [
                ("PlatonicTransformer", model, (scalars, vectors, positions)),
                ("TetraFourierTransformer", fourier_model, (scalars, vectors, positions)),
            ]

            times = {}
            for name, layer, dummy_input in benchmark_configs:
                # Clear Dynamo's compilation caches before timing each model.
                torch._dynamo.reset()
                if use_cuda:
                    # torch.cuda.synchronize() raises on machines without CUDA.
                    torch.cuda.synchronize()
                times[name] = benchmark_layer(
                    layer,
                    dummy_input,
                    backward=backward,
                )

            results[hidden_channels] = {"times": times}
            print_results(results[hidden_channels])

    return results


if __name__ == "__main__":
    # Supported values for --dtype, mapped to the torch dtype used by the benchmark.
    dtypes = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}

    parser = argparse.ArgumentParser(description='TetraFourierTransformer tests')
    parser.add_argument('--run_correctness_tests', action="store_true",
                        help="Compare TetraFourierTransformer against PlatonicTransformer.")
    parser.add_argument('--run_small_speed_benchmark', action="store_true",
                        help="Time both models for several hidden sizes.")
    parser.add_argument('--float32_matmul_precision', type=str, default="highest",
                        choices=["highest", "high", "medium"],
                        help="Value passed to torch.set_float32_matmul_precision.")
    parser.add_argument('--dtype', type=str, default="float32", choices=list(dtypes),
                        help="Dtype of the benchmarked models and inputs.")
    parser.add_argument('--compile', action="store_true",
                        help="Wrap the benchmarked models in torch.compile.")
    parser.add_argument('--backward', action="store_true",
                        help="Also time the backward pass in the benchmark.")
    args = parser.parse_args()

    # Without an action flag the script would otherwise silently do nothing.
    if not (args.run_correctness_tests or args.run_small_speed_benchmark):
        parser.print_help()

    if args.run_correctness_tests:
        print("<<<< Running correctness tests. >>>>")
        test_equivalence()
        print("<<<< All tests passed! >>>>")

    if args.run_small_speed_benchmark:
        print("<<<< Running small speed benchmark. >>>>")
        small_speed_benchmark(
            matmul_precision=args.float32_matmul_precision,
            dtype=dtypes[args.dtype],
            torch_compile=args.compile,
            backward=args.backward,
        )
