#%%
import sys
sys.path.append("..")

import torch
from torch import nn

import pytorch_lightning as pl
import data
from data.promptable_picking_datamodule import PromptablePickingDatamodule
from data.augmentation import AugmentationPipeline
from matplotlib import pyplot as plt
import os
from backup_tools import backup_python_files
import numpy as np
from model.promptable_picker import PromptablePicker
from model.siamese_net_3d.promptable_decoder import SiameseNet3DDecoder
from pytorch_lightning import seed_everything
from train_cfg import datamodule_args, logger, train_loss_args, val_loss_args, model_args, optimizer_args, gpus, augmentation_args, max_epochs, coldstart
# if you want to fine tune a model, comment out the line above and uncomment the line below
# from finetune_cfg import datamodule_args, logger, train_loss_args, val_loss_args, model_args, optimizer_args, gpus, augmentation_args, max_epochs, coldstart

# Set the NCCL_P2P_LEVEL environment variable to "PXB" for this process.
# This override may not be needed on some systems.
os.environ["NCCL_P2P_LEVEL"] = "PXB"


class MyModel(pl.LightningModule):
    """Lightning module that trains and validates a ``PromptablePicker``.

    The loss functions and the optimizer are instantiated from config
    dictionaries of the form ``{"class": "<expression>", "class_args": {...}}``.
    Loss configs additionally carry a ``"mask_empty_targets"`` flag.

    NOTE(security): the ``"class"`` strings are resolved with ``eval`` in this
    module's namespace. This is only acceptable because the configs are trusted
    Python files that live next to this script; never feed external input here.
    """

    def __init__(self, model_args, optimizer_args, train_loss_args, val_loss_args=None):
        """Build the wrapped model and the train/val loss functions.

        Args:
            model_args: keyword arguments forwarded to ``PromptablePicker``.
            optimizer_args: dict with ``"class"`` and ``"class_args"`` used in
                ``configure_optimizers``.
            train_loss_args: dict with ``"class"``, ``"class_args"`` and
                ``"mask_empty_targets"`` describing the training loss.
            val_loss_args: same layout as ``train_loss_args``; falls back to
                ``train_loss_args`` when ``None``.
        """
        super().__init__()
        self.model = PromptablePicker(**model_args)
        self.optimizer_args = optimizer_args
        self.train_loss_args = train_loss_args
        self.val_loss_args = val_loss_args if val_loss_args is not None else train_loss_args
        # NOTE(security): eval of trusted config strings only (see class docstring).
        self.train_loss_fn = eval(self.train_loss_args["class"])(**self.train_loss_args["class_args"])
        self.val_loss_fn = eval(self.val_loss_args["class"])(**self.val_loss_args["class_args"])
        self.save_hyperparameters()

    def on_train_start(self):
        """Back up the code into ``<log_dir>/code_backup`` when training starts."""
        print("Making code backup...")
        backup_code_dir = f"{self.logger.log_dir}/code_backup"
        # "code_backup" is excluded so that an existing backup is not backed up again.
        backup_python_files(src=".", dest=backup_code_dir, exclude_dirs=["code_backup"])
        print("... done!")

    def forward(self, x, prompt):
        """Run the wrapped model on ``x`` with a single ``prompt``.

        Args:
            x: 4D or 5D tensor; a 4D input gets a singleton dimension inserted
                at position 1 before being passed on.
            prompt: 2D tensor.

        Returns:
            Whatever ``self.model(x, prompt)`` returns.
        """
        if len(x.shape) == 4:
            x = x.unsqueeze(1)
        assert len(x.shape) == 5, f"Expected 5D input, got {x.shape}"
        assert len(prompt.shape) == 2, f"Expected 2D input, got {prompt.shape}"
        return self.model(x, prompt)

    def get_loss(self, model_output, model_targets, mode, return_per_pdb_losses=False):
        """Compute the scalar loss and, optionally, the per-pdb losses.

        The configured loss function is applied element-wise and averaged over
        the last three dimensions, so it must not reduce its output itself
        (presumably configured with ``reduction="none"`` — confirm in the config).

        Args:
            model_output: model predictions.
            model_targets: targets; moved to the device of ``model_output``.
            mode: ``"train"`` or ``"val"``; selects loss function and loss args.
            return_per_pdb_losses: if True, also return the per-pdb losses.

        Returns:
            ``loss`` or ``(loss, per_pdb_losses)`` where ``per_pdb_losses`` is
            indexed as ``[batch, pdb]``.
        """
        assert mode in ["train", "val"], f"mode must be 'train' or 'val', got {mode}"
        loss_args = self.train_loss_args if mode == "train" else self.val_loss_args
        mask_empty_targets = loss_args["mask_empty_targets"]
        loss_fn = self.train_loss_fn if mode == "train" else self.val_loss_fn

        model_targets = model_targets.to(model_output.device)

        # BCELoss needs inputs in [0, 1]; the check must look at the loss
        # function (a tensor is never an instance of BCELoss).
        if isinstance(loss_fn, torch.nn.BCELoss):
            # check if model output is in 0-1
            if model_output.min() < 0 or model_output.max() > 1:
                print("Model output is not in 0-1 range. Got min, max: ", model_output.min(), model_output.max())
                model_output = model_output.clamp(0, 1)
            if model_targets.min() < 0 or model_targets.max() > 1:
                print("Model targets is not in 0-1 range. Got min, max: ", model_targets.min(), model_targets.max())
                model_targets = model_targets.clamp(0, 1)

        per_pdb_losses = loss_fn(model_output, model_targets).mean(dim=(-1, -2, -3))
        if mask_empty_targets:
            # only pdbs whose target volume is non-empty contribute to the loss
            mask = model_targets.sum(dim=(-1, -2, -3)) > 0
            # if the losses are one output short of the model outputs, we need to adjust the mask
            # (background is not included in the loss calculation, but is included in the model outputs)
            if per_pdb_losses.shape[1] == model_output.shape[1] - 1:
                mask = mask[:, 1:]
            # ensure that at least one pdb is included in the mask, otherwise
            # the mean over an empty selection would be NaN
            if not mask.any():
                mask[0, 0] = True
            loss = per_pdb_losses[mask].mean()
        else:
            loss = per_pdb_losses.mean()

        return (loss, per_pdb_losses) if return_per_pdb_losses else loss

    def get_model_output(self, batch):
        """Encode ``batch["model_input"]`` once and decode it for every prompt.

        Args:
            batch: dict with ``"model_input"`` and ``"prompts"``. ``"prompts"``
                is either ``None`` or a tensor of shape
                (batch_size, num_prompts, prompt_dim).

        Returns:
            The decoder output; with prompts, the per-prompt outputs
            concatenated along dimension 1.
        """
        _, feats = self.model.encode(batch["model_input"].unsqueeze(1).to(self.device))
        # TODO: this is messy. currently, I assume that if the target has 1 channel, we use prompting and that if it has more than 1 channel, we don't use prompting
        if batch["prompts"] is None:
            model_output = self.model.decode(feats, prompt=None)
        else:
            prompts = batch["prompts"].to(self.device)
            # the encoder features are shared; only the decoder runs once per prompt
            output = [
                self.model.decode(feats, prompt=prompts[:, prompt_id])
                for prompt_id in range(prompts.shape[1])
            ]
            # concat along prompt dimension
            model_output = torch.concat(output, dim=1)
        return model_output

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        model_output = self.get_model_output(batch)
        loss = self.get_loss(
            model_output=model_output,
            model_targets=batch["model_targets"],
            mode="train",
        )
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Compute the validation loss and log mean, per-pdb losses and example figures.

        ``batch["pdbs"]`` is indexed as ``[pdb_id][batch_id]``, while losses,
        outputs and targets are indexed as ``[batch_id][pdb_id]``.
        """
        model_outputs = self.get_model_output(batch)
        loss, per_pdb_losses = self.get_loss(
            model_output=model_outputs,
            model_targets=batch["model_targets"],
            mode="val",
            return_per_pdb_losses=True,
        )
        # EVERYTHING BELOW IS JUST LOGGING
        # get name of dataset corresponding to dataloader_idx for logging
        dataset = self.trainer.val_dataloaders[dataloader_idx].dataset
        dataset_name = dataset.name if hasattr(dataset, "name") else None
        log_prefix = f"val_loss{f'_{dataset_name}' if dataset_name is not None else ''}"
        # log mean loss
        self.log(f"{log_prefix}/mean", loss, on_step=False, on_epoch=True, logger=True, sync_dist=True)
        # collect per pdb losses, outputs and targets
        unique_pdbs = set(np.array(batch["pdbs"]).flatten())
        per_pdb_loss_dict = {pdb: [] for pdb in unique_pdbs}
        per_pdb_model_output_dict = {pdb: [] for pdb in unique_pdbs}
        per_pdb_model_target_dict = {pdb: [] for pdb in unique_pdbs}
        for pdb_id in range(len(batch["pdbs"])):
            for batch_id in range(len(batch["pdbs"][0])):
                pdb = batch["pdbs"][pdb_id][batch_id]
                per_pdb_loss_dict[pdb].append(per_pdb_losses[batch_id][pdb_id].item())
                per_pdb_model_output_dict[pdb].append(model_outputs[batch_id][pdb_id])
                per_pdb_model_target_dict[pdb].append(batch["model_targets"][batch_id][pdb_id])

        # log some example outputs and targets (second val dataloader, every third batch)
        if dataloader_idx == 1 and batch_idx % 3 == 0:
            for pdb in unique_pdbs:
                model_output = per_pdb_model_output_dict[pdb][0].squeeze().cpu()
                model_target = per_pdb_model_target_dict[pdb][0].squeeze().cpu()
                # union of the first target of every non-background pdb
                all_targets = torch.zeros_like(model_target)
                for pdb_, pdb_targets in per_pdb_model_target_dict.items():
                    if pdb_ == "background":
                        continue
                    all_targets += pdb_targets[0].squeeze().cpu()
                all_targets = all_targets > 0

                fig, ax = plt.subplots(1, 3, figsize=(9, 3))
                fig.suptitle(f"pdb={pdb}, loss={per_pdb_loss_dict[pdb][0]}")
                ax[0].imshow(model_output.sum(0), vmin=0, vmax=1)
                ax[1].imshow(model_target.sum(0), vmin=0, vmax=1)
                ax[2].imshow(all_targets.sum(0), vmin=0, vmax=1)
                ax[0].set_title("Output")
                ax[1].set_title("Target")
                ax[2].set_title("Sum of all pdb targets")
                self.logger.experiment.add_figure(f"{log_prefix}_output/{pdb}_{batch_idx}", fig)
                # Close explicitly so figures cannot pile up when the logger
                # does not close them itself; closing twice is harmless.
                plt.close(fig)

        # log per pdb loss (mean over all occurrences of the pdb in this batch)
        per_pdb_loss_dict = {k: sum(v) / len(v) for k, v in per_pdb_loss_dict.items()}
        for pdb, pdb_loss in per_pdb_loss_dict.items():
            self.log(f"{log_prefix}/{pdb}", pdb_loss, on_step=False, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Instantiate the optimizer described by ``self.optimizer_args``."""
        print(f"Configuring optimizer {self.optimizer_args['class']} with args {self.optimizer_args['class_args']}")
        # NOTE(security): eval of trusted config strings only (see class docstring).
        optimizer = eval(self.optimizer_args["class"])(self.parameters(), **self.optimizer_args["class_args"])
        return optimizer
    
#%%
if __name__ == "__main__":
    seed_everything(187)

    # A config that carries a "ckpt" entry means we continue from an existing model.
    resume_from_ckpt = "ckpt" in model_args

    # Build the augmentation pipeline from the configured augmentation classes.
    augmentation_pipeline = AugmentationPipeline(
        [eval(args["class"])(**args["class_args"]) for args in augmentation_args],
        seed=int(1e9),
    )

    data_module = PromptablePickingDatamodule(
        augmentation_pipeline=augmentation_pipeline,
        **datamodule_args["class_args"],
    )

    # this extracts subtomos for training and saves them to disk, it also generates prompts for all particles using tomotwin
    data_module.prepare_data(**datamodule_args["prepare_data_args"])
    data_module.setup()

    if resume_from_ckpt:
        # Re-create the module with the hyperparameters stored in the loaded
        # model, but with the optimizer and loss settings of the current config.
        model_ = model_args["model"]
        hparams = model_.hparams
        model = MyModel(
            model_args=hparams["model_args"],
            optimizer_args=optimizer_args,
            train_loss_args=train_loss_args,
            val_loss_args=val_loss_args,
        )
        # Unless a cold start is requested, reuse the loaded network.
        if not coldstart:
            model.model = model_.model
        model_ = None
    else:
        model = MyModel(
            model_args=model_args,
            optimizer_args=optimizer_args,
            train_loss_args=train_loss_args,
            val_loss_args=val_loss_args,
        )

    # All checkpoints live below the logger's version directory.
    checkpoint_root = f"{logger.save_dir}/{logger.name}/version_{logger.version}/checkpoints"

    # callback to periodically save latest model
    epoch_callback = pl.callbacks.ModelCheckpoint(
        dirpath=checkpoint_root,
        filename="{epoch}",
        monitor="epoch",
        mode="max",
        verbose=True,
        save_top_k=1,
        every_n_epochs=1,
        save_on_train_epoch_end=True,
    )
    # callback that keeps the model with the lowest mean loss on val dataloader 0
    val_loss_callback = pl.callbacks.ModelCheckpoint(
        dirpath=f"{checkpoint_root}/val_loss",
        filename="min_val_loss_{epoch}",
        monitor="val_loss/mean/dataloader_idx_0",
        verbose=True,
        mode="min",
        save_top_k=1,
        every_n_epochs=1,
    )
    # callback that keeps the model with the lowest mean loss on val dataloader 1
    exclusive_val_loss_callback = pl.callbacks.ModelCheckpoint(
        dirpath=f"{checkpoint_root}/val_loss",
        filename="min_exclusive_val_loss_{epoch}",
        monitor="val_loss_exclusive/mean/dataloader_idx_1",
        verbose=True,
        mode="min",
        save_top_k=1,
        every_n_epochs=1,
    )

    trainer = pl.Trainer(
        num_sanity_val_steps=0,
        max_epochs=max_epochs,
        #limit_train_batches=1,
        #limit_val_batches=1,
        gpus=gpus,
        logger=logger,
        check_val_every_n_epoch=1,
        callbacks=[epoch_callback, val_loss_callback, exclusive_val_loss_callback],
        detect_anomaly=True,
    )

    # When continuing from a checkpoint, validate once before training starts.
    if resume_from_ckpt:
        trainer.validate(model=model, dataloaders=data_module.val_dataloader())
    # Train the model
    trainer.fit(
        model,
        train_dataloaders=data_module.train_dataloader(),
        val_dataloaders=data_module.val_dataloader(),
    )