import os
from typing import Dict, List, Literal

import lightning as L
import matplotlib.pyplot as plt
import torch
import wandb
from jaxtyping import Bool, Float, Int
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from loguru import logger
from sklearn.decomposition import PCA
from torch import Tensor

from proteinfoundation.partial_autoencoder.decoder import DecoderTransformer
from proteinfoundation.partial_autoencoder.decoder_ff import DecoderFFLocal
from proteinfoundation.partial_autoencoder.encoder import EncoderTransformer
from proteinfoundation.utils.coors_utils import nm_to_ang
from proteinfoundation.utils.pdb_utils import write_prot_to_pdb

COLORS_RT = [
    "#FF0000",
    "#008000",
    "#0000FF",
    "#FFFF00",
    "#FFA500",
    "#800080",
    "#00FFFF",
    "#FF00FF",
    "#00FF00",
    "#FFC0CB",
    "#008080",
    "#E6E6FA",
    "#A52A2A",
    "#F5F5DC",
    "#800000",
    "#808000",
    "#FF7F50",
    "#000080",
    "#AAF0D1",
    "#FFDB58",
]


@rank_zero_only
def create_dir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir, exist_ok=True)


class AutoEncoder(L.LightningModule):
    def __init__(self, cfg_ae, store_dir=None):
        super().__init__()
        self.save_hyperparameters()
        self.cfg_ae = cfg_ae

        self.store_dir = store_dir if store_dir is not None else "./tmp_ae"
        self.val_path_tmp = os.path.join(self.store_dir, "val_stuff")
        create_dir(self.val_path_tmp)

        self.encoder = EncoderTransformer(**self.cfg_ae.nn_ae)

        decoder_type = self.cfg_ae.nn_ae.decoder.get("type", "transformer")
        if decoder_type == "transformer":
            self.decoder = DecoderTransformer(**self.cfg_ae.nn_ae)
        elif decoder_type == "ff_local":
            self.decoder = DecoderFFLocal(**self.cfg_ae.nn_ae)
        else:
            raise IOError(f"Invalid decoder_type {decoder_type}")

        self.nsamples_processed = 0
        self.nparams_enc = sum(
            p.numel() for p in self.encoder.parameters() if p.requires_grad
        )
        self.nparams_dec = sum(
            p.numel() for p in self.decoder.parameters() if p.requires_grad
        )

        self.validation_output = []
        self.validation_data_samples = []
        self.validation_rec_samples = []

        self.latent_dim = self.cfg_ae.nn_ae["latent_z_dim"]

    def log_histogram(self, id_log: str, v: Float[torch.Tensor, "r"]):

        assert v.ndim == 1, f"Tensor v has shape {v.shape}, cannot log histogram"
        try:
            self.logger.experiment.log(
                {id_log: wandb.Histogram(v.cpu().detach().numpy())}
            )
        except:
            pass

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            [p for p in self.parameters() if p.requires_grad],
            lr=self.cfg_ae.opt.lr,
            amsgrad=True,
            weight_decay=1e-2,
        )
        return optimizer

    def on_save_checkpoint(self, checkpoint):

        checkpoint["nsamples_processed"] = self.nsamples_processed

    def on_load_checkpoint(self, checkpoint):

        try:
            self.nsamples_processed = checkpoint["nsamples_processed"]
        except:
            logger.info("Failed to load nsamples_processed from checkpoint")
            self.nsamples_processed = 0

    def encode(self, batch: Dict) -> Float[torch.Tensor, "b n d"]:

        mask = batch["mask_dict"]["coords"][..., 0, 0]
        batch["mask"] = mask
        return self.encoder(batch)

    def decode(
        self,
        z_latent: Float[torch.Tensor, "b n d"],
        ca_coors_nm: Float[torch.Tensor, "b n 3"],
        mask: Bool[torch.Tensor, "b n"],
    ) -> Dict:

        input_decoder = {
            "z_latent": z_latent,
            "ca_coors_nm": ca_coors_nm,
            "residue_mask": mask,
            "mask": mask,
        }
        output_dec = self.decoder(input_decoder)
        mask = output_dec["residue_mask"]
        atom_mask = output_dec["atom_mask"]
        coors_nm = output_dec["coors_nm"] * mask[..., None, None] * atom_mask[..., None]
        return {
            "coors_nm": coors_nm,
            "residue_type": output_dec["aatype_max"] * mask,
            "residue_mask": mask,
            "atom_mask": atom_mask,
        }

    def training_step(self, batch: Dict, batch_idx: int):

        val_step = batch_idx == -1
        log_prefix = "validation_loss" if val_step else "train"
        histogram_every_n = 5000
        pca_every_n = 5000
        per_aatype_kl = True

        mask = batch["mask_dict"]["coords"][..., 0, 0]
        batch["mask"] = mask
        ca_coors_nm = batch["coords_nm"][..., 1, :]
        ca_coors_nm = ca_coors_nm * mask[..., None]
        bs, n = mask.shape[0], mask.shape[1]

        output_enc = self.encoder(batch)

        log_prefix_stats = (
            "train_stats_latent" if "train" in log_prefix else "val_stats_latent"
        )
        self.log_tensor_statistics(
            bs=bs,
            v=output_enc["mean"],
            log_prefix=log_prefix_stats + "_mean_LS",
            mask=mask,
            histogram_every_n=histogram_every_n,
        )
        self.log_tensor_statistics(
            bs=bs,
            v=torch.exp(output_enc["log_scale"]),
            log_prefix=log_prefix_stats + "_scale_LS",
            mask=mask,
            histogram_every_n=histogram_every_n,
        )
        self.log_tensor_statistics(
            bs=bs,
            v=output_enc["z_latent"],
            log_prefix=log_prefix_stats + "_z_LS",
            mask=mask,
            histogram_every_n=histogram_every_n,
        )
        self.log_pca(
            v=output_enc["z_latent"],
            log_prefix=log_prefix_stats + "_z_PCA",
            mask=mask,
            every_n=pca_every_n,
        )
        self.log_pca_per_residue_type(
            v=output_enc["z_latent"],
            log_prefix=log_prefix_stats + "_z_PCA_pre_aatype",
            mask=mask,
            every_n=pca_every_n,
            res_ty=batch["residue_type"],
        )

        input_decoder = {
            "z_latent": output_enc["z_latent"],
            "ca_coors_nm": ca_coors_nm,
            "residue_mask": mask,
            "mask": mask,
        }
        output_dec = self.decoder(input_decoder)

        losses = {}

        f = (
            min(1.0, self.global_step / self.cfg_ae.loss.kl.patience)
            if self.cfg_ae.loss.kl.anneal
            else 1.0
        )
        self.log(
            "kl_weight",
            self.cfg_ae.loss.kl.weight * f,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            sync_dist=True,
            add_dataloader_idx=False,
        )
        losses.update(
            self.compute_kl_penalty(
                mean=output_enc["mean"],
                log_scale=output_enc["log_scale"],
                mask=mask,
                w=self.cfg_ae.loss.kl.weight * f,
            )
        )

        losses.update(
            self.compute_struct_rec_loss(
                output_dec=output_dec,
                batch=batch,
                reduce_mode="sum",
                loss_ty=self.cfg_ae.loss.struct.type,
                weight=self.cfg_ae.loss.struct.weight,
            )
        )

        losses.update(
            self.compute_seq_rec_loss(
                output_dec=output_dec,
                batch=batch,
                weight=self.cfg_ae.loss.seq.weight,
            )
        )

        self.log_losses(bs, losses, log_prefix)
        train_loss = sum([torch.mean(losses[k]) for k in losses if "_justlog" not in k])

        self.log(
            f"{log_prefix}/loss",
            train_loss,
            on_step=True,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            batch_size=bs,
            sync_dist=True,
            add_dataloader_idx=False,
        )

        if not val_step:

            componentwise_kl = self._per_component_kl(
                mean=output_enc["mean"],
                log_scale=output_enc["log_scale"],
                mask=mask,
            )
            self.log_tensor_statistics(
                bs=bs,
                v=componentwise_kl,
                log_prefix=log_prefix + "_kl_latent",
                mask=mask,
                histogram_every_n=histogram_every_n,
            )
            self.log_tensor_statistics(
                bs=bs,
                v=(componentwise_kl > 0.1) * 1.0,
                log_prefix=log_prefix + "_kl_latent_active_thresh_0p1",
                mask=mask,
                do_hist=False,
                stats_to_log=["mean"],
            )

            if per_aatype_kl and self.global_step % histogram_every_n == 2000:
                for i in range(20):
                    self.log_tensor_statistics(
                        bs=bs,
                        v=(componentwise_kl > 0.1) * 1.0,
                        log_prefix=log_prefix_stats
                        + "_kl_latent_active_thresh_0p1_per_aatype",
                        mask=mask * (batch["residue_type"] == i),
                        do_hist=False,
                        suffix=f"_aa_{i}",
                        stats_to_log=["mean"],
                    )

            self.log_train_loss_n_prog_bar(bs, train_loss)
            self.update_n_log_nsamples_processed(bs)
            self.log_nparams()

        if val_step:
            return train_loss, output_dec
        return train_loss

    def compute_struct_rec_loss(
        self,
        output_dec: Dict[str, torch.Tensor],
        batch: Dict[str, torch.Tensor],
        reduce_mode: Literal["sum", "mean"] = "sum",
        loss_ty: str = "l2",
        weight: float = 1.0,
    ) -> Dict[str, Float[Tensor, "b"]]:

        def reduce_37(
            err: Float[torch.Tensor, "b n 37 3"],
            mask: Bool[torch.Tensor, "b n"],
            atom_mask: Bool[torch.Tensor, "b n 37"],
            mode: Literal["sum", "mean"] = "sum",
        ) -> Float[torch.Tensor, "b"]:
            nres = mask.sum(dim=-1)
            nat = atom_mask.sum(dim=-1) * mask
            err = torch.sum(err, dim=(-1, -2))
            if mode == "mean":
                err = err / nat
            err = err.sum(dim=-1) / nres
            return err

        coors_nm_pred = output_dec["coors_nm"]
        coors_nm_true = batch["coords_nm"]
        mask = output_dec["residue_mask"]
        atom_mask_true = batch["coord_mask"] * mask[..., None]

        err = coors_nm_true - coors_nm_pred
        err = err * mask[..., None, None] * atom_mask_true[..., None]

        losses = {}

        err_ang = nm_to_ang(err)
        err_ang = torch.linalg.norm(err_ang, dim=-1) ** 2
        err_ang = err_ang * mask[..., None] * atom_mask_true
        nat = atom_mask_true.sum((-1, -2))
        rmsd = torch.sqrt(torch.sum(err_ang, dim=(-1, -2)) / nat)
        losses["rmsd_no_align_a37_ang_justlog"] = rmsd

        err_l1 = reduce_37(torch.abs(err), mask, atom_mask_true, mode=reduce_mode)
        err_l2 = reduce_37(err**2, mask, atom_mask_true, mode=reduce_mode)

        if loss_ty == "l1":
            losses["struct_l1"] = err_l1 * weight
            losses["struct_l2_justlog"] = err_l2 * weight
        elif loss_ty == "l2":
            losses["struct_l1_justlog"] = err_l1 * weight
            losses["struct_l2"] = err_l2 * weight
        elif loss_ty == "l12":
            losses["struct_l1"] = err_l1 * weight
            losses["struct_l2"] = err_l2 * weight
        else:
            raise IOError(f"Loss type {loss_ty} not recognized")
        return losses

    def compute_seq_rec_loss(
        self,
        output_dec: Dict[str, torch.Tensor],
        batch: Dict[str, torch.Tensor],
        weight: float = 1.0,
    ) -> Dict[str, Float[Tensor, "b"]]:

        assert (
            "residue_type" in batch
        ), "`residue_type` not in batch, failed in compute_seq_rec_loss"
        mask = output_dec["residue_mask"]
        nres = mask.sum(dim=-1)
        logits_pred = output_dec["seq_logits"]
        target_aa = batch["residue_type"]
        target_aa = target_aa * mask

        assert logits_pred.shape[-1] == 20, "Wrong number of logits"

        b, n = mask.shape[0], mask.shape[1]
        logits_pred_flat = logits_pred.view(b * n, 20)
        target_aa_flat = target_aa.view(b * n)
        seq_loss_flat = torch.nn.functional.cross_entropy(
            input=logits_pred_flat,
            target=target_aa_flat,
            reduction="none",
        )
        seq_loss = seq_loss_flat.view(b, n)
        seq_loss = seq_loss * mask
        seq_loss = torch.sum(seq_loss, dim=-1) / nres

        pred_aa = output_dec["aatype_max"]
        seq_rec = pred_aa == target_aa
        seq_rec = seq_rec * mask
        seq_rec_rate = seq_rec.sum(dim=-1) / nres
        return {
            "ce_seq": seq_loss * weight,
            "ce_seq_now": seq_loss,
            "seq_rec_rate_justlog": seq_rec_rate,
        }

    def _per_component_kl(
        self,
        mean: Float[Tensor, "b n d"],
        log_scale: Float[Tensor, "b n d"],
        mask: Bool[Tensor, "b n"],
    ) -> Float[Tensor, "b n d"]:

        scale = torch.exp(log_scale)
        kl_div = (scale**2 + mean**2 - 1.0 - 2.0 * log_scale) * 0.5
        return kl_div * mask[..., None]

    def compute_kl_penalty(
        self,
        mean: Float[Tensor, "b n d"],
        log_scale: Float[Tensor, "b n d"],
        mask: Bool[Tensor, "b n"],
        w: float,
    ) -> Dict[str, Float[Tensor, "b"]]:

        nres = torch.sum(mask, dim=-1)
        kl_div = self._per_component_kl(
            mean=mean,
            log_scale=log_scale,
            mask=mask,
        )
        kl_div = torch.sum(kl_div, dim=(-1, -2))
        kl_div = kl_div / nres
        loss = {
            "kl_w": kl_div * w,
            "kl_now_justlog": kl_div,
        }
        return loss

    def log_tensor_statistics(
        self,
        bs: int,
        v: Float[torch.Tensor, "b n d"],
        log_prefix: str,
        mask: Bool[torch.Tensor, "b n"],
        do_stats: bool = True,
        do_hist: bool = True,
        suffix: str = "",
        histogram_every_n: int = 3000,
        stats_to_log: List[str] = ["max", "min", "mean", "median", "std"],
    ) -> None:

        vals = v.clone()[mask]
        vals = torch.flatten(vals)
        if do_stats:
            funs = {
                "max": torch.max,
                "min": torch.min,
                "mean": torch.mean,
                "median": torch.median,
                "std": torch.std,
            }
            for k in stats_to_log:
                self.log(
                    f"{log_prefix}/{k}{suffix}",
                    funs[k](vals),
                    on_step=True,
                    on_epoch=True,
                    prog_bar=False,
                    logger=True,
                    batch_size=bs,
                    sync_dist=True,
                    add_dataloader_idx=False,
                )

        if self.global_step % histogram_every_n == 0 and do_hist:
            self.log_histogram(id_log=f"{log_prefix}/histogram{suffix}", v=vals)

    def log_pca(
        self,
        v: Float[torch.Tensor, "b n d"],
        log_prefix: str,
        mask: Bool[torch.Tensor, "b n"],
        every_n: int = 3000,
    ) -> None:

        def _log_scatter(vals_x, vals_y, xlabel, ylabel, log_id):
            fig, ax = plt.subplots()
            ax.scatter(vals_x, vals_y)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            ax.set_title(f"{xlabel} - {ylabel} - Step: {self.trainer.global_step}")
            try:
                self.logger.experiment.log({log_id: wandb.Image(fig)})
            except:
                pass
            plt.close("all")

        if self.global_step % every_n != 0:
            return

        n_components = 4
        vals = v.clone()[mask].cpu().detach().float().numpy()
        vals_pca = PCA(n_components=n_components).fit_transform(vals)

        for i in range(n_components):
            for j in range(n_components):
                if j <= i:
                    continue
                _log_scatter(
                    vals_x=vals_pca[:, i],
                    vals_y=vals_pca[:, j],
                    xlabel=f"PC {i}",
                    ylabel=f"PC {j}",
                    log_id=f"{log_prefix}/{i}_{j}",
                )

    def log_pca_per_residue_type(
        self,
        v: Float[torch.Tensor, "b n d"],
        res_ty: Int[torch.Tensor, "b n"],
        log_prefix: str,
        mask: Bool[torch.Tensor, "b n"],
        every_n: int = 3000,
    ) -> None:

        return None

    def log_losses(
        self, bs: int, losses: Dict[str, Float[torch.Tensor, "b"]], log_prefix: str
    ) -> None:
        for k in losses:
            log_name = k[: -len("_justlog")] if k.endswith("_justlog") else k
            self.log(
                f"{log_prefix}/loss_{log_name}",
                torch.mean(losses[k]),
                on_step=True,
                on_epoch=True,
                prog_bar=False,
                logger=True,
                batch_size=bs,
                sync_dist=True,
                add_dataloader_idx=False,
            )

    def log_train_loss_n_prog_bar(self, b: int, train_loss: torch.Tensor) -> None:
        self.log(
            f"train_loss",
            train_loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            batch_size=b,
            sync_dist=True,
            add_dataloader_idx=False,
        )

    def log_nparams(self):
        self.log(
            "scaling/nparams_enc",
            self.nparams_enc * 1.0,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            batch_size=1,
            sync_dist=True,
        )

        self.log(
            "scaling/nparams_dec",
            self.nparams_dec * 1.0,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            batch_size=1,
            sync_dist=True,
        )

    def update_n_log_nsamples_processed(self, b: int):
        self.nsamples_processed = self.nsamples_processed + b * self.trainer.world_size
        self.log(
            "scaling/nsamples_processed",
            self.nsamples_processed * 1.0,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            logger=True,
            batch_size=1,
            sync_dist=True,
        )

    def validation_step(self, batch: Dict, batch_idx: int):

        with torch.no_grad():
            bs = batch["coords_nm"].shape[0]
            loss, rec_sample = self.training_step(batch, batch_idx=-1)
            self.validation_output.append(loss.item())
            for i in range(bs):
                self.validation_data_samples.append(
                    {
                        "coors_nm": batch["coords_nm"][i, ...],
                        "aatype": batch["residue_type"][i, ...],
                        "atom_mask": batch["coord_mask"][i, ...],
                        "mask": batch["mask_dict"]["coords"][i, :, 0, 0],
                    }
                )
                self.validation_rec_samples.append(
                    {
                        "coors_nm": rec_sample["coors_nm"][i, ...],
                        "aatype": rec_sample["aatype_max"][i, ...],
                        "atom_mask": rec_sample["atom_mask"][i, ...],
                        "mask": rec_sample["residue_mask"][i, ...],
                    }
                )

    def on_validation_epoch_end(self):

        count = 0
        for sample_data, sample_rec in zip(
            self.validation_data_samples, self.validation_rec_samples
        ):
            count += 1
            if count > 10:
                break
            coors_data = (
                nm_to_ang(sample_data["coors_nm"]).float().detach().cpu().numpy()
            )
            aatype_data = sample_data["aatype"].detach().cpu().numpy()
            mask_data = sample_data["mask"].detach().cpu().numpy()
            atom_mask_data = sample_data["atom_mask"].detach().cpu().numpy()

            coors_rec = nm_to_ang(sample_rec["coors_nm"]).float().detach().cpu().numpy()
            aatype_rec = sample_rec["aatype"].detach().cpu().numpy()
            mask_rec = sample_rec["mask"].detach().cpu().numpy()
            atom_mask_rec = sample_rec["atom_mask"].detach().cpu().numpy()

            f_base = f"epoch_{self.current_epoch}_id_{count}_rank_{self.global_rank}"

            fname = f_base + "_data_sample.pdb"
            fpath = os.path.join(self.val_path_tmp, fname)
            write_prot_to_pdb(
                prot_pos=coors_data
                * mask_data[..., None, None]
                * atom_mask_data[..., None],
                file_path=fpath,
                aatype=aatype_data * mask_data,
                overwrite=True,
                no_indexing=True,
            )

            fname = f_base + "_rec_sample.pdb"
            fpath = os.path.join(self.val_path_tmp, fname)
            write_prot_to_pdb(
                prot_pos=coors_rec
                * mask_rec[..., None, None]
                * atom_mask_rec[..., None],
                file_path=fpath,
                aatype=aatype_rec * mask_rec,
                overwrite=True,
                no_indexing=True,
            )

            fname = f_base + "_rec_sample_w_true_seq.pdb"
            fpath = os.path.join(self.val_path_tmp, fname)
            write_prot_to_pdb(
                prot_pos=coors_rec
                * mask_data[..., None, None]
                * atom_mask_data[..., None],
                file_path=fpath,
                aatype=aatype_data * mask_data,
                overwrite=True,
                no_indexing=True,
            )

        self.validation_data_samples = []
        self.validation_rec_samples = []
        self.validation_output_data = []

    def predict_step(self, batch: Dict, batch_idx: int) -> Dict:

        mask = batch["mask_dict"]["coords"][..., 0, 0]
        batch["mask"] = mask
        ca_coors_nm = batch["coords_nm"][..., 1, :]
        ca_coors_nm = ca_coors_nm * mask[..., None]

        output_enc = self.encoder(batch)

        input_decoder = {
            "z_latent": output_enc["z_latent"],
            "ca_coors_nm": ca_coors_nm,
            "residue_mask": mask,
            "mask": mask,
        }
        output = self.decoder(input_decoder)

        output.update(output_enc)
        return (batch, output)
