from typing import Literal, Optional
import functools

import torch
from torch.nn import functional
import pytorch_lightning as pl

from vis_models.training import Trainer
from vis_models.inference_recording import intermediate_representations
# from vis_models.architectures.utils.layer_accessors import get_penultimate_layer_accessor
from vis_models.architectures.access import get_penultimate_layer, ModelConfig
from vis_models.metrics.representation_similarity.cka import (
    linear_CKA
)
from .cka import linear_CKA as np_linear_CKA


# Name paired with the layer returned by `get_penultimate_layer` (see
# `get_pu_layer_item`); it is also the key under which that layer's
# representations are stored when no explicit layers are requested.
PU_LAYER_NAME = "pu"

# A monitored layer, given as a (layer name, layer module) pair.
LayerItem = tuple[str, torch.nn.Module]

def get_pu_layer_item(
    model: torch.nn.Module,
    model_config: ModelConfig,
) -> LayerItem:
    """Return the model's penultimate layer paired with its standard name.

    Args:
        model: The network to look the layer up in.
        model_config: Configuration forwarded to `get_penultimate_layer`
            together with `model`.

    Returns:
        A `(PU_LAYER_NAME, layer)` tuple, where `layer` is whatever
        `get_penultimate_layer(model_config, model)` returns.
    """
    return (PU_LAYER_NAME, get_penultimate_layer(model_config, model))

def compute_representations(
    model: torch.nn.Module,
    model_config: ModelConfig,
    dataset: pl.LightningDataModule,
    layers: Optional[list[LayerItem]] = None,
) -> dict[str, torch.Tensor]:
    """Run `model` over the test split and collect layer representations.

    Args:
        model: The network whose intermediate outputs are recorded.
        model_config: Configuration used to locate the penultimate layer
            when `layers` is not given.
        dataset: Data module; its "test" stage is set up and its test
            dataloader is iterated.
        layers: `(name, module)` pairs to monitor. When `None`, only the
            penultimate layer (see `get_pu_layer_item`) is monitored.

    Returns:
        A mapping from layer name to the representations gathered by
        `RepComputationTask` for that layer.
    """
    # Only `None` triggers the default; an explicitly empty list is kept.
    monitored = (
        layers
        if layers is not None
        else [get_pu_layer_item(model, model_config)]
    )
    names = [name for name, _module in monitored]

    dataset.setup("test")
    loader = dataset.test_dataloader()
    runner = get_trainer()

    with intermediate_representations(model, monitored) as hooked_model:
        task = RepComputationTask(hooked_model, names)
        runner.test(task, loader)
        # Read the result while the monitoring context is still active.
        collected = task.reps
    return collected


@functools.cache
def get_trainer() -> Trainer:
    """Return the shared `Trainer` used for representation computation.

    The function takes no arguments and is wrapped in `functools.cache`,
    so the `Trainer` is constructed on the first call and the same
    instance is returned on every later call.
    """
    trainer_options = {
        "task_name": ["representation_computation"],
        "accelerator": "gpu",
        "devices": 1,
        "max_epochs": 1,
        "enable_progress_bar": False,
    }
    return Trainer(**trainer_options)

class RepComputationTask(pl.LightningModule):
    """Lightning test task that gathers representations of monitored layers.

    Each test batch is passed through `model`; for every monitored layer the
    corresponding entry of the output's `layer_reps` is detached, moved to
    the CPU, flattened from dimension 1 onwards and stored. When the test
    run ends, the stored batches are concatenated along dimension 0 and
    exposed as `reps`.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        monitored_layer_names: list[str],
    ) -> None:
        """Store the model and the names of the layers to collect.

        Args:
            model: Module called on every test batch. Its output must expose
                a `layer_reps` mapping indexed by layer name.
            monitored_layer_names: Keys of `layer_reps` to collect.
        """
        super().__init__()

        self.model = model
        self.monitored_layer_names = monitored_layer_names

        # Per-layer list of flattened CPU batches gathered in `test_step`.
        self.rep_batches: dict[str, list[torch.Tensor]] = {}
        # Defined up front so `reps` can be read (as an empty mapping) even
        # if no test run has completed yet, instead of raising
        # AttributeError.
        self.reps: dict[str, torch.Tensor] = {}

    def on_test_start(self) -> None:
        """Discard batches left over from a previous test run.

        Without this reset, running the same task through `trainer.test`
        more than once would concatenate stale batches with the new ones.
        """
        self.rep_batches = {}

    def test_step(self, batch, batch_idx: int) -> None:
        """Run the model on `batch` and store each monitored layer's output."""
        output = self.model(batch)
        for layer_name in self.monitored_layer_names:
            layer_rep = output.layer_reps[layer_name]
            # Detach and move to the CPU before storing; keep dimension 0
            # and flatten every remaining dimension into one.
            flat_rep = torch.flatten(layer_rep.detach().cpu(), 1)
            self.rep_batches.setdefault(layer_name, []).append(flat_rep)

    def on_test_end(self) -> None:
        """Concatenate the stored batches of every layer into `reps`."""
        self.reps = {
            layer_name: torch.cat(layer_batches, 0)
            for layer_name, layer_batches in self.rep_batches.items()
        }

# Names of the supported representation-distance metrics.
RepDistanceMetric = Literal["l2", "linear_cka"]


def compute_representation_distances(
    reps_1: torch.Tensor,
    reps_2: torch.Tensor,
    metric: RepDistanceMetric,
) -> torch.Tensor:
    """Compare two sets of representations with the chosen metric.

    Args:
        reps_1: First set of representations.
        reps_2: Second set of representations, compared against `reps_1`.
        metric: `"l2"` for the mean Euclidean distance between corresponding
            vectors, or `"linear_cka"` for the value computed by
            `np_linear_CKA` on the NumPy versions of both inputs.

    Returns:
        For `"l2"`, a detached scalar tensor. For `"linear_cka"`, whatever
        `np_linear_CKA` returns.
        NOTE(review): the `"linear_cka"` result is presumably a NumPy scalar
        rather than a `torch.Tensor` - confirm against `np_linear_CKA`.

    Raises:
        ValueError: If `metric` is not one of the supported names.
    """
    if metric == "l2":
        distances = functional.pairwise_distance(reps_1, reps_2, p=2.0)
        return distances.detach().mean()
    if metric == "linear_cka":
        # `Tensor.numpy()` fails for tensors that require grad or live on a
        # non-CPU device, so detach and move to the CPU first. This mirrors
        # the `.detach()` in the "l2" branch.
        return np_linear_CKA(
            reps_1.detach().cpu().numpy(),
            reps_2.detach().cpu().numpy(),
        )
    raise ValueError(
        f"Unknown representation distance metric {metric!r}; "
        "expected 'l2' or 'linear_cka'."
    )
