import logging
import pathlib
import time
import typing
import warnings

import torch
from torch.utils.data import DataLoader, IterableDataset

from torch.utils.tensorboard import SummaryWriter

import eval.optimizers as optimizers
from influence.influence_function import InfluenceFunctionComputer
from influence.task_fast import ClassificationTask
from influence.torch_architectures import TorchArchitecture
from influence.utils import clear_gpu_cache


class InfluenceStochasticOptimizer(optimizers.CanaryOptimizer):
    def __init__(
        self,
        num_optimization_steps: int,
        canary_lr: float,
        canary_momentum: float,
        clamp: bool,
        num_jacobian_samples: int | None,
        architecture: TorchArchitecture,
        early_stopping_metric: typing.Literal["loss", "hinge_score", "logit_score"] | None,
        num_eval_models: int,
    ):
        self.num_optimization_steps = num_optimization_steps
        self.canary_lr = canary_lr
        self.canary_momentum = canary_momentum
        self.clamp = clamp
        self.num_jacobian_samples = num_jacobian_samples
        self.architecture = architecture
        self.early_stopping_metric = early_stopping_metric
        self.num_eval_models = num_eval_models

        self._log = logging.getLogger(self.__class__.__name__)
        # TODO: Just for debugging !!!
        self._log.setLevel(logging.DEBUG)

    def optimize(
        self,
        target: int,
        seed: int,
        canary_idx: int,
        intermediate_log_dir: pathlib.Path,
    ) -> tuple[torch.Tensor, int]:
        """Optimize one training sample into a canary using influence-based gradients.

        `prepare_data` must have been called before, since this method reads the
        datasets and normalization statistics stored there.

        Args:
            target: Expected label of the canary. Only used to warn if it differs from
                the label stored in the dataset; the dataset label is always used.
            seed: Currently unused; no random generator is passed to training.
            canary_idx: Index of the training sample that is turned into the canary.
            intermediate_log_dir: Directory receiving the TensorBoard logs
                (`tensorboard/`) and per-step canaries (`intermediate_canaries/`).

        Returns:
            The canary image (unstandardized, clamped to [0, 1], float32, on CPU) and
            its integer label.
        """
        tensorboard_dir = intermediate_log_dir / "tensorboard"
        tensorboard_dir.mkdir(parents=True, exist_ok=True)
        writer = SummaryWriter(log_dir=tensorboard_dir)

        # Make sure the writer is flushed and closed even if training or the influence
        # computation raises, so the events logged so far are not lost
        try:
            intermediate_canary_dir = intermediate_log_dir / "intermediate_canaries"
            intermediate_canary_dir.mkdir(parents=True, exist_ok=True)
            return self._run_canary_optimization(target, canary_idx, writer, intermediate_canary_dir)
        finally:
            writer.close()

    def _run_canary_optimization(
        self,
        target: int,
        canary_idx: int,
        writer: "SummaryWriter",
        intermediate_canary_dir: pathlib.Path,
    ) -> tuple[torch.Tensor, int]:
        """Run the optimization loop of `optimize`, logging to an already opened writer."""
        # Setup random generator
        # generator = torch.Generator(DEVICE)
        generator = None

        # All images (including canary) are always standardized
        # We keep two sets of images: with and without padding
        # The canary is injected into both sets every iteration, with padding applied dynamically
        # I.e., canary_image is standardized but not padded

        # Collect the initial canary example from the dataset
        canary_image = self._train_images_original[canary_idx].clone()
        # Keep the image being optimized over and optimizer state in high precision to avoid numerical issues
        canary_image = canary_image.to(dtype=torch.float64)

        # # TODO: Just for debugging!!! mislabel
        # self._train_targets_original[canary_idx] = 9 - self._train_targets_original[canary_idx]
        # if self.architecture.one_hot_targets:
        #     self._train_targets[canary_idx] = torch.nn.functional.one_hot(self._train_targets_original[canary_idx], num_classes=10)
        # else:
        #     self._train_targets[canary_idx] = self._train_targets_original[canary_idx]

        canary_label = self._train_targets_original[canary_idx]
        assert canary_label.ndim == 0
        canary_label_train = self._train_targets[canary_idx]
        assert canary_label_train.ndim == 0 if not self.architecture.one_hot_targets else canary_label_train.ndim == 1
        if canary_label.item() != target:
            warnings.warn(
                f"Canary label in dataset {canary_label.item()} does not match given target {target}; will use the original label {canary_label.item()}.",
            )
        # Drop the name so the (possibly mismatching) target cannot be used by accident below
        del target

        def standardize(image: torch.Tensor) -> torch.Tensor:
            return (image - self._dataset_mean_std[0].view(-1, 1, 1)) / self._dataset_mean_std[1].view(-1, 1, 1)

        def unstandardize(image: torch.Tensor) -> torch.Tensor:
            return image * self._dataset_mean_std[1].view(-1, 1, 1) + self._dataset_mean_std[0].view(-1, 1, 1)

        writer.add_image("initial_canary_image", unstandardize(canary_image), 0)

        class MyIterableDataset(IterableDataset):
            """Dataset that yields exactly one (canary image, canary label) pair."""

            def __init__(self, canary_image, canary_label):
                self.canary_image = canary_image
                self.canary_label = canary_label

            def canary_iterator(self):
                yield self.canary_image, self.canary_label

            def __iter__(self):
                return self.canary_iterator()  # creates a new generator each time

        device = self._train_images.device

        pad_amount = self.architecture.pad_amount

        # Train images are (potentially) padded
        train_images = self._train_images.clone()
        train_images_unpadded = self._train_images_original.clone()
        # Train targets are one-hot encoded; same for model training and curvature estimation
        train_targets = self._train_targets.clone()
        # Inject label; image will be updated each iteration
        train_targets[canary_idx] = canary_label_train

        test_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(self._test_images, self._test_targets_original),
            batch_size=512,
            shuffle=False,
            drop_last=False,
        )

        # Crop size is from original unpadded images
        crop_size = train_images_unpadded.shape[2]

        # Model without canary for reference
        assert self.early_stopping_metric is None or self.num_eval_models > 0
        self._log.debug("Training %d out models for evaluation", self.num_eval_models)
        out_models = tuple(
            self._train_model_with_canary(
                train_images,
                train_targets,
                canary_image=self._train_images_original[canary_idx],  # use original image; padding will be applied
                canary_idx=canary_idx,
                pad_amount=pad_amount,
                crop_size=crop_size,
                device=device,
                generator=generator,
            )
            for _ in range(self.num_eval_models)
        )

        # Train OUT model for influence calculation
        # TODO: Use one fixed OUT model, or retrain every time?
        self._log.debug("Training OUT model for influence calculation")
        model_out = self._train_model_with_canary(
            train_images,
            train_targets,
            canary_image=self._train_images_original[canary_idx],  # use original image; padding will be applied
            canary_idx=canary_idx,
            pad_amount=pad_amount,
            crop_size=crop_size,
            device=device,
            generator=generator,
        )
        model_out.eval()

        task = ClassificationTask(
            influence_modules=self.architecture.influence_modules(),
            representation_module=self.architecture.representation_module(),
            device=device,
        )

        best_canary_image = None
        best_canary_metric = -float("inf")
        best_canary_step = None

        # Init momentum
        grad_momentum = None

        for step_idx in range(self.num_optimization_steps):
            # Step 1: Train model
            self._log.debug("Training the model with the %d iteration of the canary...", step_idx)
            start_time = time.time()

            model_in = self._train_model_with_canary(
                train_images, train_targets, canary_image, canary_idx, pad_amount, crop_size, device, generator
            )
            model_in.eval()

            # Evaluate model on test set; this is for the PREVIOUS canary
            test_preds = []
            with torch.no_grad():
                for batch_images, _ in test_loader:
                    test_preds.append(model_in(batch_images))
            test_preds = torch.cat(test_preds)
            test_acc = (test_preds.argmax(dim=-1) == self._test_targets_original).float().mean()
            test_loss = torch.nn.functional.cross_entropy(test_preds, self._test_targets_original)
            writer.add_scalar("test_acc", test_acc, step_idx)
            writer.add_scalar("test_loss", test_loss, step_idx)

            # Step 2: Calculate New Canary Gradient
            # 2.1: Calculate IN model gradient
            model_in.requires_grad_(True)
            ekfac_in = InfluenceFunctionComputer(
                model=model_in,
                task=task,
                n_epoch=1,
                force_half_precision=(self.architecture.model_dtype == torch.float16),
            )
            train_images_unpadded[canary_idx] = canary_image.to(dtype=self.architecture.model_dtype)
            curvature_loader = torch.utils.data.DataLoader(
                torch.utils.data.TensorDataset(train_images_unpadded, train_targets),
                batch_size=512,  # TODO: how much can we afford?
                shuffle=False,
                drop_last=False,
            )
            ekfac_in.build_curvature_blocks(loader=curvature_loader)
            # Important: the label here needs to be the integer label, not the one-hot encoded label!
            canary_grad_in = ekfac_in.compute_self_scores_with_loader_double_jac(
                loader=DataLoader(
                    MyIterableDataset(canary_image.to(self.architecture.model_dtype), canary_label),
                    batch_size=1,
                ),
                use_measurement=True,
                num_samples=self.num_jacobian_samples,
            ).to(dtype=torch.float64)
            del ekfac_in

            # 2.2: Calculate OUT model gradient
            # This is just on the OUT model measurement directly
            model_out.zero_grad()
            canary_image.requires_grad_(True)

            # Upcasting the logits to float64 directly yields a float64 gradient
            logits_out = model_out(canary_image.unsqueeze(0).to(self.architecture.model_dtype)).to(torch.float64)
            labels_out = canary_label.unsqueeze(0)
            assert logits_out.ndim == 2 and labels_out.ndim == 1
            logits_out_target = logits_out[torch.arange(logits_out.shape[0]), ..., labels_out]
            # Mask the target logit so the max below runs over all other classes
            logits_out_except_target = torch.scatter(
                logits_out, dim=1, index=labels_out.unsqueeze(0), value=float("-inf"),
            )
            hinge_score_out = logits_out_target - torch.max(logits_out_except_target, dim=-1).values
            hinge_score_out.sum().backward()
            assert canary_image.grad is not None and canary_image.grad.dtype == torch.float64
            canary_grad_out = canary_image.grad.clone()
            del canary_image.grad

            # Step 3: Update the canary
            canary_grad = canary_grad_in - canary_grad_out  # maximize IN, minimize OUT
            assert canary_grad.ndim == 4 and canary_grad.shape[0] == 1
            assert canary_grad.dtype == torch.float64
            writer.add_scalar("canary_grad_norm", torch.linalg.norm(canary_grad).item(), step_idx + 1)

            # Apply momentum (torch/optax-style)
            with torch.no_grad():
                if grad_momentum is None:
                    grad_momentum = canary_grad[0].clone().detach()
                else:
                    grad_momentum = canary_grad[0].clone().detach() + self.canary_momentum * grad_momentum
                # The update creates a fresh tensor under no_grad, so it no longer requires grad
                canary_image = (canary_image + self.canary_lr * grad_momentum)

            if self.clamp:
                canary_image = standardize(torch.clamp(unstandardize(canary_image), 0, 1))

            writer.add_image("canary_image", unstandardize(canary_image), step_idx + 1)

            # Save intermediate canary to disk
            torch.save(
                (unstandardize(canary_image).cpu(), canary_label.cpu()),
                intermediate_canary_dir / f"canary_{step_idx + 1}.pt",
            )

            # Step 4: Evaluate canary performance and do early stopping
            if self.num_eval_models == 0:
                assert self.early_stopping_metric is None
                # No need to clone, as this will output the final canary anyway
                best_canary_image = canary_image
            else:
                # Train IN models
                in_models = tuple(
                    self._train_model_with_canary(
                        train_images,
                        train_targets,
                        canary_image=canary_image,
                        canary_idx=canary_idx,
                        pad_amount=pad_amount,
                        crop_size=crop_size,
                        device=device,
                        generator=generator,
                    )
                    for _ in range(self.num_eval_models)
                )
                assert (
                    len(in_models) == len(out_models)
                    and len(in_models) == self.num_eval_models
                    and self.num_eval_models > 0
                )
                metrics_out = self._get_metrics(out_models, canary_image, canary_label)
                metrics_in = self._get_metrics(in_models, canary_image, canary_label)
                del in_models

                assert metrics_in.keys() == metrics_out.keys()
                for key in metrics_in.keys():
                    writer.add_scalar(f"{key}_in", metrics_in[key].mean().item(), step_idx + 1)
                    writer.add_scalar(f"{key}_out", metrics_out[key].mean().item(), step_idx + 1)
                    writer.add_scalar(f"{key}_diff", (metrics_in[key] - metrics_out[key]).mean().item(), step_idx + 1)
                    writer.add_histogram(f"full_{key}_in", metrics_in[key].cpu(), step_idx + 1)
                    writer.add_histogram(f"full_{key}_out", metrics_out[key].cpu(), step_idx + 1)
                    writer.add_histogram(f"full_{key}_diff", (metrics_in[key] - metrics_out[key]).cpu(), step_idx + 1)

                # Early stopping
                if self.early_stopping_metric is not None:
                    if (
                        current_metric := (
                            metrics_in[self.early_stopping_metric] - metrics_out[self.early_stopping_metric]
                        )
                        .abs()
                        .mean()
                        .item()
                    ) > best_canary_metric:
                        best_canary_image = canary_image.clone()
                        best_canary_metric = current_metric
                        best_canary_step = step_idx + 1
                else:
                    # No need to clone, as this will output the final canary anyway
                    best_canary_image = canary_image

            del model_in
            clear_gpu_cache()

            self._log.debug("Took %s seconds.", time.time() - start_time)

        writer.flush()

        if best_canary_image is None:
            # Happens with zero optimization steps, or if every early-stopping gap was NaN
            # (NaN never compares greater than the running best). Fall back to the most
            # recent canary instead of failing on `None` below.
            warnings.warn("No best canary was selected during optimization; returning the most recent canary image.")
            best_canary_image = canary_image.detach()
        elif self.early_stopping_metric is not None:
            self._log.info(
                "Early stopping: found best canary at step %d with abs %s gap of %f.",
                best_canary_step,
                self.early_stopping_metric,
                best_canary_metric,
            )

        # Convert final canary image into correct format
        result_image = unstandardize(best_canary_image).to(torch.float32).cpu()
        result_image = torch.clamp(result_image, 0, 1)
        result_target = canary_label.item()
        return result_image, result_target

    def _train_model_with_canary(
        self,
        train_images: torch.Tensor,
        train_targets: torch.Tensor,
        canary_image: torch.Tensor,
        canary_idx: int,
        pad_amount: int,
        crop_size: int,
        device: torch.device,
        generator: torch.Generator | None,
    ) -> torch.nn.Module:
        # Inject canary into the training sets
        canary_image = canary_image.to(dtype=self.architecture.model_dtype)
        if pad_amount > 0:
            canary_image_padded = self._pad_images(canary_image, pad_amount)
        else:
            canary_image_padded = canary_image
        assert canary_image_padded.shape == train_images[canary_idx].shape
        train_images[canary_idx] = canary_image_padded
        # No need to touch label, as that does not change

        model = self.architecture.build_model(train_images, device)
        model = self.architecture.train(
            train_images,
            train_targets,
            crop_size=crop_size,
            net=model,
            generator=generator,
            drop_last_batch=False,
        )
        return model

    def _get_metrics(
        self, models: tuple[torch.nn.Module, ...], canary_image: torch.Tensor, canary_label: torch.Tensor
    ) -> dict[str, torch.Tensor]:
        assert canary_label.ndim == 0
        losses = []
        hinge_scores = []
        logit_scores = []
        for model in models:
            # Downcast canary to architecture precision but then upcast logits for maximum precision
            canary_logits = model(canary_image.unsqueeze(0).to(self.architecture.model_dtype)).to(torch.float64)

            losses.append(torch.nn.functional.cross_entropy(canary_logits, canary_label.unsqueeze(0)))

            # Hinge and logit scores
            score_logits = torch.clone(canary_logits)
            canary_target_logit = canary_logits[0, canary_label.item()]
            score_logits[0, canary_label.item()] = -float("inf")
            hinge_scores.append(canary_target_logit - torch.max(score_logits))
            logit_scores.append(canary_target_logit - torch.logsumexp(score_logits[0], 0))

        loss = torch.stack(losses)
        hinge_score = torch.stack(hinge_scores)
        logit_score = torch.stack(logit_scores)

        return {
            "loss": loss,
            "hinge_score": hinge_score,
            "logit_score": logit_score,
        }

    def prepare_data(
        self,
        train_images: torch.Tensor,
        train_labels: torch.Tensor,
        test_images: torch.Tensor,
        test_labels: torch.Tensor,
        dataset_mean_std: tuple[torch.Tensor, torch.Tensor],
        max_test_samples: None,
        test_sample_seed: None,    
    ) -> None:
        # Move to GPU
        train_images = train_images.cuda()
        train_labels = train_labels.cuda()
        test_images = test_images.cuda()
        test_labels = test_labels.cuda()

        self._dataset_mean_std = tuple(map(lambda x: x.to(device=train_images.device), dataset_mean_std))

        def standardize(input_images):
            return (input_images - self._dataset_mean_std[0].view(1, -1, 1, 1)) / self._dataset_mean_std[1].view(
                1, -1, 1, 1
            )

        # Standardize images
        train_images = standardize(train_images)
        test_images = standardize(test_images)

        # Convert dataset to target precision now for the rest of the process....
        train_images = train_images.to(dtype=self.architecture.model_dtype).requires_grad_(False)
        test_images = test_images.to(dtype=self.architecture.model_dtype).requires_grad_(False)

        # Store original images in case padding will be applied
        self._train_images_original = torch.clone(train_images)
        self._train_images = train_images
        self._test_images = test_images

        # Convert this to one-hot to support the usage of cutmix (or whatever strange label tricks/magic you desire!)
        self._train_targets_original = torch.clone(train_labels)
        self._test_targets_original = test_labels
        if self.architecture.one_hot_targets:
            self._train_targets = torch.nn.functional.one_hot(train_labels).to(dtype=self.architecture.model_dtype)
        else:
            self._train_targets = train_labels

        if self.architecture.pad_amount > 0:
            warnings.warn('Using padding. If not wanted, set `hyp["net"]["pad_amount"] = 0`')
            ## Uncomfortable shorthand, but basically we pad evenly on all _4_ sides with the pad_amount specified in the original dictionary
            self._train_images = self._pad_images(self._train_images, self.architecture.pad_amount)

    @classmethod
    def _pad_images(cls, image: torch.Tensor, pad_amount: int) -> torch.Tensor:
        return torch.nn.functional.pad(
            image,
            (pad_amount,) * 4,
            "reflect",
        )
