import warnings

import torch

import eval.trainers as trainers
import influence.torch_architectures


class TorchTrainer(trainers.ModelTrainer[None, torch.nn.Module]):
    """Trainer that builds and trains a PyTorch model described by a ``TorchArchitecture``.

    Both ``train`` and ``predict`` standardize inputs with ``self.images_mean_std``.
    This class reads that attribute but never sets it. Presumably the
    ``ModelTrainer`` base class or a caller populates it before use (TODO confirm).
    The auxiliary value paired with the model is always ``None``.
    """

    def __init__(
        self,
        architecture: influence.torch_architectures.TorchArchitecture,
    ):
        # TODO: test whether enabling `torch.backends.cudnn.benchmark = True` helps.
        self.architecture = architecture

    def train(
        self,
        images: torch.Tensor,
        targets: torch.Tensor,
        seed: int,
        device: torch.device,
    ) -> tuple[torch.nn.Module, None]:
        """Train a fresh model, seeding torch's RNG with ``seed`` in an isolated RNG scope.

        Args:
            images: Training images. Per-channel statistics are broadcast over dim 1
                (NCHW layout assumed — TODO confirm).
            targets: Class labels. They are converted to one-hot if the architecture asks for it.
            seed: Passed to ``torch.manual_seed`` inside ``torch.random.fork_rng``,
                so the caller's RNG state is restored on exit.
            device: Device the data is moved to and the model is built on.

        Returns:
            ``(model, None)``.
        """
        assert self.images_mean_std is not None
        device = torch.device(device)
        # `fork_rng(devices=...)` lists CUDA devices whose RNG state is saved and restored.
        # The CPU RNG state is always forked. Passing a CPU device here would make it
        # query CUDA RNG state, which fails on CPU-only machines.
        rng_devices = [device] if device.type == "cuda" else []
        with torch.random.fork_rng(devices=rng_devices):
            torch.manual_seed(seed)
            return self._train(images, targets, device)

    def _standardize(self, images: torch.Tensor) -> torch.Tensor:
        """Normalize ``images`` per channel using ``self.images_mean_std`` (mean, std).

        The statistics are moved to ``images.device`` and reshaped to ``(1, C, 1, 1)``
        so they broadcast over dim 1.
        """
        assert self.images_mean_std is not None
        mean, std = (stat.to(device=images.device) for stat in self.images_mean_std)
        return (images - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)

    def _train(
        self,
        images: torch.Tensor,
        targets: torch.Tensor,
        device: torch.device,
    ) -> tuple[torch.nn.Module, None]:
        """Preprocess data on ``device``, build the model, and run the architecture's training loop."""
        # Honor the requested device. `.cuda()` would always pick the default CUDA
        # device and mismatch the statistics placed on `device`.
        images = images.to(device)
        targets = targets.to(device)

        # Spatial size before padding. It is handed to the architecture's training
        # loop as `crop_size`.
        crop_size = images.shape[2]

        images = self._standardize(images)

        # Convert the dataset once to the model's precision for the rest of training.
        images = images.to(dtype=self.architecture.model_dtype).requires_grad_(False)

        # One-hot targets allow label-mixing tricks (e.g. cutmix) during training.
        if self.architecture.one_hot_targets:
            # NOTE(review): without `num_classes`, `one_hot` infers the width as
            # `targets.max() + 1`. A subset missing the highest class yields a narrower
            # encoding — confirm this cannot happen for callers.
            targets = torch.nn.functional.one_hot(targets).to(dtype=self.architecture.model_dtype)

        if self.architecture.pad_amount > 0:
            warnings.warn('Using padding. If not wanted, set `hyp["net"]["pad_amount"] = 0`')
            # Reflect-pad evenly on all four spatial sides.
            images = self._pad_images(images, self.architecture.pad_amount)

        model = self.architecture.build_model(images, device)
        model = self.architecture.train(
            images,
            targets,
            crop_size=crop_size,
            net=model,
            drop_last_batch=True,  # drop last batch to match JAX trainer
            generator=None,
        )
        return model, None

    @classmethod
    def _pad_images(cls, image: torch.Tensor, pad_amount: int) -> torch.Tensor:
        """Reflect-pad the last two dimensions of ``image`` by ``pad_amount`` on each side."""
        return torch.nn.functional.pad(
            image,
            (pad_amount,) * 4,
            "reflect",
        )

    def predict(self, images: torch.Tensor, model: torch.nn.Module, aux: None) -> torch.Tensor:
        """Run ``model`` on ``images`` and return its raw outputs as float32.

        Images are standardized and cast to the model dtype, as in training, and then
        evaluated in batches of 2500 under ``torch.no_grad()``. The model runs in eval
        mode, and its previous train/eval mode is restored afterwards.

        Returns:
            The concatenated model outputs along dim 0 (not argmax class indices).
        """
        images = self._standardize(images).to(dtype=self.architecture.model_dtype)

        eval_batchsize = 2500
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                # `split` yields views along dim 0. This avoids the per-row indexing and
                # re-stacking that a DataLoader over a tensor performs.
                outputs = [model(batch).to(torch.float32) for batch in images.split(eval_batchsize)]
        finally:
            model.train(was_training)

        return torch.cat(outputs, dim=0)
