import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import DTYPES, no_train


# Physical simulation parameters, keyed by transform grid resolution.
# `construct_linear_transform` reads the entry via DEVICE_CONFIGS[resolution], and
# `get_true_resolution` uses the smallest key as the floor for "adaptive" resolution.
#   lambda_value: wavelength; the wavenumber is k = 2 * pi / lambda_value.
#   width:        physical extent of the sampled plane; samples are the centres of
#                 2 * resolution equal cells spanning [-width / 2, width / 2].
#   lens_z:       distance between the propagation planes; it also appears as the
#                 focal length in the quadratic lens phase term.
# NOTE(review): units are not stated anywhere in this file (presumably metres) — confirm.
DEVICE_CONFIGS = {
    32: {
        "lambda_value": 0.00001,
        "width": 0.00064,
        "lens_z": 0.00064,
    },
    64: {
        "lambda_value": 0.00001,
        "width": 0.00128,
        "lens_z": 0.00128,
    },
    128: {
        "lambda_value": 0.00001,
        "width": 0.005,
        "lens_z": 0.00976562,
    },
    256: {
        "lambda_value": 0.00001,
        "width": 0.01,
        "lens_z": 0.0195312,
    }
}


def get_indices(input_shape, stride=(1, 1), resolution=32, flipped=False):
    """Return row-major flat indices of a window placed in a square grid.

    The ``(h, w)`` window from ``input_shape`` is centred on cell
    ``resolution // 2`` of a ``resolution`` x ``resolution`` grid and sampled
    every ``stride`` cells. With ``flipped=True`` the window is mirrored about
    the point ``resolution // 2 - 0.5`` and traversed in reverse, stepping
    backwards by ``stride``.

    Args:
        input_shape: ``(h, w)`` window size in cells.
        stride: ``(step_h, step_w)`` sampling step.
        resolution: Side length of the square grid.
        flipped: Use the mirrored, reverse-order placement.

    Returns:
        A list of ints ``row * resolution + col``, rows outermost.
    """
    center = resolution // 2
    n_rows, n_cols = input_shape
    step_rows, step_cols = stride
    row_shift, col_shift = n_rows // 2, n_cols // 2

    if not flipped:
        rows = range(center - row_shift, center + n_rows - row_shift, step_rows)
        cols = range(center - col_shift, center + n_cols - col_shift, step_cols)
    else:
        top = center - 1 + row_shift
        left = center - 1 + col_shift
        rows = range(top, top - n_rows, -step_rows)
        cols = range(left, left - n_cols, -step_cols)

    return [row * resolution + col for row in rows for col in cols]


@torch.no_grad()
def construct_linear_transform(resolution=32, device="cuda", dtype="fp16"):
    """Build the dense sonic transform matrix for a ``resolution`` x ``resolution`` grid.

    The transform composes two propagation stages on a 2x oversampled plane
    (``2 * resolution`` samples per side):

    1. From each input point (m, n) in the central ``resolution``-wide region
       to every point of the intermediate plane.
    2. With a lens phase applied at the intermediate plane, from that plane to
       each output point (i, j).

    Complex values are encoded as real 2x2 blocks, so that right-multiplying a
    row vector ``[re, im]`` performs complex multiplication.

    Args:
        resolution: Grid size; must be a key of ``DEVICE_CONFIGS``.
        device: Device used for the float64 intermediate computation.
        dtype: Key into ``DTYPES`` selecting the dtype of the stored result.

    Returns:
        A CPU tensor of shape ``(2 * resolution**2, 2 * resolution**2)``. The
        row index flattens (re/im, m, n) and the column index flattens
        (re/im, i, j).
    """
    config = DEVICE_CONFIGS[resolution]
    lambda_value, width, lens_z = config["lambda_value"], config["width"], config["lens_z"]
    k_value = 2 * torch.pi / lambda_value
    # Tile edge for the patched loops below; the 8192 budget was tuned for an A6000.
    patch_size = min(resolution, 8192 // resolution)

    def construct_lens(m, n, i, j, after_lens):
        """Return the 2x2-block propagation kernel between points (m, n) and (i, j).

        Output shape is ``(2, len(m), len(n), 2, len(i), len(j))``. When
        ``after_lens`` is true, a quadratic lens phase in (m, n) is subtracted.
        """
        m, n, i, j = torch.meshgrid(m, n, i, j, indexing="ij")
        r2 = lens_z * lens_z + torch.square(i - m) + torch.square(j - n)
        phase = k_value * torch.sqrt(r2)
        if after_lens:
            phase -= k_value / 2 / lens_z * (torch.square(m) + torch.square(n))
        intensity = width * width / (resolution * 2) / (resolution * 2) * lens_z / lambda_value / r2
        c, s = intensity * torch.cos(phase), intensity * torch.sin(phase)
        return torch.stack([torch.stack([c, -s], dim=0), torch.stack([s, c], dim=0)], dim=3)

    result = torch.empty(2, resolution, resolution, 2, resolution, resolution, device="cpu", dtype=DTYPES[dtype])

    # Cell-centred coordinates of the 2x plane and its central `resolution`-wide window.
    base2x = (torch.arange(resolution * 2, device=device, dtype=torch.float64) + 0.5) * width / (resolution * 2) - width / 2
    base = base2x[resolution // 2: resolution // 2 * 3]
    for m_offset in range(0, resolution, patch_size):
        m = base[m_offset: m_offset + patch_size]
        for n_offset in range(0, resolution, patch_size):
            n = base[n_offset: n_offset + patch_size]
            # The first stage depends only on (m, n): build it once per (m, n)
            # tile rather than once per (m, n, i, j) tile.
            lens_in = construct_lens(m, n, base2x, base2x, False)
            for i_offset in range(0, resolution, patch_size):
                i = base[i_offset: i_offset + patch_size]
                for j_offset in range(0, resolution, patch_size):
                    j = base[j_offset: j_offset + patch_size]
                    result[
                        :,
                        m_offset: m_offset + patch_size,
                        n_offset: n_offset + patch_size,
                        :,
                        i_offset: i_offset + patch_size,
                        j_offset: j_offset + patch_size
                    ] = torch.einsum(
                        "abcdef, defghi -> abcghi",
                        lens_in,
                        construct_lens(base2x, base2x, i, j, True)
                    ).to(DTYPES[dtype]).cpu()
            # Free this tile before building the next one so two copies never coexist.
            del lens_in
    return result.reshape(resolution * resolution * 2, resolution * resolution * 2)


def get_true_resolution(resolution, *args):
    """Resolve the transform grid resolution to an integer.

    Args:
        resolution: Either an ``int`` (returned unchanged) or ``"adaptive"``.
        *args: Spatial sizes (e.g. height, width) considered in adaptive mode.

    Returns:
        For ``"adaptive"``: the smallest power of two that is >= ``max(args)``,
        but never below the smallest key of ``DEVICE_CONFIGS``.

    Raises:
        NotImplementedError: For any other ``resolution`` value.
    """
    if isinstance(resolution, int):
        return resolution
    elif resolution == "adaptive":
        size = int(np.ceil(max(args)))
        # Exact integer next-power-of-two. A float log2 can overshoot, e.g.
        # log(2**29) / log(2) == 29.000000000000004, which would double the grid.
        next_pow2 = 1 << max(size - 1, 0).bit_length()
        return max(next_pow2, min(DEVICE_CONFIGS.keys()))
    else:
        raise NotImplementedError


# Process-wide caches of pre-sliced transform matrices.
# They are filled by `conv_to_sonic` (via forward pre-hooks, or loaded from its
# cache file) and read by `SonicConv2d.forward`.
#   input_transforms:  (resolution, input_h, input_w) -> selected rows of the transform table
#   kernel_transforms: (resolution, kernel_h, kernel_w, dilation_h, dilation_w) -> selected rows
#   output_transforms: (resolution, output_h, output_w, stride_h, stride_w) -> selected columns
input_transforms = {}
kernel_transforms = {}
output_transforms = {}


class SonicConv2d(nn.Conv2d):
    """``nn.Conv2d`` variant that evaluates the convolution through sonic transforms.

    The normal forward pass looks up the module-level ``input_transforms``,
    ``kernel_transforms`` and ``output_transforms`` tables (populated by
    ``conv_to_sonic``). It then:

    * projects input and kernel into the transform domain;
    * multiplies them element-wise as complex numbers (real/imaginary halves);
    * sums over input channels and projects back onto the output grid.

    In *speed-test* mode the transforms are skipped. The layer returns a
    channel reduction over an uninitialised, broadcast (stride-0) tensor of
    the intermediate size, so only the reduction cost is exercised. Output
    values in that mode are meaningless; only their shape is.
    """

    def __init__(self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode='zeros',
        device=None,
        dtype=None,
        sonic_resolution="adaptive"
    ):
        """Same arguments as ``nn.Conv2d``, plus ``sonic_resolution`` (int or ``"adaptive"``)."""
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype
        )
        self.sonic_resolution = sonic_resolution
        self.disable_speedtest()

    @classmethod
    def from_conv2d(cls, layer: nn.Conv2d, sonic_resolution="adaptive"):
        """Build a ``SonicConv2d`` mirroring ``layer``, sharing its weight and bias parameters."""
        model = cls(
            in_channels=layer.in_channels,
            out_channels=layer.out_channels,
            kernel_size=layer.kernel_size,
            stride=layer.stride,
            padding=layer.padding,
            dilation=layer.dilation,
            groups=layer.groups,
            bias=layer.bias is not None,
            padding_mode=layer.padding_mode,
            device=layer.weight.device,
            dtype=layer.weight.dtype,
            sonic_resolution=sonic_resolution,
        )
        model.weight = layer.weight
        if layer.bias is not None:
            model.bias = layer.bias
        return model

    def enable_speedtest(self):
        """Switch ``forward`` to the transform-free speed-test path."""
        self.speedtesting = True

    def disable_speedtest(self):
        """Switch ``forward`` back to the sonic path and drop the speed-test scratch buffer."""
        self.speedtesting = False
        if hasattr(self, "test_intermediate_output"):
            delattr(self, "test_intermediate_output")

    def _conv_output_hw(self, x):
        """Return the stride-1 output ``(height, width)`` for input ``x``.

        Padding is read from ``_reversed_padding_repeated_twice`` (width pads
        first, then height pads), and the kernel extent includes dilation.
        """
        pad = self._reversed_padding_repeated_twice
        output_h = int(x.shape[-2]) + sum(pad[2:4]) - self.dilation[0] * (self.kernel_size[0] - 1)
        output_w = int(x.shape[-1]) + sum(pad[0:2]) - self.dilation[1] * (self.kernel_size[1] - 1)
        return output_h, output_w

    def _speedtest_forward(self, x):
        """Reduce a broadcast scratch tensor shaped like the sonic intermediate result."""
        output_h, output_w = self._conv_output_hw(x)
        shape = (
            int(x.shape[0]),
            self.out_channels,
            self.in_channels // self.groups,
            (output_h - 1) // self.stride[0] + 1,
            (output_w - 1) // self.stride[1] + 1,
        )
        cached = getattr(self, "test_intermediate_output", None)
        # Rebuild when missing or when the input geometry changed; reusing a
        # stale buffer would return an output of the wrong shape. The buffer is
        # scratch data, so keep it out of state_dict().
        if cached is None or tuple(cached.shape) != shape:
            self.register_buffer(
                "test_intermediate_output",
                torch.empty(1, 1, 1, 1, 1, device=self.weight.device, dtype=self.weight.dtype).expand(*shape),
                persistent=False,
            )
        return torch.sum(self.test_intermediate_output, dim=2)

    def _sonic_forward(self, x):
        """Evaluate the convolution through the cached sonic transforms."""
        output_h, output_w = self._conv_output_hw(x)
        # Only non-zero padding modes are padded explicitly here. For "zeros",
        # presumably the empty grid cells around the centred input act as the
        # padding — confirm against the transform construction.
        if self.padding_mode != "zeros":
            x = F.pad(x, self._reversed_padding_repeated_twice, mode=self.padding_mode)
        batch_size, _, input_h, input_w = map(int, x.shape)
        resolution = get_true_resolution(self.sonic_resolution, input_h, input_w)

        # Input -> transform domain, laid out as (B, G, 1, C_in/G, [re | im]).
        x = x.reshape(batch_size * self.in_channels, input_h * input_w)
        x = torch.matmul(x, input_transforms[(resolution, input_h, input_w)])
        x = x.reshape(batch_size, self.groups, 1, self.in_channels // self.groups, -1)
        x_re, x_im = torch.chunk(x, 2, dim=-1)

        # Kernel -> transform domain, laid out as (1, G, C_out/G, C_in/G, [re | im]).
        kernel = self.weight.reshape(self.out_channels * self.in_channels // self.groups, self.kernel_size[0] * self.kernel_size[1])
        kernel = torch.matmul(kernel, kernel_transforms[(resolution, *self.kernel_size, *self.dilation)])
        kernel = kernel.reshape(1, self.groups, self.out_channels // self.groups, self.in_channels // self.groups, -1)
        k_re, k_im = torch.chunk(kernel, 2, dim=-1)

        # Complex product x * k summed over input channels (i); the size-1
        # batch / output-channel axes broadcast.
        z_re = torch.einsum("bgoic, bgoic -> bgoc", x_re, k_re) - torch.einsum("bgoic, bgoic -> bgoc", x_im, k_im)
        z_im = torch.einsum("bgoic, bgoic -> bgoc", x_re, k_im) + torch.einsum("bgoic, bgoic -> bgoc", k_re, x_im)
        z = torch.cat([z_re, z_im], dim=-1)

        output = torch.matmul(z, output_transforms[(resolution, output_h, output_w, *self.stride)])
        return output.reshape(batch_size, self.out_channels, (output_h - 1) // self.stride[0] + 1, (output_w - 1) // self.stride[1] + 1)

    def forward(self, x):
        """Run the sonic (or speed-test) convolution and add the bias if present."""
        if self.speedtesting:
            output = self._speedtest_forward(x)
        else:
            output = self._sonic_forward(x)

        if self.bias is not None:
            output += self.bias.reshape(1, -1, 1, 1)
        return output


@no_train
def conv_to_sonic(
        model,
        input_shape,
        sonic_resolution="adaptive",
        dtype="fp16",
        device="cuda",
        cache_dir="sonic_transforms",
        cache_name=None,
        verbose=True,
    ):
    """Replace every nested ``nn.Conv2d`` of ``model`` with a ``SonicConv2d``, in place.

    The module-level ``input_transforms`` / ``kernel_transforms`` /
    ``output_transforms`` tables needed by ``SonicConv2d.forward`` come from one
    of two sources:

    * loaded from ``<cache_dir>/<cache_name>-<sonic_resolution>.pth``; or
    * derived by running one dummy forward pass of shape ``(1, *input_shape)``,
      with pre-hooks that record each convolution's geometry.

    The full per-resolution transform tables are also cached in ``cache_dir``.

    Args:
        model: Module whose ``nn.Conv2d`` descendants are replaced. The root
            module itself is not converted. Presumably it must already live on
            ``device``, since the dummy input is sent there.
        input_shape: Per-sample input shape for the dummy forward pass.
        sonic_resolution: Integer grid size or ``"adaptive"``.
        dtype: Key into ``DTYPES`` for the transforms and the dummy input.
        device: Device for table construction and for the final transforms.
        cache_dir: Directory for cached tables and transforms (created if missing).
        cache_name: Name of the whole-model transform cache; ``None`` disables it.
        verbose: Print progress messages.
    """
    global input_transforms, kernel_transforms, output_transforms
    os.makedirs(cache_dir, exist_ok=True)
    sonic_transform_path = os.path.join(cache_dir, f"{cache_name}-{sonic_resolution}.pth")

    if cache_name is not None and os.path.isfile(sonic_transform_path):
        if verbose:
            print(f"Loading pre-computed sonic transforms from {sonic_transform_path} ...")
        # Loading must not depend on `verbose`.
        # NOTE: torch.load unpickles its input; only point cache_dir at trusted files.
        # NOTE(review): the cache name does not encode `dtype`, so a cache built
        # with another dtype is reused as-is — confirm this is intended.
        sonic_transforms = torch.load(sonic_transform_path, map_location="cpu")
        input_transforms = sonic_transforms["input_transforms"]
        kernel_transforms = sonic_transforms["kernel_transforms"]
        output_transforms = sonic_transforms["output_transforms"]
    else:
        sonic_transform_tables = {}

        def get_sonic_transform_table(resolution):
            """Return the full table for ``resolution``, memoised in memory and on disk."""
            if resolution in sonic_transform_tables:
                return sonic_transform_tables[resolution]
            sonic_transform_table_path = os.path.join(cache_dir, f"sonic-transform-table-{resolution}x{resolution}-{dtype}.pth")
            if os.path.isfile(sonic_transform_table_path):
                if verbose:
                    print(f"Loading pre-computed sonic transform table {resolution}x{resolution} ({dtype}) from {sonic_transform_table_path} ...")
                sonic_transform_tables[resolution] = torch.load(sonic_transform_table_path, map_location="cpu")
            else:
                if verbose:
                    print(f"Constructing sonic transform table {resolution}x{resolution} ({dtype}) ...")
                # Build on the caller-requested device so CPU-only conversion works.
                sonic_transform_tables[resolution] = construct_linear_transform(resolution, device, dtype)
                if verbose:
                    print(f"Saving pre-computed sonic transform table {resolution}x{resolution} ({dtype}) to {sonic_transform_table_path} ...")
                torch.save(sonic_transform_tables[resolution], sonic_transform_table_path)
            return sonic_transform_tables[resolution]

        def update_transforms(layer, x):
            """Forward pre-hook: slice and cache the transforms this conv layer needs."""
            x = x[0]
            input_h, input_w = x.shape[-2:]
            # SonicConv2d.forward explicitly pads only for non-zero padding modes,
            # so the input geometry must match that here.
            if layer.padding_mode != "zeros":
                input_h += sum(layer._reversed_padding_repeated_twice[2:4])
                input_w += sum(layer._reversed_padding_repeated_twice[0:2])
            resolution = get_true_resolution(sonic_resolution, input_h, input_w)
            sonic_transform_table = get_sonic_transform_table(resolution)

            indices = get_indices((input_h, input_w), resolution=resolution)
            input_key = (resolution, input_h, input_w)
            if input_key not in input_transforms:
                input_transforms[input_key] = sonic_transform_table[indices].clone()
                input_transforms[input_key].requires_grad = False

            kernel_h, kernel_w = layer.kernel_size
            dh, dw = layer.dilation
            # Dilated kernel extent; the taps are picked out with a dilation-sized stride.
            kernel_h, kernel_w = dh * (kernel_h - 1) + 1, dw * (kernel_w - 1) + 1
            indices = get_indices((kernel_h, kernel_w), (dh, dw), resolution=resolution)
            kernel_key = (resolution, *layer.kernel_size, *layer.dilation)
            if kernel_key not in kernel_transforms:
                kernel_transforms[kernel_key] = sonic_transform_table[indices].clone()
                kernel_transforms[kernel_key].requires_grad = False

            output_h = x.shape[-2] + sum(layer._reversed_padding_repeated_twice[2:4]) - kernel_h + 1
            output_w = x.shape[-1] + sum(layer._reversed_padding_repeated_twice[0:2]) - kernel_w + 1
            indices = get_indices((output_h, output_w), layer.stride, resolution=resolution, flipped=True)
            output_key = (resolution, output_h, output_w, *layer.stride)
            if output_key not in output_transforms:
                output_transforms[output_key] = sonic_transform_table[:, indices].clone()
                output_transforms[output_key].requires_grad = False

        hook_handles = []

        def add_hook(module):
            """Attach `update_transforms` to every nested Conv2d, keeping the handles."""
            for _, layer in module.named_children():
                if isinstance(layer, nn.Conv2d):
                    hook_handles.append(layer.register_forward_pre_hook(update_transforms))
                add_hook(layer)

        if verbose:
            print(f"Constructing sonic transforms {cache_name}-{sonic_resolution} ...")
        add_hook(model)
        try:
            _ = model(torch.rand(1, *input_shape).to(DTYPES[dtype]).to(device))
        finally:
            # The hooks exist only for this calibration pass; detach them even if it fails.
            for handle in hook_handles:
                handle.remove()
        del sonic_transform_tables

        if cache_name is not None:
            if verbose:
                print(f"Saving pre-computed sonic transforms to {sonic_transform_path} ...")
            torch.save(
                {
                    "input_transforms": input_transforms,
                    "kernel_transforms": kernel_transforms,
                    "output_transforms": output_transforms,
                },
                sonic_transform_path
            )

    for ts in [input_transforms, kernel_transforms, output_transforms]:
        for k in ts:
            ts[k] = ts[k].to(device)
    if verbose:
        print("Input transforms:", input_transforms.keys())
        print("Kernel transforms:", kernel_transforms.keys())
        print("Output transforms:", output_transforms.keys())

    def convert(model):
        """Recursively swap each nested Conv2d for an equivalent SonicConv2d."""
        for name, layer in model.named_children():
            if isinstance(layer, nn.Conv2d):
                setattr(model, name, SonicConv2d.from_conv2d(layer, sonic_resolution))
            convert(layer)
    convert(model)


def enable_speedtest(model):
    """Put every ``SonicConv2d`` in ``model`` into speed-test mode.

    ``model.modules()`` includes ``model`` itself, so passing a bare
    ``SonicConv2d`` works too. Shared submodules are visited once.
    """
    for layer in model.modules():
        if isinstance(layer, SonicConv2d):
            layer.enable_speedtest()


def disable_speedtest(model):
    """Return every ``SonicConv2d`` in ``model`` to the normal sonic forward path.

    ``model.modules()`` includes ``model`` itself, so passing a bare
    ``SonicConv2d`` works too. Shared submodules are visited once.
    """
    for layer in model.modules():
        if isinstance(layer, SonicConv2d):
            layer.disable_speedtest()
