# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import time
from typing import Optional
from torch import Tensor
import smplx

from .base import Datastruct, dataclass, Transform

from .rots2rfeats import Rots2Rfeats, Globalvelandy
from .rots2joints import Rots2Joints, SMPLH
from .joints2jfeats import Joints2Jfeats

class SMPLTransform(Transform):
    """Bundle of converters used to build :class:`SMPLDatastruct` objects.

    Holds three converters:

    * ``rots2rfeats`` -- rotations <-> rotation features
      (defaults to a ``Globalvelandy`` instance);
    * ``rots2joints`` -- rotations -> joints/vertices
      (defaults to an ``SMPLH`` body model);
    * ``joints2jfeats`` -- joints -> joint features (no default; stays ``None``).

    Args:
        batch_size: forwarded to the default ``SMPLH`` converter.
        rots2rfeats: custom rotation-feature converter; ``None`` uses the default.
        rots2joints: custom body-model converter; ``None`` uses the default.
        joints2jfeats: custom joint-feature converter; ``None`` leaves it unset.
        **kwargs: accepted for config compatibility and ignored.
    """

    def __init__(self, batch_size=16, rots2rfeats: Optional[Rots2Rfeats] = None,
                 rots2joints: Optional[Rots2Joints] = None,
                 joints2jfeats: Optional[Joints2Jfeats] = None,
                 **kwargs):
        if rots2rfeats is None:
            rots2rfeats = Globalvelandy(
                path='./data_loaders/amass/transforms/rots2rfeats/globalvelandy/rot6d/babel-amass/separate_pairs',
                normalization=True,
                pose_rep='rot6d',
                canonicalize=True,
                offset=True,
                name='Globalvelandy')
        if rots2joints is None:
            rots2joints = SMPLH(path='./body_models/smpl_models/smplh',
                                jointstype='smplnh',
                                input_pose_rep='matrix',
                                batch_size=batch_size,
                                gender='male',
                                name='SMPLH')
        # FIXME: there is no default joints -> jfeats converter, so
        # ``joints2jfeats`` stays ``None`` unless the caller provides one;
        # accessing ``jfeats`` on the resulting datastruct then cannot work.

        self.rots2rfeats = rots2rfeats
        self.rots2joints = rots2joints
        self.joints2jfeats = joints2jfeats

    def Datastruct(self, **kwargs):
        """Create an :class:`SMPLDatastruct` wired to this transform's converters."""
        return SMPLDatastruct(_rots2rfeats=self.rots2rfeats,
                              _rots2joints=self.rots2joints,
                              _joints2jfeats=self.joints2jfeats,
                              transforms=self,
                              **kwargs)

    def __repr__(self):
        return "SMPLTransform()"

class SlimSMPLTransform(Transform):
    """Bundle of converters used to build :class:`SlimSMPLDatastruct` objects.

    Like the full SMPL transform but without a joints -> joint-features
    converter. Holds:

    * ``rots2rfeats`` -- rotations <-> rotation features
      (defaults to a ``Globalvelandy`` instance);
    * ``rots2joints`` -- rotations -> joints
      (defaults to an ``SMPLH`` body model).

    Args:
        batch_size: forwarded to the default ``SMPLH`` converter.
        rots2rfeats: custom rotation-feature converter; ``None`` uses the default.
        rots2joints: custom body-model converter; ``None`` uses the default.
        **kwargs: only ``canonicalize`` is read (default ``True``), and only
            when the default ``Globalvelandy`` converter is built; everything
            else is ignored.
    """

    def __init__(self, batch_size=16, rots2rfeats: Optional[Rots2Rfeats] = None,
                 rots2joints: Optional[Rots2Joints] = None,
                 **kwargs):
        if rots2rfeats is None:
            rots2rfeats = Globalvelandy(
                path='./data_loaders/amass/transforms/rots2rfeats/globalvelandy/rot6d/babel-amass/separate_pairs',
                normalization=True,
                pose_rep='rot6d',
                canonicalize=kwargs.get("canonicalize", True),
                offset=True,
                name='Globalvelandy')
        if rots2joints is None:
            rots2joints = SMPLH(path='./body_models/smpl_models/smplh',
                                jointstype='smplnh',
                                input_pose_rep='matrix',
                                batch_size=batch_size,
                                gender='male',
                                name='SMPLH')

        self.rots2rfeats = rots2rfeats
        self.rots2joints = rots2joints

    def SlimDatastruct(self, **kwargs):
        """Create a :class:`SlimSMPLDatastruct` wired to this transform's converters."""
        return SlimSMPLDatastruct(_rots2rfeats=self.rots2rfeats,
                                  _rots2joints=self.rots2joints,
                                  transforms=self,
                                  **kwargs)

    def __repr__(self):
        return "SlimSMPLTransform()"


class RotIdentityTransform(Transform):
    """Trivial transform whose datastruct just stores ``rots`` and ``trans``."""

    def __init__(self, **kwargs):
        # No state to set up; any keyword arguments are accepted and ignored.
        pass

    def Datastruct(self, **kwargs):
        """Wrap the given fields in a :class:`RotTransDatastruct`."""
        datastruct = RotTransDatastruct(**kwargs)
        return datastruct

    def __repr__(self):
        return "RotIdentityTransform()"


@dataclass
class RotTransDatastruct(Datastruct):
    """Container holding a rotation tensor together with a translation tensor.

    This is the type the SMPL datastructs use for their ``rots_`` field.
    """
    rots: Tensor
    trans: Tensor

    transforms: RotIdentityTransform = RotIdentityTransform()

    def __post_init__(self):
        # Names of the data-carrying attributes; presumably consumed by the
        # generic helpers of the ``Datastruct`` base class (not visible here).
        keys = ["rots", "trans"]
        self.datakeys = keys

    def __len__(self):
        # Length is taken from the first dimension of ``rots``.
        rotations = self.rots
        return len(rotations)


@dataclass
class SMPLDatastruct(Datastruct):
    """Lazily-converted motion container built from a set of SMPL converters.

    Start from either ``features`` / ``rfeats_`` (rotation features) or
    ``rots_`` (rotations + translations). Every other representation is
    computed on first access with the converters passed in (normally supplied
    by ``SMPLTransform.Datastruct``). The result is cached in the matching
    ``*_`` field.
    """
    transforms: SMPLTransform
    _rots2rfeats: Rots2Rfeats
    _rots2joints: Rots2Joints
    _joints2jfeats: Joints2Jfeats

    # Cached representations; ``None`` means "not computed yet".
    features: Optional[Tensor] = None
    rots_: Optional[RotTransDatastruct] = None
    rfeats_: Optional[Tensor] = None
    joints_: Optional[Tensor] = None
    jfeats_: Optional[Tensor] = None
    vertices_: Optional[Tensor] = None

    def __post_init__(self):
        self.datakeys = ['features', 'rots_', 'rfeats_',
                         'joints_', 'jfeats_', 'vertices_']
        # ``features`` are treated as rotation features: use them as the
        # starting point unless rfeats were given explicitly.
        if self.features is not None and self.rfeats_ is None:
            self.rfeats_ = self.features

    @property
    def rots(self):
        """Rotations/translations, decoded from ``rfeats`` on first access."""
        # Cached value
        if self.rots_ is not None:
            return self.rots_

        # self.rfeats_ should be defined
        assert self.rfeats_ is not None

        # Move the converter to the input's device before running it.
        self._rots2rfeats.to(self.rfeats.device)
        self.rots_ = self._rots2rfeats.inverse(self.rfeats)
        return self.rots_

    @property
    def rfeats(self):
        """Rotation features, encoded from ``rots`` on first access."""
        # Cached value
        if self.rfeats_ is not None:
            return self.rfeats_

        # self.rots_ should be defined
        assert self.rots_ is not None

        self._rots2rfeats.to(self.rots.device)
        self.rfeats_ = self._rots2rfeats(self.rots)
        return self.rfeats_

    @property
    def joints(self):
        """Joint positions produced by the body-model converter from ``rots``."""
        # Cached value
        if self.joints_ is not None:
            return self.joints_

        self._rots2joints.to(self.rots.device)
        self.joints_ = self._rots2joints(self.rots)
        return self.joints_

    @property
    def jfeats(self):
        """Joint features computed from ``joints`` with ``_joints2jfeats``.

        Raises:
            AttributeError: if no joints -> jfeats converter was provided.
                ``SMPLTransform`` leaves it ``None`` by default.
        """
        # Cached value
        if self.jfeats_ is not None:
            return self.jfeats_

        # Fail with an explicit message instead of the opaque
        # "'NoneType' object has no attribute 'to'". AttributeError is kept
        # so the exception type seen by callers (and hasattr) is unchanged.
        if self._joints2jfeats is None:
            raise AttributeError(
                "jfeats are unavailable: no joints2jfeats converter was "
                "provided to this SMPLDatastruct")

        self._joints2jfeats.to(self.joints.device)
        self.jfeats_ = self._joints2jfeats(self.joints)
        return self.jfeats_

    @property
    def vertices(self):
        """Mesh vertices produced by the body-model converter from ``rots``."""
        # Cached value
        if self.vertices_ is not None:
            return self.vertices_

        self._rots2joints.to(self.rots.device)
        self.vertices_ = self._rots2joints(self.rots, jointstype="vertices")
        return self.vertices_

    def __len__(self):
        return len(self.rfeats)


@dataclass
class SlimSMPLDatastruct(Datastruct):
    """Lazily-converted motion container without joint features or vertices.

    Start from either ``features`` / ``rfeats_`` (rotation features) or
    ``rots_`` (rotations + translations). ``rots``, ``rfeats`` and ``joints``
    are computed on first access with the converters passed in (normally
    supplied by ``SlimSMPLTransform.SlimDatastruct``) and cached in the
    matching ``*_`` field.
    """
    transforms: SlimSMPLTransform
    _rots2rfeats: Rots2Rfeats
    _rots2joints: Rots2Joints

    # Cached representations; ``None`` means "not computed yet".
    features: Optional[Tensor] = None
    rots_: Optional[RotTransDatastruct] = None
    rfeats_: Optional[Tensor] = None
    joints_: Optional[Tensor] = None

    def __post_init__(self):
        self.datakeys = ['features', 'rots_', 'joints_', 'rfeats_']

        # ``features`` are treated as rotation features: use them as the
        # starting point unless rfeats were given explicitly.
        if self.features is not None and self.rfeats_ is None:
            self.rfeats_ = self.features

    @property
    def rots(self):
        """Rotations/translations, decoded from ``rfeats`` on first access."""
        # Cached value
        if self.rots_ is not None:
            return self.rots_

        # self.rfeats_ should be defined
        assert self.rfeats_ is not None

        # Move the converter to the input's device before running it.
        self._rots2rfeats.to(self.rfeats.device)
        self.rots_ = self._rots2rfeats.inverse(self.rfeats)
        return self.rots_

    @property
    def rfeats(self):
        """Rotation features, encoded from ``rots`` on first access."""
        # Cached value
        if self.rfeats_ is not None:
            return self.rfeats_

        # self.rots_ should be defined
        assert self.rots_ is not None

        self._rots2rfeats.to(self.rots.device)
        self.rfeats_ = self._rots2rfeats(self.rots)
        return self.rfeats_

    @property
    def joints(self):
        """Joint positions produced by the body-model converter from ``rots``."""
        # Cached value
        if self.joints_ is not None:
            return self.joints_

        self._rots2joints.to(self.rots.device)
        self.joints_ = self._rots2joints(self.rots)
        return self.joints_

    def __len__(self):
        return len(self.rfeats)

def get_body_model(model_type, gender, batch_size, device='cpu', ext='pkl'):
    '''Create an smplx body model from the files under ``body_models/smpl_models``.

    The model file is ``body_models/smpl_models/<model_type>/<MODEL_TYPE>_<GENDER>.<ext>``.

    Args:
        model_type: smpl, smplx, smplh and others (refer to the smplx tutorial).
            Used as-is for the sub-directory and upper-cased for the file prefix.
        gender: 'male', 'female' or 'neutral'. Non-``str`` values must provide
            ``astype(str)`` (e.g. a numpy string array) and are converted first.
        batch_size: a positive integer, forwarded to ``smplx.create``.
        device: anything accepted by ``torch.nn.Module.to`` ('cpu', 'cuda',
            'cuda:0', a ``torch.device`` ...). Defaults to 'cpu'.
        ext: model file extension; forced to 'npz' for the neutral model.

    Returns:
        The created smplx model, moved to ``device``.
    '''
    mtype = model_type.upper()
    # Normalise non-str genders first so the 'neutral' check below also works
    # for them (a numpy 'neutral' value previously crashed on ``.upper()``).
    if not isinstance(gender, str):
        gender = str(gender.astype(str))
    if gender == 'neutral':
        # The neutral model is stored as .npz regardless of ``ext``.
        ext = 'npz'
    gender = gender.upper()
    body_model_path = f'body_models/smpl_models/{model_type}/{mtype}_{gender}.{ext}'

    body_model = smplx.create(body_model_path, model_type=model_type,
                              gender=gender, ext=ext,
                              use_pca=False,
                              num_pca_comps=12,
                              create_global_orient=True,
                              create_body_pose=True,
                              create_betas=True,
                              create_left_hand_pose=True,
                              create_right_hand_pose=True,
                              create_expression=True,
                              create_jaw_pose=True,
                              create_leye_pose=True,
                              create_reye_pose=True,
                              create_transl=True,
                              batch_size=batch_size)
    # ``Module.to`` handles every device spec and returns the module itself,
    # so 'cpu' keeps the freshly created model unchanged.
    return body_model.to(device)

