import jax
import jax.numpy as jnp
import numpy as np
import torch

import architectures
import eval.trainers as trainers
import unrolled_canaries
import multimodel_train_state


class JaxTrainer(trainers.ModelTrainer):
    """Trainer that fits a single MLP through the JAX ``unrolled_canaries`` code.

    Inputs and outputs at the public interface are torch tensors; they are
    converted to JAX arrays internally by ``_prepare_data``.

    NOTE(review): ``self.images_mean_std`` is read by ``_prepare_data`` but is
    never assigned in this class and ``__init__`` does not call
    ``super().__init__()`` -- presumably ``trainers.ModelTrainer`` (or the
    caller) provides it; confirm.
    """

    def __init__(
        self,
        learning_rate: float,
        momentum: float,
        num_epochs: int,
        batch_size: int,
        mlp_width: int,
        standardize: bool,
        dp_params: unrolled_canaries.DPParams | None,
    ):
        """Store the hyperparameters and build the MLP architecture.

        Args:
            learning_rate: Forwarded to ``unrolled_canaries.create_train_state``.
            momentum: Forwarded to ``unrolled_canaries.create_train_state``.
            num_epochs: Number of epochs for which permutations are generated.
            batch_size: Batch size used when generating epoch permutations.
            mlp_width: Width passed to ``architectures.MLP``.
            standardize: Whether ``_prepare_data`` standardizes images with
                ``self.images_mean_std``.
            dp_params: DP parameters, or ``None``; training uses DP exactly
                when this is not ``None``.
        """
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.standardize = standardize
        self.dp_params = dp_params

        self.architecture = architectures.MLP(width=mlp_width)

    def train(
        self,
        images: torch.Tensor,
        targets: torch.Tensor,
        seed: int,
        device: torch.device,
    ) -> tuple[multimodel_train_state.MultiModelTrainState, tuple[float, ...]]:
        """Train one model on the given data.

        Args:
            images: Training images as a torch tensor in (N, C, H, W) layout.
            targets: Training labels.
            seed: Seed for the JAX PRNG key that drives model initialization
                and the epoch permutations.
            device: Not used by this method.

        Returns:
            The trained state and a one-element tuple holding the training
            accuracy reported by ``unrolled_canaries.train_model``.
        """
        key = jax.random.PRNGKey(seed)

        images, targets = self._prepare_data(images, targets)

        # Initialize model
        key, key_init = jax.random.split(key)
        image_shape_hwc = images.shape[1:]
        init_state = unrolled_canaries.create_train_state(
            (key_init,),
            learning_rate=self.learning_rate,
            momentum=self.momentum,
            num_models=1,
            arch=self.architecture,
            image_shape=image_shape_hwc,
            dp_params=self.dp_params,
        )
        del key_init

        # Build epoch permutations
        key, key_perms = jax.random.split(key)
        epoch_perms = unrolled_canaries.generate_model_perms(
            key=key_perms,
            num_models=1,
            num_epochs=self.num_epochs,
            num_samples=images.shape[0],
            batch_size=self.batch_size,
            sample_non_canaries=False,  # already done in given images (if required)
            canary_idx=None,
        )
        del key_perms

        # Train (no test data is supplied, so only the train accuracy is kept)
        trained_state, _, _, train_accuracy = unrolled_canaries.train_model(
            state=init_state,
            epoch_perms=epoch_perms,
            train_images=images,
            train_targets=targets,
            test_images=None,
            test_targets=None,
            use_dp=self.dp_params is not None,
            verbose=False,
        )
        aux = (train_accuracy,)

        return trained_state, aux

    def predict(
        self,
        images: torch.Tensor,
        model: multimodel_train_state.MultiModelTrainState,
        aux: tuple[float, ...] | None,
    ) -> torch.Tensor:
        """Compute logits for ``images`` with a trained model.

        Args:
            images: Images as a torch tensor in (N, C, H, W) layout.
            model: Trained state as returned by ``train``.
            aux: Not used by this method.

        Returns:
            float32 logits of the ``"model_0"`` entry, on the same device as
            ``images``.
        """
        target_device = images.device
        images, _ = self._prepare_data(images, None)

        logits_per_model = unrolled_canaries.get_logits(model, images)
        # np.array copies into a fresh host buffer that torch can wrap.
        logits = torch.from_numpy(np.array(logits_per_model["model_0"])).to(torch.float32)

        # Need to move logits to same device as images
        return logits.to(target_device)

    @staticmethod
    def _to_host(data):
        """Return torch tensors as host NumPy arrays; pass anything else through.

        Detaching and moving to CPU first makes the conversion independent of
        the tensor's device and autograd state.
        """
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().numpy()
        return data

    def _prepare_data(
        self, images: torch.Tensor, targets: torch.Tensor | None
    ) -> tuple[jnp.ndarray, jnp.ndarray | None]:
        """Convert torch inputs to JAX arrays and optionally standardize images.

        Args:
            images: Images as a torch tensor in (N, C, H, W) layout.
            targets: Labels, or ``None`` when only images are needed.

        Returns:
            ``(images, targets)`` where images are a JAX array in (N, H, W, C)
            layout and targets are an int16 JAX array, or ``None`` if no
            targets were given.
        """
        assert self.images_mean_std is not None, "images_mean_std must be set before preparing data"

        # Move training data to jax
        # Code expects (N, H, W, C) in [0, 1], but images are torch (N, C, H, W)
        images = jnp.array(self._to_host(images.permute(0, 2, 3, 1)))
        if targets is not None:
            targets = jnp.array(self._to_host(targets), dtype=jnp.int16)

        # Standardize with given statistics
        if self.standardize:
            dataset_mean = jnp.array(self.images_mean_std[0], dtype=jnp.float32)
            dataset_std = jnp.array(self.images_mean_std[1], dtype=jnp.float32)
            # Statistics are broadcast over the last (channel) axis.
            images = (images - dataset_mean.reshape(1, 1, 1, -1)) / dataset_std.reshape(1, 1, 1, -1)

        return images, targets
