"""Implementing the ParamRepulsor/ParamPaCMAP Algorithm as a sklearn estimator."""
import time
from typing import Optional, Callable

import torch
import torch.optim as optim
import torch.utils.data
import numpy as np
from sklearn import preprocessing, decomposition
from sklearn.base import BaseEstimator

from paramrepulsor.models import module, dataset
from paramrepulsor.utils import data, utils
from source.paramrepulsor.paramrepulsor import training


def pacmap_weight_schedule(epoch):
    """Return the ParamPaCMAP loss weights for ``epoch`` as ``np.array([w_nn, w_fp, w_mn])``.

    Phases:
      * ``epoch < 100``: w_nn = 2, w_fp = 1, and w_mn interpolates linearly
        from 1000 (epoch 0) towards 3 (epoch 100);
      * ``100 <= epoch < 200``: w_nn = 3, w_fp = 1, w_mn = 3;
      * otherwise: w_nn = 1, w_fp = 1, w_mn = 0 (mid-near term switched off).
    """
    if epoch < 100:
        # 1000 * (1 - epoch / 100) + 3 * (epoch / 100), written without division.
        decaying_mn = 10 * (100 - epoch) + 0.03 * epoch
        return np.array([2.0, 1.0, decaying_mn])
    if epoch < 200:
        return np.array([3.0, 1.0, 3.0])
    return np.array([1.0, 1.0, 0.0])


def paramrepulsor_weight_schedule(epoch):
    """Return the ParamRepulsor loss weights for ``epoch`` as ``np.array([w_nn, w_fp, w_mn])``.

    Before epoch 200 the weights are ``[4, 8, 0]``. From epoch 200 onwards they
    are ``[1, 8, -12]``: a smaller neighbor weight and a negative mid-near weight.
    """
    in_first_phase = epoch < 200
    neighbor_weight = 4.0 if in_first_phase else 1.0
    mid_near_weight = 0.0 if in_first_phase else -12.0
    further_weight = 8.0  # constant across both phases
    return np.array([neighbor_weight, further_weight, mid_near_weight])


class ParamRepulsor(BaseEstimator):
    """ParamRepulsor / ParamPaCMAP implemented with PyTorch as a sklearn estimator.

    ``fit`` performs these steps:

    * optionally reduces the input with PCA (when it has more than 100
      features and ``perform_pca`` is set);
    * optionally scales it (``apply_scale`` of ``"standard"`` or ``"minmax"``);
    * builds neighbor (NN), further (FP) and mid-near (MN) pairs with
      ``data.generate_pair``;
    * trains a ``module.ParamPaCMAP`` network with ``module.PaCMAPLoss``. The
      loss weights are reset every epoch from ``weight_schedule(epoch)``.

    The defaults configure ParamRepulsor. For ParamPaCMAP, pass
    ``weight_schedule=pacmap_weight_schedule`` and ``consts=[10, 1, 10000]``.

    Note:
        ``apply_pca`` and ``lr_schedule`` are stored (so ``get_params`` reports
        them) but are not read by this class. PCA is controlled by
        ``perform_pca``.
    """
    def __init__(
        self,
        n_components: int = 2,
        n_neighbors: int = 10,
        n_FP: int = 20,
        n_MN: int = 5,
        distance: str = "euclidean",
        optim_type: str = "Adam",
        lr: float = 1e-3,
        lr_schedule: Optional[bool] = None,
        apply_pca: bool = True,
        apply_scale: Optional[str] = None,
        model_dict: Optional[dict] = utils.DEFAULT_MODEL_DICT,
        intermediate_snapshots: Optional[list] = [],
        loss_weight: Optional[list] = [1, 1, 1],
        batch_size: int = 1024,
        data_reshape = None,
        num_epochs: int = 450,
        verbose: bool = False,
        weight_schedule: Callable = paramrepulsor_weight_schedule, # Change this to pacmap_weight_schedule for parampacmap
        num_workers: int = 12,
        dtype: torch.dtype = torch.float32,
        perform_pca: bool = True,
        use_ns_loader: bool = False,
        consts: list[float] = [10, 1, 1],  # Change this to [10, 1, 10000] for parampacmap
        torch_compile: bool = False,
    ):
        super().__init__()
        # NOTE(review): the list/dict defaults above are shared between
        # instances. This class never mutates them; confirm that
        # module.ParamPaCMAP / module.PaCMAPLoss do not either.
        self.n_components = n_components  # output_dims
        self.n_neighbors = n_neighbors
        self.n_FP = n_FP
        self.n_MN = n_MN
        self.distance = distance
        self.optim_type = optim_type
        self.lr = lr
        self.lr_schedule = lr_schedule
        self.apply_pca = apply_pca
        self.apply_scale = apply_scale
        # Placeholder for the model. The model is initialized during fit.
        self.model = None
        self.model_dict = model_dict
        self.intermediate_snapshots = intermediate_snapshots
        self.loss_weight = loss_weight
        self.batch_size = batch_size
        self.data_reshape = data_reshape
        self.num_epochs = num_epochs
        self.verbose = verbose
        self.weight_schedule = weight_schedule
        self.num_workers = num_workers
        # Stored under its parameter name so that BaseEstimator.get_params,
        # repr, clone and set_params can find it. ``_dtype`` (below) is a
        # read/write alias kept for existing internal and external uses.
        self.dtype = dtype
        self.perform_pca = perform_pca
        self._scaler = None
        self._projector = None
        self.time_profiles = None
        self.use_ns_loader = use_ns_loader
        self.consts = consts
        self.torch_compile = torch_compile

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

    @property
    def _dtype(self) -> torch.dtype:
        """Backward-compatible alias of the ``dtype`` parameter."""
        return self.dtype

    @_dtype.setter
    def _dtype(self, value: torch.dtype) -> None:
        self.dtype = value

    def fit(self, X: np.ndarray, profile_only: bool = False, per_layer: bool = False) -> "ParamRepulsor":
        """Fit the parametric embedding on ``X``.

        Args:
            X: Input data. ``X.shape[0]`` is the number of samples and
                ``X.shape[1]`` the number of features.
            profile_only: If True, run a single profiled training epoch, set
                the embedding to None, and return without full training.
            per_layer: If True, the final embedding and any intermediate
                snapshots are computed with ``_inference_per_layer`` (one array
                per layer) instead of ``_inference``.

        Returns:
            ``self``, following the scikit-learn convention.

        Raises:
            ValueError: If ``optim_type`` is neither ``"Adam"`` nor ``"SGD"``.
        """
        fit_begin = time.perf_counter()

        # Data preprocessing (fits the PCA projector / scaler used by transform).
        X = self._fit_preprocessing(X)
        input_dims = X.shape[1]

        self.model = module.ParamPaCMAP(
            input_dims=input_dims,
            output_dims=self.n_components,
            model_dict=self.model_dict
        ).to(self.device).to(self._dtype)
        # NOTE(review): the loss is cast to dtype but not moved to self.device;
        # presumably it holds no device-bound tensors. Confirm.
        self.loss = module.PaCMAPLoss(
            weight=self.loss_weight,
            consts=self.consts
        ).to(self._dtype)
        self.intermediate_outputs = []

        # Constructing dataloader
        pair_neighbors, pair_MN, pair_FP, _ = data.generate_pair(
            X, n_neighbors=self.n_neighbors, n_MN=self.n_MN, n_FP=self.n_FP,
            distance=self.distance, verbose=False
        )
        nn_pairs, fp_pairs, mn_pairs = training.convert_pairs(
            pair_neighbors, pair_FP, pair_MN, X.shape[0])

        if self.use_ns_loader:
            train_loader_ctor = dataset.FastNSDataloader
        else:
            train_loader_ctor = dataset.FastDataloader
        train_loader = train_loader_ctor(
            data=X,
            nn_pairs=nn_pairs,
            fp_pairs=fp_pairs,
            mn_pairs=mn_pairs,
            batch_size=self.batch_size,
            device=self.device,
            shuffle=True,
            reshape=self.data_reshape,
            dtype=self._dtype,
        )
        test_loader = self._make_test_loader(X)

        optimizer = self._build_optimizer()

        if profile_only:
            epoch_begin = time.perf_counter()
            print(f"Time Profile: Before Epoch\n"
                  f"Preparation:{(epoch_begin - fit_begin):03.3f}s\n")
            self._tune_weight(epoch=0)
            self._profile_epoch(train_loader, optimizer)
            self._embedding = None
            return self

        # The previous code dispatched on an undefined ``use_ibns_loader``
        # attribute to a non-existent ``_train_epoch_ib`` method, so fit always
        # raised AttributeError. Both loader types go through ``_train_epoch``,
        # just as ``_profile_epoch`` already consumes either loader.
        # NOTE(review): assumes FastNSDataloader yields (num_items, model_input)
        # batches like FastDataloader. Confirm.
        train_func = self._train_epoch
        if self.torch_compile:
            train_func = torch.compile(train_func)

        inference = self._inference_per_layer if per_layer else self._inference
        for epoch in range(self.num_epochs):
            if epoch in self.intermediate_snapshots:
                self.intermediate_outputs.append(inference(test_loader))
            # Tune the weights, then train for one epoch.
            self._tune_weight(epoch=epoch)
            train_func(train_loader, epoch, optimizer)

        self._embedding = inference(test_loader)
        return self

    def _fit_preprocessing(self, X: np.ndarray) -> np.ndarray:
        """Fit the optional PCA projector and scaler on ``X`` and return the transformed data."""
        # Reset first so a re-fit never reuses a projector/scaler from an earlier fit.
        self._projector = None
        self._scaler = None
        if X.shape[1] > 100 and self.perform_pca:
            # PCA cannot keep more components than there are samples.
            self._projector = decomposition.PCA(n_components=min(100, X.shape[0]))
            X = self._projector.fit_transform(X)
        if self.apply_scale == "standard":
            self._scaler = preprocessing.StandardScaler()
        elif self.apply_scale == "minmax":
            self._scaler = preprocessing.MinMaxScaler()
        # Any other apply_scale value (including None) leaves X unscaled.
        if self._scaler is not None:
            X = self._scaler.fit_transform(X)
        return X

    def _build_optimizer(self) -> optim.Optimizer:
        """Create the optimizer selected by ``optim_type`` for the current model.

        Raises:
            ValueError: If ``optim_type`` is neither ``"Adam"`` nor ``"SGD"``.
        """
        if self.optim_type == "Adam":
            return optim.Adam(self.model.parameters(), lr=self.lr)
        if self.optim_type == "SGD":
            return optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9)
        raise ValueError(f"Unsupported optimizer type: {self.optim_type}")

    def _make_test_loader(self, X) -> torch.utils.data.DataLoader:
        """Build the ordered (non-shuffled) DataLoader used for inference over ``X``."""
        test_set = dataset.TensorDataset(data=X, reshape=self.data_reshape, dtype=self._dtype)
        return torch.utils.data.DataLoader(
            dataset=test_set,
            batch_size=2 * self.batch_size,
            shuffle=False,
            drop_last=False,
            # Pinned host memory only helps host-to-GPU copies; on CPU it just warns.
            pin_memory=self.device.type == "cuda",
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=self.num_workers > 0,
        )

    def _tune_weight(self, epoch: int):
        """Set the loss weights for ``epoch`` from ``self.weight_schedule``."""
        weight = self.weight_schedule(epoch)
        self.loss.update_weight(weight)

    def _split_output(self, model_output, num_items):
        """Split one batched forward output into the four loss inputs.

        The rows of ``model_output`` are laid out as ``num_items`` basis rows,
        then ``num_items * n_neighbors`` NN rows, ``num_items * n_FP`` FP rows,
        and the remaining (MN) rows. Each pair group is reshaped to
        ``(num_items, n_pairs, dim)``. The basis becomes ``(num_items, 1, dim)``.

        Returns:
            Tuple ``(basis, nn_pairs, fp_pairs, mn_pairs)``.
        """
        nn_end = num_items * (self.n_neighbors + 1)
        fp_end = num_items * (self.n_neighbors + self.n_FP + 1)
        basis = torch.unsqueeze(model_output[:num_items], 1)
        nn_pairs = model_output[num_items:nn_end]
        fp_pairs = model_output[nn_end:fp_end]
        mn_pairs = model_output[fp_end:]
        nn_pairs = nn_pairs.view(num_items, self.n_neighbors, nn_pairs.shape[1])
        fp_pairs = fp_pairs.view(num_items, self.n_FP, fp_pairs.shape[1])
        mn_pairs = mn_pairs.view(num_items, self.n_MN, mn_pairs.shape[1])
        return basis, nn_pairs, fp_pairs, mn_pairs

    def _train_epoch(self, train_loader, epoch,
                     optimizer: optim.Optimizer):
        """Perform a single epoch of training.

        When ``verbose``, prints the loss of the epoch's last batch on the first
        epoch and every 20th epoch.
        """
        loss = None  # stays None if the loader yields no batches
        for num_items, model_input in train_loader:
            optimizer.zero_grad()
            model_output = self.model(model_input)
            loss = self.loss(*self._split_output(model_output, num_items))
            loss.backward()
            optimizer.step()
        if self.verbose and loss is not None and ((epoch + 1) % 20 == 0 or epoch == 0):
            print(
                f'Epoch [{epoch + 1}/{self.num_epochs}], Loss: {loss.item():.4f},',
                flush=True
            )

    def _profile_epoch(self, train_loader,
                       optimizer: optim.Optimizer):
        """Perform a single epoch of training with per-batch timing, then print a summary.

        For each batch the following durations (seconds) are recorded:
        waiting for the loader, unpacking + ``zero_grad``, model forward, and
        split/reshape + loss + backward + optimizer step. The results are stored
        in ``self.time_profiles`` with shape ``(num_batches, 4)``.
        """
        use_cuda = self.device.type == "cuda"

        def synced_clock() -> float:
            # CUDA kernels run asynchronously, so wait for them before reading
            # the clock. torch.cuda.synchronize() fails on CPU-only setups.
            if use_cuda:
                torch.cuda.synchronize()
            return time.perf_counter()

        time_profiles = []
        batch_begin = time.perf_counter()
        for batch in train_loader:
            time_dataloader = synced_clock()
            optimizer.zero_grad()
            num_items, model_input = batch
            time_reshape = synced_clock()
            model_output = self.model(model_input)
            time_forward = synced_clock()
            loss = self.loss(*self._split_output(model_output, num_items))
            loss.backward()
            optimizer.step()
            time_backward = synced_clock()
            time_profiles.append([
                time_dataloader - batch_begin,
                time_reshape - time_dataloader,
                time_forward - time_reshape,
                time_backward - time_forward,
            ])
            batch_begin = time_backward
        # reshape keeps a (0, 4) shape when the loader yields nothing, so the
        # summary below still indexes four columns.
        self.time_profiles = np.array(time_profiles, dtype=float).reshape(-1, 4)
        # Generate a profile report
        time_summary = np.sum(self.time_profiles, axis=0)
        summary_text = (
            f"Time Profile: Sum in Epoch\n"
            f"Dataloader: {time_summary[0]:03.3f}s\n"
            f"Reshape:    {time_summary[1]:03.3f}s\n"
            f"Forward:    {time_summary[2]:03.3f}s\n"
            f"Backward:   {time_summary[3]:03.3f}s\n"
        )
        print(summary_text)

    def _inference(self, test_loader):
        """Run the model over ``test_loader`` and return the concatenated output as float32 numpy."""
        # NOTE(review): the model is not switched to eval() here, so layers such
        # as dropout/batch-norm (if model_dict configures any) run in training
        # mode. Confirm this is intended.
        with torch.inference_mode():
            outputs = [self.model(batch.to(self.device)) for batch in test_loader]
            embedding = torch.concatenate(outputs)
            return embedding.float().cpu().numpy()

    def _inference_per_layer(self, test_loader):
        """Run the model in per-layer output mode and return one float32 numpy array per layer.

        Raises:
            RuntimeError: If batches report different numbers of layer outputs.
        """
        self.model.set_output_per_layer(True)
        try:
            with torch.inference_mode():
                # Each forward returns a list with one embedding per layer.
                batch_outputs = [self.model(batch.to(self.device)) for batch in test_loader]
        finally:
            # Always restore the normal output mode, even if inference fails.
            self.model.set_output_per_layer(False)
        num_layers = len(batch_outputs[0])
        if any(len(output) != num_layers for output in batch_outputs):
            raise RuntimeError(
                "Model returned a different number of per-layer outputs across batches."
            )
        return [
            torch.concatenate([output[layer] for output in batch_outputs]).float().cpu().numpy()
            for layer in range(num_layers)
        ]

    def fit_transform(self, X: np.ndarray, per_layer: bool=False):
        """Fit on ``X`` and return its embedding.

        Returns:
            The embedding alone when no intermediate snapshots were recorded.
            Otherwise returns the tuple ``(embedding, intermediate_outputs)``.
        """
        self.fit(X, per_layer=per_layer)
        if len(self.intermediate_outputs) == 0:
            return self._embedding
        return self._embedding, self.intermediate_outputs

    def transform(self, X: np.ndarray, per_layer: bool=False) -> np.ndarray:
        """Embed new data with the fitted preprocessing and network.

        Raises:
            sklearn.exceptions.NotFittedError: If ``fit`` has not been called.
        """
        if self.model is None:
            from sklearn.exceptions import NotFittedError
            raise NotFittedError(
                "This ParamRepulsor instance is not fitted yet. Call 'fit' first."
            )
        if self._projector is not None:
            X = self._projector.transform(X)
        if self._scaler is not None:
            X = self._scaler.transform(X)
        test_loader = self._make_test_loader(X)
        if per_layer:
            return self._inference_per_layer(test_loader)
        return self._inference(test_loader)

