import math
import wandb
import os
import numpy as np

from spaghettini import quick_register
from torch import nn
import torch
import torch.nn.functional as F
from torch.nn import Linear, Bilinear
import matplotlib.pyplot as plt
import torchgeometry as tgm


from src.dl.models.spatial_transformer import transform_image


@quick_register
class FindPlusProver(nn.Module):
    """Convolutional prover mapping an image to a proof vector and auxiliary outputs.

    Three ``conv -> leaky ReLU -> InstanceNorm`` stages are applied. The result is
    averaged over the channel dimension (spatial positions are kept), flattened,
    and read by two linear heads: the proof head and an auxiliary head.

    Args:
        in_channels: number of channels of the input images.
        hid_channels: number of channels of every hidden conv layer.
        hid_feats: size of the flattened features fed to both heads. Features are
            averaged over channels only, so for odd ``kernel_size`` this must equal
            ``h * w`` of the input images.
        proof_feats: output size of the proof head.
        aux_feats: output size of the auxiliary head.
        kernel_size: conv kernel size. Padding is ``kernel_size // 2``.
        cat_position_embeddings: if True, two coordinate channels with values in
            [-1, 1] (row and column position) are concatenated to the input.
        detach_aux_head: if True, the auxiliary head receives detached features, so
            its loss does not backpropagate into the conv stack.
        add_aux_hidden_layer: if True, the auxiliary head is a 128-unit linear layer
            followed by a small-init linear layer (no nonlinearity in between).
        add_first_layer_batch_norm: if True, the (possibly position-augmented) input
            is passed through an InstanceNorm before the first conv.
    """

    def __init__(self, in_channels, hid_channels, hid_feats, proof_feats, aux_feats, kernel_size=3,
                 cat_position_embeddings=False, detach_aux_head=False, add_aux_hidden_layer=False,
                 add_first_layer_batch_norm=False):
        super().__init__()
        # The two coordinate channels (if requested) are appended before the first conv.
        in_channels = in_channels if not cat_position_embeddings else in_channels + 2
        padding = kernel_size // 2
        self.bn0 = nn.InstanceNorm2d(num_features=in_channels)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=hid_channels, kernel_size=kernel_size,
                               padding=padding)
        self.bn1 = nn.InstanceNorm2d(num_features=hid_channels)
        self.conv2 = nn.Conv2d(in_channels=hid_channels, out_channels=hid_channels, kernel_size=kernel_size,
                               padding=padding)
        self.bn2 = nn.InstanceNorm2d(num_features=hid_channels)
        self.conv3 = nn.Conv2d(in_channels=hid_channels, out_channels=hid_channels, kernel_size=kernel_size,
                               padding=padding)
        self.bn3 = nn.InstanceNorm2d(num_features=hid_channels)

        self.proof_linear = SmallInitLinear(in_features=hid_feats, out_features=proof_feats)
        self.cat_position_embeddings = cat_position_embeddings
        # NOTE(review): ln_proof is registered but not used in forward; kept so that
        # existing checkpoints / state dicts still load.
        self.ln_proof = nn.LayerNorm([proof_feats])

        self.add_first_layer_batch_norm = add_first_layer_batch_norm

        # Auxiliary head.
        self.detach_aux_head = detach_aux_head
        self.add_aux_hidden_layer = add_aux_hidden_layer
        if not add_aux_hidden_layer:
            self.aux_head = nn.Sequential(*[SmallInitLinear(in_features=hid_feats, out_features=aux_feats)])
        else:
            self.aux_head = nn.Sequential(*[Linear(in_features=hid_feats, out_features=128),
                                            SmallInitLinear(in_features=128, out_features=aux_feats)])

        print("WARNING: DON'T HARDCODE AUX LINEAR HIDDEN FEATURES")

    def forward(self, xs, hints=None):
        """Compute the proof and auxiliary outputs for a batch of images.

        Args:
            xs: 4D image tensor unpacked as ``(bs, c, h, w)``.
            hints: unused.

        Returns:
            Tuple ``(proofs, aux_head)`` with the outputs of the proof head and the
            auxiliary head.
        """
        bs, c, h, w = xs.shape

        # If asked, concatenate position embeddings. The grids are built on the input's
        # device and dtype so the concatenation also works for GPU / half inputs.
        if self.cat_position_embeddings:
            h_coords = torch.linspace(start=-1., end=1., steps=h, device=xs.device, dtype=xs.dtype)
            w_coords = torch.linspace(start=-1., end=1., steps=w, device=xs.device, dtype=xs.dtype)
            h_coords = h_coords.view(1, 1, h, 1).expand(bs, 1, h, w)
            w_coords = w_coords.view(1, 1, 1, w).expand(bs, 1, h, w)
            xs = torch.cat((xs, h_coords, w_coords), dim=1)

        # Run through the convolutional feature extractor. conv1 consumes ``zs`` so that
        # the optional bn0 normalization actually takes effect.
        zs = xs
        if self.add_first_layer_batch_norm:
            zs = self.bn0(zs)
        zs = self.bn1(F.leaky_relu(self.conv1(zs)))
        zs = self.bn2(F.leaky_relu(self.conv2(zs)))
        zs = self.bn3(F.leaky_relu(self.conv3(zs)))

        # Average over the channel dimension (dim=1) and flatten the spatial map per example.
        feats = torch.mean(zs, dim=1).view(bs, -1)

        # Proof head and auxiliary head.
        proofs = self.proof_linear(feats)
        aux_inputs = feats.detach() if self.detach_aux_head else feats
        aux_head = self.aux_head(aux_inputs)

        return proofs, aux_head


@quick_register
class FindPlusBilinearVerifier(nn.Module):
    """Verifier that scores an (input, proof) pair with a single bilinear layer.

    Both arguments of ``forward`` are flattened per example, so ``in_features`` and
    ``proof_features`` must equal the number of elements per example of ``xs`` and
    ``proofs`` respectively.
    """

    def __init__(self, in_features, proof_features, out_features):
        super().__init__()
        # Bilinear(in1_features, in2_features, out_features).
        self.bilinear1 = Bilinear(in_features, proof_features, out_features)

    def forward(self, xs, proofs):
        """Return the bilinear scores of the flattened inputs and proofs."""
        batch_size = xs.shape[0]
        flat_xs, flat_proofs = (tensor.view(batch_size, -1) for tensor in (xs, proofs))
        return self.bilinear1(flat_xs, flat_proofs)


@quick_register
class FindPlusSTNVerifier(nn.Module):
    """Verifier that warps the input image with proof-derived affine transforms and classifies the result.

    Each proof is mapped by ``proof_preprocessor`` to ``num_heads`` 2x3 affine matrices (theta), offset by
    the identity transform. Each head warps a copy of the image through ``transform_image``. The warped
    images of all heads are flattened, concatenated and mapped to a decision vector. Decision vectors of
    all proofs for the same example are summed and then classified. An auxiliary MLP maps each proof's
    theta to 2 outputs.

    NOTE(review): ``forward`` reshapes the input to one channel per head, so it only works for
    single-channel inputs (``c == 1``), and ``in_features`` must then equal ``h * w``.
    """
    def __init__(self, in_features, proof_features, decision_features, out_features, num_heads=1, padding_mode="zeros",
                 shear=True, scaling=True, default_scale_factor=1., detach_aux_head=True, ):
        """Build the verifier.

        Args:
            in_features: number of elements of one warped image (``h * w`` for single-channel input).
            proof_features: size of each proof vector (input size of ``proof_preprocessor``).
            decision_features: size of the per-proof decision vector.
            out_features: number of classifier outputs.
            num_heads: number of affine transforms (warped copies) per proof.
            padding_mode: forwarded to ``transform_image``.
            shear: if False, the off-diagonal entries theta[..., 0, 1] and theta[..., 1, 0] are zeroed.
            scaling: if False, the diagonal entries of theta are replaced by ``default_scale_factor``.
            default_scale_factor: value used for the diagonal entries when ``scaling`` is False.
            detach_aux_head: if True, the auxiliary head receives a detached copy of theta.
        """
        super().__init__()
        self.in_features = in_features
        self.proof_features = proof_features
        self.out_features = out_features
        self.num_heads = num_heads
        self.padding_mode = padding_mode
        self.shear = shear
        self.scaling = scaling
        self.default_scale_factor = default_scale_factor
        self.detach_aux_head = detach_aux_head

        # Maps each proof to num_heads flattened 2x3 affine matrices (6 numbers per head).
        self.proof_preprocessor = Linear(in_features=proof_features, out_features=num_heads * 6)
        # Consumes the concatenation of all heads' flattened warped images.
        self.penultimate_linear = Linear(in_features=num_heads * in_features, out_features=decision_features)
        self.classifier = Linear(in_features=decision_features, out_features=out_features)
        self.aux_head = self._get_aux_head()

        # Warnings.
        print("WARNING: Currently the verifier aux head is hardcoded. ")

    def _get_aux_head(self):
        """Return the auxiliary MLP: (6 * num_heads) -> 128 -> 128 -> 2, with LayerNorm + LeakyReLU between layers."""
        return nn.Sequential(*[Linear(in_features=6 * self.num_heads, out_features=128),
                               nn.LayerNorm((128,)),
                               nn.LeakyReLU(),
                               Linear(in_features=128, out_features=128),
                               nn.LayerNorm((128,)),
                               nn.LeakyReLU(),
                               Linear(in_features=128, out_features=2)])

    def forward(self, xs, proofs_list):
        """Classify ``xs`` using every proof in ``proofs_list``.

        Args:
            xs: 4D image tensor unpacked as ``(bs, c, h, w)``; ``c`` must be 1 (see class note).
            proofs_list: list or tuple of ``num_p`` proof tensors. They are concatenated along dim 0 and
                must yield ``num_p * bs`` rows in total; presumably each is ``(bs, proof_features)``.

        Returns:
            Tuple ``(decision, ys_aux, model_dict)``:
              - decision: classifier output, ``(bs, out_features)``.
              - ys_aux: auxiliary outputs arranged as ``(bs, num_p, 2)`` and then ``squeeze()``-d. Note
                that any size-1 dimension is dropped, so the shape changes when ``bs == 1`` or ``num_p == 1``.
              - model_dict: tensors for ``log_transformed_images`` with keys ``"xs"`` (the input repeated
                ``num_p`` times), ``"xs_hat"`` (tuple of ``num_heads`` tensors, each ``(num_p * bs, 1, h, w)``)
                and ``"thetas"`` (``(num_p * bs, num_heads, 2, 3)``).
        """
        assert isinstance(proofs_list, list) or isinstance(proofs_list, tuple)
        bs, c, h, w = xs.shape
        num_p = len(proofs_list)  # Number of proofs.

        # Expand proofs along the batch dimension, and duplicate xs to match the new batch dimension.
        # Row k * bs + i corresponds to proof k applied to example i.
        proofs = torch.cat(proofs_list, dim=0)
        xs = torch.cat(len(proofs_list) * [xs])

        # Get the spatial transformer theta (affine transform parameters): (num_p * bs, num_heads, 6).
        theta = self.proof_preprocessor(proofs).view(num_p * bs, self.num_heads, 6)

        # Offset by the flattened identity affine [[1, 0, 0], [0, 1, 0]] so that a zero preprocessor
        # output yields the identity transform (intended to speed up learning).
        theta = theta + torch.tensor([1., 0., 0., 0., 1., 0.]).type_as(xs)
        theta = theta.view(num_p * bs, self.num_heads, 2, 3)

        # Set some coordinates of theta to 0 if shear is disabled.
        if not self.shear:
            shear_mask = torch.ones_like(theta)
            shear_mask[:, :, 0, 1] = 0.
            shear_mask[:, :, 1, 0] = 0.
            theta = theta * shear_mask

        # If scaling is disabled, overwrite the diagonal entries with default_scale_factor.
        if not self.scaling:
            scaling_mask = torch.zeros_like(theta)
            scaling_mask[:, :, 0, 0] = 1.
            scaling_mask[:, :, 1, 1] = 1.
            scaling_component = theta * scaling_mask
            theta = (theta - scaling_component) + self.default_scale_factor * scaling_mask

        # ____ Transform the image. ____
        # Concatenate copies of the input image to be transformed: one single-channel copy per head.
        # This view only matches the element count when c == 1.
        xs_repeated = xs.repeat(1, self.num_heads, 1, 1).view(num_p * bs * self.num_heads, 1, h, w)
        theta_folded = theta.view(num_p * bs * self.num_heads, 2, 3)

        # Transform the image. NOTE(review): transform_image is a project helper
        # (src.dl.models.spatial_transformer); presumably it warps each image by its 2x3 theta — confirm there.
        transformed_imgs = transform_image(theta_folded, xs_repeated,
                                           padding_mode=self.padding_mode).view(num_p * bs, self.num_heads, h, w)
        # Split into a tuple of num_heads tensors, each (num_p * bs, 1, h, w).
        transformed_imgs = torch.split(transformed_imgs, self.num_heads * [1], dim=1)

        # Collect the tensors that the caller can pass to log_transformed_images.
        model_dict = {"xs": xs, "xs_hat": transformed_imgs, "thetas": theta}

        # Concatenate the images: (num_p * bs, num_heads * h * w).
        transformed_imgs = [im.view(num_p * bs, -1) for im in transformed_imgs]
        xs_hat_cat = torch.cat(transformed_imgs, dim=1)

        # Extract decision vectors.
        decision_vectors = self.penultimate_linear(xs_hat_cat)

        # Unfold the batch dimension and perform pooling on the decision vectors:
        # split into num_p chunks of bs rows, stack them as (bs, num_p, D), then sum over the proofs.
        decision_vectors_unfolded = torch.cat(torch.split(decision_vectors.view(num_p * bs, 1, -1),
                                                          num_p * [bs], dim=0), dim=1)
        decision_vectors_pooled = torch.sum(decision_vectors_unfolded, dim=1)

        # Make final decision.
        decision = self.classifier(decision_vectors_pooled.view(bs, -1))

        # Run the auxiliary head on theta (optionally detached), then arrange its output as (bs, num_p, 2).
        theta = theta.clone().detach() if self.detach_aux_head else theta
        ys_aux = self.aux_head(theta.view(num_p * bs, -1))
        ys_aux = torch.cat(torch.split(ys_aux.view(num_p * bs, 1, -1), num_p * [bs], dim=0), dim=1).squeeze()

        return decision, ys_aux, model_dict

    def log_transformed_images(self, logger, model_logs, ys_true, batch_nb, prepend_key):
        """Plot original vs. transformed images from ``model_logs`` and log the plot.

        Args:
            logger: object exposing ``logger.experiment.log(dict)``.
            model_logs: the ``model_dict`` returned by ``forward``.
            ys_true: labels used to select which examples to plot.
            batch_nb: unused.
            prepend_key: key under which the plot is logged.
        """
        plotter = _plot_original_and_transformed_images(xs=model_logs["xs"], xs_hat=model_logs["xs_hat"],
                                                        thetas=model_logs["thetas"], ys=ys_true.clone().detach(),
                                                        title=f"")
        logger.experiment.log({f"{prepend_key}": plotter})
        # Release all matplotlib figures created for the plot.
        plt.close("all")


class SmallInitLinear(Linear):
    """``nn.Linear`` variant initialized with tiny weights and a zero bias.

    Weights are sampled from a normal distribution with mean 0 and standard
    deviation 0.001; the bias, when present, starts at exactly zero.
    """

    def reset_parameters(self):
        # Overrides ``Linear.reset_parameters`` to replace the default initialization.
        nn.init.normal_(self.weight, 0., 0.001)
        if self.bias is None:
            return
        nn.init.zeros_(self.bias)


def _plot_original_and_transformed_images(xs, xs_hat, thetas, ys, title):
    """Draw original images (top row) above their transformed versions (one row per head).

    Up to ``num_imgs // 2`` examples labelled 0 and up to ``num_imgs // 2`` labelled 1
    are selected (first occurrences in ``ys``), and label-0 examples are plotted first.
    Each label is collected independently of the other.

    Args:
        xs: tensor of original images, indexed as ``xs[idx, 0]`` to obtain a 2D image.
        xs_hat: list/tuple with one tensor per head, each indexed as ``xs_hat[j][idx, 0]``.
        thetas: 4D tensor of transform parameters; only ``thetas.shape[1]`` (the number
            of heads) is used.
        ys: labels; only entries equal to 0 or 1 can be selected.
        title: currently unused (kept for interface compatibility).

    Returns:
        The ``matplotlib.pyplot`` module; the current figure holds the image grid.
    """
    assert isinstance(xs_hat, (list, tuple))
    assert len(thetas.shape) == 4
    num_heads = thetas.shape[1]

    # Pick example indices: the first few of each label, label 0 first.
    num_imgs = 10
    per_class = num_imgs // 2
    picked = {0.: [], 1.: []}
    for i in range(ys.shape[0]):
        bucket = picked.get(float(ys[i]))
        if bucket is not None and len(bucket) < per_class:
            bucket.append(i)
            if all(len(b) == per_class for b in picked.values()):
                break
    idxs = picked[0.] + picked[1.]

    # Convert each tensor to numpy once, rather than once per plotted cell.
    xs_np = xs.detach().cpu().numpy()
    xs_hat_np = [head_imgs.detach().cpu().numpy() for head_imgs in xs_hat]

    # squeeze=False keeps ``axs`` two-dimensional regardless of the number of heads.
    _, axs = plt.subplots(1 + num_heads, num_imgs, squeeze=False)
    # Hide the axes of every cell, including columns left empty when too few examples exist.
    for ax in axs.flat:
        ax.axis('off')
    for col, idx in enumerate(idxs):
        axs[0, col].imshow(X=xs_np[idx, 0], cmap="coolwarm")
        for head in range(num_heads):
            axs[1 + head, col].imshow(X=xs_hat_np[head][idx, 0], cmap="coolwarm")
    plt.tight_layout()

    return plt


class FindThePlusConvFeatExtractor(nn.Module):
    """Placeholder convolutional feature extractor (not implemented).

    The constructor takes a subset of ``FindPlusProver``'s hyper-parameters but
    currently ignores all of them and registers no submodules or parameters.
    """

    def __init__(self, in_channels, hid_channels, hid_feats, proof_feats, aux_feats, kernel_size=3,
                 cat_position_embeddings=False):
        # NOTE: every argument is presently unused; presumably layers are to be added later.
        super(FindThePlusConvFeatExtractor, self).__init__()


if __name__ == "__main__":
    # Manual smoke tests for the models in this file. Run from the repository root:
    #     python -m src.dl.models.task_specific.find_plus_nets
    # Each test ends in ``breakpoint()`` so the outputs can be inspected interactively.
    test_num = 3

    if test_num == 0:
        # Test the prover. hid_feats must equal h * w (10 * 10) because the prover
        # averages its conv features over channels only.
        img = torch.randn(size=(1, 1, 10, 10))

        net = FindPlusProver(in_channels=1, hid_channels=20, hid_feats=100, proof_feats=10, aux_feats=2)
        print(net)

        out = net(img)
        breakpoint()

    if test_num == 1:
        # Test the bilinear verifier; both inputs flatten to 100 features per example.
        imgs = torch.randn(size=(1, 1, 10, 10))
        proofs_ = torch.randn(size=(1, 1, 10, 10))

        net = FindPlusBilinearVerifier(in_features=100, proof_features=100, out_features=2)
        print(net)

        out = net(xs=imgs, proofs=proofs_)
        breakpoint()

    if test_num == 2:
        # Test the spatial transformer verifier. It expects single-channel images with
        # in_features == h * w, and a list of proof tensors (one entry per proof).
        imgs = torch.randn(size=(2, 1, 10, 10))
        proofs_ = torch.randn(size=(2, 2))

        net = FindPlusSTNVerifier(in_features=100, proof_features=2, decision_features=32, out_features=2)
        print(net)

        out = net(xs=imgs, proofs_list=[proofs_])
        breakpoint()

    if test_num == 3:
        # Test if pre-constructed ResNet backbones can be used with InstanceNorm.
        # NOTE(review): with 12x14 inputs the layer3/layer4 feature maps are 1x1, and
        # instance norm may reject single-element spatial maps in training mode — confirm.
        import torchvision.models as models
        from torch.nn import InstanceNorm2d
        resnet18 = models.resnet18(pretrained=False, norm_layer=InstanceNorm2d)
        xs = torch.randn(size=(256, 3, 12, 14))
        ys = resnet18(xs)
        breakpoint()

