from typing import Callable

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from tqdm import tqdm

from spc.dataset import LabelledDataset, IndexFilteredDataset
from spc.dfconst import PLATE_COLUMN, MOA_COLUMN, CONTROL_MOA_NAME


# Default for the `max_sample_size` argument of `make_embeddings`: the largest
# number of samples pushed through the model in its per-plate train-mode pass.
MAX_SAMPLE_SIZE = 10000


class Embedder(nn.Module):
    """Wrap an encoder with a projection head that produces embeddings.

    Args:
        encoder: Module applied first to the input.
        encoder_out_dim: Width of the encoder output; used as the input
            width of the ``'mlp'`` and ``'linear'`` heads.
        embed_dim: Output width of the ``'mlp'`` and ``'linear'`` heads, also
            exposed through the ``embed_dim`` property.
        head_type: One of ``'mlp'`` (Linear -> ReLU -> Linear), ``'linear'``
            (a single Linear layer) or ``'identity'`` (no projection).

    Raises:
        ValueError: If ``head_type`` is not one of the supported values.
    """

    def __init__(self, encoder: nn.Module, encoder_out_dim: int, embed_dim: int, head_type: str):
        super(Embedder, self).__init__()
        self.encoder = encoder
        # Dispatch on the `head_type` argument itself.
        if head_type == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(encoder_out_dim, encoder_out_dim),
                nn.ReLU(),
                nn.Linear(encoder_out_dim, embed_dim),
            )
        elif head_type == 'linear':
            self.head = torch.nn.Linear(encoder_out_dim, embed_dim)
        elif head_type == 'identity':
            # NOTE(review): the identity head passes the encoder output through
            # unchanged, so `embed_dim` is not enforced here; callers presumably
            # pass embed_dim == encoder_out_dim in this mode -- confirm.
            self.head = torch.nn.Identity()
        else:
            raise ValueError(f"Unknown head type: {head_type}")

        self._embed_dim = embed_dim

    @property
    def embed_dim(self) -> int:
        """Embedding width as given to the constructor."""
        return self._embed_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``head(encoder(x))``."""
        x = self.encoder(x)
        x = self.head(x)
        return x


class RandomTest(nn.Module):
    """Module that ignores input content and emits uniform random vectors.

    The output has one row per input row (``x.shape[0]``) and ``dim`` columns,
    drawn with ``torch.rand`` on the same device as the input.
    """

    def __init__(self, dim: int = 128):
        super().__init__()
        self.dim = dim

    @property
    def rep_dim(self) -> int:
        """Number of columns in the produced representation."""
        return self.dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a fresh ``(x.shape[0], dim)`` tensor of uniform samples."""
        n_rows = x.shape[0]
        out_shape = (n_rows, self.dim)
        return torch.rand(out_shape, device=x.device)


def make_embeddings(
        model: Embedder,
        dataset: LabelledDataset,
        eval_transforms: Callable,
        sampling_strategy: str,
        eval_batch_size: int = 128,
        max_sample_size: int = MAX_SAMPLE_SIZE,
) -> np.ndarray:
    """Embed every sample of ``dataset``, processing one plate at a time.

    For each distinct value of ``PLATE_COLUMN`` the samples of that plate are
    embedded in eval mode. When ``sampling_strategy`` is ``'plate'`` or
    ``'plate__random'``, the model is first run once in train mode (without
    gradients) on samples from the same plate; the intent is to refit
    batch-norm statistics on that plate before embedding it.

    Args:
        model: Model to evaluate; must expose ``embed_dim``.
        dataset: Dataset yielding ``(image, label)`` pairs and a dataframe
            via ``get_df()``.
        eval_transforms: Callable applied to each image batch before the model.
        sampling_strategy: ``'plate'`` refits on the plate's control rows
            (``MOA_COLUMN == CONTROL_MOA_NAME``); ``'plate__random'`` refits on
            all rows of the plate; any other value skips the refit pass.
        eval_batch_size: Batch size used for the eval-mode pass.
        max_sample_size: Upper bound on the number of samples used in the
            train-mode pass; a random subset is drawn when more are available.

    Returns:
        A float32 array of shape ``(len(df), model.embed_dim)``.

    Raises:
        ValueError: If a refit pass is requested but the plate has no
            candidate samples (e.g. no control rows under ``'plate'``).

    Side effects:
        The model is left in eval mode, and with the plate strategies its
        train-mode state (e.g. batch-norm running statistics) is updated.
    """
    df = dataset.get_df()
    embeddings = np.zeros((len(df), model.embed_dim), dtype=np.float32)
    for plate in df[PLATE_COLUMN].unique():
        plate_indices = df[df[PLATE_COLUMN] == plate].index.tolist()
        dataset_filtered_by_plate = IndexFilteredDataset(dataset, plate_indices)

        if sampling_strategy in ['plate', 'plate__random']:
            # compute the batch norm layer statistics on controls for this plate
            plate_df = dataset_filtered_by_plate.get_df()

            # use controls or not when refitting batch norm layers
            if sampling_strategy == 'plate':  # control version
                control_indices = plate_df[plate_df[MOA_COLUMN] == CONTROL_MOA_NAME].index
            else:
                control_indices = plate_df.index
            control_indices = list(control_indices)

            if not control_indices:
                # Fail with a clear message instead of an opaque error from
                # collating an empty list; silently skipping would embed this
                # plate with statistics fitted on a previous plate.
                raise ValueError(
                    f"No samples available to refit batch norm statistics for plate "
                    f"{plate!r} with sampling strategy {sampling_strategy!r}"
                )

            # Subsample the indices *before* loading, so that samples which
            # would be discarded are never read or collated.
            if len(control_indices) > max_sample_size:
                keep = torch.randperm(len(control_indices))[:max_sample_size].tolist()
                control_indices = [control_indices[i] for i in keep]

            # NOTE(review): indexing the filtered dataset with labels taken from
            # its own dataframe assumes those labels are valid item positions
            # in the filtered dataset -- confirm against IndexFilteredDataset.
            controls, _ = default_collate(
                [dataset_filtered_by_plate[idx] for idx in control_indices]
            )

            # Single train-mode forward pass over the whole control batch;
            # no gradients are needed, only the module state updates.
            model.train()
            with torch.no_grad():
                model(eval_transforms(controls))

        # no shuffle, so order of embeddings matches order of plate indices
        dataloader = DataLoader(
            dataset=dataset_filtered_by_plate,
            batch_size=eval_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=0,
        )

        model.eval()
        plate_embeddings = []
        with torch.no_grad():
            for ims, _ in tqdm(dataloader):
                z = model(eval_transforms(ims))
                # Keep each batch as an array; avoids boxing every float
                # into Python objects via .tolist().
                plate_embeddings.append(z.detach().cpu().numpy())

        # NOTE(review): uses dataframe index labels as row positions of the
        # output array, i.e. assumes `df` has a 0..len(df)-1 index -- confirm.
        embeddings[plate_indices] = np.concatenate(plate_embeddings, axis=0)

    return embeddings


class MetadataWeights(nn.Module):
    """Learnable ``(n_classes, embed_dim)`` weight matrix, one row per class.

    The matrix is Xavier-uniform initialised. When ``norm_after_update`` is
    truthy, each row is additionally L2-normalised at construction time.
    ``renorm()`` re-applies that row normalisation in place on demand.
    """

    def __init__(self, n_classes, embed_dim, norm_after_update):
        super().__init__()
        weights = torch.FloatTensor(n_classes, embed_dim)
        self.w = torch.nn.Parameter(weights)
        torch.nn.init.xavier_uniform_(self.w)
        self.norm_after_update = norm_after_update
        if norm_after_update:
            self.renorm()

    def renorm(self):
        """L2-normalise every row of ``w`` in place, outside of autograd."""
        with torch.no_grad():
            unit_rows = F.normalize(self.w, p=2, dim=1)
            self.w.copy_(unit_rows)

    def forward(self):
        """Return the weight parameter itself."""
        return self.w



