from .ema import ExponentialMovingAverage
from .logger import get_logger
from .residue_constants import aatype_to_str_sequence

logger = get_logger(__name__)

import pytorch_lightning as pl
import torch, time, os, wandb
import numpy as np
import pandas as pd
from .rigid_utils import Rigid, Rotation
from collections import defaultdict
from functools import partial

from .model.latent_model import LatentMDGenModel
from .transport.transport_st import create_transport, Sampler
from .utils import get_offsets, atom14_to_pdb
from .tensor_utils import tensor_tree_map
from .geometry import frames_torsions_to_atom14, atom37_to_atom14


def gather_log(log, world_size):
    """Merge the per-rank log dictionaries into one dictionary of lists.

    Args:
        log: Mapping from log key to a list of logged values on this rank.
        world_size: Number of participating processes. With a single process
            ``log`` is returned unchanged (same object, no copy).

    Returns:
        A dict with the same keys as the local ``log``; each value is the
        concatenation, in rank order, of that key's list from every rank.

    Raises:
        KeyError: If another rank's log lacks a key present in the local log.
    """
    if world_size == 1:
        return log
    # all_gather_object is a collective call: it fills one slot per rank.
    log_list = [None] * world_size
    torch.distributed.all_gather_object(log_list, log)
    # Flatten in a single pass; ``sum(lists, [])`` re-copies the accumulated
    # list for every rank and is quadratic in the total number of entries.
    return {
        key: [value for rank_log in log_list for value in rank_log[key]]
        for key in log
    }


def get_log_mean(log):
    """Return the NaN-ignoring mean of every numeric entry in ``log``.

    Args:
        log: Mapping from log key to a list of logged values.

    Returns:
        A dict mapping each key whose values could be averaged to
        ``np.nanmean`` of those values. Keys whose values cannot be averaged
        (e.g. lists of strings) are deliberately left out.
    """
    out = {}
    for key, values in log.items():
        try:
            out[key] = np.nanmean(values)
        except (TypeError, ValueError):
            # Non-numeric or ragged entries are skipped on purpose. Anything
            # else (KeyboardInterrupt, SystemExit, ...) must propagate.
            continue
    return out


# Module-level index constants. None of them is referenced by the code in the
# reviewed portion of this file.
# NOTE(review): presumably slots 1-2 are "designed" positions, slots 0 and 3
# are conditioning positions, and DESIGN_MAP_TO_COND maps each of the four
# slots to its conditioning slot — confirm against the callers.
DESIGN_IDX = [1, 2]
COND_IDX = [0, 3]
DESIGN_MAP_TO_COND = [0, 0, 3, 3]


class Wrapper(pl.LightningModule):
    """Base LightningModule with metric buffering, log printing and EMA swaps.

    Subclasses are expected to provide ``self.model`` and ``general_step`` and,
    when ``args.ema`` is truthy, ``self.ema``; none of them is defined here.
    """

    def __init__(self, args):
        """Store ``args`` and initialise the logging state."""
        super().__init__()
        self.save_hyperparameters()
        self.args = args

        # Raw logged values keyed by "<prefix>_<name>"; drained by print_log.
        self._log = defaultdict(list)
        self.last_log_time = time.time()
        self.iter_step = 0
        # Holds the training weights while EMA weights are swapped in.
        # Initialised here so the hooks below can always read it.
        self.cached_weights = None

    def log(self, key, data):
        """Buffer ``data`` under the current stage (and the "iter" prefix).

        NOTE: this intentionally shadows ``LightningModule.log`` with a
        different signature; values go to ``self._log``, not to Lightning.

        Args:
            key: Metric name without prefix.
            data: Value to record. Tensors are reduced to their mean as a
                Python number; anything else is stored as-is.
        """
        if isinstance(data, torch.Tensor):
            data = data.mean().item()
        log = self._log
        if self.stage == "train" or self.args.validate:
            log["iter_" + key].append(data)
        log[self.stage + "_" + key].append(data)

    def load_ema_weights(self):
        """Swap the EMA parameters into ``self.model``, caching the originals."""
        # model.state_dict() contains references to model weights rather
        # than copies. Therefore, we need to clone them before calling
        # load_state_dict().
        logger.info("Loading EMA weights")
        clone_param = lambda t: t.detach().clone()
        self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
        self.model.load_state_dict(self.ema.state_dict()["params"])

    def restore_cached_weights(self):
        """Put back the weights saved by ``load_ema_weights`` and drop the cache."""
        logger.info("Restoring cached weights")
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def on_before_zero_grad(self, *args, **kwargs):
        """Update the EMA from the current model weights when EMA is enabled."""
        if self.args.ema:
            self.ema.update(self.model)

    def training_step(self, batch, batch_idx):
        """Run ``general_step`` in the "train" stage and return its loss."""
        if self.args.ema:
            # Keep the EMA on the same device as the module.
            if self.ema.device != self.device:
                self.ema.to(self.device)
        return self.general_step(batch, stage="train")

    def validation_step(self, batch, batch_idx):
        """Run ``general_step`` in the "val" stage, using EMA weights if enabled."""
        if self.args.ema:
            if self.ema.device != self.device:
                self.ema.to(self.device)
            # Swap in EMA weights once; they are restored at validation epoch
            # end or before a checkpoint is saved.
            if self.cached_weights is None:
                self.load_ema_weights()

        self.general_step(batch, stage="val")
        # self.validation_step_extra(batch, batch_idx)
        if self.args.validate and self.iter_step % self.args.print_freq == 0:
            self.print_log()

    def validation_step_extra(self, batch, batch_idx):
        """Optional extra validation work; no-op in the base class."""
        pass

    def on_train_epoch_end(self):
        """Print and clear the buffered "train_" metrics."""
        self.print_log(prefix="train", save=False)

    def on_validation_epoch_end(self):
        """Restore the training weights (if EMA) and print the "val_" metrics."""
        if self.args.ema:
            self.restore_cached_weights()
        self.print_log(prefix="val", save=False)

    def on_before_optimizer_step(self, optimizer):
        """Print "iter_" metrics every ``print_freq`` steps; optionally check grads."""
        if (self.trainer.global_step + 1) % self.args.print_freq == 0:
            self.print_log()

        if self.args.check_grad:
            for name, p in self.model.named_parameters():
                if p.grad is None:
                    logger.warning(f"Param {name} has no grad")

    def on_load_checkpoint(self, checkpoint):
        """Restore the EMA state from ``checkpoint["ema"]`` when EMA is enabled."""
        if self.args.ema:
            logger.info("Loading EMA state dict")
            ema = checkpoint["ema"]
            self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint):
        """Store the EMA state under ``checkpoint["ema"]`` when EMA is enabled."""
        if self.args.ema:
            # Ensure the model holds the training weights, not the EMA ones,
            # before the checkpoint is written.
            if self.cached_weights is not None:
                self.restore_cached_weights()
            checkpoint["ema"] = self.ema.state_dict()

    def print_log(self, prefix="iter", save=False, extra_logs=None):
        """Aggregate, report and clear the buffered metrics for ``prefix``.

        Args:
            prefix: Only keys starting with ``f"{prefix}_"`` are reported and
                then removed from the buffer.
            save: If true, rank zero also writes the gathered raw values to
                ``$MODEL_DIR/{prefix}_{epoch}.csv``.
            extra_logs: Optional dict merged into the reported means.
        """
        tag = f"{prefix}_"
        # Match on the prefix only: a substring test would also catch keys
        # that merely contain e.g. "iter_" somewhere in their name.
        log = {key: values for key, values in self._log.items() if key.startswith(tag)}
        # Collective call: every rank must reach it, even with an empty log.
        log = gather_log(log, self.trainer.world_size)
        mean_log = get_log_mean(log)

        mean_log.update(
            {
                "epoch": self.trainer.current_epoch,
                "trainer_step": self.trainer.global_step + int(prefix == "iter"),
                "iter_step": self.iter_step,
                # Length of the first gathered series; 0 when nothing was
                # logged under this prefix (previously raised StopIteration).
                f"{prefix}_count": len(next(iter(log.values()), [])),
            }
        )
        if extra_logs:
            mean_log.update(extra_logs)
        try:
            # With several param groups the last group's lr is reported.
            for param_group in self.optimizers().optimizer.param_groups:
                mean_log["lr"] = param_group["lr"]
        except Exception:
            # Best effort: no single optimizer may be available (for example
            # in a validation-only run); the lr entry is then simply omitted.
            pass

        if self.trainer.is_global_zero:
            logger.info(str(mean_log))
            if self.args.wandb:
                wandb.log(mean_log)
            if save:
                path = os.path.join(
                    os.environ["MODEL_DIR"],
                    f"{prefix}_{self.trainer.current_epoch}.csv",
                )
                pd.DataFrame(log).to_csv(path)
        # Drop the reported series so the next window starts empty.
        for key in log:
            self._log.pop(key, None)

    def configure_optimizers(self):
        """Return AdamW (if ``args.adamW``) or Adam over the trainable parameters."""
        cls = torch.optim.AdamW if self.args.adamW else torch.optim.Adam
        optimizer = cls(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.args.lr,
        )
        return optimizer


class NewMDGenWrapper(Wrapper):
    """Wrapper that trains and samples a ``LatentMDGenModel`` through a transport.

    Each residue latent has 21 channels: 7 frame-offset channels followed by
    14 torsion channels (see ``prep_batch`` and ``inference``).
    """

    def __init__(self, args):
        """Build the model, the transport and its sampler, and optionally the EMA.

        Args:
            args: Configuration namespace. Optional feature flags missing from
                it are set to ``False`` in place; ``ema`` also defaults to
                ``False``.
        """
        super().__init__(args)
        # Fill in optional flags so later code can read them unconditionally.
        for key in [
            "inpainting",
            "no_torsion",
            "hyena",
            "no_aa_emb",
            "supervise_all_torsions",
            "supervise_no_torsions",
            "design_key_frames",
            "no_design_torsion",
            "cond_interval",
            "mpnn",
            "dynamic_mpnn",
            "no_offsets",
            "no_frames",
        ]:
            if not hasattr(args, key):
                setattr(args, key, False)
        latent_dim = 21

        self.latent_dim = latent_dim
        self.model = LatentMDGenModel(args, latent_dim)

        self.transport = create_transport(
            args,
            args.path_type,
            args.prediction,
            None,
        )
        self.transport_sampler = Sampler(self.transport)
        self.offsets_from_noise = args.offsets_from_noise
        self.torsions_from_noise = args.torsions_from_noise

        if not hasattr(args, "ema"):
            args.ema = False
        if args.ema:
            self.ema = ExponentialMovingAverage(model=self.model, decay=args.ema_decay)
            self.cached_weights = None

    def prep_batch(self, batch, inference=False):
        """Turn a raw batch into latents, a loss mask and model keyword arguments.

        When ``inference`` is false, ``batch`` is modified in place: every
        frame is paired with its ``*_plus_tau`` counterpart, so
        ``trans``/``rots``/``torsions`` go from ``(B, T, ...)`` to
        ``(B*T, 2, ...)``, and ``torsion_mask``/``mask``/``seqres`` are
        repeated across ``T`` and flattened to ``(B*T, ...)``.

        Args:
            batch: Dict with ``trans``, ``rots``, ``torsions``,
                ``torsion_mask``, ``mask`` and ``seqres`` (plus the
                ``*_plus_tau`` entries when ``inference`` is false).
            inference: If true, ``batch`` is used as given.

        Returns:
            Dict with ``rigids``, ``latents`` (offsets concatenated with
            torsions on the last axis), ``loss_mask`` and ``model_kwargs``.
        """
        if not inference:
            for key in ["trans", "rots", "torsions"]:
                B, T = batch[key].shape[:2]
                rest_shape = batch[key].shape[2:]
                # Pair each frame with its tau-shifted frame along dim 1.
                batch[key] = torch.stack(
                    [
                        batch[key].reshape(B * T, *rest_shape),
                        batch[key + "_plus_tau"].reshape(B * T, *rest_shape),
                    ],
                    dim=1,
                )
            # Per-sample tensors are repeated for each of the T frames and
            # flattened to match the new leading dimension B*T.
            batch["torsion_mask"] = (
                batch["torsion_mask"]
                .unsqueeze(1)
                .expand(B, T, *batch["torsion_mask"].shape[1:])
            )
            batch["torsion_mask"] = batch["torsion_mask"].reshape(
                B * T, *batch["torsion_mask"].shape[2:]
            )
            batch["mask"] = (
                batch["mask"].unsqueeze(1).expand(B, T, *batch["mask"].shape[1:])
            )
            batch["mask"] = batch["mask"].reshape(B * T, *batch["mask"].shape[2:])
            batch["seqres"] = (
                batch["seqres"].unsqueeze(1).expand(B, T, *batch["seqres"].shape[1:])
            )
            batch["seqres"] = batch["seqres"].reshape(B * T, *batch["seqres"].shape[2:])

        rigids = Rigid(
            trans=batch["trans"], rots=Rotation(rot_mats=batch["rots"])
        )  # B, T, L
        B, T, L = rigids.shape

        # Read both noise levels defensively. Previously the guard used
        # getattr defaults but the call read the attributes directly, which
        # raised AttributeError when only one of the two was configured.
        noise_trans = noise_rots = 0
        add_frame_noise = False
        if self.training:
            noise_trans = getattr(self.args, "conditioning_noise_trans", 0)
            noise_rots = getattr(self.args, "conditioning_noise_rots", 0)
            add_frame_noise = noise_trans > 0 or noise_rots > 0

        if add_frame_noise:
            noisy_start_frames = Rigid.add_noise_to_rigid(
                rigids[:, 0],
                noise_std_trans=noise_trans,
                noise_std_rot=noise_rots,
            )
            offsets = get_offsets(noisy_start_frames.unsqueeze(1), rigids)
        else:
            noisy_start_frames = rigids[:, 0]
            offsets = get_offsets(rigids[:, 0:1], rigids)

        # Flip the sign of the first four offset channels wherever channel 0
        # is negative, so channel 0 is never negative.
        # NOTE(review): channels 0-3 are presumably a quaternion with the real
        # part first — confirm against get_offsets.
        offsets[..., :4] *= torch.where(offsets[:, :, :, 0:1] < 0, -1, 1)

        frame_loss_mask = batch["mask"].unsqueeze(-1).expand(-1, -1, 7)  # B, L, 7

        # Each torsion mask entry covers two torsion channels.
        torsion_loss_mask = (
            batch["torsion_mask"].unsqueeze(-1).expand(-1, -1, -1, 2).reshape(B, L, 14)
        )

        # Only the first frame (index 0 along T) is a conditioning frame.
        cond_mask = torch.zeros(B, T, L, dtype=int, device=offsets.device)
        cond_mask[:, 0] = 1

        torsions = batch["torsions"].view(B, T, L, 14)

        if self.training and getattr(self.args, "conditioning_noise_torsion", 0) > 0:
            noise_tors = (
                torch.randn_like(torsions) * self.args.conditioning_noise_torsion
            )
            # Noise is applied to the conditioning frame only.
            torsions = torsions + noise_tors * cond_mask.unsqueeze(-1).float()

        latents = torch.cat([offsets, torsions], -1)

        loss_mask = torch.cat([frame_loss_mask, torsion_loss_mask], -1)
        loss_mask = loss_mask.unsqueeze(1).expand(-1, T, -1, -1)

        # All-ones mask: every position keeps its true residue type below.
        aatype_mask = torch.ones_like(batch["seqres"])

        return {
            "rigids": rigids,
            "latents": latents,
            "loss_mask": loss_mask,
            "model_kwargs": {
                "start_frames": noisy_start_frames,
                "end_frames": rigids[:, -1],
                "mask": batch["mask"].unsqueeze(1).expand(-1, T, -1),
                "aatype": torch.where(aatype_mask.bool(), batch["seqres"], 20),
                "x_cond": torch.where(cond_mask.unsqueeze(-1).bool(), latents, 0.0),
                "x_cond_mask": cond_mask,
            },
        }

    def general_step(self, batch, stage="train"):
        """Compute the transport training loss for ``batch`` and log metrics.

        Args:
            batch: Raw batch; it is passed through ``prep_batch`` (which
                modifies it in place).
            stage: Prefix used for the logged metrics ("train" or "val").

        Returns:
            The mean of the loss returned by ``transport.training_losses``.
        """
        self.iter_step += 1
        self.stage = stage
        start1 = time.time()

        prep = self.prep_batch(batch)

        start = time.time()
        out_dict = self.transport.training_losses(
            model=self.model,
            x1=prep["latents"],
            aatype1=batch["seqres"] if self.args.design else None,
            mask=prep["loss_mask"],
            model_kwargs=prep["model_kwargs"],
            offsets_from_noise=self.offsets_from_noise,
            torsions_from_noise=self.torsions_from_noise,
            noise_OT=self.args.noise_OT,
        )
        self.log("model_dur", time.time() - start)
        loss = out_dict["loss"]
        self.log("loss", loss)
        if self.args.design:
            aa_out = torch.argmax(out_dict["logits"], dim=-1)
            aa_recovery = aa_out == batch["seqres"][:, None, :].expand(
                -1, aa_out.shape[1], -1
            )

            # Interior positions versus the first and last position.
            self.log(
                "category_pred_design_aa_recovery",
                aa_recovery[:, :, 1:-1].float().mean().item(),
            )
            cond_aa_recovery = torch.cat(
                [aa_recovery[:, :, 0:1], aa_recovery[:, :, -1:]], 2
            )
            self.log(
                "category_pred_cond_aa_recovery", cond_aa_recovery.float().mean().item()
            )

            self.log("loss_continuous", out_dict["loss_continuous"].mean())
            self.log("loss_discrete", out_dict["loss_discrete"])

        self.log("time", out_dict["t"])
        # Wall-clock time since the end of the previous general_step.
        self.log("dur", time.time() - self.last_log_time)
        if "name" in batch:
            self.log("name", ",".join(batch["name"]))
        self.log("general_step_dur", time.time() - start1)
        self.last_log_time = time.time()
        return loss.mean()

    def inference(self, batch):
        """Sample latents with the ODE sampler and decode them to atom14 coordinates.

        Args:
            batch: Batch already shaped for ``prep_batch(..., inference=True)``.

        Returns:
            Tuple ``(atom14, aa_out)`` where ``aa_out`` is ``batch["seqres"]``
            expanded to ``(B, T, L)`` (the input sequence, not a prediction).
        """

        prep = self.prep_batch(batch, inference=True)

        latents = prep["latents"]
        rigids = prep["rigids"]
        B, T, L = rigids.shape

        zs = torch.randn(B, T, L, self.latent_dim, device=self.device)

        sample_fn = self.transport_sampler.sample_ode(
            sampling_method=self.args.sampling_method
        )

        if not self.torsions_from_noise:
            # Start the torsion channels from the first frame's torsions.
            zs[:, :, :, 7:] = latents[:, 0, :, 7:].unsqueeze(1)

        if not self.offsets_from_noise:
            # Start the offset channels from (1, 0, 0, 0, 0, 0, 0).
            # NOTE(review): presumably the identity offset — confirm.
            zs[:, :, :, :7] = torch.zeros_like(zs[:, :, :, :7])
            zs[:, :, :, 0] = 1.0

        # Keep only the last element returned by the sampler.
        samples = sample_fn(
            zs, partial(self.model.forward_inference, **prep["model_kwargs"])
        )[-1]

        offsets = samples[..., :7]

        torsions = samples[..., 7:21]

        # Sampled offsets are composed onto the first frame.
        frames = rigids[:, 0:1].compose(
            Rigid.from_tensor_7(offsets, normalize_quats=True)
        )
        torsions = torsions.reshape(B, T, L, 7, 2)
        atom14 = frames_torsions_to_atom14(
            frames,
            torsions,
            batch["seqres"][:, None].expand(B, T, L),
        )

        aa_out = batch["seqres"][:, None].expand(B, T, L)
        return atom14, aa_out

    def validation_step_extra(self, batch, batch_idx):
        """Optionally run inference on a validation batch and write a PDB file.

        Inference runs only on global rank zero, for the first
        ``args.inference_batches`` batches, and when either the epoch matches
        ``args.designability_freq`` or ``args.validate`` is set. Otherwise NaN
        placeholders are logged under the same keys.

        NOTE(review): ``inference`` returns the input sequence as ``aa_out``,
        so the recovery values logged here compare the sequence with itself.
        """
        print("validation step extra implemented")
        do_designability = (
            batch_idx < self.args.inference_batches
            and (
                (self.current_epoch + 1) % self.args.designability_freq == 0
                or self.args.validate
            )
            and self.trainer.is_global_zero
        )
        if do_designability:
            atom14, aa_out = self.inference(batch)
            aa_recovery = aa_out == batch["seqres"][:, None, :].expand(
                -1, aa_out.shape[1], -1
            )
            self.log(
                "design_aa_recovery", aa_recovery[:, :, 1:-1].float().mean().item()
            )
            cond_aa_recovery = torch.cat(
                [aa_recovery[:, :, 0:1], aa_recovery[:, :, -1:]], 2
            )
            self.log("cond_aa_recovery", cond_aa_recovery.float().mean().item())
            self.log(
                "seq_pred",
                ",".join([aatype_to_str_sequence(aa) for aa in aa_out[:, 0]]),
            )
            self.log(
                "seq_true",
                ",".join([aatype_to_str_sequence(aa) for aa in batch["seqres"]]),
            )
            # Only the first sample of the batch is written to disk.
            prot_name = batch["name"][0]
            path = os.path.join(
                os.environ["MODEL_DIR"], f"epoch{self.current_epoch}_{prot_name}.pdb"
            )

            atom14_to_pdb(
                atom14[0].cpu().numpy(), batch["seqres"][0].cpu().numpy(), path
            )
        else:
            self.log("design_aa_recovery", np.nan)
            self.log("cond_aa_recovery", np.nan)
            self.log("seq_pred", "nan")
            self.log("seq_true", "nan")
