import os
import torch
import mediapipe as mp
from tensorboardX import SummaryWriter
from models.emotionsyncK import EmotionSyncK
from utils.dataset import Audio2ExpDataset
from utils.loadconfig import load_config
from tqdm import tqdm

def denormalize_params(normalized_params, norm_stats):
    """Undo min-max normalisation using the bounds stored in ``norm_stats``.

    Evaluates ``normalized_params * (exp_max - exp_min) + exp_min``.

    Args:
        normalized_params: Normalised values. May be a ``torch.Tensor`` or any
            other object that supports the arithmetic above (e.g. a float).
        norm_stats: Mapping providing the ``'exp_min'`` and ``'exp_max'``
            bounds.

    Returns:
        The rescaled values. When ``normalized_params`` is a tensor, the
        bounds are first placed on its device (tensor bounds are moved,
        non-tensor bounds are wrapped with ``torch.tensor``).
    """
    lower = norm_stats['exp_min']
    upper = norm_stats['exp_max']

    if isinstance(normalized_params, torch.Tensor):
        # Bring both bounds onto the same device as the data before mixing them.
        target_device = normalized_params.device
        if isinstance(lower, torch.Tensor):
            lower = lower.to(target_device)
            upper = upper.to(target_device)
        else:
            lower = torch.tensor(lower, device=target_device)
            upper = torch.tensor(upper, device=target_device)

    return normalized_params * (upper - lower) + lower

def train_one_epoch(model, train_loader, optimizer, device, epoch, num_epochs, writer, global_step):
    """Run one optimisation pass over ``train_loader``.

    For every batch the ``'mel'``, ``'exp'`` and ``'ref'`` entries are moved to
    ``device``, the model is called as ``model(batch['mel'], batch['ref'])``
    and a weighted Smooth-L1 loss against ``batch['exp']`` is back-propagated.
    Per-batch scalars are written to ``writer`` under ``train/batch_loss`` and
    ``train/loss_p``.

    Args:
        model: Network to train; switched to training mode.
        train_loader: Iterable of batch dicts with the keys ``'mel'``,
            ``'exp'``, ``'ref'`` and ``'norm_stats'``.
        optimizer: Optimiser that is stepped once per batch.
        device: Target device for the batch tensors.
        epoch: Zero-based epoch index (used for the progress-bar label only).
        num_epochs: Total number of epochs (progress-bar label only).
        writer: Object exposing ``add_scalar(tag, value, step)``.
        global_step: Step counter to use for the first batch.

    Returns:
        tuple: ``(avg_loss, global_step)`` where ``avg_loss`` is the mean of
        the per-batch total losses (``nan`` if the loader produced no batch)
        and ``global_step`` has been advanced by the number of batches.
    """
    model.train()

    tensor_keys = ('mel', 'exp', 'ref')
    loss_scale = 20  # factor applied to the weighted loss to obtain the total
    # The criterion holds no state, so build it once instead of once per batch.
    criterion = torch.nn.SmoothL1Loss(reduction='none')

    running_loss = 0.0
    num_batches = 0

    pbar = tqdm(train_loader, desc=f"Train Epoch {epoch+1}/{num_epochs}")

    for batch in pbar:
        for key in tensor_keys:
            batch[key] = batch[key].to(device)

        # The statistics are moved to the device alongside the batch; they are
        # not read again inside this function.
        norm_stats = batch['norm_stats']
        if norm_stats['exp_min'] is not None:
            norm_stats['exp_min'] = norm_stats['exp_min'].to(device)
            norm_stats['exp_max'] = norm_stats['exp_max'].to(device)

        optimizer.zero_grad()

        output = model(batch['mel'], batch['ref'])

        # Element-wise Smooth L1 loss (original note gives the shape as [b, T, 64]).
        per_element = criterion(output, batch['exp'])
        # Eye-parameter weighting: channels 22..29 count ten times as much.
        weights = torch.ones_like(per_element)
        weights[:, :, 22:30] = 10.0
        loss_p = (per_element * weights).mean()
        # A temporal smoothness term (mean absolute frame-to-frame difference
        # of ``output``) used to be added here and is currently disabled.
        # Total loss.
        loss = loss_scale * loss_p

        loss.backward()
        optimizer.step()

        # Read each scalar back once; every ``.item()`` call copies to the host.
        loss_value = loss.item()
        loss_p_value = loss_p.item()

        running_loss += loss_value
        num_batches += 1

        writer.add_scalar('train/batch_loss', loss_value, global_step)
        writer.add_scalar('train/loss_p', loss_p_value, global_step)

        global_step += 1

        # The bar shows the scaled contribution of ``loss_p``; as it is the
        # only term at the moment it equals the total.
        pbar.set_postfix(
            loss=f"{loss_value:.4f}",
            loss_p=f"{loss_scale * loss_p_value:.4f}",
        )

    # Average over the batches actually seen; avoids dividing by zero when the
    # loader is empty.
    avg_loss = running_loss / num_batches if num_batches else float('nan')

    return avg_loss, global_step

def valid_one_epoch(model, val_loader, device, writer, step_offset=0):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    The loss is the same weighted Smooth-L1 loss used for training: channels
    22..29 of the last axis are weighted by 10 and the mean is scaled by 20.
    Per-batch scalars are written to ``writer`` under ``val/batch_loss`` and
    ``val/loss_p``.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        val_loader: Iterable of batch dicts with the keys ``'mel'``, ``'exp'``,
            ``'ref'`` and ``'norm_stats'``.
        device: Target device for the batch tensors.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        step_offset: Value added to the batch index to form the logging step.
            With the default of 0 the steps restart at 0 on every call, so
            per-batch scalars from successive calls share the same step
            values; pass a distinct offset per call to keep them apart.

    Returns:
        float: Mean of the per-batch total losses, or ``nan`` if the loader
        produced no batch.
    """
    model.eval()

    tensor_keys = ('mel', 'exp', 'ref')
    loss_scale = 20  # factor applied to the weighted loss to obtain the total
    # The criterion holds no state, so build it once instead of once per batch.
    criterion = torch.nn.SmoothL1Loss(reduction='none')

    running_loss = 0.0
    num_batches = 0

    pbar = tqdm(val_loader, desc="Validation")

    with torch.no_grad():
        for idx, batch in enumerate(pbar):
            for key in tensor_keys:
                batch[key] = batch[key].to(device)

            # The statistics are moved to the device alongside the batch; they
            # are not read again inside this function.
            norm_stats = batch['norm_stats']
            if norm_stats['exp_min'] is not None:
                norm_stats['exp_min'] = norm_stats['exp_min'].to(device)
                norm_stats['exp_max'] = norm_stats['exp_max'].to(device)

            output = model(batch['mel'], batch['ref'])

            per_element = criterion(output, batch['exp'])
            # Channels 22..29 count ten times as much as the others.
            weights = torch.ones_like(per_element)
            weights[:, :, 22:30] = 10.0
            loss_p = (per_element * weights).mean()
            loss = loss_scale * loss_p

            # Read each scalar back once; every ``.item()`` call copies to the host.
            loss_value = loss.item()
            loss_p_value = loss_p.item()

            step = step_offset + idx
            writer.add_scalar('val/batch_loss', loss_value, step)
            writer.add_scalar('val/loss_p', loss_p_value, step)

            running_loss += loss_value
            num_batches += 1

            # The bar shows the scaled contribution of ``loss_p``; as it is the
            # only term at the moment it equals the total.
            pbar.set_postfix(
                loss=f"{loss_value:.4f}",
                loss_p=f"{loss_scale * loss_p_value:.4f}",
            )

    # Average over the batches actually seen; avoids dividing by zero when the
    # loader is empty.
    avg_loss = running_loss / num_batches if num_batches else float('nan')

    return avg_loss

def train(cfg):
    """Train ``EmotionSyncK`` on ``Audio2ExpDataset`` as described by ``cfg``.

    Configuration entries read:
        * ``cfg['training']``: ``device``, ``num_epochs``, ``batch_size``,
          ``learning_rate``
        * ``cfg['dataset']``: passed to ``Audio2ExpDataset``
        * ``cfg['logging']``: ``log_dir``, ``checkpoint_dir``, ``save_interval``

    The dataset is split 80/20 into training and validation subsets. After
    every epoch the epoch losses are logged; every ``save_interval`` epochs a
    checkpoint is written and only the three most recent periodic checkpoints
    are kept. A final checkpoint ``aud2exp_final.pth`` is written once all
    epochs have completed.

    Args:
        cfg: Nested configuration mapping (see above).
    """
    device = cfg['training']['device']
    # ``torch.cuda.set_device`` is only meaningful for an indexed CUDA device;
    # skipping it otherwise lets the same script run with e.g. 'cpu'.
    resolved_device = torch.device(device)
    if resolved_device.type == 'cuda' and resolved_device.index is not None:
        torch.cuda.set_device(device)

    num_epochs = cfg['training']['num_epochs']
    batch_size = cfg['training']['batch_size']

    dataset = Audio2ExpDataset(cfg=cfg['dataset'])
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )

    model = EmotionSyncK().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=float(cfg['training']['learning_rate']))

    logdir = cfg['logging']['log_dir']
    check_point_dir = cfg['logging']['checkpoint_dir']
    save_interval = cfg['logging']['save_interval']

    os.makedirs(logdir, exist_ok=True)
    os.makedirs(check_point_dir, exist_ok=True)
    writer = SummaryWriter(logdir=logdir)

    max_kept_checkpoints = 3
    global_step = 0
    best_val_loss = float('inf')

    # Paths of the periodic checkpoints still on disk, oldest first.
    checkpoint_files = []

    # Close the writer even when training is interrupted, so that events
    # logged so far are not left behind in an open writer.
    try:
        for epoch in range(num_epochs):

            train_loss, global_step = train_one_epoch(
                model, train_loader, optimizer, device, epoch, num_epochs, writer, global_step
            )

            val_loss = valid_one_epoch(model, val_loader, device, writer)

            # Track the best validation loss so the value stored in the
            # checkpoints is meaningful (a NaN loss never counts as better).
            if val_loss < best_val_loss:
                best_val_loss = val_loss

            writer.add_scalar('train/epoch_loss', train_loss, epoch)
            writer.add_scalar('val/epoch_loss', val_loss, epoch)

            print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

            if (epoch + 1) % save_interval == 0:
                checkpoint_path = os.path.join(check_point_dir, f'aud2exp_epoch_{epoch+1}.pth')
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': train_loss,
                    'val_loss': val_loss,
                    'best_val_loss': best_val_loss
                }, checkpoint_path)
                print(f'Checkpoint saved to {checkpoint_path}')

                checkpoint_files.append(checkpoint_path)

                # Rotate: drop the oldest periodic checkpoint once more than
                # ``max_kept_checkpoints`` have been written.
                if len(checkpoint_files) > max_kept_checkpoints:
                    oldest_checkpoint = checkpoint_files.pop(0)
                    try:
                        os.remove(oldest_checkpoint)
                    except FileNotFoundError:
                        # Already gone (e.g. removed by hand); nothing to do.
                        pass
                    else:
                        print(f'Removed old checkpoint: {oldest_checkpoint}')

        final_path = os.path.join(check_point_dir, 'aud2exp_final.pth')
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_loss': best_val_loss
        }, final_path)
        print(f'Final model saved to {final_path}')
    finally:
        writer.close()

if __name__ == "__main__":
    # Script entry point: load the training configuration from its fixed
    # location and hand it straight to the training routine.
    train(load_config('configs/train_aud2exp.yaml'))
    print("Training completed!")