import torch
from tqdm import tqdm
import wandb
from typing import Union, List
from torch.utils.data import DataLoader
import os
from pydantic import BaseModel
import wandb
from neuro.utils.correlation import pearsonr
from neuro.utils.optimizer import get_lr
from neuro.models.brain_response_predictor.linear_mapper import LinearMapper
from neuro.models.brain_response_predictor.conv_mapper import ConvMapper
from neuro.utils.device import move_model_to_gpus
from neuro.utils.kfold import get_k_folds_on_dataset
from nesim.configs import NesimConfig
from nesim.losses.nesim_loss import NesimLoss
from .dataset import NaturalScenesImageEncodingDataset
from typing import Callable


def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    NOTE: this module-level definition shadows the ``get_lr`` imported from
    ``neuro.utils.optimizer`` at the top of the file.

    Args:
        optimizer: Any object exposing an iterable ``param_groups`` whose
            items are mappings with an ``"lr"`` key.

    Returns:
        The ``"lr"`` value of the first parameter group, or ``None`` when
        ``param_groups`` is empty.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group["lr"]


class SchedulerConfig(BaseModel, extra="forbid"):
    """Settings for the optional ``torch.optim.lr_scheduler.StepLR`` schedule.

    The trainer in this file calls ``scheduler.step()`` once per training
    step (inside ``run_train_step``), so ``step_size`` is counted in
    optimizer steps rather than in epochs.
    """

    # Forwarded to StepLR as ``step_size``.
    step_size: int
    # Forwarded to StepLR as ``gamma``.
    gamma: float
    # Forwarded to StepLR as ``verbose``.
    verbose: bool = False


class NaturalScenesRegressionTrainerConfig(BaseModel, extra="forbid"):
    """Configuration consumed by ``NaturalScenesRegressionTrainer``.

    Unknown keys are rejected (``extra="forbid"``).
    """

    # Passed unchanged to ``NesimLoss`` when a fold is trained.
    nesim_config: NesimConfig
    # Number of passes over the training dataloader per fold.
    num_epochs: int
    # Batch size of the training dataloader (validation uses one batch
    # holding the whole validation split).
    batch_size: int
    # SGD hyper-parameters.
    learning_rate: float
    momentum: float
    weight_decay: float
    # The first id becomes the primary device ("cuda:<id>"); more than one id
    # marks the run as data-parallel. The list is handed to ``move_model_to_gpus``.
    device_ids: List[int]
    # Both folders are forwarded to ``NaturalScenesImageEncodingDataset``.
    image_filenames_and_labels_folder: str
    image_encodings_folder: str
    # NOTE(review): not referenced by the trainer code visible in this file --
    # confirm where (or whether) it is used.
    transforms: Callable
    # ``None`` disables the StepLR learning-rate schedule.
    scheduler_config: Union[None, SchedulerConfig]
    # Folder in which the best checkpoint of each fold is written as
    # ``fold_<idx>.pth``.
    checkpoint_folder: str
    train_on_fold: Union[None, int] = None  ## None = train on all folds
    # Number of splits requested from ``get_k_folds_on_dataset``.
    num_folds: int = 10
    # Show tqdm progress bars for training and validation loops.
    progress: bool = False
    # Suppress the "best validation loss" console lines.
    quiet: bool = False
    # Send training/validation metrics to Weights & Biases.
    wandb_log: bool = False
    # The NeSim loss is added on steps where ``train_step_idx % n == 0``.
    apply_nesim_every_n_steps: int = 3
    # When set, a checkpoint is saved on steps where ``train_step_idx % n == 0``
    # into ``save_checkpoint_every_n_steps_folder`` (which must then be set
    # and must already exist).
    save_checkpoint_every_n_steps: Union[None, int] = None
    save_checkpoint_every_n_steps_folder: Union[None, str] = None


class NaturalScenesRegressionTrainer:
    def __init__(self, config: NaturalScenesRegressionTrainerConfig) -> None:
        """Load the dataset, build the cross-validation folds and pick the device.

        Args:
            config: Trainer configuration; must be a
                ``NaturalScenesRegressionTrainerConfig`` instance.

        Raises:
            AssertionError: If ``config`` has the wrong type.
        """
        assert isinstance(config, NaturalScenesRegressionTrainerConfig)
        self.config = config

        self.dataset = NaturalScenesImageEncodingDataset(
            image_filenames_and_labels_folder=config.image_filenames_and_labels_folder,
            image_encodings_folder=config.image_encodings_folder,
        )

        # Fixed random_state so the same folds are produced on every run.
        self.fold_datasets = get_k_folds_on_dataset(
            dataset=self.dataset, n_splits=config.num_folds, random_state=0
        )

        # The first listed GPU is the primary device; more than one GPU means
        # the model is expected to be wrapped for data-parallel execution.
        gpu_ids = config.device_ids
        self.device = f"cuda:{gpu_ids[0]}"
        self.is_dataparallel = len(gpu_ids) > 1

    @torch.no_grad()
    def get_correlation(self, batch: dict, model: Union[LinearMapper, ConvMapper]):
        image_batch = batch["image_encoding"].to(self.device)
        pred = model(image_batch)
        labels = batch["brain_response"]

        correlation_score = pearsonr(
            x=pred.reshape(-1), y=labels.reshape(-1).to(self.device)
        )
        return correlation_score

    def get_loss(self, batch: dict, model: Union[LinearMapper, ConvMapper]):
        image_batch = batch["image_encoding"].to(self.device)
        pred = model(image_batch)

        labels = batch["brain_response"]

        assert (
            pred.shape == labels.shape
        ), f"Expected pred.shape: {pred.shape} to be the same as labels.shape: {labels.shape}"
        loss = torch.nn.functional.mse_loss(pred, labels.to(self.device))
        return loss

    def run_train_step(
        self,
        batch,
        optimizer,
        model: Union[LinearMapper, ConvMapper],
        train_step_idx: int,
        nesim_loss: NesimLoss,
        scheduler=None,
    ):
        optimizer.zero_grad()
        loss = self.get_loss(batch=batch, model=model)

        if self.config.wandb_log:
            wandb.log({"train_loss": loss.item(), "lr": get_lr(optimizer)})

        if train_step_idx % self.config.apply_nesim_every_n_steps == 0:
            nesim_loss_item = nesim_loss.compute(reduce_mean=True)
            loss += nesim_loss_item

            if self.config.wandb_log:
                nesim_loss.wandb_log()

        loss.backward()
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

        if self.config.save_checkpoint_every_n_steps is not None:
            if train_step_idx % self.config.save_checkpoint_every_n_steps == 0:
                assert self.config.save_checkpoint_every_n_steps_folder is not None
                assert os.path.exists(self.config.save_checkpoint_every_n_steps_folder)
                filename = os.path.join(
                    self.config.save_checkpoint_every_n_steps_folder,
                    f"train_step_idx_{train_step_idx}.pth",
                )
                model.save(checkpoint_filename=filename)
                print(f"Saved: {filename}")

        return loss.item()

    def validate(
        self,
        validation_dataloader,
        model: Union[LinearMapper, ConvMapper],
        progress=True,
    ):
        val_losses = []
        correlation_scores = []
        with torch.no_grad():
            for batch in tqdm(validation_dataloader, disable=not (progress)):
                loss = self.get_loss(batch=batch, model=model)
                correlation = self.get_correlation(batch=batch, model=model)
                val_losses.append(loss.item())
                correlation_scores.append(correlation.item())

        mean_validation_loss = torch.tensor(val_losses).mean().item()
        mean_correlation_score = torch.tensor(correlation_scores).mean().item()

        if self.config.wandb_log:
            wandb.log({"validation_loss": sum(val_losses) / len(val_losses)})
            wandb.log({"validation_correlation_score": mean_correlation_score})

        return {
            "validation_loss": mean_validation_loss,
            "correlation_score": mean_correlation_score,
        }

    def train_single_fold(self, fold_idx: int, model: Union[LinearMapper, ConvMapper]):
        model = move_model_to_gpus(model, device_ids=self.config.device_ids)

        nesim_loss = NesimLoss(
            model=model, config=self.config.nesim_config, device=self.device
        )

        trainable_modules = list(model.parameters())
        optimizer = torch.optim.SGD(
            trainable_modules,
            lr=self.config.learning_rate,
            momentum=self.config.momentum,
            weight_decay=self.config.weight_decay,
        )

        if self.config.scheduler_config is not None:
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer=optimizer,
                step_size=self.config.scheduler_config.step_size,
                gamma=self.config.scheduler_config.gamma,
                verbose=self.config.scheduler_config.verbose,
                last_epoch=-1,
            )
        else:
            scheduler = None

        train_dataset = self.fold_datasets[fold_idx]["train"]
        validation_dataset = self.fold_datasets[fold_idx]["validation"]

        print(f"[fold: {fold_idx}] Train dataset: {len(train_dataset)} items")
        print(f"[fold: {fold_idx}] Validation dataset: {len(validation_dataset)} items")

        train_dataloader = DataLoader(
            train_dataset, batch_size=self.config.batch_size, shuffle=True
        )
        validation_dataloader = DataLoader(
            validation_dataset, batch_size=len(validation_dataset), shuffle=True
        )

        train_losses = []
        validation_losses = []
        correlation_scores = []

        checkpoint_filename = os.path.join(
            self.config.checkpoint_folder, f"fold_{fold_idx}.pth"
        )

        train_step_idx = 0
        for epoch_idx in range(self.config.num_epochs):
            # if on first epoch, validate once before running any train steps
            if epoch_idx == 0:
                validation_result = self.validate(
                    validation_dataloader=validation_dataloader,
                    progress=self.config.progress,
                    model=model,
                )
                validation_losses.append(validation_result["validation_loss"])
                correlation_scores.append(validation_result["correlation_score"])

                if not self.config.quiet:
                    print(
                        f"[epoch: {str(epoch_idx).zfill(3)}] best validation loss:",
                        validation_result["validation_loss"],
                        f"correlation score: {validation_result['correlation_score']}",
                    )

            ## run training loop letsgo
            for batch in tqdm(
                train_dataloader,
                disable=not (self.config.progress),
                desc=f"Training epoch: {epoch_idx} ",
            ):
                loss = self.run_train_step(
                    optimizer=optimizer,
                    scheduler=scheduler,
                    batch=batch,
                    model=model,
                    nesim_loss=nesim_loss,
                    train_step_idx=train_step_idx,
                )
                train_losses.append(loss)
                train_step_idx += 1

            validation_result = self.validate(
                validation_dataloader=validation_dataloader,
                progress=self.config.progress,
                model=model,
            )

            if validation_result["validation_loss"] < min(validation_losses):
                if not self.config.quiet:
                    print(
                        f"[epoch: {str(epoch_idx).zfill(3)}] best validation loss:",
                        validation_result["validation_loss"],
                        f"correlation score: {validation_result['correlation_score']}",
                    )

                if checkpoint_filename is not None:
                    if self.is_dataparallel:
                        model.module.save(
                            checkpoint_filename=checkpoint_filename,
                        )
                    else:
                        model.save(
                            checkpoint_filename=checkpoint_filename,
                        )
            validation_losses.append(validation_result["validation_loss"])
            correlation_scores.append(validation_result["correlation_score"])

        return {
            "train_losses": train_losses,
            "validation_losses": validation_losses,
            "correlation_scores": correlation_scores,
            "checkpoint_filename": checkpoint_filename,
        }
