from typing import Any, Dict, Tuple
import os, shutil

import torch
from lightning import LightningModule
from torchmetrics import MaxMetric, MeanMetric
from torchmetrics.classification.accuracy import Accuracy
from huggingface_hub import create_repo, upload_folder

from src.models.components.losses import clip_contrastive_loss, clip_contrastive_loss_negative

from src.utils import RankedLogger

# Module-level logger built from the project's RankedLogger helper.
# NOTE(review): `rank_zero_only=True` presumably restricts output to the rank-zero
# process in distributed runs — confirm against `src.utils.RankedLogger`.
# The name `logging` would clash with the stdlib `logging` module if that module
# were ever imported in this file.
logging = RankedLogger(__name__, rank_zero_only=True)

class LitLitModule(LightningModule):
    """Lightning module that trains a text encoder and a vision encoder contrastively.

    Training batches carry a positive and a "negative" image/text pair and are scored
    with ``clip_contrastive_loss_negative``; validation and test batches carry a single
    image/text pair and are scored with ``clip_contrastive_loss``. Both loss functions
    are expected to return a ``(loss, accuracy)`` pair.

    Optionally the encoders are pushed to the Hugging Face Hub at the end of every
    training epoch and/or at the start of testing, controlled by the ``huggingface``
    configuration.
    """

    def __init__(
        self,
        text_encoder: torch.nn.Module,
        vision_encoder: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler,
        compile: bool,
        huggingface: dict,
        negclip: bool = True,
        go_brr: bool = True,
        train_text: bool = True,
        train_vision: bool = True,
        batch_size: int = 128,
        temperature: float = -1,
    ) -> None:
        """Build the module.

        :param text_encoder: Text model; called as ``text_encoder(**text_input)`` and its
            output must expose a ``text_embeds`` attribute.
        :param vision_encoder: Vision model; called as ``vision_encoder(image_input)`` and
            its output is used directly as the image features.
        :param optimizer: Optimizer factory, called as ``optimizer(params=...)``.
        :param scheduler: Scheduler factory, called as
            ``scheduler(optimizer=..., total_steps=...)``.
        :param compile: Whether to ``torch.compile`` both encoders when fitting.
        :param huggingface: Upload configuration. Accessed by attribute (``upload``,
            ``validation``, ``eval``, ``ckpt_dir``, ``username``, ``name``), so it must
            support attribute access — presumably an OmegaConf ``DictConfig``; confirm
            against the caller.
        :param negclip: Stored as a hyperparameter only; not read inside this class.
        :param go_brr: If true, features are gathered from all processes before the
            training loss is computed.
        :param train_text: If false, the text encoder is put in eval mode and frozen.
        :param train_vision: If false, the vision encoder's parameters are frozen (its
            train/eval mode is left untouched).
        :param batch_size: Stored as a hyperparameter only; not read inside this class.
        :param temperature: Fixed temperature passed to the losses, or ``-1`` to create a
            trainable temperature parameter initialised to ``0.07``.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.text_encoder = text_encoder
        self.vision_encoder = vision_encoder

        if temperature == -1:
            # `-1` is the sentinel for "learn the temperature".
            logging.warning("Initializing the trainable temperature parameter.")
            self.temperature = torch.nn.Parameter(torch.ones([]) * 0.07)
        else:
            self.temperature = temperature

        if self.hparams.go_brr:
            logging.info("GPUs go brrrrrrr on the DDP!!")

        if not train_text:
            logging.info("Text encoder is set as evaluation only")
            # `requires_grad_` is the in-place nn.Module API; it returns the module itself.
            self.text_encoder = self.text_encoder.eval().requires_grad_(False)
        if not train_vision:
            logging.info("Vision enoder is set as evalaution only")
            # Only gradients are disabled here; the module's train/eval mode is not changed.
            self.vision_encoder.requires_grad_(False)

        # loss functions: plain contrastive loss for eval, negative-aware loss for training
        self.criterion = clip_contrastive_loss
        self.neg_criterion = clip_contrastive_loss_negative

        # metric objects for averaging accuracy across batches
        self.train_acc = MeanMetric()
        self.val_acc = MeanMetric()
        self.test_acc = MeanMetric()

        # for averaging loss across batches
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

        # for tracking best so far validation accuracy
        self.val_acc_best = MaxMetric()

    def forward(self, text_input, image_input):
        """Encode one text batch and one image batch.

        :param text_input: Mapping of keyword arguments unpacked into the text encoder.
        :param image_input: Input passed positionally to the vision encoder.
        :return: ``(text_encoder_output, vision_encoder_output)`` — note the text output
            comes first.
        """
        return self.text_encoder(**text_input), self.vision_encoder(image_input)

    def on_train_start(self) -> None:
        """Lightning hook that is called when training begins."""
        # by default lightning executes validation step sanity checks before training starts,
        # so it's worth to make sure validation metrics don't store results from these checks
        self.train_loss.reset()
        self.val_loss.reset()
        self.train_acc.reset()
        self.val_acc.reset()
        # `on_validation_epoch_end` also runs during the sanity check and feeds this metric,
        # so it has to be cleared as well or the "best" value may come from the sanity pass.
        self.val_acc_best.reset()

    def model_step(self, batch: Tuple[Any, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the evaluation loss and accuracy for one batch.

        :param batch: A pair ``(image_input, text_input)``.
        :return: A tuple ``(loss, accuracy)`` as returned by ``clip_contrastive_loss``.
        """
        image_input, text_input = batch
        text_out, image_feats = self.forward(text_input, image_input)
        loss, accuracy = self.criterion(image_feats, text_out.text_embeds, self.temperature)
        return loss, accuracy

    def _gather_features(self, feats: torch.Tensor) -> torch.Tensor:
        """Gather ``feats`` from all processes (keeping gradients) and flatten to 2-D.

        In multi-process runs the gathered tensor has an extra leading dimension that is
        merged into the batch dimension. Using the last dimension as the feature size
        (rather than a hard-coded index) also handles a result that is already 2-D.
        """
        gathered = self.all_gather(feats, sync_grads=True)
        return gathered.reshape(-1, gathered.size(-1))

    def model_step_train(
        self, batch: Tuple[Any, Any, Any, Any]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the training loss and accuracy for one batch.

        The negative-aware loss is applied in both directions: positive images against
        (positive, negative) texts, and negative images against (negative, positive) texts.

        :param batch: A 4-tuple
            ``(image_input, negative_image_input, text_input, negative_text_input)``.
        :return: A tuple ``(loss, accuracy)`` where ``loss`` is the sum of the two
            directional losses and ``accuracy`` is the mean of the two accuracies.
        """
        image_input, negative_image_input, text_input, negative_text_input = batch
        text_out, image_feats = self.forward(text_input, image_input)
        negative_text_out, negative_image_feats = self.forward(
            negative_text_input, negative_image_input
        )

        text_embeds = text_out.text_embeds
        negative_text_embeds = negative_text_out.text_embeds

        if self.hparams.go_brr:
            # Score against the features of every process, not just the local ones.
            text_embeds = self._gather_features(text_embeds)
            image_feats = self._gather_features(image_feats)
            negative_text_embeds = self._gather_features(negative_text_embeds)
            negative_image_feats = self._gather_features(negative_image_feats)

        loss1, accuracy1 = self.neg_criterion(
            image_feats, text_embeds, negative_text_embeds, self.temperature
        )
        loss2, accuracy2 = self.neg_criterion(
            negative_image_feats, negative_text_embeds, text_embeds, self.temperature
        )

        loss = loss1 + loss2
        accuracy = (accuracy1 + accuracy2) / 2

        return loss, accuracy

    def training_step(self, batch: Tuple[Any, Any, Any, Any], batch_idx: int) -> torch.Tensor:
        """Perform a single training step on a batch of data from the training set.

        :param batch: A 4-tuple as expected by :meth:`model_step_train`.
        :param batch_idx: The index of the current batch.
        :return: The training loss for this batch.
        """
        loss, accuracy = self.model_step_train(batch)

        # update and log metrics
        self.train_loss(loss)
        self.train_acc(accuracy)
        self.log("train/loss", self.train_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train/acc", self.train_acc, on_step=True, on_epoch=True, prog_bar=True)

        # return loss or backpropagation will fail
        return loss

    def _push_to_hub(self, repo_name: str) -> None:
        """Save the encoders under ``huggingface.ckpt_dir`` and upload that folder.

        The text encoder is always saved; the vision encoder only when it was trained.
        The project's ``utils.py`` is copied next to the weights. The target repository
        ``<username>/<repo_name>`` is created as private if it does not exist yet.

        :param repo_name: Repository name without the user/organisation prefix.
        """
        hf_cfg = self.hparams.huggingface
        ckpt_dir = hf_cfg.ckpt_dir

        self.text_encoder.save_pretrained(f"{ckpt_dir}/text-encoder")
        if self.hparams.train_vision:
            self.vision_encoder.save_pretrained(f"{ckpt_dir}/vision-encoder")
        shutil.copyfile("src/models/components/utils.py", os.path.join(ckpt_dir, "utils.py"))

        repo_id = create_repo(
            repo_id=f"{hf_cfg.username}/{repo_name}",
            exist_ok=True,
            private=True,
        ).repo_id
        upload_folder(
            folder_path=ckpt_dir,
            repo_id=repo_id,
            repo_type="model",
        )

    def on_train_epoch_end(self) -> None:
        """Lightning hook that is called when a training epoch ends.

        Uploads a per-epoch snapshot when ``huggingface.upload`` and
        ``huggingface.validation`` are both set. Only the global-zero process uploads.
        """
        if (
            self.trainer.is_global_zero
            and self.hparams.huggingface.upload
            and self.hparams.huggingface.validation
        ):
            self._push_to_hub(f"{self.hparams.huggingface.name}-epoch-{self.current_epoch}")

    def on_validation_start(self) -> None:
        """Lightning hook that is called when validation begins."""
        pass

    def validation_step(self, batch: Tuple[Any, Any], batch_idx: int) -> None:
        """Perform a single validation step on a batch of data from the validation set.

        :param batch: A pair ``(image_input, text_input)``.
        :param batch_idx: The index of the current batch.
        """
        loss, accuracy = self.model_step(batch)

        # update and log metrics
        self.val_loss(loss)
        self.val_acc(accuracy)
        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

    def on_validation_epoch_end(self) -> None:
        """Lightning hook that is called when a validation epoch ends."""
        acc = self.val_acc.compute()  # get current val acc
        self.val_acc_best(acc)  # update best so far val acc
        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
        # otherwise metric would be reset by lightning after each epoch
        self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)

    def test_step(self, batch: Tuple[Any, Any], batch_idx: int) -> None:
        """Perform a single test step on a batch of data from the test set.

        :param batch: A pair ``(image_input, text_input)``.
        :param batch_idx: The index of the current batch.
        """
        loss, accuracy = self.model_step(batch)

        # update and log metrics
        self.test_loss(loss)
        self.test_acc(accuracy)
        self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)

    def on_test_epoch_start(self) -> None:
        """Lightning hook that is called when a test epoch starts.

        Uploads the current weights as the ``-best`` repository when ``huggingface.upload``
        and ``huggingface.eval`` are both set. Only the global-zero process uploads, so
        that multi-process runs do not save and push the same folder concurrently.
        """
        if (
            self.trainer.is_global_zero
            and self.hparams.huggingface.upload
            and self.hparams.huggingface.eval
        ):
            self._push_to_hub(f"{self.hparams.huggingface.name}-best")

    def on_test_epoch_end(self) -> None:
        """Lightning hook that is called when a test epoch ends."""
        pass

    def setup(self, stage: str) -> None:
        """Lightning hook that is called at the beginning of fit (train + validate), validate,
        test, or predict.

        This is a good hook when you need to build models dynamically or adjust something about
        them. This hook is called on every process when using DDP.

        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
        """
        if self.hparams.compile and stage == "fit":
            self.text_encoder = torch.compile(self.text_encoder)
            self.vision_encoder = torch.compile(self.vision_encoder)

    def configure_optimizers(self) -> Tuple[list, list]:
        """Create the optimizer and the per-step learning-rate scheduler.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

        :return: A pair ``([optimizer], [scheduler_config])`` where ``scheduler_config`` is a
            dict that steps the scheduler every optimizer step and logs it as ``train/lr``.
        """
        optimizer = self.hparams.optimizer(params=self.parameters())
        # The scheduler factory must accept `total_steps`; it is given the trainer's
        # estimate of the total number of optimizer steps for the whole run.
        scheduler = self.hparams.scheduler(
            optimizer=optimizer,
            total_steps=self.trainer.estimated_stepping_batches,
        )
        return [optimizer], [
            {"name": "train/lr", "scheduler": scheduler, "interval": "step"}
        ]
