import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import math
from typing import List, Any
from numbers import Number

from .groups import PLATONIC_GROUPS
from .linear import PlatonicLinear
from .utils_fourier import (
    basis_tetra_regular_to_fourier,
    ToTetraFourier,
    FromTetraFourier,
    SparseToTetraFourier,
    SparseFromTetraFourier,
    ToTetraFourierQuarterBatch,
    FromTetraFourierQuarterBatch,
)


class TetraFourierLinear(nn.Module):
    """
    A linear layer constrained to be a group convolution over the tetrahedral group,
    implemented in the Fourier domain.

    The weight is stored as independent blocks (see `reset_parameters` for how they are
    extracted from a regular-basis weight):

    * `linear1`: the (C_out x C_in) block; the optional bias lives here only.
    * `weight2_1` / `weight2_2`: two (C_out x C_in) matrices from which the
      (2*C_out x 2*C_in) block ``[[W1, -W2], [W2, W1]]`` is assembled.
    * `linear3`: the (3*C_out x 3*C_in) block that appears three times on the block
      diagonal of the transformed weight.

    Here ``C_in = in_features // G`` and ``C_out = out_features // G`` with ``G`` the
    group order taken from ``PLATONIC_GROUPS["tetrahedron"]``.

    Args:
        in_features: Size of the input in the regular basis; must be divisible by ``G``.
        out_features: Size of the output in the regular basis; must be divisible by ``G``.
        bias: Whether `linear1` carries a bias.
        transform_to_fourier: If True, `forward` applies `ToTetraFourier` to its input;
            otherwise the input must already be the triple ``(x1, x2, x3)``.
        transform_back_from_fourier: If True, `forward` applies `FromTetraFourier` to the
            result; otherwise the triple ``(y1, y2, y3)`` is returned.

    Raises:
        ValueError: If `in_features` or `out_features` is not divisible by ``G``.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True, transform_to_fourier: bool = True, transform_back_from_fourier: bool = True):
        super().__init__()

        solid = "tetrahedron"
        group = PLATONIC_GROUPS[solid]
        self.G = group.G

        # Validate before building any sub-module so that invalid sizes fail fast.
        if in_features % self.G != 0:
            raise ValueError(f"in_features ({in_features}) must be divisible by the group order {self.G} for solid '{solid}'")
        if out_features % self.G != 0:
            raise ValueError(f"out_features ({out_features}) must be divisible by the group order {self.G} for solid '{solid}'")

        self.in_features = in_features
        self.out_features = out_features
        self.transform_to_fourier = transform_to_fourier
        self.transform_back_from_fourier = transform_back_from_fourier
        # The transforms are only instantiated when requested, so `to_fourier` /
        # `from_fourier` do not exist as attributes when the matching flag is False.
        if self.transform_to_fourier:
            self.to_fourier = ToTetraFourier()
        if self.transform_back_from_fourier:
            self.from_fourier = FromTetraFourier()

        self.in_channels = in_features // self.G
        self.out_channels = out_features // self.G

        self.linear1 = nn.Linear(self.in_channels, self.out_channels, bias=bias)  # maps invariant irrep to itself (bias only allowed here)
        self.weight2_1 = nn.Parameter(torch.empty(self.out_channels, self.in_channels))  # maps first part of 2D-irrep to first and second part of 2D-irrep to second
        self.weight2_2 = nn.Parameter(torch.empty(self.out_channels, self.in_channels))  # maps second part of 2D-irrep to negative first and first part of 2D-irrep to second part
        self.linear3 = nn.Linear(3 * self.in_channels, 3 * self.out_channels, bias=False)  # maps 3 copies of 3D-irrep to 3 copies of 3D-irrep

        self.reset_parameters()

    def reset_parameters(self, tmp_tetra_linear: "PlatonicLinear | None" = None) -> None:
        """
        Initialize the weights and bias to be equivalent to a corresponding non-fourier linear layer.
        If `tmp_tetra_linear` is given, the weights are taken from that layer, otherwise a new one is constructed.

        Raises:
            ValueError: If `tmp_tetra_linear` has a different `in_features`, `out_features`
                or `G` than this layer, or if this layer has a bias but `tmp_tetra_linear`
                has none.
        """
        if tmp_tetra_linear is None:
            tmp_tetra_linear = PlatonicLinear(self.in_features, self.out_features, "tetrahedron", bias=self.linear1.bias is not None)
        elif tmp_tetra_linear.in_features != self.in_features or tmp_tetra_linear.out_features != self.out_features or tmp_tetra_linear.G != self.G:
            raise ValueError(
                "tmp_tetra_linear does not match this layer: expected "
                f"in_features={self.in_features}, out_features={self.out_features}, G={self.G}; got "
                f"in_features={tmp_tetra_linear.in_features}, out_features={tmp_tetra_linear.out_features}, G={tmp_tetra_linear.G}"
            )
        if self.linear1.bias is not None:
            if tmp_tetra_linear.bias is None:
                raise ValueError("This layer has a bias but tmp_tetra_linear has none.")
            with torch.no_grad():
                # The reference bias is rescaled by sqrt(G) when copied into the Fourier layer.
                self.linear1.bias.copy_(math.sqrt(self.G) * tmp_tetra_linear.bias.data)

        C_in = self.in_channels
        C_out = self.out_channels

        tmp_w = tmp_tetra_linear.get_weight().clone().detach()
        basis_regular_to_fourier = basis_tetra_regular_to_fourier().to(tmp_w.device)
        # The change of basis should be made accurately, so we use highest precision
        # and restore the caller's global setting afterwards, even on error.
        precision = torch.get_float32_matmul_precision()
        try:
            torch.set_float32_matmul_precision("highest")
            w_fourier = torch.einsum(
                "gh,hdic,ij->gdjc",
                basis_regular_to_fourier,
                tmp_w.reshape(self.G, C_out, self.G, C_in),
                basis_regular_to_fourier.mT,
            ).reshape(self.out_features, self.in_features)
        finally:
            torch.set_float32_matmul_precision(precision)

        # in total there are 12 * C_in * C_out free parameters, they have now been reparameterized to be on a block diagonal and are extracted next
        w_1 = w_fourier[:C_out, :C_in]  # C_out * C_in
        w_2_1 = w_fourier[C_out:2*C_out, C_in:2*C_in]  # C_out * C_in, this is the same as the slice 2*C_out:3*C_out, 2*C_in:3*C_in
        w_2_2 = w_fourier[2*C_out:3*C_out, C_in:2*C_in]  # C_out * C_in, this is the negative of the slice C_out:2*C_out, 2*C_in:3*C_in
        # Warning: the 2D irrep is ultra annoying because its endomorphism ring is isomorphic to the complex numbers, hence this form of the 2x2 weight block
        w_2 = torch.cat((
            torch.cat((w_2_1, -w_2_2), dim=1),
            torch.cat((w_2_2,  w_2_1), dim=1),
        ), dim=0)
        w_3 = w_fourier[3*C_out:6*C_out, 3*C_in:6*C_in]  # (3 * C_out) * (3 * C_in)
        # Sanity check of the extraction: rebuilding the block diagonal must reproduce w_fourier.
        assert torch.allclose(w_fourier, torch.block_diag(w_1, w_2, w_3, w_3, w_3), atol=1e-6), "Transformed weight should be block diagonal of a specific form."

        with torch.no_grad():
            self.linear1.weight.copy_(w_1)
            self.weight2_1.copy_(w_2_1)
            self.weight2_2.copy_(w_2_2)
            self.linear3.weight.copy_(w_3)

    def get_2d_weight(self):
        """Assemble the (2*C_out x 2*C_in) block ``[[W1, -W2], [W2, W1]]`` from `weight2_1` and `weight2_2`."""
        return torch.cat((
            torch.cat((self.weight2_1, -self.weight2_2), dim=1),
            torch.cat((self.weight2_2,  self.weight2_1), dim=1)
        ), dim=0)

    def linear2(self, x):
        """Apply the assembled 2D-irrep weight (no bias) to `x`, whose last dimension must be 2*C_in."""
        return F.linear(x, self.get_2d_weight())

    def forward(self, x):
        """
        Applies the group-equivariant linear transformation.

        Args:
            x: A tensor that `ToTetraFourier` accepts when `transform_to_fourier` is True,
                otherwise an iterable ``(x1, x2, x3)`` whose last dimensions are C_in,
                2*C_in and 3*C_in respectively.

        Returns:
            The output of `FromTetraFourier` when `transform_back_from_fourier` is True,
            otherwise the tuple ``(y1, y2, y3)`` of per-block results.
        """
        if self.transform_to_fourier:
            x1, x2, x3 = self.to_fourier(x)
        else:
            x1, x2, x3 = x

        y1 = self.linear1(x1)
        y2 = self.linear2(x2)
        y3 = self.linear3(x3)

        if self.transform_back_from_fourier:
            return self.from_fourier(y1, y2, y3)
        else:
            return (y1, y2, y3)

    def __repr__(self) -> str:
        # The bias is owned by `linear1`; this module has no `bias` attribute of its own,
        # so reading `self.bias` would raise AttributeError.
        return f"{self.__class__.__name__}(G={self.G}, in_features={self.in_features}, out_features={self.out_features}, bias={self.linear1.bias is not None})"



class TetraFourierLinearQuarterBatch(nn.Module):
    """
    A linear layer constrained to be a group convolution over the tetrahedral group,
    implemented in the Fourier domain with all blocks applied in one batched contraction.

    The Fourier-domain weight is stored as `weight_1d` (C_out x C_in), `weight_2d_1` /
    `weight_2d_2` (C_out x C_in each, combined as ``[[W1, -W2], [W2, W1]]``) and
    `weight_3d` (3*C_out x 3*C_in). For the forward pass these are stacked into a single
    (4, 3*C_out, 3*C_in) tensor: slot 0 is ``block_diag(weight_1d, 2d-block)`` and slots
    1-3 all share `weight_3d`.

    Here ``C_in = in_features // G`` and ``C_out = out_features // G`` with ``G`` the
    group order taken from ``PLATONIC_GROUPS["tetrahedron"]``.

    Args:
        in_features: Size of the input in the regular basis; must be divisible by ``G``.
        out_features: Size of the output in the regular basis; must be divisible by ``G``.
        bias: Whether to learn a bias of size C_out (added to the first C_out entries of slot 0).
        transform_to_fourier: If True, `forward` applies `ToTetraFourierQuarterBatch` first.
        transform_back_from_fourier: If True, `forward` applies
            `FromTetraFourierQuarterBatch` to the result.

    Raises:
        ValueError: If `in_features` or `out_features` is not divisible by ``G``.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True, transform_to_fourier: bool = True, transform_back_from_fourier: bool = True):
        super().__init__()

        solid = "tetrahedron"
        group = PLATONIC_GROUPS[solid]
        self.G = group.G

        # Validate before building any sub-module so that invalid sizes fail fast.
        if in_features % self.G != 0:
            raise ValueError(f"in_features ({in_features}) must be divisible by the group order {self.G} for solid '{solid}'")
        if out_features % self.G != 0:
            raise ValueError(f"out_features ({out_features}) must be divisible by the group order {self.G} for solid '{solid}'")

        self.in_features = in_features
        self.out_features = out_features
        self.transform_to_fourier = transform_to_fourier
        self.transform_back_from_fourier = transform_back_from_fourier
        # The transforms are only instantiated when requested, so `to_fourier` /
        # `from_fourier` do not exist as attributes when the matching flag is False.
        if self.transform_to_fourier:
            self.to_fourier = ToTetraFourierQuarterBatch()
        if self.transform_back_from_fourier:
            self.from_fourier = FromTetraFourierQuarterBatch()

        self.in_channels = in_features // self.G
        self.out_channels = out_features // self.G

        self.weight_1d = nn.Parameter(
            torch.empty(self.out_channels, self.in_channels)
        )
        self.weight_2d_1 = nn.Parameter(
            torch.empty(self.out_channels, self.in_channels)
        )
        self.weight_2d_2 = nn.Parameter(
            torch.empty(self.out_channels, self.in_channels)
        )
        self.weight_3d = nn.Parameter(
            torch.empty(3 * self.out_channels, 3 * self.in_channels)
        )

        if bias:
            self.bias = nn.Parameter(torch.empty(self.out_channels))
        else:
            # Registering None keeps `self.bias` a valid attribute for the checks below.
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self, tmp_tetra_linear: "PlatonicLinear | None" = None) -> None:
        """
        Initialize the weights and bias to be equivalent to a corresponding non-fourier linear layer.
        If `tmp_tetra_linear` is given, the weights are taken from that layer, otherwise a new one is constructed.

        Raises:
            ValueError: If `tmp_tetra_linear` has a different `in_features`, `out_features`
                or `G` than this layer, or if this layer has a bias but `tmp_tetra_linear`
                has none.
        """
        if tmp_tetra_linear is None:
            tmp_tetra_linear = PlatonicLinear(self.in_features, self.out_features, "tetrahedron", bias=self.bias is not None)
        elif tmp_tetra_linear.in_features != self.in_features or tmp_tetra_linear.out_features != self.out_features or tmp_tetra_linear.G != self.G:
            raise ValueError(
                "tmp_tetra_linear does not match this layer: expected "
                f"in_features={self.in_features}, out_features={self.out_features}, G={self.G}; got "
                f"in_features={tmp_tetra_linear.in_features}, out_features={tmp_tetra_linear.out_features}, G={tmp_tetra_linear.G}"
            )
        if self.bias is not None:
            if tmp_tetra_linear.bias is None:
                raise ValueError("This layer has a bias but tmp_tetra_linear has none.")
            with torch.no_grad():
                # The reference bias is rescaled by sqrt(G) when copied into the Fourier layer.
                self.bias.copy_(math.sqrt(self.G) * tmp_tetra_linear.bias.data)

        C_in = self.in_channels
        C_out = self.out_channels

        tmp_w = tmp_tetra_linear.get_weight().clone().detach()
        basis_regular_to_fourier = basis_tetra_regular_to_fourier().to(tmp_w.device)
        # The change of basis should be made accurately, so we use highest precision
        # and restore the caller's global setting afterwards, even on error.
        precision = torch.get_float32_matmul_precision()
        try:
            torch.set_float32_matmul_precision("highest")
            w_fourier = torch.einsum(
                "gh,hdic,ij->gdjc",
                basis_regular_to_fourier,
                tmp_w.reshape(self.G, C_out, self.G, C_in),
                basis_regular_to_fourier.mT,
            ).reshape(self.out_features, self.in_features)
        finally:
            torch.set_float32_matmul_precision(precision)

        # in total there are 12 * C_in * C_out free parameters, they have now been reparameterized to be on a block diagonal and are extracted next
        w_1 = w_fourier[:C_out, :C_in]  # C_out * C_in
        w_2_1 = w_fourier[C_out:2*C_out, C_in:2*C_in]  # C_out * C_in, this is the same as the slice 2*C_out:3*C_out, 2*C_in:3*C_in
        w_2_2 = w_fourier[2*C_out:3*C_out, C_in:2*C_in]  # C_out * C_in, this is the negative of the slice C_out:2*C_out, 2*C_in:3*C_in
        # Warning: the 2D irrep is ultra annoying because its endomorphism ring is isomorphic to the complex numbers, hence this form of the 2x2 weight block
        w_2 = torch.cat((
            torch.cat((w_2_1, -w_2_2), dim=1),
            torch.cat((w_2_2,  w_2_1), dim=1),
        ), dim=0)
        w_3 = w_fourier[3*C_out:6*C_out, 3*C_in:6*C_in]  # (3 * C_out) * (3 * C_in)
        # Sanity check of the extraction: rebuilding the block diagonal must reproduce w_fourier.
        assert torch.allclose(w_fourier, torch.block_diag(w_1, w_2, w_3, w_3, w_3), atol=1e-6), "Transformed weight should be block diagonal of a specific form."

        with torch.no_grad():
            self.weight_1d.copy_(w_1)
            self.weight_2d_1.copy_(w_2_1)
            self.weight_2d_2.copy_(w_2_2)
            self.weight_3d.copy_(w_3)

    def get_2d_weight(self):
        """Assemble the (2*C_out x 2*C_in) block ``[[W1, -W2], [W2, W1]]`` from `weight_2d_1` and `weight_2d_2`."""
        return torch.cat((
            torch.cat((self.weight_2d_1, -self.weight_2d_2), dim=1),
            torch.cat((self.weight_2d_2,  self.weight_2d_1), dim=1)
        ), dim=0)

    def get_1d_and_2d_weight(self):
        """Return ``block_diag(weight_1d, 2d-block)``, a (3*C_out x 3*C_in) matrix."""
        return torch.block_diag(self.weight_1d, self.get_2d_weight())

    def get_quarter_batched_weight(self):
        """
        Return the stacked (4, 3*C_out, 3*C_in) weight used by `forward`.

        Slot 0 holds the combined 1D/2D block; slots 1-3 are views of the single
        `weight_3d` parameter (expanded, not copied, before stacking).
        """
        return torch.vstack((
            self.get_1d_and_2d_weight()[None],
            self.weight_3d[None].expand(3, -1, -1),
        ))

    def forward(self, x):
        """
        Applies the group-equivariant linear transformation.

        Args:
            x: A tensor that `ToTetraFourierQuarterBatch` accepts when
                `transform_to_fourier` is True, otherwise a tensor whose last two
                dimensions are (4, 3*C_in).

        Returns:
            The output of `FromTetraFourierQuarterBatch` when `transform_back_from_fourier`
            is True, otherwise a tensor whose last two dimensions are (4, 3*C_out).
        """
        if self.transform_to_fourier:
            x = self.to_fourier(x)

        # One contraction applies the matching weight slot to each of the 4 quarters.
        x = torch.einsum('ijk,...ik->...ij', self.get_quarter_batched_weight(), x)
        # Alternative formulation, kept for reference:
        # batch_dims = x.shape[:-2]
        # x = torch.bmm(
        #     self.get_quarter_batched_weight(),
        #     x.view(-1, 4, 3*self.in_channels).permute(1, 2, 0)
        # ).permute(2, 0, 1).view(*batch_dims, 4, 3*self.out_channels)

        if self.bias is not None:
            # The bias only acts on the first C_out entries of slot 0 (the `weight_1d` outputs).
            # `x` is the fresh einsum result here, so writing into it does not touch the caller's input.
            x[..., 0, :self.out_channels] = x[..., 0, :self.out_channels] + self.bias

        if self.transform_back_from_fourier:
            x = self.from_fourier(x)
        return x

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(G={self.G}, in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None})"


def correctness_tests():
    """
    Run the self-checks for the Fourier-domain layers.

    Checks that the regular-to-Fourier basis is orthogonal, then, for each of
    `TetraFourierLinear` and `TetraFourierLinearQuarterBatch`, that a layer initialized
    from a `PlatonicLinear` reproduces that layer's output on random inputs.

    The global float32 matmul precision is set to "highest" for the comparison and
    restored to its previous value afterwards, even if a check fails.

    Raises:
        AssertionError: If any check fails.
    """
    print("=== Testing orthogonality of basis. ===")
    basis = basis_tetra_regular_to_fourier()
    assert torch.allclose(basis @ basis.mT, torch.eye(12), atol=1e-6), "Basis should be orthogonal."

    print("=== Testing consistency with original implementation. ===")
    C_in = 256
    C_out = 192
    G = 12
    B = 32
    N = 256
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    previous_precision = torch.get_float32_matmul_precision()
    torch.set_float32_matmul_precision("highest")
    try:
        with torch.inference_mode():
            for fourier_cls in (TetraFourierLinear, TetraFourierLinearQuarterBatch):
                print(f"------ Testing consistency for {fourier_cls.__name__}. ------")
                for _ in range(10):
                    reference = PlatonicLinear(G*C_in, G*C_out, "tetrahedron").to(device)
                    fourier_layer = fourier_cls(G*C_in, G*C_out).to(device)
                    # Overwrite the default initialization with the reference layer's weights.
                    fourier_layer.reset_parameters(tmp_tetra_linear=reference)
                    x = torch.randn([B, N, G*C_in], device=device)
                    y = reference(x)
                    fy = fourier_layer(x)
                    assert torch.allclose(y, fy, atol=1e-5), "Outputs should be the same."
    finally:
        torch.set_float32_matmul_precision(previous_precision)


def benchmark_layer(layer, dummy_input, calculate_flops=True, num_runs=100, warmups=10):
    """
    Benchmarks a single layer.

    Runs `warmups` untimed calls followed by `num_runs` timed calls of
    ``layer(dummy_input)``. When CUDA is available the timing uses CUDA events,
    otherwise it falls back to a wall-clock timer so the function also works on
    CPU-only hosts.

    Args:
        layer: Callable invoked as ``layer(dummy_input)``.
        dummy_input: The single argument passed to `layer` on every call.
        calculate_flops: If True, count FLOPs with fvcore's `FlopCountAnalysis`
            (imported lazily, so fvcore is only required in this case).
        num_runs: Number of timed calls; must be positive.
        warmups: Number of untimed warm-up calls.

    Returns:
        Tuple ``(time_ms, flops)``: the mean time per call in milliseconds and the
        FLOP count, or NaN when `calculate_flops` is False.

    Raises:
        ValueError: If `num_runs` is not positive.
    """
    if num_runs <= 0:
        raise ValueError(f"num_runs must be positive, got {num_runs}")

    use_cuda = torch.cuda.is_available()

    if use_cuda:
        torch.cuda.synchronize()
    for _ in range(warmups):
        _ = layer(dummy_input)

    if use_cuda:
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        torch.cuda.synchronize()
        start_event.record()
        for _ in range(num_runs):
            _ = layer(dummy_input)
        end_event.record()
        # Wait for the recorded work to finish before reading the elapsed time.
        torch.cuda.synchronize()
        time_ms = start_event.elapsed_time(end_event) / num_runs
    else:
        import time

        start = time.perf_counter()
        for _ in range(num_runs):
            _ = layer(dummy_input)
        time_ms = (time.perf_counter() - start) * 1e3 / num_runs

    if calculate_flops:
        # TODO: Potentially add unsupported (often actually negligible) operators to flop analysis
        from fvcore.nn import FlopCountAnalysis
        flops = FlopCountAnalysis(layer, (dummy_input,)).total()
    else:
        flops = np.nan

    return time_ms, flops


def print_results(results):
    """
    Print one benchmark result set as an aligned text table.

    `results` must provide the mappings ``results["times"]`` (name -> milliseconds) and
    ``results["flops"]`` (name -> FLOP count). Rows follow the iteration order of
    ``results["times"]``; FLOPs are shown in units of 1e9.
    """
    timings = results["times"]
    labels = list(timings)

    # The name column fits the longest label (or the header) plus four spaces of padding.
    name_width = 4 + max(max(map(len, labels)), len("Model"))

    header = f"{'Model':<{name_width}}{'Time (ms)':<12}{'GFLOPs':<10}"
    print(header)
    print("-" * (name_width + 22))

    for label in labels:
        elapsed_ms = timings[label]
        giga_flops = results["flops"][label] / 1e9
        row = f"{label:<{name_width}}{elapsed_ms:<12.3f}{giga_flops:<10.2f}"
        print(row)


def small_speed_benchmark(
    batch_size=128,
    sequence_length=128,
    channel_sizes=(64, 128, 256),
    matmul_precision="highest",
    dtype=torch.float32,
):
    """
    Benchmark the individual pieces of the Fourier layers and print a table per channel size.

    For every ``C_in`` in `channel_sizes` (with ``in_features = out_features = 12 * C_in``)
    this times a plain `nn.Linear`, the to/from Fourier transforms on their own, and the
    `TetraFourierLinear` / `TetraFourierLinearQuarterBatch` layers with various transforms
    disabled, then prints the results via `print_results`.

    Args:
        batch_size: First dimension of the random inputs.
        sequence_length: Second dimension of the random inputs.
        channel_sizes: Iterable of per-group channel counts to benchmark.
        matmul_precision: Value passed to `torch.set_float32_matmul_precision`
            (set globally and not restored).
        dtype: dtype of the layers and the inputs.

    Returns:
        Dict mapping each channel size to ``{"times": {...}, "flops": {...}}`` keyed by
        benchmark name; FLOPs are NaN for entries that skip FLOP counting.
    """
    use_cuda = torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'
    torch.set_float32_matmul_precision(matmul_precision)

    class Wrapper(nn.Module):
        """Adapts a module taking three positional tensors to the single-argument benchmark call."""
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, x):
            return self.module(x[0], x[1], x[2])

    results = {}
    with torch.inference_mode():
        for C_in in channel_sizes:
            print(f"== Benchmarking for {C_in} channels ==")
            in_features = 12 * C_in
            out_features = 12 * C_in

            # Initialize layers
            standard_linear = nn.Linear(in_features, out_features).to(device).to(dtype)

            to_fourier = ToTetraFourier().to(device).to(dtype)

            from_fourier = Wrapper(FromTetraFourier()).to(device).to(dtype)

            # sparse_to_fourier = SparseToTetraFourier(C_in).to(device).to(dtype)

            # sparse_from_fourier = Wrapper(SparseFromTetraFourier(C_in)).to(device).to(dtype)

            fourier_tetra_linear = torch.compile(
                TetraFourierLinear(in_features, out_features).to(device).to(dtype),
                mode="max-autotune-no-cudagraphs",
            )

            fourier_tetra_linear_no_pre_transform = TetraFourierLinear(
                in_features, out_features,
                transform_to_fourier=False, transform_back_from_fourier=True,
            ).to(device).to(dtype)

            fourier_tetra_linear_no_post_transform = TetraFourierLinear(
                in_features, out_features,
                transform_to_fourier=True, transform_back_from_fourier=False,
            ).to(device).to(dtype)

            fourier_tetra_linear_no_transform = TetraFourierLinear(
                in_features, out_features,
                transform_to_fourier=False, transform_back_from_fourier=False,
            ).to(device).to(dtype)

            to_fourier_qb = ToTetraFourierQuarterBatch().to(device).to(dtype)

            from_fourier_qb = FromTetraFourierQuarterBatch().to(device).to(dtype)

            fourier_tetra_linear_qb = torch.compile(
                TetraFourierLinearQuarterBatch(in_features, out_features).to(device).to(dtype),
                mode="max-autotune-no-cudagraphs",
            )

            fourier_tetra_linear_qb_no_transform = TetraFourierLinearQuarterBatch(
                in_features, out_features,
                transform_to_fourier=False, transform_back_from_fourier=False,
            ).to(device).to(dtype)

            # Create dummy input data
            x = torch.randn([batch_size, sequence_length, in_features], device=device, dtype=dtype)
            # Create dummy input for the Fourier layer without transform
            x_fourier = (
                torch.randn([batch_size, sequence_length, 1, C_in], device=device, dtype=dtype),
                torch.randn([batch_size, sequence_length, 2*C_in], device=device, dtype=dtype),
                torch.randn([batch_size, sequence_length, 3, 3*C_in], device=device, dtype=dtype),
            )
            x_fourier_qb = torch.randn([batch_size, sequence_length, 4, 3*C_in], device=device, dtype=dtype)

            benchmark_configs = [
                ("Linear (non-equivariant)", standard_linear, x),
                ("ToFourier", to_fourier, x),
                ("FromFourier", from_fourier, x_fourier),
                # ("SparseToFourier", sparse_to_fourier, x),
                # ("SparseFromFourier", sparse_from_fourier, x_fourier),
                ("Compiled Fourier", fourier_tetra_linear, x),
                ("Fourier (no pre-transform)", fourier_tetra_linear_no_pre_transform, x_fourier),
                ("Fourier (no post-transform)", fourier_tetra_linear_no_post_transform, x),
                ("Fourier (no transform)", fourier_tetra_linear_no_transform, x_fourier),
                ("ToFourierQuarterBatch", to_fourier_qb, x),
                ("FromFourierQuarterBatch", from_fourier_qb, x_fourier_qb),
                ("Compiled FourierQuarterBatch", fourier_tetra_linear_qb, x),
                ("FourierQuarterBatch (no transform)", fourier_tetra_linear_qb_no_transform, x_fourier_qb),
            ]

            times = {}
            flops = {}

            for name, layer, dummy_input in benchmark_configs:
                torch._dynamo.reset()
                # Only synchronize when a CUDA device exists; the call raises on CPU-only hosts.
                if use_cuda:
                    torch.cuda.synchronize()
                layer_time, layer_flops = benchmark_layer(
                    layer,
                    dummy_input,
                    # FLOP counting is skipped for compiled (and sparse) entries.
                    calculate_flops=("Compiled" not in name and "Sparse" not in name),
                )

                times[name] = layer_time
                flops[name] = layer_flops

            results[C_in] = {"times": times, "flops": flops}
            print_results(results[C_in])

    return results


def speed_benchmark(
    batch_size=128,
    sequence_length=128,
    channel_sizes=(32, 64, 128, 256, 512),
    matmul_precision="highest",
    dtype=torch.float32,
):
    """
    Benchmark eager and compiled variants of the linear layers across channel sizes.

    For every ``C_in`` in `channel_sizes` (with ``in_features = out_features = 12 * C_in``)
    this times `nn.Linear`, `PlatonicLinear`, `TetraFourierLinear` and
    `TetraFourierLinearQuarterBatch` (with and without the Fourier transforms), each both
    eagerly and wrapped in `torch.compile`.

    Args:
        batch_size: First dimension of the random input.
        sequence_length: Second dimension of the random input.
        channel_sizes: Iterable of per-group channel counts to benchmark.
        matmul_precision: Value passed to `torch.set_float32_matmul_precision`
            (set globally and not restored).
        dtype: dtype of the layers and the inputs.

    Returns:
        Dict mapping each channel size to ``{"times": {...}, "flops": {...}}`` keyed by
        benchmark name; FLOPs are NaN for the compiled entries.
    """
    use_cuda = torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'
    torch.set_float32_matmul_precision(matmul_precision)

    compile_kwargs = {"dynamic": False, "mode": "max-autotune-no-cudagraphs"}

    results = {}
    with torch.inference_mode():
        for C_in in channel_sizes:
            print(f"== Benchmarking for {C_in} channels ==")
            C_out = C_in
            in_features = 12 * C_in
            out_features = 12 * C_out

            # Initialize layers
            standard_linear = nn.Linear(in_features, out_features).to(device).to(dtype) # Standard Linear Layer

            tetra_linear = PlatonicLinear(in_features, out_features, "tetrahedron").to(device).to(dtype)

            fourier_tetra_linear = TetraFourierLinear(in_features, out_features).to(device).to(dtype)
            fourier_tetra_linear_no_transform = TetraFourierLinear(in_features, out_features, transform_to_fourier=False, transform_back_from_fourier=False).to(device).to(dtype)

            fourier_tetra_linear_qb = TetraFourierLinearQuarterBatch(in_features, out_features).to(device).to(dtype)
            fourier_tetra_linear_no_transform_qb = TetraFourierLinearQuarterBatch(in_features, out_features, transform_to_fourier=False, transform_back_from_fourier=False).to(device).to(dtype)

            # Compile layers
            compiled_standard_linear = torch.compile(standard_linear, **compile_kwargs)
            compiled_tetra_linear = torch.compile(tetra_linear, **compile_kwargs)
            compiled_fourier_tetra_linear = torch.compile(fourier_tetra_linear, **compile_kwargs)
            compiled_fourier_no_transform = torch.compile(fourier_tetra_linear_no_transform, **compile_kwargs)
            compiled_fourier_tetra_linear_qb = torch.compile(fourier_tetra_linear_qb, **compile_kwargs)
            compiled_fourier_no_transform_qb = torch.compile(fourier_tetra_linear_no_transform_qb, **compile_kwargs)

            # Create dummy input data
            x = torch.randn([batch_size, sequence_length, in_features], device=device, dtype=dtype)
            # Create dummy input for the Fourier layers without transform by running the
            # layers' own forward transforms on x.
            x_fourier_input = fourier_tetra_linear.to_fourier(x)
            x_fourier_input_qb = fourier_tetra_linear_qb.to_fourier(x)

            benchmark_configs = [
                ("Linear (non-equivariant)", standard_linear, x),
                ("Compiled Linear (non-equivariant)", compiled_standard_linear, x),
                ("PlatonicLinear", tetra_linear, x),
                ("Compiled Platonic", compiled_tetra_linear, x),
                ("Fourier", fourier_tetra_linear, x),
                ("Compiled Fourier", compiled_fourier_tetra_linear, x),
                ("Fourier (no transform)", fourier_tetra_linear_no_transform, x_fourier_input),
                ("Compiled Fourier (no transform)", compiled_fourier_no_transform, x_fourier_input),
                ("FourierQB", fourier_tetra_linear_qb, x),
                ("Compiled FourierQB", compiled_fourier_tetra_linear_qb, x),
                ("FourierQB (no transform)", fourier_tetra_linear_no_transform_qb, x_fourier_input_qb),
                ("Compiled FourierQB (no transform)", compiled_fourier_no_transform_qb, x_fourier_input_qb),
            ]

            times = {}
            flops = {}

            for name, layer, dummy_input in benchmark_configs:
                torch._dynamo.reset()
                # Only synchronize when a CUDA device exists; the call raises on CPU-only hosts.
                if use_cuda:
                    torch.cuda.synchronize()
                layer_time, layer_flops = benchmark_layer(
                    layer,
                    dummy_input,
                    # FLOP counting is skipped for compiled entries.
                    calculate_flops=("Compiled" not in name),
                )

                times[name] = layer_time
                flops[name] = layer_flops

            results[C_in] = {"times": times, "flops": flops}
    return results


def plot_speed_benchmark_results(results, plot_path):
    """
    Plot benchmark results as three stacked log-log panels and save them to `plot_path`.

    The panels show, against channel size: execution time of the non-compiled layers,
    execution time of the compiled layers, and FLOPs of the non-compiled layers.

    Args:
        results: Dict mapping channel size to a dict with optional ``"times"`` and
            ``"flops"`` entries, each mapping benchmark name to a value. The benchmark
            names plotted below must be present.
        plot_path: Destination passed to matplotlib's `savefig`.
    """
    import matplotlib.pyplot as plt
    import pandas as pd

    times_data = {size: metrics['times'] for size, metrics in results.items() if 'times' in metrics}
    flops_data = {size: metrics['flops'] for size, metrics in results.items() if 'flops' in metrics}

    df_times = pd.DataFrame.from_dict(times_data, orient='index')
    df_times.index.name = 'Channel Size'
    df_flops = pd.DataFrame.from_dict(flops_data, orient='index')
    df_flops.index.name = 'Channel Size'

    colors = plt.colormaps.get_cmap('Dark2')
    color_cycle = [colors(i) for i in [2, 5, 1, 7, 3, 4]]

    eager_series = [
        'Linear (non-equivariant)',
        'PlatonicLinear',
        'Fourier',
        'Fourier (no transform)',
        'FourierQB',
        'FourierQB (no transform)',
    ]
    compiled_series = [
        'Compiled Linear (non-equivariant)',
        'Compiled Platonic',
        'Compiled Fourier',
        'Compiled Fourier (no transform)',
        'Compiled FourierQB',
        'Compiled FourierQB (no transform)',
    ]

    # One entry per panel: (data frame, columns, legend-label suffix, y-label, title).
    panels = [
        (df_times, eager_series, '', 'Execution Time (ms)', 'Layer Execution Time vs. Channel Size (Non-Compiled)'),
        (df_times, compiled_series, '', 'Execution Time (ms)', 'Layer Execution Time vs. Channel Size (Compiled)'),
        (df_flops, eager_series, ' FLOPs', 'FLOPs', 'Layer FLOPs vs. Channel Size (Non-Compiled)'),
    ]

    fig, axes = plt.subplots(3, 1, figsize=(8, 12))
    try:
        for ax, (frame, columns, label_suffix, y_label, title) in zip(axes, panels):
            ax.set_prop_cycle(color=color_cycle)
            for position, column in enumerate(columns):
                # Line styles alternate solid / dashed down the list of series.
                linestyle = '-' if position % 2 == 0 else '--'
                ax.plot(frame.index, frame[column], marker='o', linestyle=linestyle, label=column + label_suffix)

            ax.set_xlabel('Channel Size')
            ax.set_ylabel(y_label)
            ax.set_title(title)
            ax.legend()
            ax.grid(True)
            ax.set_yscale('log', base=2)
            ax.set_xscale('log', base=2)
            # Label the ticks with the actual channel sizes.
            ax.set_xticks(frame.index, labels=frame.index)

        fig.tight_layout()
        fig.savefig(plot_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


if __name__ == "__main__":
    # Command-line entry point: optionally run the correctness tests and/or the benchmarks.
    dtypes = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}

    parser = argparse.ArgumentParser(description='TetraFourierLinear tests')
    parser.add_argument('--run_correctness_tests', action="store_true")
    parser.add_argument('--run_small_speed_benchmark', action="store_true")
    parser.add_argument('--run_full_speed_benchmark', action="store_true")
    parser.add_argument('--plot_path', type=str, default="fourier_linear_bench.png")
    # `choices` lets argparse reject invalid values with a usage message that lists the
    # accepted ones, instead of failing later with an uninformative error.
    parser.add_argument('--float32_matmul_precision', type=str, default="highest", choices=("highest", "high", "medium"))
    parser.add_argument('--dtype', type=str, default="float32", choices=tuple(dtypes))
    args = parser.parse_args()

    if args.run_correctness_tests:
        print("<<<< Running correctness tests. >>>>")
        correctness_tests()
        print("<<<< All tests passed! >>>>")

    if args.run_small_speed_benchmark:
        print("<<<< Running small speed benchmark. >>>>")
        small_speed_benchmark(matmul_precision=args.float32_matmul_precision, dtype=dtypes[args.dtype])
    if args.run_full_speed_benchmark:
        print("<<<< Running full speed benchmark. >>>>")
        results = speed_benchmark(matmul_precision=args.float32_matmul_precision, dtype=dtypes[args.dtype])
        print(f"<<<< Plotting results to {args.plot_path}. >>>>")
        plot_speed_benchmark_results(results, plot_path=args.plot_path)
