import os
import math
import time
import json
import torch
import random
import numpy as np
from pathlib import Path
from os.path import join as pjoin
from mGPT.models.base import BaseModel
from mStream.losses.mask_dit import GPTLosses
from mGPT.config import instantiate_from_config, get_obj_from_str
from mGPT.utils.tensors import collate_tensors
from mGPT.utils.prompts import T2M_EVAL_PROMPT


class MoMask(BaseModel):
    def __init__(
        self,
        cfg,
        datamodule,
        motion_vae,
        mtr,
        motion_sr,
        padding_idx=None,
        rec_loss=False,
        rec_only=False,
        max_seq_len=256,
        **kwargs,
    ):
        """Build the motion VAE, the two transformers and the per-split losses.

        Args:
            cfg: Project configuration. ``cfg.model.params`` is read here for
                the optional ``use_momask_vq`` / ``use_momask_res`` /
                ``use_momask_t2m`` switches and for the motion VAE's
                ``stride_t`` and ``down_t``; ``cfg`` is also handed to
                ``GPTLosses``.
            datamodule: Kept on the instance; supplies ``njoints`` and
                ``feats2joints``.
            motion_vae: Config passed to ``instantiate_from_config``; the
                result is stored as ``self.motion_vae``.
            mtr: Config passed to ``instantiate_from_config``; the result is
                stored as ``self.mtr``.
            motion_sr: Config passed to ``instantiate_from_config``; the
                result is stored as ``self.motion_sr``.
            padding_idx: Padding id stored as ``self.pad_id``. When ``None``
                the id is read from the instantiated VAE (see below).
            rec_loss: Forwarded to ``GPTLosses`` as ``rec``.
            rec_only: Forwarded to ``GPTLosses`` as ``rec_only``.
            max_seq_len: Not used in this constructor.
                NOTE(review): other methods read ``self.hparams.max_seq_len``,
                presumably captured by ``save_hyperparameters`` - confirm.
            **kwargs: Not used directly in this constructor.
                NOTE(review): methods of this class read ``stage``,
                ``num_quantizers``, ``codebook_size``, ``task`` and
                ``metrics_dict`` from ``self.hparams``; presumably they
                arrive through these keyword arguments - confirm.
        """
        # Hyper-parameters are recorded (and the datamodule attached) before
        # the base-class constructor runs; keep this order.
        self.save_hyperparameters(ignore="datamodule", logger=False)
        self.datamodule = datamodule
        super().__init__()

        self.motion_vae = instantiate_from_config(motion_vae)
        self.motion_sr = instantiate_from_config(motion_sr)
        self.mtr = instantiate_from_config(mtr)

        # Optionally load pretrained MoMask checkpoints into motion_vae,
        # motion_sr and mtr. Each switch value is passed on as the checkpoint
        # path; the loaders fall back to a built-in path when it is not an
        # existing file path.
        if "use_momask_vq" in cfg.model.params and cfg.model.params.use_momask_vq:
            self.load_momask_vq_model(cfg.model.params.use_momask_vq)

        if "use_momask_res" in cfg.model.params and cfg.model.params.use_momask_res:
            self.load_momask_res_model(cfg.model.params.use_momask_res)

        if "use_momask_t2m" in cfg.model.params and cfg.model.params.use_momask_t2m:
            self.load_momask_t2m_model(cfg.model.params.use_momask_t2m)

        # Default padding id: read from the VAE under whichever attribute name
        # it exposes (hparams.nb_code, hparams.code_num, num_code or code_num).
        # Presumably all of these hold the codebook size - confirm.
        if padding_idx is None:
            if hasattr(self.motion_vae, "hparams"):
                if hasattr(self.motion_vae.hparams, "nb_code"):
                    self.pad_id = self.motion_vae.hparams.nb_code
                else:
                    self.pad_id = self.motion_vae.hparams.code_num
            else:
                if hasattr(self.motion_vae, "num_code"):
                    self.pad_id = self.motion_vae.num_code
                else:
                    self.pad_id = self.motion_vae.code_num
        else:
            self.pad_id = padding_idx

        # Freeze the motion tokenizer for lm training
        # NOTE(review): assigning `.training = False` only sets the attribute
        # on these two top-level modules; it is not a call to `.eval()`.
        # Confirm that is the intent.
        if "vae" not in self.hparams.stage:
            self.motion_vae.training = False
            for p in self.motion_vae.parameters():
                p.requires_grad = False
            self.motion_sr.training = False
            for p in self.motion_sr.parameters():
                p.requires_grad = False

        # Instantiate the losses (one independent accumulator per split)
        self._losses = torch.nn.ModuleDict(
            {
                split: GPTLosses(
                    cfg,
                    self.hparams.stage,
                    self.datamodule.njoints,
                    rec=rec_loss,
                    rec_only=rec_only,
                )
                for split in ["losses_train", "losses_test", "losses_val"]
            }
        )

        # Data transform
        self.feats2joints = datamodule.feats2joints

        # Factor used by other methods to convert between frame counts and
        # token counts: stride_t raised to the power down_t.
        self.down_t = (
            cfg.model.params.motion_vae.params.stride_t
            ** cfg.model.params.motion_vae.params.down_t
        )

        # self.manual_backward = True

    def load_momask_res_model(self, model_path):
        if isinstance(model_path, str) and os.path.exists(model_path):
            model_path = model_path
        else:
            model_path = "/mnt/momask-codes/checkpoints/t2m/tres_nlayer8_ld384_ff1024_rvq6ns_cdp0.2_sw/model/net_best_fid.tar"

        ckpt = torch.load(model_path, map_location="cpu")
        missing_keys, unexpected_keys = self.motion_sr.load_state_dict(
            ckpt["res_transformer"], strict=False
        )

        assert len(unexpected_keys) == 0
        assert all([k.startswith("clip_model.") for k in missing_keys])

    def load_momask_t2m_model(self, model_path):
        if isinstance(model_path, str) and os.path.exists(model_path):
            model_path = model_path
        else:
            model_path = "/mnt/momask-codes/checkpoints/t2m/t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns/model/latest.tar"

        ckpt = torch.load(model_path, map_location="cpu")
        missing_keys, unexpected_keys = self.mtr.load_state_dict(
            ckpt["t2m_transformer"], strict=False
        )

        assert len(unexpected_keys) == 0
        assert all([k.startswith("clip_model.") for k in missing_keys])

    def load_momask_vq_model(self, model_path):
        if isinstance(model_path, str) and os.path.exists(model_path):
            model_path = model_path
        else:
            model_path = "/mnt/momask-codes/checkpoints/t2m/rvq_nq6_dc512_nc512_noshare_qdp0.2/model/net_best_fid.tar"

        ckpt = torch.load(model_path, map_location="cpu")
        self.motion_vae.load_state_dict(ckpt["net"], strict=True)

    def configure_optimizers(self):
        """Create the optimizer and LR scheduler named in ``cfg.TRAIN``.

        Target names without a dot are looked up in ``torch.optim`` /
        ``torch.optim.lr_scheduler``. A scheduler target containing
        ``"Lambda"`` receives a linear warm-up followed by a cosine curve,
        built from ``LR_SCHEDULER.params.warmup_epochs`` and
        ``LR_SCHEDULER.params.total_epochs``; any other scheduler is built
        directly from ``LR_SCHEDULER.params``.

        Returns:
            dict: ``{"optimizer": ..., "lr_scheduler": ...}``.
        """
        train_cfg = self.hparams.cfg.TRAIN

        # Optimizer
        optim_target = train_cfg.OPTIM.target
        if "." not in optim_target:
            optim_target = "torch.optim." + optim_target
        optimizer = get_obj_from_str(optim_target)(
            params=self.parameters(), **train_cfg.OPTIM.params
        )

        # Scheduler
        scheduler_target = train_cfg.LR_SCHEDULER.target
        if "." not in scheduler_target:
            scheduler_target = "torch.optim.lr_scheduler." + scheduler_target
        if "Lambda" in scheduler_target:
            # Define the warm-up + cosine annealing scheduler
            def lr_lambda(
                epoch,
                warmup_epochs=train_cfg.LR_SCHEDULER.params.warmup_epochs,
                total_epochs=train_cfg.LR_SCHEDULER.params.total_epochs,
            ):
                if epoch < warmup_epochs:
                    return epoch / warmup_epochs

                # Length of the cosine phase (epochs remaining after warm-up).
                decay_epochs = total_epochs - warmup_epochs
                if decay_epochs <= 0:
                    # No annealing span configured: hold the base learning
                    # rate instead of dividing by zero.
                    return 1.0

                adjusted_epoch = epoch - warmup_epochs
                return 0.5 * (1 + math.cos(math.pi * adjusted_epoch / decay_epochs))

            lr_scheduler = get_obj_from_str(scheduler_target)(
                optimizer=optimizer, lr_lambda=lr_lambda
            )
        else:
            lr_scheduler = get_obj_from_str(scheduler_target)(
                optimizer=optimizer, **train_cfg.LR_SCHEDULER.params
            )

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def forward(self, x):
        """Delegate to ``self.motion_sr`` and return its output unchanged."""
        return self.motion_sr(x)

    def motion_list2tensor(self, motion_list, first_layer=False):
        if isinstance(motion_list, list):
            motion_list = [m.squeeze(0) for m in motion_list]
            motion_list = collate_tensors(motion_list, self.pad_id)  # b n q
        else:
            motion_list = motion_list.reshape(
                motion_list.shape[0], -1, self.hparams.num_quantizers
            ).long()

        if first_layer:
            motion_list = motion_list[:, :, 0].squeeze(-1)

        return motion_list

    def motion_tensor2list(self, motion_tensor):
        """Return ``motion_tensor`` unchanged (no conversion is performed)."""
        return motion_tensor

    def encode_motion(self, feats_ref, lengths):
        self.motion_vae.float()
        tokens_ref = []
        lengths_token = []
        feats_ref = feats_ref.float()

        for i in range(len(feats_ref)):
            self.motion_vae.reset_buffer()
            tokens, _ = self.motion_vae.encode(feats_ref[i : i + 1, : lengths[i]])
            tokens_ref.append(tokens)
            lengths_token.append(tokens[0].squeeze(0).shape[0])

        return tokens_ref, lengths_token

    def decode_motion(self, texts, tokens_ref, feats_ref=None, lengths_ref=None):
        rst_lens = [len(t) * self.down_t for t in tokens_ref]

        if isinstance(tokens_ref, list):
            tokens_ref = [
                torch.zeros([1, 2], device=self.device).long()
                if token.shape[0] == 1
                else token
                for token in tokens_ref
            ]
            tokens_ref = self.motion_list2tensor(tokens_ref, False)

        if tokens_ref.dim() == 1:
            tokens_ref = tokens_ref.unsqueeze(-1)

        min_len = lengths_ref.copy()
        for i in range(len(texts)):
            min_len[i] = min(rst_lens[i], lengths_ref[i] * self.down_t)
        token_lens = torch.LongTensor(lengths_ref).to(self.device)

        outputs = self.motion_sr.generate(
            tokens_ref, texts, token_lens, temperature=1, cond_scale=2
        )

        if feats_ref is None:
            feats_rst = torch.zeros(
                [len(texts), self.hparams.max_seq_len, self.datamodule.nfeats]
            )
        else:
            feats_rst = torch.zeros_like(feats_ref)

        for i in range(len(outputs)):
            self.motion_vae.reset_buffer()
            outputs[i] = torch.clamp(
                outputs[i], -1, self.hparams.codebook_size - 1, out=None
            )

            motion = self.motion_vae.decode(outputs[i : i + 1])

            # Cut Motion
            feats_rst[i : i + 1, : min_len[i], ...] = motion[:, : min_len[i]]

        return feats_rst, min_len

    def training_step(self, batch, batch_idx=None):
        outputs = self.train_forward(batch)
        loss = self._losses["losses_train"].update(outputs)

        # print(loss)
        # # Backward
        # loss.backward()
        # for name, param in self.named_parameters():
        #     print(name, param.grad)

        if self._losses["losses_train"].count != 0:
            loss_total = (
                self._losses["losses_train"].total / self._losses["losses_train"].count
            )
            self.log_progress_bar({"loss": loss_total.item()})

        return loss

    def validation_step(self, batch, batch_idx=None):
        feats_ref = batch["motion"]
        lengths = batch["length"]

        tokens_ref, tokens_length = self.encode_motion(feats_ref, lengths)
        batch["tokens_ref"] = tokens_ref
        batch["tokens_length"] = tokens_length

        outputs = self.train_forward(batch)
        loss = self._losses["losses_val"].update(outputs)

        outputs.update(self.val_forward(batch, split="val"))
        self.compute_metrics(batch, outputs)

        return loss

    def test_step(self, batch, batch_idx):
        feats_ref = batch["motion"]
        lengths = batch["length"]

        tokens_ref, _ = self.encode_motion(feats_ref, lengths)
        batch["tokens_ref"] = tokens_ref

        outputs = self.val_forward(batch, split="val")
        self.compute_metrics(batch, outputs)
        return outputs

    def train_process(self, batch):
        if "labels" in batch:
            texts = batch["inputs"]
            labels = batch["labels"]
            tokens_ref = None
            lengths = None
            tasks = None
        else:
            tokens_ref = batch["motion"]
            texts = batch["text"]
            lengths = batch["length"]
            # tasks = batch["tasks"]
            tasks = None
            # all_captions = batch["all_captions"]
            labels = None

        if "tokens_ref" in batch:
            tokens_ref = batch["tokens_ref"]
        if "tokens_length" in batch:
            lengths = batch["tokens_length"]

        # if self.hparams.condition == "caption":
        #     texts = [random.choice(all_captions[i]) for i in range(len(texts))]

        tokens_ref = self.motion_list2tensor(tokens_ref, True)
        return texts, tokens_ref.long(), lengths, tasks, labels

    def train_forward(self, batch):
        texts, tokens_ref, lengths, tasks, labels = self.train_process(batch)
        token_lens = torch.LongTensor(lengths).to(self.device)

        # LLM Forward
        _loss, _pred_ids, _acc = self.mtr(tokens_ref, texts, token_lens)
        # print(tokens_ref)
        # print(token_lens)
        # print(texts)
        # print("Error in LLM forward")
        outputs = {"losses": {"pred": _loss}, "pred": _pred_ids}

        if (self.hparams.rec_loss and self.current_epoch > self.hparams.rec_loss) or (
            self.hparams.rec_only
        ):
            feats_ref, _ = self.decode_motion(texts, tokens_ref, None, lengths)
            feats_rst, _ = self.decode_motion(texts, _pred_ids, feats_ref, lengths)

            outputs.update({"m_ref": feats_ref, "m_rst": feats_rst})

        return outputs

    @torch.no_grad()
    def val_forward(self, batch, split="test"):
        feats_ref = batch["motion"]
        texts = batch["text"]
        lengths = batch["length"]
        lengths = [x // self.down_t for x in lengths]
        token_lens = torch.LongTensor(lengths).to(self.device)

        if self.trainer.datamodule.is_mm:
            texts = texts * self.hparams.cfg.METRIC.MM_NUM_REPEATS
            feats_ref = feats_ref.repeat_interleave(
                self.hparams.cfg.METRIC.MM_NUM_REPEATS, dim=0
            )
            lengths = lengths * self.hparams.cfg.METRIC.MM_NUM_REPEATS
            instructions = pjoin(
                self.datamodule.hparams.data_root, "template_instructions.json"
            )
            instructions = json.load(open(instructions, "r"))
            # tasks = [instructions["Text-to-Motion"]["caption"]] * len(texts)

        # Forward
        if hasattr(self.mtr, "timesteps"):
            # timesteps = self.mtr.timesteps
            timesteps = 10
        else:
            timesteps = 25
        start = time.time()
        outputs = []
        self.times = [0] * len(texts)
        from torch.nn.utils.rnn import pad_sequence

        for i in range(len(texts)):
            start = time.time()
            outputs.append(
                self.mtr.generate(
                    texts[i : i + 1],
                    token_lens[i : i + 1],
                    timesteps=timesteps,
                    cond_scale=4,
                    temperature=1,
                )[0]
            )
            end = time.time()
            self.times[i] += end - start
        outputs = pad_sequence(outputs, True)
        # outputs2 = self.motion_list2tensor(batch["tokens_ref"], True).long()

        feats_rst, min_len = self.decode_motion(texts, outputs, feats_ref, lengths)
        end = time.time()
        print(self.times)
        # Recover joints for evaluation
        joints_ref = self.feats2joints(feats_ref)
        joints_rst = self.feats2joints(feats_rst)

        # Renorm for evaluation
        feats_ref = self.datamodule.renorm4t2m(feats_ref)
        feats_rst = self.datamodule.renorm4t2m(feats_rst)

        # return set
        rs_set = {
            "m_ref": feats_ref,
            "m_rst": feats_rst,
            "joints_ref": joints_ref,
            "joints_rst": joints_rst,
            "length": min_len,
        }

        return rs_set

    @torch.no_grad()
    def compute_metrics(self, batch, rs_set):
        if self.hparams.task not in ["m2t"]:
            # MultiModality evaluation sperately
            if self.trainer.datamodule.is_mm:
                metrics_dicts = ["MMMetrics"]
            else:
                metrics_dicts = self.hparams.metrics_dict

            for metric in metrics_dicts:
                lengths = batch["length"]
                if metric == "TemosMetric":
                    getattr(self.metrics, metric).update(
                        rs_set["joints_rst"], rs_set["joints_ref"], lengths
                    )
                elif metric in ["TM2TMetrics", "MotionxMetrics"]:
                    word_embs = batch["word_embs"]
                    pos_ohot = batch["pos_ohot"]
                    text_lengths = batch["text_len"]
                    if self.trainer.datamodule.is_mm:
                        word_embs = word_embs.repeat_interleave(
                            self.hparams.cfg.METRIC.MM_NUM_REPEATS, dim=0
                        )
                        pos_ohot = pos_ohot.repeat_interleave(
                            self.hparams.cfg.METRIC.MM_NUM_REPEATS, dim=0
                        )
                        text_lengths = text_lengths.repeat_interleave(
                            self.hparams.cfg.METRIC.MM_NUM_REPEATS, dim=0
                        )

                    getattr(self.metrics, metric).update(
                        feats_ref=rs_set["m_ref"],
                        feats_rst=rs_set["m_rst"],
                        lengths_ref=lengths,
                        lengths_rst=rs_set["length"],
                        word_embs=word_embs,
                        pos_ohot=pos_ohot,
                        text_lengths=text_lengths,
                    )
                elif metric == "TMRMetrics":
                    if self.hparams.stage in ["lm_instruct", "lm_pretrain"]:
                        texts = batch["text"]
                    else:
                        texts = ["" for _ in range(len(lengths))]
                    if self.trainer.datamodule.is_mm:
                        texts = texts * self.hparams.cfg.METRIC.MM_NUM_REPEATS

                    getattr(self.metrics, metric).update(
                        feats_ref=rs_set["m_ref_denorm"],
                        feats_rst=rs_set["m_rst_denorm"],
                        lengths_ref=lengths,
                        lengths_rst=rs_set["length"],
                        texts=texts,
                    )
                elif metric == "TeachMetrics":
                    if self.trainer.datamodule.is_mm:
                        texts = texts * self.hparams.cfg.METRIC.MM_NUM_REPEATS
                    getattr(self.metrics, metric).update(
                        feats_ref=rs_set["m_ref_denorm"],
                        feats_ref0=rs_set["m_ref0_denorm"],
                        feats_ref1=rs_set["m_ref1_denorm"],
                        feats_rst=rs_set["m_rst_denorm"],
                        feats_rst0=rs_set["m_rst0_denorm"],
                        feats_rst1=rs_set["m_rst1_denorm"],
                        lengths_ref=rs_set["length"],
                        lengths_ref0=rs_set["length0"],
                        lengths_ref1=rs_set["length1"],
                        lengths_rst=rs_set["length"],
                        lengths_rst0=rs_set["length0"],
                        lengths_rst1=rs_set["length1"],
                        texts0=rs_set["text_0"],
                        texts1=rs_set["text_1"],
                    )
                elif metric == "UncondMetrics":
                    getattr(self.metrics, metric).update(
                        recmotion_embeddings=rs_set["lat_rm"],
                        gtmotion_embeddings=rs_set["lat_m"],
                        lengths=lengths,
                    )
                elif metric in ["MRMetrics", "BedlamMetrics", "PredMetrics"]:
                    getattr(self.metrics, metric).update(
                        rs_set["joints_rst"], rs_set["joints_ref"], lengths
                    )
                elif metric == "MMMetrics":
                    # pass
                    getattr(self.metrics, metric).update(
                        rs_set["m_rst"], rs_set["length"]
                    )
                else:
                    raise TypeError(f"Not support this metric {metric}")

        elif self.hparams.task == "m2t" and self.hparams.stage in [
            "lm_instruct",
            "lm_pretrain",
            "lm_rl",
        ]:
            self.hparams.metrics_dict = metrics_dicts = ["M2TMetrics"]
            for metric in metrics_dicts:
                if metric == "M2TMetrics":
                    getattr(self.metrics, metric).update(
                        feats_ref=rs_set["m_ref"],
                        pred_texts=rs_set["t_pred"],
                        gt_texts=batch["all_captions"],
                        lengths=rs_set["length"],
                        word_embs=batch["word_embs"],
                        pos_ohot=batch["pos_ohot"],
                        text_lengths=batch["text_len"],
                    )
