import numpy as np

from spaghettini import quick_register
from torch import nn
import torch
import torch.nn.functional as F
from torch.nn import Linear
import matplotlib.pyplot as plt

from src.dl.models.spatial_transformer import transform_image


@quick_register
class STNProofInputProcessor(nn.Module):
    """Warp the input image with affine transforms predicted from a proof vector.

    A linear layer maps each proof vector to ``num_heads`` sets of 2x3 affine
    parameters (theta). Each head warps its own copy of the single-channel
    input image through the project helper ``transform_image``. The warped
    images are flattened and concatenated along the feature dimension.

    Args:
        proof_features: Size of the last dimension of the proof tensor.
        num_heads: Number of independent transforms (and warped copies) per image.
        padding_mode: Forwarded unchanged to ``transform_image``.
        shear: If False, the off-diagonal entries of each 2x2 linear part of
            theta are zeroed (this also rules out rotation).
        scaling: If False, the diagonal entries of each 2x2 linear part of
            theta are fixed to ``default_scale_factor``.
        default_scale_factor: Diagonal value used when ``scaling`` is False.
    """

    def __init__(self, proof_features, num_heads=1, padding_mode="zeros",
                 shear=True, scaling=True, default_scale_factor=1.):
        super().__init__()
        self.proof_features = proof_features
        self.num_heads = num_heads
        self.padding_mode = padding_mode
        self.shear = shear
        self.scaling = scaling
        self.default_scale_factor = default_scale_factor

        # Six numbers per head: the flattened 2x3 affine matrix.
        self.proof_preprocessor = Linear(in_features=proof_features, out_features=num_heads * 6)

    def forward(self, xs, proofs_list):
        """Warp ``xs`` once per head and return the flattened results.

        Args:
            xs: Image batch of shape (bs, 1, h, w). Only single-channel inputs
                are supported.
            proofs_list: Proof tensor, or a list/tuple whose first element is the
                proof tensor. It must hold ``bs`` rows of ``proof_features`` values.

        Returns:
            A pair ``(xs_hat_cat, model_dict)``:
            - ``xs_hat_cat`` has shape (bs, num_heads * h * w) and holds the
              flattened warped images of all heads, concatenated head by head.
            - ``model_dict`` has the keys:
              - ``"xs"``: the input images.
              - ``"xs_hat"``: a tuple of ``num_heads`` tensors of shape (bs, 1, h, w).
              - ``"thetas"``: the final transforms, of shape (bs, num_heads, 2, 3).

        Raises:
            ValueError: If ``xs`` has more than one channel.
        """
        proofs = proofs_list[0] if isinstance(proofs_list, (list, tuple)) else proofs_list
        bs, c, h, w = xs.shape
        if c != 1:
            # The head-folding reshape below only works for one channel; fail clearly.
            raise ValueError(f"STNProofInputProcessor expects single-channel images, got {c} channels.")

        # Predict the affine parameters (theta) for every head.
        theta = self.proof_preprocessor(proofs).view(bs, self.num_heads, 6)

        # Offset by the identity transform, so a near-zero linear output already
        # gives (roughly) the untransformed image. This speeds up early learning.
        theta = theta + torch.tensor([1., 0., 0., 0., 1., 0.]).type_as(xs)
        theta = theta.view(bs, self.num_heads, 2, 3)

        # Without shear, zero the off-diagonal terms of the 2x2 linear part,
        # leaving only axis-aligned scaling plus translation.
        if not self.shear:
            shear_mask = torch.ones_like(theta)
            shear_mask[:, :, 0, 1] = 0.
            shear_mask[:, :, 1, 0] = 0.
            theta = theta * shear_mask

        # Without scaling, replace the diagonal with the fixed default scale.
        if not self.scaling:
            scaling_mask = torch.zeros_like(theta)
            scaling_mask[:, :, 0, 0] = 1.
            scaling_mask[:, :, 1, 1] = 1.
            scaling_component = theta * scaling_mask
            theta = (theta - scaling_component) + self.default_scale_factor * scaling_mask

        # ____ Transform the image. ____
        # Fold heads into the batch axis: one copy of each image per head, in
        # the same (batch-major, head-minor) order as theta_folded.
        xs_repeated = xs.repeat(1, self.num_heads, 1, 1).view(bs * self.num_heads, 1, h, w)
        theta_folded = theta.view(bs * self.num_heads, 2, 3)

        transformed = transform_image(theta_folded, xs_repeated, padding_mode=self.padding_mode)
        # Tuple of num_heads tensors, each of shape (bs, 1, h, w).
        transformed_imgs = torch.split(transformed.view(bs, self.num_heads, h, w), 1, dim=1)

        # Keep the un-flattened results around for image logging.
        model_dict = {"xs": xs, "xs_hat": transformed_imgs, "thetas": theta}

        # Flatten each head's images and concatenate them along the feature axis.
        xs_hat_cat = torch.cat([im.view(bs, -1) for im in transformed_imgs], dim=1)

        return xs_hat_cat, model_dict

    def log_transformed_images(self, logger, model_logs, ys_true, batch_nb, prepend_key):
        """Plot original vs. per-head warped images and log the figure.

        Args:
            logger: Object exposing ``experiment.log(dict)``. Presumably a
                Lightning W&B logger; TODO confirm.
            model_logs: The ``model_dict`` returned by ``forward``.
            ys_true: Per-sample labels. Samples labelled 0 and 1 are the ones
                selected for plotting.
            batch_nb: Unused.
            prepend_key: Key under which the plot is logged.
        """
        plotter = _plot_original_and_transformed_images(xs=model_logs["xs"], xs_hat=model_logs["xs_hat"],
                                                        thetas=model_logs["thetas"], ys=ys_true.clone().detach(),
                                                        title="")
        logger.experiment.log({f"{prepend_key}": plotter})
        # Release the figure(s) created for this log call.
        plt.close("all")


def _plot_original_and_transformed_images(xs, xs_hat, thetas, ys, title):
    """Plot original images (top row) above their per-head warped versions.

    Up to ``num_imgs // 2`` samples labelled 0 and ``num_imgs // 2`` samples
    labelled 1 are selected, each in batch order. Label-0 samples fill the
    left-most columns.

    Args:
        xs: Original images. ``xs[i, 0]`` is drawn for each selected sample
            (presumably (bs, c, h, w); TODO confirm).
        xs_hat: List/tuple with one tensor per head, indexed like ``xs``.
        thetas: Rank-4 tensor. Only ``thetas.shape[1]`` (number of heads) is used.
        ys: Per-sample labels; each ``ys[i]`` must be convertible with ``float``.
        title: Currently unused.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure as the current figure.
    """
    assert isinstance(xs_hat, (list, tuple))
    assert len(thetas.shape) == 4
    num_heads = thetas.shape[1]

    # Pick the first `per_class` samples of each label independently, so
    # label-1 samples appearing early in the batch are not skipped.
    num_imgs = 10
    per_class = num_imgs // 2
    picked = {0.: [], 1.: []}
    for i in range(ys.shape[0]):
        bucket = picked.get(float(ys[i]))
        if bucket is not None and len(bucket) < per_class:
            bucket.append(i)
        if all(len(b) == per_class for b in picked.values()):
            break
    idxs = picked[0.] + picked[1.]

    # Convert to numpy once per tensor, not once per plotted image.
    xs_np = xs.clone().detach().cpu().numpy()
    xs_hat_np = [xs_hat[j].clone().detach().cpu().numpy() for j in range(num_heads)]

    # squeeze=False keeps `axs` 2-D regardless of the number of rows.
    _, axs = plt.subplots(1 + num_heads, num_imgs, squeeze=False)
    for col, y_idx in enumerate(idxs):
        axs[0, col].imshow(X=xs_np[y_idx, 0], cmap="coolwarm")
        axs[0, col].axis('off')
        for j in range(num_heads):
            axs[1 + j, col].imshow(X=xs_hat_np[j][y_idx, 0], cmap="coolwarm")
            axs[1 + j, col].axis('off')
    plt.axis('off')
    plt.tight_layout()

    return plt


@quick_register
class FindPlusConvFeatExtractor(nn.Module):
    """Three-layer convolutional feature extractor followed by a channel mean.

    Each layer is conv -> leaky ReLU -> instance norm. Padding is
    ``kernel_size // 2``, which keeps the spatial size for odd kernel sizes.
    ``forward`` returns the mean over channels of the last feature map,
    flattened per sample.

    Args:
        in_channels: Channels of the input image, before any position embeddings.
        hid_channels: Channels of the two hidden conv layers.
        out_channels: Channels of the last conv layer. They are averaged away
            in ``forward``.
        kernel_size: Convolution kernel size.
        cat_position_embeddings: If True, two coordinate channels with values in
            [-1, 1] (row and column position) are appended to the input.
    """

    def __init__(self, in_channels, hid_channels, out_channels, kernel_size=3, cat_position_embeddings=False):
        super().__init__()

        # The two coordinate channels are concatenated before the first conv.
        in_channels = in_channels + 2 if cat_position_embeddings else in_channels
        # NOTE(review): norm0 is never applied in forward(). It is kept so the
        # module's attributes stay unchanged (it has no parameters: affine=False).
        self.norm0 = nn.InstanceNorm2d(num_features=in_channels)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=hid_channels, kernel_size=kernel_size,
                               padding=kernel_size // 2)
        self.norm1 = nn.InstanceNorm2d(num_features=hid_channels)
        self.conv2 = nn.Conv2d(in_channels=hid_channels, out_channels=hid_channels, kernel_size=kernel_size,
                               padding=kernel_size // 2)
        self.norm2 = nn.InstanceNorm2d(num_features=hid_channels)
        self.conv3 = nn.Conv2d(in_channels=hid_channels, out_channels=out_channels, kernel_size=kernel_size,
                               padding=kernel_size // 2)
        # Normalizes conv3's output, so it must be sized by out_channels.
        # Since affine=False, this adds no parameters or state_dict entries.
        self.norm3 = nn.InstanceNorm2d(num_features=out_channels)

        self.cat_position_embeddings = cat_position_embeddings

    def forward(self, xs):
        """Extract features from an image batch.

        Args:
            xs: Tensor of shape (bs, c, h, w).

        Returns:
            Tensor of shape (bs, h * w) for odd ``kernel_size``: the per-pixel
            mean over the last layer's channels.
        """
        bs, c, h, w = xs.shape

        # If asked, concatenate position embeddings. They are created on xs's
        # device and dtype so the concatenation also works for GPU / half inputs.
        if self.cat_position_embeddings:
            h_coords = torch.linspace(start=-1., end=1., steps=h, device=xs.device, dtype=xs.dtype)
            h_coords = h_coords.view(1, 1, h, 1).repeat(bs, 1, 1, w)
            w_coords = torch.linspace(start=-1., end=1., steps=w, device=xs.device, dtype=xs.dtype)
            w_coords = w_coords.view(1, 1, 1, w).repeat(bs, 1, h, 1)
            xs = torch.cat((xs, h_coords, w_coords), dim=1)

        # Run through convolutional feature extractor.
        zs = xs
        zs = self.norm1(F.leaky_relu(self.conv1(zs)))
        zs = self.norm2(F.leaky_relu(self.conv2(zs)))
        zs = self.norm3(F.leaky_relu(self.conv3(zs)))

        # Average over the channel dimension (not spatial pooling) and flatten.
        return torch.mean(zs, dim=1).view(bs, -1)


@quick_register
class FindPlusProofGenerator(nn.Module):
    """Thin wrapper that runs ``module`` and also exposes its output for logging.

    Args:
        module: Callable (typically an ``nn.Module``) mapping features to prover outputs.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, feats):
        """Return ``(outputs, {"prover_outs": outputs})`` where ``outputs = self.module(feats)``."""
        outputs = self.module(feats)
        return outputs, {"prover_outs": outputs}


if __name__ == "__main__":
    # Smoke tests. Run from the repository root:
    #     python -m src.dl.models.task_specific.find_plus_nets
    test_num = 0

    if test_num == 0:
        from torch.optim import Adam
        from torch.nn.functional import mse_loss

        # Fit a FindPlusProofGenerator wrapping a 1 -> 1 linear layer to y = x.
        module = nn.Linear(in_features=1, out_features=1)
        pg = FindPlusProofGenerator(module=module)

        # Shape (1000, 1): 1000 samples with a single feature each.
        xs = torch.arange(1000, dtype=torch.float32)[:, None]
        ys = xs.clone()

        optim = Adam(pg.parameters(), lr=0.03)
        fn = lambda x: x  # Output post-processing hook; identity here.

        for i in range(100):
            optim.zero_grad()
            outs, _ = pg(xs)
            outs = fn(outs)
            err = mse_loss(outs, ys)
            err.backward()
            optim.step()
            print(float(err))


