import os
import math
import time
import torch
import numpy as np
from pathlib import Path
from ..base import BaseModel
from mGPT.losses.msr import GPTLosses
from mGPT.config import instantiate_from_config, get_obj_from_str
from mGPT.utils.tensors import collate_tensors
import scripts.plot_3d_global as plot_3d


class MotionSR(BaseModel):
    """Train/evaluate a token model (``motion_sr``) on codes from a motion tokenizer.

    ``motion_vae`` encodes motion features into token stacks (last axis indexes
    quantizer layers) and decodes tokens back to features. ``motion_sr`` returns a
    training loss from reference tokens + text, and at evaluation time generates
    tokens from the first quantizer layer and the text. Generated tokens are
    decoded with ``motion_vae`` and passed to ``compute_metrics`` (provided by
    ``BaseModel`` — not visible here).
    """

    # Fallback checkpoint used when ``use_momask_vq`` is not an existing path.
    DEFAULT_MOMASK_VQ_CKPT = "/mnt/momask-codes/checkpoints/t2m/rvq_nq6_dc512_nc512_noshare_qdp0.2/model/net_best_fid.tar"

    def __init__(
        self, cfg, datamodule, motion_vae, motion_sr, padding_idx=None, **kwargs
    ):
        # Hyper-parameters are captured before ``super().__init__()`` (kept from
        # the original ordering); the datamodule is excluded from them.
        self.save_hyperparameters(ignore="datamodule", logger=False)
        self.datamodule = datamodule
        super().__init__()

        self.motion_vae = instantiate_from_config(motion_vae)
        self.motion_sr = instantiate_from_config(motion_sr)

        if "use_momask_vq" in cfg.model.params and cfg.model.params.use_momask_vq:
            self.load_momask_vq_model(cfg.model.params.use_momask_vq)

        # Number of quantizer layers per token step; where it lives depends on
        # which tokenizer implementation was instantiated.
        if hasattr(self.motion_vae, "hparams"):
            self.num_quantizers = self.motion_vae.hparams.num_quantizers
        elif hasattr(self.motion_vae, "num_quantizers"):
            self.num_quantizers = self.motion_vae.num_quantizers
        else:
            self.num_quantizers = self.motion_vae.code_depth

        # Padding id: explicit argument wins; otherwise presumably the codebook
        # size (one past the last valid code) — confirm against the tokenizer.
        if padding_idx is not None:
            self.pad_id = padding_idx
        elif hasattr(self.motion_vae, "hparams"):
            self.pad_id = self.motion_vae.hparams.code_num
        elif hasattr(self.motion_vae, "num_code"):
            self.pad_id = self.motion_vae.num_code
        else:
            self.pad_id = self.motion_vae.code_num

        # Freeze the motion tokenizer unless a VAE stage is being trained.
        # ``eval()`` (unlike assigning ``training = False``) also switches every
        # sub-module. NOTE: a trainer calling ``train()`` on the whole model
        # later will re-enable train mode.
        if "vae" not in self.hparams.stage:
            self.motion_vae.eval()
            self.motion_vae.requires_grad_(False)

        # Instantiate the losses
        self._losses = torch.nn.ModuleDict(
            {
                split: GPTLosses(cfg, self.hparams.stage, self.datamodule.njoints)
                for split in ["losses_train", "losses_test", "losses_val"]
            }
        )

        # Data transform
        self.feats2joints = datamodule.feats2joints

        # Temporal down-sampling factor of the tokenizer (stride_t ** down_t).
        self.down_t = (
            cfg.model.params.motion_vae.params.stride_t
            ** cfg.model.params.motion_vae.params.down_t
        )

    def load_momask_vq_model(self, model_path):
        """Load MoMask RVQ weights into ``self.motion_vae`` (strict loading).

        Args:
            model_path: path to a checkpoint whose ``"net"`` entry is a state
                dict. Any value that is not an existing path string (e.g. the
                boolean config flag ``True``) falls back to
                ``DEFAULT_MOMASK_VQ_CKPT``.
        """
        if not (isinstance(model_path, str) and os.path.exists(model_path)):
            model_path = self.DEFAULT_MOMASK_VQ_CKPT

        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        ckpt = torch.load(model_path, map_location="cpu")
        self.motion_vae.load_state_dict(ckpt["net"], strict=True)

    @staticmethod
    def _warmup_cosine_factor(epoch, warmup_epochs, total_epochs):
        """Return the LR multiplier for linear warm-up followed by half-cosine decay.

        The factor rises linearly from 0 (at epoch 0) to 1 over
        ``warmup_epochs``, then follows ``0.5 * (1 + cos(pi * t / T))`` with
        ``T = total_epochs - warmup_epochs``. Past ``total_epochs`` the cosine is
        not clamped, so the factor rises again.
        """
        if epoch < warmup_epochs:
            return epoch / warmup_epochs
        # Guard against a zero/negative decay span (total_epochs <= warmup_epochs).
        decay_epochs = max(total_epochs - warmup_epochs, 1)
        adjusted_epoch = epoch - warmup_epochs
        return 0.5 * (1 + math.cos(math.pi * adjusted_epoch / decay_epochs))

    def configure_optimizers(self):
        """Build the optimizer and LR scheduler described by ``cfg.TRAIN``.

        Bare class names (no dots) are resolved under ``torch.optim`` /
        ``torch.optim.lr_scheduler``. Any scheduler whose target contains
        ``"Lambda"`` gets the warm-up + cosine multiplier driven by
        ``LR_SCHEDULER.params.warmup_epochs`` / ``total_epochs``.
        """
        train_cfg = self.hparams.cfg.TRAIN

        # Optimizer
        optim_target = train_cfg.OPTIM.target
        if len(optim_target.split(".")) == 1:
            optim_target = "torch.optim." + optim_target
        optimizer = get_obj_from_str(optim_target)(
            params=self.parameters(), **train_cfg.OPTIM.params
        )

        # Scheduler
        scheduler_target = train_cfg.LR_SCHEDULER.target
        if len(scheduler_target.split(".")) == 1:
            scheduler_target = "torch.optim.lr_scheduler." + scheduler_target
        if "Lambda" in scheduler_target:
            # Read the schedule bounds once, as the original default args did.
            warmup_epochs = train_cfg.LR_SCHEDULER.params.warmup_epochs
            total_epochs = train_cfg.LR_SCHEDULER.params.total_epochs

            def lr_lambda(epoch):
                return self._warmup_cosine_factor(epoch, warmup_epochs, total_epochs)

            lr_scheduler = get_obj_from_str(scheduler_target)(
                optimizer=optimizer, lr_lambda=lr_lambda
            )
        else:
            lr_scheduler = get_obj_from_str(scheduler_target)(
                optimizer=optimizer, **train_cfg.LR_SCHEDULER.params
            )

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def forward(self, x):
        """Delegate directly to ``self.motion_sr``."""
        return self.motion_sr(x)

    def encode_motion(self, feats_ref, lengths):
        """Tokenize each sample separately and collate the results.

        Args:
            feats_ref: batched motion features; sample ``i`` is cropped to
                ``lengths[i]`` frames along axis 1 before encoding.
            lengths: per-sample frame counts.

        Returns:
            ``(tokens_ref, lengths_token)``: tokens padded with ``self.pad_id``
            via ``motion_list2tensor``, and the number of token steps per sample.
        """
        tokens_ref = []
        lengths_token = []
        for i in range(len(feats_ref)):
            self.motion_vae.reset_buffer()
            tokens, _ = self.motion_vae.encode(feats_ref[i : i + 1, : lengths[i]])
            tokens_ref.append(tokens)
            lengths_token.append(tokens.squeeze(0).shape[0])
        tokens_ref = self.motion_list2tensor(tokens_ref)
        return tokens_ref, lengths_token

    def motion_list2tensor(self, motion_list, first_layer=False):
        """Convert token sequences into one ``(batch, steps, num_quantizers)`` tensor.

        Args:
            motion_list: either a list of per-sample token tensors (each with a
                leading singleton axis, which is squeezed) that are collated and
                padded with ``self.pad_id``, or an already-batched tensor that is
                reshaped to ``(batch, -1, self.num_quantizers)``.
            first_layer: if True, keep only quantizer layer 0 and drop that axis.
        """
        if isinstance(motion_list, list):
            motion_list = [m.squeeze(0) for m in motion_list]
            motion_list = collate_tensors(motion_list, self.pad_id)  # b n q
        else:
            # ``num_quantizers`` is computed in ``__init__`` from the tokenizer;
            # it is not one of the saved constructor hyper-parameters.
            motion_list = motion_list.reshape(
                motion_list.shape[0], -1, self.num_quantizers
            )

        if first_layer:
            motion_list = motion_list[:, :, 0]

        return motion_list

    def motion_tensor2list(self, motion_tensor):
        """Identity hook: generated tokens are already indexable per sample."""
        return motion_tensor

    def _attach_reference_tokens(self, batch):
        """Encode ``batch["motion"]`` and store the result as ``batch["tokens_ref"]``."""
        tokens_ref, _ = self.encode_motion(batch["motion"], batch["length"])
        batch["tokens_ref"] = tokens_ref
        return tokens_ref

    def _token_lengths(self, lengths):
        """Return ``lengths`` as a LongTensor on ``self.device``, in token steps if needed.

        When the tokenizer exposes ``v_lengths``, lengths are integer-divided by
        its last entry. ``motion_vae.hparams.v_lengths`` is preferred when present
        (the original lookup); otherwise the ``v_lengths`` attribute itself is
        used instead of raising.
        """
        token_lens = torch.LongTensor(lengths).to(self.device)
        if hasattr(self.motion_vae, "v_lengths"):
            hparams = getattr(self.motion_vae, "hparams", None)
            v_lengths = getattr(hparams, "v_lengths", self.motion_vae.v_lengths)
            token_lens = token_lens // v_lengths[-1]
        return token_lens

    def training_step(self, batch, batch_idx=None):
        """Run the training forward pass, update running losses and report the mean."""
        outputs = self.train_msr_forward(batch)
        train_losses = self._losses["losses_train"]
        loss = train_losses.update(outputs)
        if train_losses.count != 0:
            loss_total = train_losses.total / train_losses.count
            self.log_progress_bar({"loss": loss_total.item()})
        return loss

    def validation_step(self, batch, batch_idx=None):
        """Compute the validation loss, then generate/decode and update metrics."""
        self._attach_reference_tokens(batch)

        outputs = self.train_msr_forward(batch)
        loss = self._losses["losses_val"].update(outputs)

        outputs.update(self.val_msr_forward(batch, split="val"))
        self.compute_metrics(batch, outputs)

        return loss

    def test_step(self, batch, batch_idx):
        """Generate/decode on a test batch, update metrics and return the outputs."""
        self._attach_reference_tokens(batch)
        outputs = self.val_msr_forward(batch, split="val")
        self.compute_metrics(batch, outputs)
        return outputs

    def train_msr_forward(self, batch):
        """Compute the ``motion_sr`` loss on reference tokens.

        Uses ``batch["tokens_ref"]`` when present, otherwise ``batch["motion"]``
        (presumably pre-tokenized data in that case — confirm with the
        datamodule). Lists of tokens are collated first.

        Returns:
            ``{"losses": {"pred": loss}, "pred": predicted_ids}``.
        """
        texts = batch["text"]
        tokens_ref = batch["tokens_ref"] if "tokens_ref" in batch else batch["motion"]

        if isinstance(tokens_ref, list):
            tokens_ref = self.motion_list2tensor(tokens_ref)

        tokens_ref = tokens_ref.to(self.device).long()
        token_lens = self._token_lengths(batch["length"])

        # Motion SR forward; the returned accuracy is currently not reported.
        ce_loss, pred_ids, _acc = self.motion_sr(tokens_ref, texts, token_lens)
        return {"losses": {"pred": ce_loss}, "pred": pred_ids}

    @torch.no_grad()
    def val_msr_forward(self, batch, split="test"):
        """Generate tokens with ``motion_sr``, decode them and build metric inputs.

        Args:
            batch: dict with ``"motion"`` (features), ``"text"``, ``"length"``
                (must support ``.copy()``, e.g. a list) and ``"tokens_ref"``
                (tokens whose last axis indexes quantizer layers).
            split: unused; kept for interface compatibility.

        Returns:
            dict with renormalised reference/generated features (``m_ref``,
            ``m_rst``), joints, denormalised features and the per-sample valid
            length ``length`` (shorter of decoded and reference length).
        """
        feats_ref = batch["motion"]
        texts = batch["text"]
        lengths = batch["length"]
        min_len = lengths.copy()

        tokens_ref = batch["tokens_ref"].long()
        token_lens = self._token_lengths(lengths)

        # Generate from quantizer layer 0 only. cond_scale=0 presumably disables
        # classifier-free guidance — confirm against motion_sr.generate.
        outputs = self.motion_sr.generate(
            tokens_ref[..., 0],
            texts,
            token_lens,
            temperature=1,
            cond_scale=0,
            topk_filter_thres=0.8,
        )
        outputs = self.motion_tensor2list(outputs)

        # Decode each sample; frames past the decoded/reference length stay zero.
        feats_rst = torch.zeros_like(feats_ref)
        for i in range(len(texts)):
            self.motion_vae.reset_buffer()
            motion = self.motion_vae.decode(outputs[i].unsqueeze(0))
            min_len[i] = min(motion.shape[1], lengths[i])
            feats_rst[i : i + 1, : min_len[i], ...] = motion[:, : min_len[i]]

        # Recover joints for evaluation
        joints_ref = self.feats2joints(feats_ref)
        joints_rst = self.feats2joints(feats_rst)

        # Denormalised features
        m_ref_denorm = self.datamodule.denormalize(feats_ref)
        m_rst_denorm = self.datamodule.denormalize(feats_rst)

        # Renorm for evaluation
        feats_ref = self.datamodule.renorm4t2m(feats_ref)
        feats_rst = self.datamodule.renorm4t2m(feats_rst)

        rs_set = {
            "m_ref": feats_ref,
            "m_rst": feats_rst,
            "joints_ref": joints_ref,
            "joints_rst": joints_rst,
            "length": min_len,
            "m_ref_denorm": m_ref_denorm,
            "m_rst_denorm": m_rst_denorm,
        }

        if self.trainer.testing:
            # NOTE(review): file names depend only on the in-batch index, so each
            # test batch overwrites the previous batch's visualizations.
            save_path = (
                "../memData/visualization/t2m_causal_babel/" + self.hparams.cfg.NAME
            )
            os.makedirs(save_path, exist_ok=True)
            self._plot_2d(lengths, joints_ref, joints_rst, save_path)

        return rs_set

    @staticmethod
    def _to_numpy(x):
        """Return ``x`` as a numpy array (detaching/moving tensors to CPU first)."""
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def _plot_2d(
        self,
        lengths,
        joints_ref,
        joints_rst=None,
        path=None,
        texts=None,
        texts_gen=None,
    ):
        """Save per-sample joint animations (and optional texts) under ``path``.

        Writes ``{bid}_gt.gif`` for every plotted sample, plus ``{bid}_gen.gif``
        when ``joints_rst`` is given and ``{bid}.txt`` / ``{bid}_gen.txt`` when
        ``texts`` / ``texts_gen`` are given. Each sample is cropped to
        ``lengths[bid]`` frames. ``path`` defaults to ``./results``.
        """
        path = "./results" if path is None else path

        for bid in range(len(lengths)):
            # Kept from the original: samples 6-9 are skipped.
            # TODO(review): intent unclear — possibly meant to cap the plot count.
            if 5 < bid < 10:
                continue
            print("Saving {}.gif".format(bid))
            length = lengths[bid]

            xyz_ref = self._to_numpy(joints_ref[bid : bid + 1])
            plot_3d.draw_to_batch(
                xyz_ref[:, :length, :, :],
                [""],
                [os.path.join(path, "{}_gt.gif".format(bid))],
            )

            if joints_rst is not None:
                xyz = self._to_numpy(joints_rst[bid : bid + 1])
                plot_3d.draw_to_batch(
                    xyz[:, :length, :, :],
                    [""],
                    [os.path.join(path, "{}_gen.gif".format(bid))],
                )

            if texts is not None:
                with open(
                    os.path.join(path, "{}.txt".format(bid)), "w", encoding="utf-8"
                ) as f:
                    f.write(texts[bid])

            if texts_gen is not None:
                with open(
                    os.path.join(path, "{}_gen.txt".format(bid)), "w", encoding="utf-8"
                ) as f:
                    f.write(texts_gen[bid])
