"""
ExpressionMapperTrainer: ARKit-to-FLAME Regression Training Script

This script trains a linear mapper with ridge regression (RR) plus an expression-regularization (ER)
guide term. It maps ARKit blendshape coefficients (51 dims) to FLAME expression + jaw parameters
(103 dims) for a specific subject.

Folder Structure:
    data/
        <subject_id>/
            arkit_weights.npy       # (N, 51) training inputs from ARKit mocap
            flame_weights.npz       # FLAME tracking outputs from VHAP (keys: 'expr', 'jaw_pose', 'shape')
    data/w_er.npy            # (51, 103) expression-regularization (guide) matrix, ARKit-to-FLAME orientation

Usage:
    python train_expression_mapper.py --subject <subject_id>

Output:
    Saves the learned mapping matrix to:
        checkpoints/<subject_id>/mat.npy    # (51, 103) trained ARKit-to-FLAME matrix (the (103, 51) parameter, transposed)
"""
import import_helper
from flame_model.flame import FlameHead
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch import optim
import numpy as np

class ExpressionMapperTrainer:
    """Learn a linear ARKit -> FLAME (expression + jaw) mapping for one subject.

    The trainable parameter ``W`` has shape (103, 51); predictions are
    ``X @ W.T`` for ARKit inputs ``X`` of shape (N, 51). The objective is::

        mean((X @ W.T - P)**2) + alpha * mean(W**2) + lambda_er * mean((W - W_er)**2)

    i.e. a ridge-regularised least-squares fit pulled towards a guide
    matrix ``W_er`` (stored internally as (103, 51)).
    """

    # Number of ARKit blendshape inputs (columns of X).
    N_ARKIT = 51
    # Number of jaw-pose parameters appended after the expression parameters.
    N_JAW = 3

    def __init__(
        self,
        alpha: float = 1e-1,
        lambda_er: float = 1.2,
        lambda_mask: float = 5e3,
        W_er: np.ndarray = None,
        batch_size: int = 64
    ):
        """
        alpha:       weight of the ridge penalty mean(W**2)
        lambda_er:   weight of the guide penalty mean((W - W_er)**2)
        lambda_mask: stored on the instance; NOTE(review): not used by compute_loss_params
        W_er:        guide matrix in either (51, 103) or (103, 51) orientation; when None it is
                     loaded from "flame_fitting/output/shape_keys.npy"
        batch_size:  mini-batch size used by train()
        """
        self.lambda_mask = lambda_mask
        self.n_expr = 100
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.flame_model = FlameHead(shape_params=300, expr_params=self.n_expr).to(self.device).eval()
        self.batch_size = batch_size
        if W_er is None:
            W_er = np.load("flame_fitting/output/shape_keys.npy")
        print(f"Using device: {self.device}")

        # NOTE(review): masks and expr_basis are computed but not used by the current loss.
        masks = self.flame_model.mask.v
        self.masks = {
            'eyes': torch.cat([masks.eye_region, masks.left_eyeball, masks.right_eyeball])
        }

        n_out = self.n_expr + self.N_JAW  # 100 expression + 3 jaw = 103 outputs
        # Normalised to (103, 51) so it can be compared directly with W.
        self.W_er = self._as_guide_matrix(W_er, self.device, n_out, self.N_ARKIT)

        shapedirs = self.flame_model.shapedirs      # (V, 3, shape + expr)
        self.expr_slice = shapedirs[..., 300:]      # expression part; first 300 columns are shape
        self.expr_basis = self.expr_slice.norm(dim=1).to(self.device)  # collapse xyz -> magnitude

        self.W = nn.Parameter(torch.zeros(n_out, self.N_ARKIT, device=self.device))  # (103, 51)
        self.alpha = alpha              # ridge reg weight
        self.lambda_guide = lambda_er   # guide (ER) penalty weight

    @staticmethod
    def _as_guide_matrix(W_er, device, n_out: int = 103, n_in: int = 51) -> torch.Tensor:
        """Return ``W_er`` as a float32 (n_out, n_in) tensor on ``device``.

        Accepts either the (n_out, n_in) orientation of ``W`` or the transposed
        (n_in, n_out) ARKit-to-FLAME orientation; anything else raises ValueError.
        """
        guide = torch.tensor(np.asarray(W_er), dtype=torch.float32, device=device)
        if guide.shape == (n_in, n_out):
            guide = guide.T.contiguous()
        if guide.shape != (n_out, n_in):
            raise ValueError(
                f"W_er must have shape ({n_out}, {n_in}) or ({n_in}, {n_out}), "
                f"got {tuple(guide.shape)}"
            )
        return guide

    def compute_loss_params(self, X: torch.Tensor, P: torch.Tensor) -> torch.Tensor:
        """
        X: (N, 51) ARKit blendshape inputs
        P: (N, 103) FLAME target parameters (expr + jaw)
        returns: scalar loss = MSE + alpha * ridge penalty + lambda_guide * guide penalty
        """
        P_pred = X @ self.W.T                            # (N, 103)
        loss_mse = ((P_pred - P) ** 2).mean()
        loss_reg = self.alpha * (self.W ** 2).mean()
        # W and W_er share the (103, 51) layout, so no transpose is needed here.
        loss_guide = ((self.W - self.W_er) ** 2).mean()
        return loss_mse + loss_reg + self.lambda_guide * loss_guide

    def train(
        self,
        X_np: 'np.ndarray',
        P_np: 'np.ndarray',
        epochs: int = 1000,
        lr: float = 1e-2,
    ) -> 'np.ndarray':
        """Fit W with Adam on shuffled mini-batches.

        X_np: (N, 51) ARKit inputs; P_np: (N, 103) FLAME targets.
        Returns the learned matrix transposed to (51, 103) as a NumPy array.
        Raises ValueError on mismatched shapes or an empty dataset.
        """
        n_out, n_in = self.W.shape
        if X_np.ndim != 2 or X_np.shape[1] != n_in:
            raise ValueError(f"X_np must have shape (N, {n_in}), got {X_np.shape}")
        if P_np.ndim != 2 or P_np.shape[1] != n_out:
            raise ValueError(f"P_np must have shape (N, {n_out}), got {P_np.shape}")
        if X_np.shape[0] != P_np.shape[0]:
            raise ValueError(
                f"Mismatched number of training samples: {X_np.shape[0]} vs {P_np.shape[0]}"
            )
        if X_np.shape[0] == 0:
            raise ValueError("No training samples")

        X = torch.from_numpy(X_np).float().to(self.device)
        P = torch.from_numpy(P_np).float().to(self.device)
        loader = DataLoader(TensorDataset(X, P), batch_size=self.batch_size, shuffle=True)

        optimizer = optim.Adam([self.W], lr=lr)
        for epoch in range(1, epochs + 1):
            total_loss = 0.0
            for Xb, Pb in loader:
                optimizer.zero_grad()
                loss = self.compute_loss_params(Xb, Pb)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            avg_loss = total_loss / len(loader)
            if epoch == 1 or epoch % 100 == 0:
                print(f"Epoch {epoch:4d} | Avg Loss: {avg_loss:.4f}")

        return self.W.detach().cpu().numpy().T

def process_flame_data(flame_path):
    """
    Load FLAME tracking results and build the regression targets.

    flame_path: path to an .npz archive containing 'expr', 'jaw_pose' and 'shape'.
    Returns:
        combined_matrix: np.hstack((expr, jaw_pose)) -- expression params followed by jaw pose,
                         (N, 103) when expr is (N, 100) and jaw_pose is (N, 3)
        fixed_data:      {'shape': shape[None, :]} -- the shape vector with a leading batch axis
    """
    # NOTE(review): allow_pickle=True permits object arrays; only load trusted tracking files.
    flame_data = np.load(flame_path, allow_pickle=True)
    try:
        expr = flame_data['expr']
        jaw_pose = flame_data['jaw_pose']
        shape = flame_data['shape']
    finally:
        # An .npz load returns an NpzFile holding an open file handle; release it.
        close = getattr(flame_data, "close", None)
        if close is not None:
            close()
    fixed_data = {
        'shape': np.array(shape[None, :])
    }
    combined_matrix = np.hstack((expr, jaw_pose))
    return combined_matrix, fixed_data
   
if __name__ == "__main__":
    # Script entry: load one subject's data, fit the mapper, save the (51, 103) matrix.
    import argparse
    import os
    import time

    start_time = time.time()
    parser = argparse.ArgumentParser(description="Train a matrix to map ARKit blendshapes to FLAME parameters.")
    parser.add_argument("--subject", type=str, default="306", help="Subject name for data processing")
    args = parser.parse_args()
    subject = args.subject

    batch_size = 32
    X_np = np.load(f"data/{subject}/arkit_weights.npy")
    P_np, fixed_flame_params = process_flame_data(f"data/{subject}/flame_weights.npz")
    W_er = np.load("data/w_er.npy")
    # Explicit check instead of assert, which is stripped when Python runs with -O.
    if X_np.shape[0] != P_np.shape[0]:
        raise ValueError(
            f"Mismatched number of training samples: {X_np.shape[0]} vs {P_np.shape[0]}"
        )

    trainer = ExpressionMapperTrainer(
        lambda_er=1.0,
        W_er=W_er,
        batch_size=batch_size,  # previously defined but never passed, so the default 64 was used
    )

    W = trainer.train(X_np, P_np, epochs=700, lr=1e-3)
    out_path = f"checkpoints/{subject}/mat.npy"
    # np.save does not create parent directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    np.save(out_path, W)

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Training completed in {elapsed_time:.2f} seconds.")
    print(f"Training complete. matrix saved to {out_path}")