import math

import pennylane as qml
import torch
import torch.nn as nn

from loss import FidLossDotProd, FidLossDotProdAAE
from models.batch_encoders import aae_encoder_for_train

from .superencoders import MLP, DeepMLP, mat_fn

# Registry of super-encoder classes keyed by architecture name.
# StateGenerator indexes it with `config.state_generator.super_encoder.arch`.
SUPER_ENCODERS = {
    "MLP": MLP,
    "DeepMLP": DeepMLP
}


class AM_StateGenerator(torch.nn.Module):
    """Produce a state by running ``qml.MottonenStatePreparation`` in a QNode.

    Two QNodes are defined in ``__init__``: a plain one, and one wrapped with
    ``qml.transforms.insert`` for ``qml.AmplitudeDamping`` and
    ``qml.DepolarizingChannel``. The ``noisy`` flag under
    ``config.state_generator.am_encoder`` selects which one ``forward`` calls.
    """

    def __init__(self, config) -> None:
        """Build the QNodes from ``config.state_generator.am_encoder``.

        Keys read from that section: ``n_qubits``, ``q_device``, and the
        optional ``noisy`` (default ``False``), ``AmplitudeDamping`` and
        ``DepolarizingChannel`` (both default ``0.0``).
        """
        # nn.Module requires the parent __init__ to run before any attribute
        # is assigned on the child.
        super().__init__()

        am_cfg = config.state_generator.am_encoder
        self.n_qubits = am_cfg.n_qubits
        self.is_noisy = am_cfg.get("noisy", False)
        # Default to 0.0 (as AAE_StateGenerator does) so a noiseless config
        # does not have to declare the noise strengths.
        amplitude_damping_prob = am_cfg.get("AmplitudeDamping", 0.0)
        depolarizing_prob = am_cfg.get("DepolarizingChannel", 0.0)

        @qml.qnode(
            qml.device(am_cfg.q_device, wires=am_cfg.n_qubits),
            interface="torch",
            diff_method="backprop",
            expansion_strategy="device"
        )
        @qml.simplify
        def am_encoder(inputs):
            # Only a single sample (leading dimension 1) is accepted.
            assert inputs.shape[0] == 1
            qml.MottonenStatePreparation(inputs.reshape(-1), wires=range(self.n_qubits))
            return qml.state()

        @qml.qnode(
            qml.device(am_cfg.q_device, wires=am_cfg.n_qubits),
            interface="torch",
            diff_method="backprop",
            expansion_strategy="device"
        )
        @qml.transforms.insert(qml.AmplitudeDamping, amplitude_damping_prob, position="all")
        @qml.transforms.insert(qml.DepolarizingChannel, depolarizing_prob, position="all")
        @qml.simplify
        def am_encoder_noisy(inputs):
            # Same circuit as am_encoder; the insert transforms above add the
            # two noise channels at every position.
            assert inputs.shape[0] == 1
            qml.MottonenStatePreparation(inputs.reshape(-1), wires=range(self.n_qubits))
            return qml.state()

        self.am_encoder = am_encoder_noisy if self.is_noisy else am_encoder

    def forward(self, target_state):
        """Run the selected QNode on ``target_state`` and return its output."""
        return self.am_encoder(target_state)


class AG_StateGenerator(torch.nn.Module):
    """Empty ``torch.nn.Module`` subclass; it defines no behaviour of its own."""


class AAE_StateGenerator(torch.nn.Module):
    """Fit the weights of an ``aae_encoder_for_train`` circuit to one target state.

    The circuit is wrapped in a ``qml.qnn.TorchLayer`` whose trainable
    ``weights`` have shape ``(1, n_encoder_layers, n_qubits)``. ``forward``
    optimises those weights against the given target state and then returns
    the state produced by the circuit.
    """

    def __init__(self, config) -> None:
        """Build the (optionally noisy) QNode, its TorchLayer and the loss.

        Reads ``n_encoder_layers``, ``n_qubits``, ``q_device`` and the
        optional ``noisy`` / ``AmplitudeDamping`` / ``DepolarizingChannel``
        keys from ``config.state_generator.aae_encoder``, plus
        ``config.device`` and ``config.state_generator.loss``.
        """
        # nn.Module requires the parent __init__ to run before any attribute
        # is assigned on the child.
        super().__init__()
        self.config = config

        aae_cfg = config.state_generator.aae_encoder
        self.n_encoder_layers = aae_cfg.n_encoder_layers
        self.n_qubits = aae_cfg.n_qubits
        self.is_noisy = aae_cfg.get("noisy", False)
        amplitude_damping_prob = aae_cfg.get("AmplitudeDamping", 0.0)
        depolarizing_prob = aae_cfg.get("DepolarizingChannel", 0.0)

        @qml.qnode(
            qml.device(aae_cfg.q_device, wires=aae_cfg.n_qubits),
            interface="torch",
            diff_method="backprop",
        )
        @qml.simplify
        def aae_encoder(inputs, weights):
            # `inputs` is ignored; it exists only because TorchLayer passes it.
            aae_encoder_for_train(
                weights,
                self.n_encoder_layers,
                self.n_qubits,
            )
            return qml.state()

        @qml.qnode(
            qml.device(aae_cfg.q_device, wires=aae_cfg.n_qubits),
            interface="torch",
            diff_method="backprop",
        )
        @qml.simplify
        @qml.transforms.insert(qml.AmplitudeDamping, amplitude_damping_prob, position="all")
        @qml.transforms.insert(qml.DepolarizingChannel, depolarizing_prob, position="all")
        def aae_encoder_noisy(inputs, weights):
            # Same circuit as aae_encoder, with both noise channels inserted.
            aae_encoder_for_train(
                weights,
                self.n_encoder_layers,
                self.n_qubits,
            )
            return qml.state()

        weight_shapes = {"weights": (1, self.n_encoder_layers, self.n_qubits)}
        self.criterion = self.get_criterion(is_noisy=self.is_noisy)

        circuit = aae_encoder_noisy if self.is_noisy else aae_encoder
        # FIXME: is matrix_fn still in use? maybe delete it
        self.matrix_fn = mat_fn(circuit)
        self.aae_encoder = qml.qnn.TorchLayer(
            circuit, weight_shapes, init_method=nn.init.uniform_
        ).to(self.config.device)

    def forward(self, target_state, verbose=False):
        """Train the circuit weights on ``target_state``, then return the state."""
        self.train_for_state(target_state, verbose=verbose)
        return self.compute_state()

    def _dummy_input(self):
        """Return the placeholder tensor TorchLayer needs as its ``inputs``."""
        # The circuit never reads its inputs, so the value is irrelevant.
        return torch.zeros((1,), dtype=torch.float32).to(
            device=self.config.device
        )

    def compute_state(self):
        """Evaluate the circuit with its current weights and return the result."""
        return self.aae_encoder(self._dummy_input())

    def get_criterion(self, is_noisy=False):
        """Return the loss selected by ``config.state_generator.loss``.

        Raises:
            ValueError: if the configured loss is neither ``"DotProd"`` nor
                ``"MSE"``.
        """
        loss_name = self.config.state_generator.loss
        if loss_name == "DotProd":
            return FidLossDotProdAAE(is_noisy)
        if loss_name == "MSE":
            return nn.MSELoss()
        raise ValueError(
            f"Unsupported state_generator.loss {loss_name!r}; "
            "expected 'DotProd' or 'MSE'"
        )

    def get_optimizer(self):
        """Instantiate the configured ``torch.optim`` optimizer for the circuit.

        The class name comes from ``config.state_generator.optimizer.name`` and
        its keyword arguments from ``config.state_generator.optimizer.args``.
        """
        optimizer_cls = getattr(torch.optim, self.config.state_generator.optimizer.name)
        return optimizer_cls(
            self.aae_encoder.parameters(),
            **self.config.state_generator.optimizer.args,
        )

    def train_for_state(self, target_state, verbose=False):
        """Optimise the circuit weights for ``config.state_generator.n_train_step`` steps.

        ``target_state`` must have shape ``(1, 2**n_qubits)``. A fresh
        optimizer is created on every call. When ``verbose`` is true the loss
        is printed every 10 steps.
        """
        # expected target_state.shape == (1, 2**n_qubits)
        assert target_state.shape == (
            1,
            2**self.n_qubits,
        ), f"{target_state.shape} != {torch.Size((1, 2**self.n_qubits))}"
        n_step = self.config.state_generator.n_train_step
        target_state = target_state.to(device=self.config.device)
        dummy_input = self._dummy_input()

        optimizer = self.get_optimizer()
        # scheduler = CosineAnnealingLR(optimizer, T_max=n_step)

        for t in range(n_step):
            result_state = self.aae_encoder(dummy_input)
            loss = self.criterion(result_state, target_state)
            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            # scheduler.step()

            if verbose and t % 10 == 0:
                print(f"loss: {loss.item()}", end="\r")


class StateGenerator(torch.nn.Module):
    """
    1. A super_encoder generates predicted parameters
    2. Construct circuit unitary matrix based on the parameters
    3. Generate state vector
    """

    def __init__(self, config):
        """Build the super-encoder and the QNode-backed TorchLayer.

        Reads ``n_encoder_layers``, ``n_qubits``, ``q_device`` and the
        optional ``noisy`` / ``AmplitudeDamping`` / ``DepolarizingChannel``
        keys from ``config.state_generator.aae_encoder``. The super-encoder
        class is looked up in ``SUPER_ENCODERS`` by
        ``config.state_generator.super_encoder.arch``.
        """
        # nn.Module requires the parent __init__ to run before any attribute
        # is assigned on the child.
        super().__init__()
        self.config = config

        aae_cfg = config.state_generator.aae_encoder
        n_encoder_layers = aae_cfg.n_encoder_layers
        n_qubits = aae_cfg.n_qubits

        if getattr(aae_cfg, "noisy", None) is not None:
            self.is_noisy = aae_cfg.get("noisy", False)
            # Default the noise strengths to 0.0 so that e.g. `noisy: false`
            # does not require them to be declared.
            amplitude_damping_prob = aae_cfg.get("AmplitudeDamping", 0.0)
            depolarizing_prob = aae_cfg.get("DepolarizingChannel", 0.0)
        else:
            self.is_noisy = False
            amplitude_damping_prob = 0.0
            depolarizing_prob = 0.0

        # wrap around aae_encoder_for_train
        @qml.qnode(
            qml.device(aae_cfg.q_device, wires=aae_cfg.n_qubits),
            interface="torch",
            diff_method="backprop",
        )
        @qml.simplify
        def aae_encoder(inputs, weights):
            # The circuit is driven by `inputs`; `weights` is unused.
            aae_encoder_for_train(
                inputs,
                n_encoder_layers,
                n_qubits,
            )
            return qml.state()

        # wrap around aae_encoder_for_train
        @qml.qnode(
            qml.device(aae_cfg.q_device, wires=aae_cfg.n_qubits),
            interface="torch",
            diff_method="backprop",
        )
        @qml.simplify
        @qml.transforms.insert(qml.AmplitudeDamping, amplitude_damping_prob, position="all")
        @qml.transforms.insert(qml.DepolarizingChannel, depolarizing_prob, position="all")
        def aae_encoder_noisy(inputs, weights):
            # Same circuit as aae_encoder, with both noise channels inserted.
            aae_encoder_for_train(
                inputs,
                n_encoder_layers,
                n_qubits,
            )
            return qml.state()

        # No use, just a hack to make pennylane torchlayer work
        weight_shapes = {"weights": (n_encoder_layers, n_qubits)}

        self.super_encoder = SUPER_ENCODERS[config.state_generator.super_encoder.arch](
            **config.state_generator.super_encoder
        )

        circuit = aae_encoder_noisy if self.is_noisy else aae_encoder
        # FIXME: is matrix_fn still in use? maybe delete it
        self.matrix_fn = mat_fn(circuit)
        self.qc = qml.qnn.TorchLayer(circuit, weight_shapes)

    def forward(self, x, output_state: bool = False):
        """Predict circuit parameters for ``x``.

        Returns the parameters alone, or ``(state.real, params)`` when
        ``output_state`` is true, where ``state`` is the output of ``self.qc``
        applied to the parameters.
        """
        params = self.super_encoder(x)

        if output_state:
            state = self.qc(params)
            return state.real, params
        return params

    def compute_state(self, x):
        """Return the output of ``self.qc`` for the parameters predicted from ``x``."""
        params = self.super_encoder(x)
        return self.qc(params)

    def save(self, save_path):
        """Write this module's ``state_dict`` to ``save_path`` with ``torch.save``."""
        torch.save(self.state_dict(), save_path)

    def load(self, save_path, map_location=None, strict=True):
        """Load a ``state_dict`` from ``save_path`` into this module."""
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        self.load_state_dict(
            torch.load(save_path, map_location=map_location), strict=strict
        )
