from io import BytesIO
from typing import Optional

import lightning.pytorch as pl
import numpy as np
import torch
from matplotlib import pyplot as plt
from PIL import Image

from ..denovo.model import Spec2Pep
from .discriminator import DiscriminationModel

class GANOVO(pl.LightningModule):
    """Adversarial (GAN-style) training of a de novo model against a discriminator.

    The de novo model plays the role of the generator. Each training step
    receives a ``'real'`` and a ``'simulated'`` batch, encodes and decodes both,
    and then

    1. updates the discriminator on *detached* encoder outputs of both batches;
    2. updates the de novo model on a mass-regression loss minus
       ``adv_scaling`` times the discriminator loss recomputed on the
       non-detached encodings, so the generator update pushes the
       discriminator loss up.

    Optimization is manual (``automatic_optimization = False``). Validation
    logs the regression loss on a ``'seen'`` and an ``'unseen'`` batch and,
    once per validation epoch, box plots of predicted masses per residue.
    """

    def __init__(self,
                 denovo_model: Spec2Pep,
                 discriminator_model: DiscriminationModel,
                 tokenizer,
                 msv_V2_tokenizer,
                 adv_loss_fn=None,
                 regression_loss_fn=None,
                 denovo_scaling=1.,
                 adv_scaling=50.,):
        """
        Args:
            denovo_model: Generator. Must provide ``_forward_step_enc``,
                ``_forward_step_dec``, ``_forward_step`` and
                ``configure_optimizers``.
            discriminator_model: Must provide ``training_step(d_batch)``
                returning ``(total, real, simulated)`` losses and
                ``configure_optimizers``.
            tokenizer: Tokenizer used to label the ``'seen'`` (msv v1)
                validation plots.
            msv_V2_tokenizer: Tokenizer used to label the ``'unseen'``
                (msv v2) validation plots.
            adv_loss_fn: Stored as ``self.adv_loss_fn``; not used inside this
                class. ``None`` (default) creates a new ``torch.nn.BCELoss()``.
            regression_loss_fn: Loss applied to ``(predicted, true)`` masses at
                non-padding positions. ``None`` (default) creates a new
                ``torch.nn.MSELoss()``.
            denovo_scaling: Weight of the regression term in the generator loss.
            adv_scaling: Weight of the adversarial term in the generator loss.
        """
        super().__init__()
        self.automatic_optimization = False

        self.denovo_model = denovo_model
        self.discriminator = discriminator_model
        # Create the loss modules per instance. A module built in the signature
        # would be one object shared by (and registered in) every GANOVO.
        self.adv_loss_fn = adv_loss_fn if adv_loss_fn is not None else torch.nn.BCELoss()
        self.regression_loss_fn = (
            regression_loss_fn if regression_loss_fn is not None else torch.nn.MSELoss()
        )
        self.denovo_scaling = denovo_scaling
        self.adv_scaling = adv_scaling
        self.tokenizer = tokenizer
        self.msv_v2_tokenizer = msv_V2_tokenizer

        # True until the first validation_step of an epoch has drawn its plots.
        self.first_in_epoch = True

    def _log_loss(self,
                  mode: str,
                  dir: Optional[str],
                  name: Optional[str],
                  loss: torch.Tensor,
                  bs: int,
                  on_epoch: bool = False):
        """Log ``loss`` under ``"{mode}_loss[/{dir}][/{name}]"``.

        Args:
            mode: Prefix of the metric name, e.g. ``"train"``.
            dir: Optional middle path component; ``None`` omits it.
            name: Optional last path component; ``None`` omits it.
            loss: Value to log (detached before logging).
            bs: Batch size passed to ``self.log``.
            on_epoch: If ``True`` log per epoch, otherwise per step.
        """
        path = "/".join(part for part in (f"{mode}_loss", dir, name) if part is not None)

        self.log(path,
                 loss.detach(),
                 on_step=not on_epoch,
                 on_epoch=on_epoch,
                 sync_dist=True,
                 batch_size=bs)

    def _log_precitions_plot(self,
                             name,
                             pred_masses,
                             true_masses,
                             truth,
                             tokenizer):
        """Log a box plot of predicted masses grouped by true residue.

        Positions whose true mass is ``<= 0`` are treated as padding and
        dropped. Each remaining prediction is grouped by
        ``tokenizer.reverse_index[truth]``. Entry 0 of ``reverse_index`` is
        left off the x-axis (presumably the padding token; TODO confirm).

        The PNG-rendered plot is sent to ``self.logger.experiment.add_image``
        (TensorBoard-style API) under ``"Box_Plot/pred_mass_{name}"``. Nothing
        is logged when there is no logger or it has no ``add_image``.
        """
        experiment = getattr(self.logger, "experiment", None)
        if experiment is None or not hasattr(experiment, "add_image"):
            return

        # .float() first: numpy has no bfloat16 (produced under mixed precision).
        preds_np = pred_masses.detach().float().cpu().numpy()
        true_np = true_masses.detach().float().cpu().numpy()
        true_tokens_np = truth.detach().cpu().numpy()

        mask = true_np > 0  # filter padding
        filtered_preds = preds_np[mask]
        filtered_classes = true_tokens_np[mask]

        # object dtype keeps "== label" an element-wise comparison even when
        # there are no (unpadded) positions at all.
        mapped_classes = np.array(
            [tokenizer.reverse_index[idx] for idx in filtered_classes], dtype=object
        )
        class_labels = sorted(tokenizer.reverse_index[1:])
        boxplot_data = [filtered_preds[mapped_classes == label] for label in class_labels]

        fig, ax = plt.subplots()
        try:
            # NOTE: newer matplotlib renames ``labels`` to ``tick_labels``.
            ax.boxplot(boxplot_data, labels=class_labels, patch_artist=True)
            ax.set_xlabel("Residue")
            ax.set_ylabel("Predicted Masses")
            ax.set_title("Predicted Masses by Residue Class")
            ax.grid(axis="y")
            ax.tick_params(axis='x', labelrotation=90)

            buf = BytesIO()
            fig.savefig(buf, format="png")
        finally:
            plt.close(fig)  # always free the figure, even if plotting failed
        buf.seek(0)

        with Image.open(buf) as png:
            image_np = np.array(png.convert("RGB"), dtype=np.uint8)
        image_tensor = torch.from_numpy(image_np).permute(2, 0, 1)  # HWC -> CHW

        experiment.add_image(f"Box_Plot/pred_mass_{name}", image_tensor, self.global_step)

    def training_step(self,
                      batches,
                      *args):
        """Run one adversarial update: discriminator first, then the de novo model.

        Args:
            batches: Mapping with keys ``'real'`` and ``'simulated'``, each a
                batch accepted by ``denovo_model._forward_step_enc``.

        The generator loss is::

            denovo_scaling * (f_real * reg_real + f_sim * reg_sim)
            - adv_scaling * adv_loss

        where ``f_*`` is each batch's share of the combined batch size and
        ``adv_loss`` is the discriminator loss on the non-detached encodings.
        Every configured LR scheduler is stepped once per call.
        """
        # Train mode must be active *before* the forward passes below.
        self.discriminator.train()
        self.denovo_model.train()

        opt_g, opt_d = self.optimizers()
        # Lightning's lr_schedulers() returns None (no scheduler configured),
        # a single scheduler, or a list; normalise to a list.
        scheds = self.lr_schedulers()
        if scheds is None:
            scheds = []
        elif not isinstance(scheds, list):
            scheds = [scheds]

        assert len(batches) == 2, f"The number of batches provided during training is not 2, but was {len(batches)} with batch={batches}"

        real_batch, simulated_batch = batches['real'], batches['simulated']

        # --- Run the de novo model on both batches
        real_enc, real_enc_masks, real_precursors, _, real_true_masses = self.denovo_model._forward_step_enc(real_batch)
        simulated_enc, simulated_enc_masks, simulated_precursors, _, simulated_true_masses = self.denovo_model._forward_step_enc(simulated_batch)
        _, real_predicted_masses = self.denovo_model._forward_step_dec(
            real_enc, real_enc_masks, real_precursors, real_true_masses)
        _, simulated_predicted_masses = self.denovo_model._forward_step_dec(
            simulated_enc, simulated_enc_masks, simulated_precursors, simulated_true_masses)
        # Drop trailing size-1 dims so predictions line up with the true masses.
        real_predicted_masses = real_predicted_masses.squeeze(-1).squeeze(-1)
        simulated_predicted_masses = simulated_predicted_masses.squeeze(-1).squeeze(-1)
        # A true mass of 0 marks padding, which the regression loss ignores.
        real_mask = real_true_masses != 0
        simulated_mask = simulated_true_masses != 0

        # --- Optimize the discriminator
        # Detached encodings: this backward pass only reaches discriminator weights.
        d_batch = {
            'real': {
                'encoding': real_enc.detach(),
                'mask': real_enc_masks.detach(),
            },
            'simulated': {
                'encoding': simulated_enc.detach(),
                'mask': simulated_enc_masks.detach(),
            },
        }

        d_loss, d_loss_real, d_loss_sim = self.discriminator.training_step(d_batch)

        self._log_loss("train", "discriminator", "total", d_loss, real_enc.shape[0])
        self._log_loss("train", "discriminator", "real", d_loss_real, real_enc.shape[0])
        self._log_loss("train", "discriminator", "simulated", d_loss_sim, simulated_enc.shape[0])

        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()

        # --- Optimize the de novo model (aka generator)
        reg_loss_real = self.regression_loss_fn(
            real_predicted_masses[real_mask],
            real_true_masses[real_mask]
        )
        reg_loss_simulation = self.regression_loss_fn(
            simulated_predicted_masses[simulated_mask],
            simulated_true_masses[simulated_mask]
        )
        # Weight each batch's regression loss by its share of the samples.
        n_real = real_predicted_masses.shape[0]
        n_simulated = simulated_predicted_masses.shape[0]
        factor_real = n_real / (n_real + n_simulated)
        factor_simulated = n_simulated / (n_real + n_simulated)

        # Recompute the discriminator loss on the non-detached encodings so its
        # gradient reaches the de novo encoder. The discriminator gradients this
        # also produces are cleared by opt_d.zero_grad() before its next update.
        d_batch["real"]["encoding"] = real_enc
        d_batch["simulated"]["encoding"] = simulated_enc
        adv_loss, _, _ = self.discriminator.training_step(d_batch)

        de_novo_loss = (
            # Regression losses for real and simulated data
            self.denovo_scaling * (factor_real * reg_loss_real
                                   + factor_simulated * reg_loss_simulation)
            # Subtracted: minimising this loss pushes the discriminator loss up
            - self.adv_scaling * adv_loss
        )

        self._log_loss("train", "denovo", "real_reg", reg_loss_real, n_real)
        self._log_loss("train", "denovo", "sim_reg", reg_loss_simulation, n_simulated)
        self._log_loss("train", "denovo", "total", de_novo_loss, n_real)

        opt_g.zero_grad()
        self.manual_backward(de_novo_loss)
        opt_g.step()

        # --- Learning rate schedulers (stepped every training step)
        for sched in scheds:
            sched.step()

        return

    def validation_step(self,
                        batches,
                        *args):
        """Log the mass-regression loss on the ``'seen'`` and ``'unseen'`` batches.

        ``batches['seen']`` is logged as ``validation_dev_loss`` and
        ``batches['unseen']`` as ``validation_zero_loss`` (per epoch). On the
        first call of each validation epoch, a box plot of predicted masses is
        also logged per batch (tags ``msv_v1`` / ``msv_v2``). A ``None`` entry
        is skipped without affecting the other one.
        """
        self.discriminator.eval()
        self.denovo_model.eval()

        assert len(batches) == 2, f"The number of batches provided during validation is not 2, but was {len(batches)} with batch={batches}"

        splits = (
            (batches['seen'], "dev", "v1", self.tokenizer),
            (batches['unseen'], "zero", "v2", self.msv_v2_tokenizer),
        )

        for batch, name, version, tokenizer in splits:
            if batch is None:
                # Presumably this loader is exhausted while the other one is
                # not; skip only this entry.
                continue

            _, pred_masses, truth, true_masses = self.denovo_model._forward_step(batch)
            pred_masses = pred_masses.squeeze(-1).squeeze(-1)
            valid = true_masses != 0  # a true mass of 0 marks padding

            loss = self.regression_loss_fn(pred_masses[valid], true_masses[valid])

            #--------- Logging
            self._log_loss(f'validation_{name}', None, None, loss, pred_masses.shape[0], on_epoch=True)

            if self.first_in_epoch:
                self._log_precitions_plot(
                    f"msv_{version}",
                    pred_masses,
                    true_masses,
                    truth,
                    tokenizer
                )

        # Plot only once per validation epoch; reset in on_validation_epoch_start.
        self.first_in_epoch = False

    def on_validation_epoch_start(self):
        """Re-enable the once-per-epoch prediction plots."""
        self.first_in_epoch = True

    def configure_optimizers(self):
        """Combine the sub-models' optimizers and schedulers.

        Optimizers are returned in the order ``[de novo..., discriminator...]``;
        ``training_step`` relies on this order when it unpacks
        ``opt_g, opt_d = self.optimizers()``.
        """
        def _as_list(value):
            # None -> [], list/tuple -> list copy, anything else -> [value]
            if value is None:
                return []
            if isinstance(value, (list, tuple)):
                return list(value)
            return [value]

        def _handle_optimizer(config):
            # Normalise a sub-model's configure_optimizers() result into
            # (list of optimizers, list of schedulers / scheduler configs).
            if isinstance(config, tuple):
                opts, scheds = _as_list(config[0]), _as_list(config[1])
            else:
                opts, scheds = _as_list(config), []

            # Unpack Lightning-style {"optimizer": ..., "lr_scheduler": ...} dicts.
            optimizers = []
            for opt in opts:
                if isinstance(opt, dict):
                    optimizers.append(opt["optimizer"])
                    scheds.extend(_as_list(opt.get("lr_scheduler")))
                else:
                    optimizers.append(opt)
            return optimizers, scheds

        # --- De novo ---
        gen_opts, gen_scheds = _handle_optimizer(self.denovo_model.configure_optimizers())

        # --- Discriminator ---
        disc_opts, disc_scheds = _handle_optimizer(self.discriminator.configure_optimizers())

        # --- Combine ---
        optimizers = gen_opts + disc_opts
        schedulers = gen_scheds + disc_scheds

        return optimizers, schedulers