import torch
from tqdm import tqdm
import wandb
from typing import Union, List
from torch.utils.data import DataLoader
import os
from pydantic import BaseModel
import wandb
from neuro.utils.correlation import pearsonr
from neuro.utils.optimizer import get_lr
from neuro.models.brain_response_predictor.linear_mapper import LinearMapper
from neuro.models.brain_response_predictor.conv_mapper import ConvMapper
from neuro.utils.device import move_model_to_gpus
from neuro.datasets.image_encoding_dataset.dataset import ImageEncodingDataset
from neuro.utils.kfold import get_k_folds_on_dataset
from nesim.configs import NesimConfig
from nesim.losses.nesim_loss import NesimLoss


class SchedulerConfig(BaseModel, extra="forbid"):
    """Settings for the ``torch.optim.lr_scheduler.StepLR`` built by the trainer.

    The trainer calls ``scheduler.step()`` once per training step (not per
    epoch), so ``step_size`` is counted in optimizer steps.
    """

    # Number of scheduler steps between learning-rate decays.
    step_size: int
    # Multiplicative learning-rate decay factor applied every `step_size` steps.
    gamma: float
    # Forwarded to StepLR's `verbose` option.
    verbose: bool = False


class NesimRegressionTrainerConfig(BaseModel, extra="forbid"):
    """Configuration for ``NesimRegressionTrainer`` (unknown keys are rejected)."""

    # Passed to NesimLoss, which is added to the regression loss during training.
    nesim_config: NesimConfig
    num_epochs: int
    # Training batch size; validation uses the whole fold as a single batch.
    batch_size: int
    # SGD hyperparameters.
    learning_rate: float
    momentum: float
    weight_decay: float
    # The first id selects the device inputs/labels are moved to; more than one
    # id makes the trainer treat the model as data-parallel (saved via `.module`).
    device_ids: List[int]
    # Both folders are passed to ImageEncodingDataset.
    image_filenames_and_labels_folder: str
    image_encodings_folder: str
    # None disables learning-rate scheduling.
    scheduler_config: Union[None, SchedulerConfig]
    # Best-validation checkpoints are written here as `fold_{fold_idx}.pth`.
    checkpoint_folder: str
    train_on_fold: Union[None, int] = None  ## None = train on all folds
    # Number of k-fold splits made over the dataset.
    num_folds: int = 10
    # Show tqdm progress bars.
    progress: bool = False
    # Suppress console output about validation results.
    quiet: bool = False
    # Log metrics to Weights & Biases.
    wandb_log: bool = False
    # The NeSim loss is added on steps where `train_step_idx % n == 0`.
    apply_nesim_every_n_steps: int = 3
    # If set, also save a checkpoint every n training steps into
    # `save_checkpoint_every_n_steps_folder`, which must already exist.
    save_checkpoint_every_n_steps: Union[None, int] = None
    save_checkpoint_every_n_steps_folder: Union[None, str] = None


class NesimRegressionTrainer:
    """Train a brain-response mapper with an added NeSim loss, one k-fold split at a time.

    The regression target for every image is the mean fMRI response over the
    subjects in ``neurotypical_subject_ids``, one value per entry of
    ``regions_of_interest``. The NeSim loss is added to the MSE regression loss
    every ``config.apply_nesim_every_n_steps`` training steps.
    """

    def __init__(
        self,
        config: NesimRegressionTrainerConfig,
    ):
        assert isinstance(config, NesimRegressionTrainerConfig)
        self.config = config
        self.dataset = ImageEncodingDataset(
            image_filenames_and_labels_folder=self.config.image_filenames_and_labels_folder,
            image_encodings_folder=self.config.image_encodings_folder,
        )
        self.regions_of_interest = ["lffa", "rffa", "lppa", "rppa", "leba", "reba"]

        self.fold_datasets = get_k_folds_on_dataset(
            dataset=self.dataset, n_splits=config.num_folds, random_state=0
        )

        # Inputs and labels are moved to the first listed device. With more than
        # one device the model is treated as wrapped and unwrapped via `.module`
        # before saving.
        self.device = f"cuda:{self.config.device_ids[0]}"
        self.is_dataparallel = len(self.config.device_ids) > 1
        self.neurotypical_subject_ids = ["p1", "p2", "p3", "p4"]
        # One model output per region of interest.
        self.num_output_neurons = len(self.regions_of_interest)

    def get_mean_activations_of_all_subjects(self, batch: dict, batch_size: int):
        """returns the mean FMRI response for each roi for all neurotypical subjects

        Args:
            batch (dict): dict containing fmri response from each subject under
                ``batch["fmri_response"][subject_id][roi]``
            batch_size (int): number of items in batch

        Raises:
            KeyError: This happens if the roi name or the subject name is invalid

        Returns:
            torch.tensor: shape (batch_size, num rois); the mean brain response
            across subjects for each batch item
        """
        labels_for_all_subjects = []

        for subject_id in self.neurotypical_subject_ids:
            labels = torch.zeros(batch_size, self.num_output_neurons)

            for roi_idx, roi in enumerate(self.regions_of_interest):
                try:
                    labels[:, roi_idx] = batch["fmri_response"][subject_id][roi]
                except KeyError as err:
                    raise KeyError(
                        f"Could not find roi: {roi} for subject_id: {subject_id}"
                    ) from err
            labels_for_all_subjects.append(labels)

        ## shape: (num subjects, batch, roi) -> mean over subjects
        return torch.stack(labels_for_all_subjects, dim=0).mean(0)

    def _predict_with_labels(
        self, batch: dict, model: Union[LinearMapper, ConvMapper]
    ):
        """Run ``model`` on the batch's image encodings and build matching targets.

        Returns:
            tuple: ``(pred, labels)``, both on ``self.device``.
        """
        image_batch = batch["image_encoding"].to(self.device)
        pred = model(image_batch)
        labels = self.get_mean_activations_of_all_subjects(
            batch=batch, batch_size=image_batch.shape[0]
        )
        return pred, labels.to(self.device)

    @staticmethod
    def _mse_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Mean squared error between predictions and labels of identical shape."""
        assert (
            pred.shape == labels.shape
        ), f"Expected pred.shape: {pred.shape} to be the same as labels.shape: {labels.shape}"
        return torch.nn.functional.mse_loss(pred, labels)

    @torch.no_grad()
    def get_correlation(self, batch: dict, model: Union[LinearMapper, ConvMapper]):
        """Pearson correlation between flattened predictions and mean labels (no grad)."""
        pred, labels = self._predict_with_labels(batch=batch, model=model)
        return pearsonr(x=pred.reshape(-1), y=labels.reshape(-1))

    def get_loss(self, batch: dict, model: Union[LinearMapper, ConvMapper]):
        """MSE between the model's predictions and the subject-averaged fMRI labels."""
        pred, labels = self._predict_with_labels(batch=batch, model=model)
        return self._mse_loss(pred, labels)

    def _unwrap_model(self, model):
        """Return the underlying mapper when ``model`` is wrapped for multiple GPUs."""
        return model.module if self.is_dataparallel else model

    def _save_model(self, model, checkpoint_filename: str):
        """Save the (unwrapped) mapper to ``checkpoint_filename``."""
        self._unwrap_model(model).save(checkpoint_filename=checkpoint_filename)

    def _maybe_save_step_checkpoint(self, model, train_step_idx: int):
        """Save a per-step checkpoint when ``save_checkpoint_every_n_steps`` says so.

        Raises:
            ValueError: periodic saving is enabled but no folder is configured.
            FileNotFoundError: the configured folder does not exist.
        """
        every_n_steps = self.config.save_checkpoint_every_n_steps
        if every_n_steps is None or train_step_idx % every_n_steps != 0:
            return

        folder = self.config.save_checkpoint_every_n_steps_folder
        if folder is None:
            raise ValueError(
                "save_checkpoint_every_n_steps is set but "
                "save_checkpoint_every_n_steps_folder is None"
            )
        if not os.path.exists(folder):
            raise FileNotFoundError(f"Checkpoint folder does not exist: {folder}")

        filename = os.path.join(folder, f"train_step_idx_{train_step_idx}.pth")
        # Unwrap first: a multi-GPU wrapper does not expose the mapper's `save`.
        self._save_model(model, filename)
        if not self.config.quiet:
            print(f"Saved: {filename}")

    def run_train_step(
        self,
        batch,
        optimizer,
        model: Union[LinearMapper, ConvMapper],
        train_step_idx: int,
        nesim_loss: NesimLoss,
        scheduler=None,
    ):
        """Run one optimizer step and return the total loss that was backpropagated.

        The returned value includes the NeSim term on steps where it is applied;
        the ``train_loss`` logged to wandb is the regression loss alone.
        """
        optimizer.zero_grad()
        regression_loss = self.get_loss(batch=batch, model=model)

        if self.config.wandb_log:
            wandb.log({"train_loss": regression_loss.item(), "lr": get_lr(optimizer)})

        total_loss = regression_loss
        if train_step_idx % self.config.apply_nesim_every_n_steps == 0:
            nesim_loss_item = nesim_loss.compute(reduce_mean=True)

            if nesim_loss_item is not None:
                # Out-of-place add: avoids mutating a tensor that is part of the graph.
                total_loss = total_loss + nesim_loss_item

            if self.config.wandb_log:
                nesim_loss.wandb_log()

        total_loss.backward()
        optimizer.step()

        # Stepped once per training step, so StepLR's step_size counts steps.
        if scheduler is not None:
            scheduler.step()

        self._maybe_save_step_checkpoint(model=model, train_step_idx=train_step_idx)

        return total_loss.item()

    def validate(
        self,
        validation_dataloader,
        model: Union[LinearMapper, ConvMapper],
        progress=True,
    ):
        """Compute mean validation loss and mean correlation over all batches.

        The model is put in eval mode for the pass and restored to its previous
        mode afterwards.

        Raises:
            ValueError: ``validation_dataloader`` yielded no batches.

        Returns:
            dict: ``{"validation_loss": float, "correlation_score": float}``
        """
        val_losses = []
        correlation_scores = []
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                for batch in tqdm(validation_dataloader, disable=not progress):
                    # One forward pass feeds both metrics.
                    pred, labels = self._predict_with_labels(batch=batch, model=model)
                    val_losses.append(self._mse_loss(pred, labels).item())
                    correlation_scores.append(
                        pearsonr(x=pred.reshape(-1), y=labels.reshape(-1)).item()
                    )
        finally:
            model.train(was_training)

        if not val_losses:
            raise ValueError("validation_dataloader yielded no batches")

        mean_validation_loss = torch.tensor(val_losses).mean().item()
        mean_correlation_score = torch.tensor(correlation_scores).mean().item()

        if self.config.wandb_log:
            # A single call keeps both metrics on the same wandb step.
            wandb.log(
                {
                    "validation_loss": mean_validation_loss,
                    "validation_correlation_score": mean_correlation_score,
                }
            )

        return {
            "validation_loss": mean_validation_loss,
            "correlation_score": mean_correlation_score,
        }

    def _build_scheduler(self, optimizer):
        """Create the StepLR scheduler described by the config, or None."""
        scheduler_config = self.config.scheduler_config
        if scheduler_config is None:
            return None

        # Newer PyTorch releases deprecate StepLR's `verbose`; only forward it
        # when it was explicitly enabled.
        extra_kwargs = {"verbose": True} if scheduler_config.verbose else {}
        return torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer,
            step_size=scheduler_config.step_size,
            gamma=scheduler_config.gamma,
            last_epoch=-1,
            **extra_kwargs,
        )

    def _print_validation_result(self, epoch_idx: int, validation_result: dict):
        """Print a new best validation result unless ``quiet`` is set."""
        if self.config.quiet:
            return
        print(
            f"[epoch: {str(epoch_idx).zfill(3)}] best validation loss:",
            validation_result["validation_loss"],
            f"correlation score: {validation_result['correlation_score']}",
        )

    def train_single_fold(self, fold_idx: int, model: Union[LinearMapper, ConvMapper]):
        """Train ``model`` on fold ``fold_idx``, keeping the best-validation checkpoint.

        Returns:
            dict: per-step ``train_losses``, per-validation ``validation_losses``
            and ``correlation_scores`` (the first entry is the pre-training
            baseline), and the ``checkpoint_filename`` of the best model.
        """
        model = move_model_to_gpus(model, device_ids=self.config.device_ids)

        nesim_loss = NesimLoss(
            model=model, config=self.config.nesim_config, device=self.device
        )

        optimizer = torch.optim.SGD(
            list(model.parameters()),
            lr=self.config.learning_rate,
            momentum=self.config.momentum,
            weight_decay=self.config.weight_decay,
        )
        scheduler = self._build_scheduler(optimizer)

        train_dataset = self.fold_datasets[fold_idx]["train"]
        validation_dataset = self.fold_datasets[fold_idx]["validation"]

        train_dataloader = DataLoader(
            train_dataset, batch_size=self.config.batch_size, shuffle=True
        )
        # The whole validation fold is one batch, so shuffling has no effect.
        validation_dataloader = DataLoader(
            validation_dataset, batch_size=len(validation_dataset), shuffle=False
        )

        train_losses = []
        validation_losses = []
        correlation_scores = []

        checkpoint_filename = os.path.join(
            self.config.checkpoint_folder, f"fold_{fold_idx}.pth"
        )

        train_step_idx = 0
        for epoch_idx in range(self.config.num_epochs):
            # if on first epoch, validate once before running any train steps
            if epoch_idx == 0:
                validation_result = self.validate(
                    validation_dataloader=validation_dataloader,
                    progress=self.config.progress,
                    model=model,
                )
                self._print_validation_result(epoch_idx, validation_result)
                # Save the baseline so checkpoint_filename always holds this
                # run's best model, even if no later epoch improves on it.
                self._save_model(model, checkpoint_filename)
                validation_losses.append(validation_result["validation_loss"])
                correlation_scores.append(validation_result["correlation_score"])

            ## run training loop letsgo
            model.train()
            for batch in tqdm(
                train_dataloader,
                disable=not self.config.progress,
                desc=f"Training epoch: {epoch_idx} ",
            ):
                loss = self.run_train_step(
                    optimizer=optimizer,
                    scheduler=scheduler,
                    batch=batch,
                    model=model,
                    nesim_loss=nesim_loss,
                    train_step_idx=train_step_idx,
                )
                train_losses.append(loss)
                train_step_idx += 1

            validation_result = self.validate(
                validation_dataloader=validation_dataloader,
                progress=self.config.progress,
                model=model,
            )

            if validation_result["validation_loss"] < min(validation_losses):
                self._print_validation_result(epoch_idx, validation_result)
                self._save_model(model, checkpoint_filename)

            validation_losses.append(validation_result["validation_loss"])
            correlation_scores.append(validation_result["correlation_score"])

        return {
            "train_losses": train_losses,
            "validation_losses": validation_losses,
            "correlation_scores": correlation_scores,
            "checkpoint_filename": checkpoint_filename,
        }
