import math

import pennylane as qml
import torch
import torch.nn as nn
import torch.nn.functional as F


def mat_fn(qc):
    """Return ``qml.matrix(qc)`` for the given circuit object.

    Thin wrapper around PennyLane's ``qml.matrix``. Elsewhere in this file
    the result is invoked as ``matrix_fn(weights=...)``, so ``qc`` is
    presumably a quantum function / QNode for which ``qml.matrix`` yields
    a callable -- TODO confirm against callers.
    """
    circuit_matrix = qml.matrix(qc)
    return circuit_matrix

class FFW(nn.Module):
    """Feed-forward block: Linear -> activation -> Linear.

    The inner width is four times ``hidden_size``; input and output widths
    are both ``hidden_size``.
    """

    def __init__(self, hidden_size, activation_fn) -> None:
        """
        Args:
            hidden_size: width of the block's input and output features
            activation_fn: module placed between the two linear layers
        """
        super().__init__()
        inner_size = hidden_size * 4
        # Build order (expand, then contract) is kept so seeded
        # initialisation stays reproducible.
        expand = nn.Linear(hidden_size, inner_size)
        contract = nn.Linear(inner_size, hidden_size)
        self.layers = nn.Sequential(expand, activation_fn, contract)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        projected = self.layers(x)
        return projected

class DeepMLP(nn.Module):
    """Residual MLP whose output is squashed to ``pi * tanh(.)``.

    Layout: input Linear -> ``num_blocks`` FFW blocks (each wrapped in a
    residual connection) -> output Linear -> Tanh, then scaled by pi.
    """

    def __init__(self, in_dim, out_dim, hidden_size=256, num_blocks=3, **kwargs):
        """
        Args:
            in_dim: number of input features
            out_dim: number of output features
            hidden_size: width of the hidden representation
            num_blocks: number of residual FFW blocks
            **kwargs: accepted and ignored
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_blocks = num_blocks
        # One Tanh instance is shared by every FFW block and the output stage.
        self.activations = nn.Tanh()

        # The list is evaluated left to right, so modules are constructed in
        # the order: input projection, FFW blocks, output projection.
        self.layers = nn.ModuleList(
            [
                nn.Linear(in_dim, hidden_size),
                *[FFW(hidden_size, self.activations) for _ in range(num_blocks)],
                nn.Linear(hidden_size, out_dim),
                self.activations,
            ]
        )

    def forward(self, x):
        """Apply every stage in order; FFW stages add their input back."""
        for stage in self.layers:
            transformed = stage(x)
            # residual connection only around FFW blocks
            x = transformed + x if isinstance(stage, FFW) else transformed
        return math.pi * x


class EncoderDecoderRNN(torch.nn.Module):
    """
    Encode the initial data E with a MLP, to be used as the initial hidden state of RNN decoder,
    RNNCell output AAE params layer by layer, last output is fed as next input

    v0.0.1: input_ are always 0.0
    v0.0.2: input_ are last output
    v0.0.3: v0.0.1 + use GRU
    v0.0.4: v0.0.1 + pi*tanh before output
    v0.0.5: v0.0.4 + use last output as next input
    v0.0.6: v0.0.1 + hidden_size=256 (64 in v0.0.1)

    The code below implements the "input_ are always 0.0" variant with a
    plain ``nn.RNNCell`` and an unbounded linear output projection.
    """

    def __init__(self, in_dim, hidden_size, num_aae_layers) -> None:
        """
        Args:
            in_dim: dim of input data
            hidden_size: hidden size of rnn
            num_aae_layers: number of decoding steps the rnn runs; each step
                emits ``ceil(log2(in_dim))`` parameters

        """
        super().__init__()

        self.num_qubits = math.ceil(math.log2(in_dim))
        self.num_aae_layers = num_aae_layers

        self.activation = nn.LeakyReLU()

        # simple feed forward
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            self.activation,
            nn.Linear(in_dim, hidden_size),
        )

        # RNN as decoder
        self.decoder = nn.RNNCell(input_size=self.num_qubits, hidden_size=hidden_size)
        self.output_projection = nn.Linear(hidden_size, self.num_qubits)

    def forward(self, x):
        """Decode ``num_aae_layers`` parameter vectors from the encoded input.

        Args:
            x: batch of input vectors; ``x.shape[0]`` is used as the batch
                size and the last dimension must match ``in_dim``

        Returns:
            Tensor of shape ``(batch, num_aae_layers * num_qubits)`` holding
            the per-step projections concatenated along dim 1.
        """
        bsz = x.shape[0]

        h = self.encoder(x)
        # Allocate the constant zero input from the hidden state so it lives
        # on the same device and has the same dtype as the model. A bare
        # torch.zeros(...) is always CPU/default-dtype and breaks the RNNCell
        # as soon as the model is moved (.cuda()) or cast (.double()).
        input_ = h.new_zeros(bsz, self.num_qubits)

        pred_params = []
        for _ in range(self.num_aae_layers):
            h = self.decoder(input_, h)
            # Other variants listed in the class docstring wrapped this in
            # torch.pi * tanh(...) and/or fed it back as the next input_;
            # neither is active here.
            pred_params.append(self.output_projection(h))

        return torch.cat(pred_params, dim=1)


# TODO: Test
class ManualChainMLP(torch.nn.Module):
    """Chain LayerMLP predictions, advancing the state with explicit matmuls.

    At each step a LayerMLP maps the current state to a parameter vector,
    the circuit matrix for those parameters is applied to the state, and
    the element-wise magnitude of the result becomes the next state.
    """

    # TODO: in_dim and num_layers may not be necessary
    def __init__(self, in_dim, quantum_circuit, num_layers, is_uniform: bool = True):
        """

        Args:
            in_dim: Input dimension (2**num_qubits)
            quantum_circuit: pennylane quantum node (single layer AAE)
            num_layers: number of layers
            is_uniform: Whether to use one model for all layers
        """
        super().__init__()
        self.num_layers = num_layers
        # A shared model is always constructed first; it is replaced by a
        # per-layer ModuleList when is_uniform is False.
        self.layer = LayerMLP(in_dim)
        if not is_uniform:
            self.layer = torch.nn.ModuleList(
                [LayerMLP(in_dim) for _ in range(self.num_layers)]
            )
        self.matrix_fn = mat_fn(quantum_circuit)

    def forward(self, x):
        """Use torch.matmul

        Returns the per-step predicted parameters concatenated along dim 1,
        or None when ``num_layers`` is 0. ``x`` is presumably
        ``(batch, in_dim)`` -- TODO confirm against callers.
        """
        per_layer_models = isinstance(self.layer, torch.nn.ModuleList)
        collected = []
        state = x
        for step in range(self.num_layers):
            model = self.layer[step] if per_layer_models else self.layer
            angles = model(state)
            collected.append(angles)

            unitary = self.matrix_fn(weights=angles)
            column = state.unsqueeze(-1).to(torch.complex64)
            # Batched matrix-vector product; only magnitudes are carried on.
            state = torch.bmm(unitary, column).squeeze(-1).abs()

        if not collected:
            return None
        return torch.cat(collected, 1)

    def save(self, save_path):
        """Only save layer model"""
        layer_model = self.layer
        torch.save(layer_model, save_path)

    def load(self, save_path):
        """Only load layer model"""
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        restored = torch.load(save_path)
        self.layer = restored


# TODO: Test
class AutoChainMLP(torch.nn.Module):
    """Chain of LayerMLP to predict all parameters"""

    # TODO: in_dim and num_layers may not be necessary
    def __init__(self, in_dim, quantum_circuit, num_layers):
        """

        Args:
            in_dim: Input dimension (2**num_qubits)
            quantum_circuit: pennylane quantum node (single layer AAE)
            num_layers: number of layers
        """
        super().__init__()
        self.num_layers = num_layers
        self.num_qubits = math.ceil(math.log2(in_dim))
        self.quantum_circuit = quantum_circuit
        self.layer_mlp = LayerMLP(in_dim)
        self.qlayer = qml.qnn.TorchLayer(
            quantum_circuit, {"weights": (self.num_layers, self.num_qubits)}
        )

    def forward(self, x):
        """Use pennylane qnode

        Repeatedly predicts a parameter vector from the current state with
        the shared LayerMLP and advances the state through the quantum
        layer.

        Returns:
            The predicted parameters of every step concatenated along
            dim 1, or None when ``num_layers`` is 0.
        """
        out = None
        in_state = x
        for _ in range(self.num_layers):
            pred_params = self.layer_mlp(in_state)
            # FIXME: pennylane TorchLayer.forward only accepts one input
            # argument, so passing pred_params as a second argument still
            # needs to be resolved.
            in_state = self.qlayer(in_state, pred_params)
            out = torch.cat((out, pred_params), 1) if out is not None else pred_params
        # Return the accumulated predictions; previously the untouched
        # input `x` was returned and `out` was discarded.
        return out


class LayerMLP(torch.nn.Module):
    """Use target state to predict the last layer's angles, see issue #4

    Four linear layers with widths
    ``in_dim -> 4*in_dim -> in_dim**2 -> in_dim -> ceil(log2(in_dim))``;
    ReLU follows the first three and the last is mapped to ``pi * tanh``.
    """

    def __init__(self, in_dim):
        """
        Args:
            in_dim: number of input features; the output has
                ``ceil(log2(in_dim))`` features
        """
        super().__init__()
        Linear = torch.nn.Linear
        wide = 4 * in_dim
        squared = in_dim**2
        n_angles = math.ceil(math.log2(in_dim))

        # Construction order is kept so seeded initialisation is unchanged.
        self.layer1 = Linear(in_dim, wide)
        self.layer2 = Linear(wide, squared)
        self.layer3 = Linear(squared, in_dim)
        self.layer4 = Linear(in_dim, n_angles)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return ``pi * tanh`` of the final linear layer's output."""
        hidden = x
        for linear in (self.layer1, self.layer2, self.layer3):
            hidden = self.relu(linear(hidden))
        angles = torch.tanh(self.layer4(hidden))
        return math.pi * angles


class MLP(torch.nn.Module):
    """Single-hidden-layer MLP (width 512) with output ``pi * tanh(.)``.

    Layout: Linear(in_dim, 512) -> Tanh -> Linear(512, out_dim) -> Tanh,
    with the result multiplied by pi in ``forward``.
    """

    def __init__(self, in_dim, out_dim, **kwargs):
        """
        Args:
            in_dim: number of input features
            out_dim: number of output features
            **kwargs: accepted and ignored
        """
        super().__init__()

        # One Tanh instance is reused for both activation positions.
        self.activations = nn.Tanh()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, 512),
            self.activations,
            nn.Linear(512, out_dim),
            self.activations,
        )

        # init_weights() is deliberately not called here; PyTorch's default
        # initialisation is used unless the caller invokes it explicitly.

    def forward(self, x):
        """Return ``pi`` times the output of the sequential stack."""
        x = math.pi * self.layers(x)
        return x

    def init_weights(self):
        """Re-initialise every weight matrix with Xavier-normal values.

        Biases are left untouched. Parameter names from
        ``named_parameters()`` are qualified (e.g. ``layers.0.weight``), so
        the match is done on the last dotted component; comparing the full
        name against ``"weight"`` never matched anything.
        """
        for name, param in self.named_parameters():
            if name.rsplit(".", 1)[-1] == "weight":
                nn.init.xavier_normal_(param)
                    

        

class ConvEncDec(torch.nn.Module):
    """Two conv layers followed by two transposed-conv layers.

    Channels go 1 -> 32 -> 64 -> 32 -> 1. ReLU follows the first three
    layers and a sigmoid follows the last, so outputs lie in [0, 1].

    NOTE(review): ``conv1`` uses a (3, 1) kernel with (1, 2) padding, which
    widens the last dimension by 4 while the remaining layers preserve
    size -- confirm this is intended.
    """

    def __init__(self):
        super().__init__()
        Conv = torch.nn.Conv2d
        Deconv = torch.nn.ConvTranspose2d

        # encoder
        self.conv1 = Conv(1, 32, kernel_size=(3, 1), padding=(1, 2))
        self.conv2 = Conv(32, 64, kernel_size=3, padding=1)

        # decoder
        self.deconv1 = Deconv(64, 32, kernel_size=3, padding=1)
        self.deconv2 = Deconv(32, 1, kernel_size=3, padding=1)

    def forward(self, x):
        """Apply the ReLU stages in order, then the sigmoid output layer."""
        features = x
        for stage in (self.conv1, self.conv2, self.deconv1):
            features = F.relu(stage(features))
        return torch.sigmoid(self.deconv2(features))
