import torch
import torch.nn as nn
import numpy as np
import os
import yaml
from tqdm import tqdm
from audio2posedataset import Audio2PoseDataset
from audio_encoder import AudioEncoder
from pose_library import PoseLibrary
from pose_retriever import PoseRetriever
from gan_refiner import PoseGAN
from tensorboardX import SummaryWriter

def keep_latest_checkpoints(ckpt_dir, prefix="audio2pose_epoch", keep=3):
    """Delete old epoch checkpoints in ``ckpt_dir``, keeping only the newest ``keep``.

    Rotated checkpoints are files named ``f"{prefix}{epoch}{'.pt'}"``. They are
    ordered by the integer ``epoch`` (not lexicographically), so ``...epoch10.pt``
    counts as newer than ``...epoch9.pt``. Files that start with ``prefix`` and end
    with ``.pt`` but whose middle part is not an integer (for example
    ``audio2pose_epoch_best.pt``) are left untouched instead of aborting the
    cleanup with a ``ValueError``.

    Args:
        ckpt_dir: Directory containing the checkpoint files.
        prefix: Filename prefix shared by the rotated checkpoints.
        keep: Number of newest checkpoints to retain. ``0`` or a negative value
            deletes every rotated checkpoint.
    """
    suffix = '.pt'
    numbered = []
    for name in os.listdir(ckpt_dir):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        middle = name[len(prefix):len(name) - len(suffix)]
        try:
            epoch = int(middle)
        except ValueError:
            # Not a rotated epoch checkpoint (e.g. a "best" file); skip it.
            continue
        numbered.append((epoch, name))

    # Oldest first, so the excess files to delete are at the front.
    numbered.sort(key=lambda item: item[0])
    # Computed explicitly: the slice ``[:-keep]`` would delete nothing for keep == 0.
    excess = len(numbered) - max(keep, 0)
    for _, name in numbered[:max(excess, 0)]:
        os.remove(os.path.join(ckpt_dir, name))

def train(configs):
    """Train the pose-refinement GAN on poses retrieved from audio.

    Per batch: the mel input is encoded by a frozen ``AudioEncoder``.
    ``PoseRetriever`` (backed by a ``PoseLibrary``) maps the embeddings to poses.
    ``PoseGAN.train_step`` is then called with the retrieved poses, the
    ground-truth poses, the audio embedding and the expression features.

    Args:
        configs: Dict loaded from the training YAML.
            Required keys: ``device``, ``save_path``, ``lib_dir``, ``dataset``,
            ``epochs``.
            Optional keys (defaults in parentheses): ``pose_dim`` (6),
            ``audio_dim`` (512), ``exp_dim`` (64), ``hidden_dim`` (256),
            ``fusion_type`` ('kron'), ``log_interval`` (10), ``save_interval`` (5),
            ``batch_size`` (4), ``shuffle`` (False), ``window_size`` (32),
            ``stride`` (16), ``keep_checkpoints`` (3).

    Side effects:
        - Writes TensorBoard logs to ``<dir of save_path>/logs``.
        - Every ``save_interval`` epochs, saves ``audio2pose_epoch{N}.pt`` next to
          ``save_path``, keeping only the newest ``keep_checkpoints`` of those.
        - Saves the model with the lowest mean epoch ``g_loss`` to
          ``configs['save_path']``.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    device = configs['device']
    save_path = configs['save_path']
    # abspath so a bare filename (dirname == '') still yields a usable directory
    # for makedirs / listdir below.
    save_dir = os.path.dirname(os.path.abspath(save_path))
    os.makedirs(save_dir, exist_ok=True)
    log_dir = os.path.join(save_dir, 'logs')
    os.makedirs(log_dir, exist_ok=True)

    num_epochs = configs['epochs']
    log_interval = configs.get('log_interval', 10)
    save_interval = configs.get('save_interval', 5)
    keep = configs.get('keep_checkpoints', 3)

    lib = PoseLibrary(
        window_size=configs.get('window_size', 32),
        stride=configs.get('stride', 16),
    )
    lib.load_library(configs['lib_dir'])

    dataset = Audio2PoseDataset(configs['dataset'])
    # NOTE(review): shuffle defaults to False to preserve the original behaviour;
    # shuffling is usually preferable for training -- confirm.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=configs.get('batch_size', 4),
        shuffle=configs.get('shuffle', False),
    )
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("Training dataloader is empty; check configs['dataset'].")

    # The audio encoder is frozen: eval mode here, no_grad during its forward pass.
    audioencoder = AudioEncoder().to(device)
    audioencoder.eval()

    retriever = PoseRetriever(lib)

    gan_model = PoseGAN(
        pose_dim=configs.get('pose_dim', 6),
        audio_dim=configs.get('audio_dim', 512),
        exp_dim=configs.get('exp_dim', 64),
        hidden_dim=configs.get('hidden_dim', 256),
        device=device,
        fusion_type=configs.get('fusion_type', 'kron')
    )

    best_loss = float('inf')
    global_step = 0

    writer = SummaryWriter(log_dir)
    try:
        for epoch in range(num_epochs):
            epoch_losses = {
                'd_loss': 0.0,
                'g_loss': 0.0,
                'g_loss_gan': 0.0,
                'g_loss_recon': 0.0,
                'g_loss_smooth': 0.0
            }

            gan_model.netG.train()
            gan_model.netD.train()

            for data in tqdm(dataloader, total=num_batches, desc=f"Epoch {epoch+1}/{num_epochs}"):
                mel = data['mel'].to(device)
                pose_gt = data['pose_gt'].to(device)
                exp_feat = data['exp_gt'].to(device)

                with torch.no_grad():
                    mel_embedding = audioencoder(mel)

                pose_retrieved = retriever.batch_retrieve(mel_embedding)
                pose_retrieved = torch.tensor(np.array(pose_retrieved), dtype=torch.float32).to(device)

                losses = gan_model.train_step(pose_retrieved, pose_gt, mel_embedding, exp_feat)

                for k, v in losses.items():
                    # train_step's value type isn't visible here; float() handles
                    # Python numbers and one-element tensors, and keeps the epoch
                    # sums as plain numbers rather than tensors. .get() tolerates
                    # loss keys not pre-seeded above.
                    epoch_losses[k] = epoch_losses.get(k, 0.0) + float(v)

                if global_step % log_interval == 0:
                    for k, v in losses.items():
                        writer.add_scalar(f'Train/{k}', v, global_step)

                global_step += 1

            # Convert per-epoch sums into per-batch means.
            for k in epoch_losses:
                epoch_losses[k] /= num_batches
                writer.add_scalar(f'Epoch/{k}', epoch_losses[k], epoch)

            if (epoch + 1) % save_interval == 0:
                checkpoint_path = os.path.join(save_dir, f"audio2pose_epoch{epoch+1}.pt")
                gan_model.save_models(checkpoint_path)
                print(f"Model saved to {checkpoint_path}")
                # Keep only the newest `keep` periodic checkpoints.
                keep_latest_checkpoints(save_dir, keep=keep)

            current_loss = epoch_losses['g_loss']
            if current_loss < best_loss:
                best_loss = current_loss
                # The best model goes to save_path (as the message always stated),
                # outside the rotated epoch files, so pruning cannot delete it.
                gan_model.save_models(save_path)
                print(f"New best model saved to {save_path}")
    finally:
        writer.close()

    print("Training complete!")

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train the audio-to-pose GAN refiner.")
    parser.add_argument(
        '--config',
        default='../configs/train_aud2pose.yaml',
        help="Path to the training YAML config (relative paths resolve against the current working directory).",
    )
    cli_args = parser.parse_args()

    # Explicit UTF-8: the default open() encoding is locale-dependent and may fail
    # on non-ASCII config content.
    with open(cli_args.config, 'r', encoding='utf-8') as f:
        # safe_load never constructs arbitrary Python objects from the file.
        args = yaml.safe_load(f)
    # An empty YAML loads as None; fail here with a clear message rather than
    # deep inside train().
    if not isinstance(args, dict):
        parser.error(f"config {cli_args.config!r} must contain a YAML mapping")
    train(args)