import os
import math
import time
import json
import torch
import numpy as np
from torch import nn
from pathlib import Path
from os.path import join as pjoin
from .base import BaseModel
from mGPT.losses.msr import GPTLosses
from mGPT.config import instantiate_from_config, get_obj_from_str
from mGPT.utils.tensors import collate_tensors
import scripts.plot_3d_global as plot_3d


class MotionVQ(BaseModel):
    def __init__(self, cfg, datamodule, motion_vae, **kwargs):
        """
        Set up the motion VQ model: the VAE, the per-split losses and bookkeeping.

        Args:
            cfg: Project configuration. Read here for
                ``cfg.model.params.motion_vae.params`` and handed to the losses.
            datamodule: Data module. Provides ``njoints`` and ``feats2joints``
                here and is excluded from the saved hyperparameters.
            motion_vae: Config passed to ``instantiate_from_config`` to build
                the VAE.
            **kwargs: Not referenced by name in this method.
                NOTE(review): ``self.hparams.stage`` is read below but is not an
                explicit parameter, so it presumably reaches the hyperparameters
                through kwargs - confirm against the model config.
        """
        # NOTE(review): hyperparameters and the datamodule are set *before* the
        # base-class constructor runs; presumably BaseModel.__init__ (not
        # visible here) depends on them - confirm before reordering these lines.
        self.save_hyperparameters(ignore="datamodule", logger=False)
        self.datamodule = datamodule
        super().__init__()

        # Build the VAE from its config entry.
        self.motion_vae = instantiate_from_config(motion_vae)

        # Mirror the VAE's quantizer settings on this module / its hparams.
        self.num_quantizers = self.motion_vae.hparams.num_quantizers
        self.hparams.codebook_size = self.motion_vae.hparams.code_num
        # Disabled alternative: scale the codebook size by the number of
        # quantizers when the VAE is configured with ``regroup``.
        # if self.motion_vae.hparams.regroup:
        #     self.hparams.codebook_size = self.num_quantizers * self.motion_vae.hparams.code_num

        # Instantiate the losses: one independent GPTLosses instance per split
        self._losses = torch.nn.ModuleDict(
            {
                split: GPTLosses(cfg, self.hparams.stage, self.datamodule.njoints)
                for split in ["losses_train", "losses_test", "losses_val"]
            }
        )

        # Data transform (features -> joints), taken from the datamodule
        self.feats2joints = datamodule.feats2joints

        # Count codebook frequency
        # codeFrequency is assigned as a plain tensor attribute (register_buffer
        # is not used), with one slot per codebook entry.
        self.codePred = []
        self.codeFrequency = torch.zeros((self.hparams.codebook_size,))

        # stride_t ** down_t. The forward passes of this class round split
        # points down to a multiple of this value; presumably it is the VAE's
        # overall temporal downsampling factor - TODO confirm.
        self.down_t = (
            cfg.model.params.motion_vae.params.stride_t
            ** cfg.model.params.motion_vae.params.down_t
        )

        # self.manual_backward = True

    def configure_optimizers(self):
        """
        Build the optimizer and learning-rate scheduler from ``cfg.TRAIN``.

        Targets without a dot are resolved inside ``torch.optim`` and
        ``torch.optim.lr_scheduler`` respectively. A scheduler whose target
        contains "Lambda" is given a linear warm-up followed by cosine
        annealing, driven by ``warmup_epochs`` and ``total_epochs`` from the
        scheduler params; any other scheduler is constructed directly from the
        configured params.

        Returns:
            dict with the keys "optimizer" and "lr_scheduler".
        """
        train_cfg = self.hparams.cfg.TRAIN

        # Optimizer
        optim_target = train_cfg.OPTIM.target
        if "." not in optim_target:
            optim_target = "torch.optim." + optim_target
        optimizer = get_obj_from_str(optim_target)(
            params=self.parameters(), **train_cfg.OPTIM.params
        )

        # Scheduler
        scheduler_cfg = train_cfg.LR_SCHEDULER
        scheduler_target = scheduler_cfg.target
        if "." not in scheduler_target:
            scheduler_target = "torch.optim.lr_scheduler." + scheduler_target

        if "Lambda" in scheduler_target:
            # Define the warm-up + cosine annealing scheduler
            def lr_lambda(
                epoch,
                warmup_epochs=scheduler_cfg.params.warmup_epochs,
                total_epochs=scheduler_cfg.params.total_epochs,
            ):
                # Linear warm-up: the multiplier grows from 0 at epoch 0.
                if epoch < warmup_epochs:
                    return epoch / warmup_epochs

                decay_epochs = total_epochs - warmup_epochs
                if decay_epochs == 0:
                    # No annealing phase configured: keep the base learning
                    # rate instead of dividing by zero.
                    return 1.0

                # Cosine annealing over the epochs remaining after warm-up
                adjusted_epoch = epoch - warmup_epochs
                return 0.5 * (1 + math.cos(math.pi * adjusted_epoch / decay_epochs))

            lr_scheduler = get_obj_from_str(scheduler_target)(
                optimizer=optimizer, lr_lambda=lr_lambda
            )
        else:
            lr_scheduler = get_obj_from_str(scheduler_target)(
                optimizer=optimizer, **scheduler_cfg.params
            )

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def forward(self, x):
        """
        Run ``x`` through the wrapped motion VAE and return its output unchanged
        """
        vae_output = self.motion_vae(x)
        return vae_output

    def training_step(self, batch: dict, batch_idx: int = None):
        """
        Run one training step and return the loss.

        Args:
            batch: Batch dict forwarded to ``train_forward``.
            batch_idx: Index of the batch; unused here.

        Returns:
            The value returned by the train losses' ``update`` call.
        """
        outputs = self.train_forward(batch)
        loss = self._losses["losses_train"].update(outputs)

        # Log the detached tensor rather than ``loss.item()``: ``.item()``
        # forces a device-to-host synchronisation on every step, and with
        # ``sync_dist=True`` the value has to be a tensor again for the
        # cross-process reduction anyway.
        self.log("loss", loss.detach(), prog_bar=True, sync_dist=True)

        return loss

    def validation_step(self, batch: dict, batch_idx: int = None):
        """
        Run one validation step: update the validation losses, then the metrics.

        The loss is computed from ``train_forward`` (the split-sequence pass),
        while the metrics are computed from a separate ``val_forward`` pass.
        Returns the value produced by the validation losses' ``update`` call.
        """
        loss = self._losses["losses_val"].update(self.train_forward(batch))

        # Metrics see a fresh dict holding only the evaluation-pass results.
        metric_inputs = dict(self.val_forward(batch, split="val"))
        self.compute_metrics(batch, metric_inputs)

        return loss

    def test_step(self, batch: dict, batch_idx: int = None):
        """
        Run one test step: evaluation forward pass followed by a metrics update.

        Returns the result dict produced by ``val_forward``.
        """
        rs_set = self.val_forward(batch, split="test")
        self.compute_metrics(batch, rs_set)
        return rs_set

    def train_forward(self, batch: dict):
        """
        Forward pass for training

        The batch is cut along dim 1 at a single random point shared by all
        samples. The two segments are passed through the motion VAE one after
        the other, with the VAE's "encoder" buffer reset in between, and the
        two reconstructions are concatenated again.

        Args:
            batch: dict providing "motion" (the features) and "length"
                (per-sample lengths).

        Returns:
            dict with the keys "m_ref", "joints_ref", "m_rst", "joints_rst",
            "loss_commit" and "perplexity". The last two are the sums of the
            values returned for the two segments.

        Raises:
            ValueError: if the shortest sequence in the batch is too short to
                draw a split point from.
        """

        feats_ref = batch["motion"]
        lengths_ref = batch["length"]

        # The split point is drawn within a window no longer than the
        # shortest sequence in the batch.
        win_size = min(self.datamodule.hparams.win_size, int(min(lengths_ref)))

        # Randomly split the motion sequence between 25% and 75% of the window.
        # np.random.randint documents integer bounds, so truncate explicitly
        # instead of relying on an implicit float conversion.
        split_low = int(0.25 * win_size)
        split_high = int(0.75 * win_size)
        if split_high <= split_low:
            raise ValueError(
                "Cannot draw a split point: window of {} frames is too short".format(
                    win_size
                )
            )
        # Round down to a multiple of down_t.
        lengths_ref0 = (
            np.random.randint(split_low, split_high) // self.down_t * self.down_t
        )
        feats_ref0 = feats_ref[:, :lengths_ref0]
        feats_ref1 = feats_ref[:, lengths_ref0:]

        # Forward 2 times, resetting the encoder buffer between the segments
        feats_rst0, loss_commit0, perplexity0 = self.motion_vae(feats_ref0)
        self.motion_vae.reset_buffer("encoder")
        feats_rst1, loss_commit1, perplexity1 = self.motion_vae(feats_ref1)

        loss_commit = loss_commit0 + loss_commit1
        perplexity = perplexity0 + perplexity1

        # Concatenate the results and crop reference and reconstruction to
        # their common length along dim 1
        feats_rst = torch.cat([feats_rst0, feats_rst1], dim=1)
        min_length = min(feats_ref.shape[1], feats_rst.shape[1])
        feats_ref = feats_ref[:, :min_length]
        feats_rst = feats_rst[:, :min_length]
        joints_ref = self.feats2joints(feats_ref)
        joints_rst = self.feats2joints(feats_rst)

        rs_set = {
            "m_ref": feats_ref,
            "joints_ref": joints_ref,
            "m_rst": feats_rst,
            "joints_rst": joints_rst,
            "loss_commit": loss_commit,
            "perplexity": perplexity,
        }

        return rs_set

    @torch.no_grad()
    def val_forward(self, batch: dict, split: str):
        """
        Forward pass for validation

        Every sample is reconstructed on its own. If the longest sequence in
        the batch exceeds the long-sequence threshold, each (padded) sample is
        cut at a midpoint rounded down to a multiple of ``down_t``; the two
        halves are encoded separately with the VAE buffer reset in between,
        decoded separately and concatenated. Otherwise the sample, cropped to
        its own length, goes through a single VAE forward call.

        Args:
            batch: dict providing "motion" (the features) and "length"
                (per-sample lengths).
            split: Name of the split being evaluated; not used by this method.

        Returns:
            dict with the keys "m_ref" / "m_rst" (features after
            ``datamodule.renorm4t2m``), "joints_ref" / "joints_rst" (computed
            from the features before that renormalisation), "length", and
            "m_ref_denorm" / "m_rst_denorm" (``datamodule.denormalize`` output).
        """
        # Batches whose longest sequence exceeds this value use the
        # two-segment path and, when testing, are also visualised.
        long_seq_threshold = 55

        feats_ref = batch["motion"]
        lengths = batch["length"]
        num_samples = len(feats_ref)

        # Loop-invariant values, computed once. The emptiness check keeps
        # max() from being called on an empty batch.
        is_long_batch = num_samples > 0 and bool(max(lengths) > long_seq_threshold)
        mid_length = feats_ref.shape[1] // 2
        mid_length = mid_length // self.down_t * self.down_t

        # Motion encode & decode
        feats_rst = torch.zeros_like(feats_ref)

        for i in range(num_samples):
            self.motion_vae.reset_buffer()

            if is_long_batch:
                feats_ref1 = feats_ref[i : i + 1, :mid_length]
                feats_ref2 = feats_ref[i : i + 1, mid_length:]

                codes1, _ = self.motion_vae.encode(feats_ref1)
                # The buffer is cleared (not carried over) before the second
                # half is encoded.
                self.motion_vae.reset_buffer()
                self.motion_vae.seg_encode_flag = False
                codes2, _ = self.motion_vae.encode(feats_ref2)
                self.motion_vae.reset_buffer()

                feats_pred1 = self.motion_vae.decode(codes1)
                feats_pred2 = self.motion_vae.decode(codes2)
                feats_pred = torch.cat((feats_pred1, feats_pred2), dim=1)

            else:
                feats_pred, _, _ = self.motion_vae(feats_ref[i : i + 1, : lengths[i]])

            # Crop the prediction to the output buffer and write it in place;
            # positions past the prediction stay zero.
            min_length = min(feats_rst.shape[1], feats_pred.shape[1])
            feats_pred = feats_pred[:, :min_length]

            feats_rst[i : i + 1, : feats_pred.shape[1], :] = feats_pred

        # Recover joints for evaluation
        joints_ref = self.feats2joints(feats_ref)
        joints_rst = self.feats2joints(feats_rst)

        m_ref_denorm = self.datamodule.denormalize(feats_ref)
        m_rst_denorm = self.datamodule.denormalize(feats_rst)

        # Renorm for evaluation
        feats_ref = self.datamodule.renorm4t2m(feats_ref)
        feats_rst = self.datamodule.renorm4t2m(feats_rst)

        if self.trainer.testing and is_long_batch:
            # NOTE(review): the output directory is a hard-coded absolute path;
            # consider sourcing it from the config.
            save_path = (
                "/mnt/memData/visualization/vq_causal_babel/"
                + self.hparams.cfg.NAME
            )
            os.makedirs(save_path, exist_ok=True)
            print("Save to {}".format(save_path))
            self._plot_2d(lengths, joints_ref, joints_rst, save_path)

        # Return set
        rs_set = {
            "m_ref": feats_ref,
            "joints_ref": joints_ref,
            "m_rst": feats_rst,
            "joints_rst": joints_rst,
            "length": lengths,
            "m_ref_denorm": m_ref_denorm,
            "m_rst_denorm": m_rst_denorm,
        }

        return rs_set

    def _plot_2d(
        self,
        lengths,
        joints_ref,
        joints_rst=None,
        path=None,
        texts=None,
        texts_gen=None,
    ):
        """
        Save per-sample animations (and optional captions) for a batch.

        For each rendered sample ``bid`` this writes ``{bid}_gt.gif`` from
        ``joints_ref`` and, when the corresponding argument is given,
        ``{bid}_gen.gif`` from ``joints_rst``, ``{bid}.txt`` from ``texts``
        and ``{bid}_gen.txt`` from ``texts_gen``.

        Args:
            lengths: Per-sample lengths; each animation is cropped to its
                sample's length along dim 1.
            joints_ref: Reference joints tensor, indexed by sample along dim 0.
            joints_rst: Optional reconstructed joints tensor, same layout.
            path: Output directory; defaults to "./results". Created if missing.
            texts: Optional per-sample strings written next to the animations.
            texts_gen: Optional per-sample generated strings.
        """
        path = "./results" if path is None else path
        # The files below are written straight into this directory.
        os.makedirs(path, exist_ok=True)

        # Identity checks: comparing a tensor/array to None with == / != is
        # not a reliable way to test for a missing argument.
        has_rst = joints_rst is not None

        for bid in range(len(lengths)):
            # NOTE(review): batch indices 6-9 are skipped; looks like a
            # leftover from a specific visualisation run - confirm.
            if 5 < bid < 10:
                continue
            print("Saving {}.gif".format(bid))

            xyz_ref = joints_ref[bid : bid + 1].detach().cpu().numpy()
            xyz = joints_rst[bid : bid + 1].detach().cpu().numpy() if has_rst else None

            plot_3d.draw_to_batch(
                xyz_ref[:, : lengths[bid], :, :],
                [""],
                [os.path.join(path, "{}_gt.gif".format(bid))],
            )

            if has_rst:
                plot_3d.draw_to_batch(
                    xyz[:, : lengths[bid], :, :],
                    [""],
                    [os.path.join(path, "{}_gen.gif".format(bid))],
                )

            if texts is not None:
                with open(
                    os.path.join(path, "{}.txt".format(bid)), "w", encoding="utf-8"
                ) as f:
                    f.write(texts[bid])

            if texts_gen is not None:
                with open(
                    os.path.join(path, "{}_gen.txt".format(bid)), "w", encoding="utf-8"
                ) as f:
                    f.write(texts_gen[bid])
