import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import os
import json
import time
from tqdm import tqdm
import argparse
from datetime import datetime
import math
import wandb
from transformers import CLIPModel, CLIPTokenizer, get_cosine_schedule_with_warmup

from MoCLIP import (
    PositionalEncoding, 
    MotionEncoder, 
    ClipMotionAlignModel,
    calculate_recall_at_k,
    evaluate_random_subset
)
from motion_loader import get_dataset_loader
from options.get_opt import get_opt
from argparse import Namespace


class ContrastiveLoss(nn.Module):
    """
    Symmetric (CLIP-style) contrastive loss for motion-text alignment.

    Row ``i`` of the motion batch and row ``i`` of the text batch form the
    positive pair; every other pairing inside the batch acts as a negative.
    The loss is the mean of the motion->text and text->motion cross-entropies
    over temperature-scaled cosine similarities.

    Args:
        temperature: Divisor applied to the cosine similarities before the
            softmax (smaller -> sharper distribution).
    """
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, motion_embeddings, text_embeddings):
        """
        Compute the bidirectional contrastive loss and in-batch accuracies.

        Args:
            motion_embeddings: (B, D) motion features (need not be normalized).
            text_embeddings: (B, D) text features (need not be normalized).

        Returns:
            dict with scalar tensors:
              'total_loss' - mean of the two directional losses (differentiable),
              'loss_m2t' / 'loss_t2m' - directional cross-entropy losses,
              'acc_m2t' / 'acc_t2m' - top-1 in-batch retrieval accuracy,
              'avg_acc' - mean of the two accuracies.
        """
        batch_size = motion_embeddings.shape[0]

        # Normalize features so the dot product is cosine similarity
        motion_embeddings = F.normalize(motion_embeddings, dim=-1)
        text_embeddings = F.normalize(text_embeddings, dim=-1)

        # motion_to_text: (B, B), entry [i, j] = sim(motion_i, text_j) / T
        logits_motion_to_text = torch.matmul(motion_embeddings, text_embeddings.T) / self.temperature
        # text_to_motion is exactly the transpose; reuse it rather than running
        # a second (B, D) x (D, B) matmul. Gradients flow through the view.
        logits_text_to_motion = logits_motion_to_text.T

        # Labels: diagonal elements are positive samples
        labels = torch.arange(batch_size, device=motion_embeddings.device)

        # Compute bidirectional loss
        loss_motion_to_text = self.criterion(logits_motion_to_text, labels)
        loss_text_to_motion = self.criterion(logits_text_to_motion, labels)

        # Total loss
        total_loss = (loss_motion_to_text + loss_text_to_motion) / 2

        # Compute top-1 in-batch accuracy (no gradient needed)
        with torch.no_grad():
            pred_m2t = torch.argmax(logits_motion_to_text, dim=1)
            pred_t2m = torch.argmax(logits_text_to_motion, dim=1)
            acc_m2t = (pred_m2t == labels).float().mean()
            acc_t2m = (pred_t2m == labels).float().mean()
            avg_acc = (acc_m2t + acc_t2m) / 2

        return {
            'total_loss': total_loss,
            'loss_m2t': loss_motion_to_text,
            'loss_t2m': loss_text_to_motion,
            'acc_m2t': acc_m2t,
            'acc_t2m': acc_t2m,
            'avg_acc': avg_acc
        }


class MoClipTrainer:
    """
    MoCLIP Trainer

    Wires together a (selectively frozen) CLIP model, a ``MotionEncoder``,
    the dataset loaders, an AdamW optimizer with a cosine warmup schedule and
    ``ContrastiveLoss``, then trains with periodic R@k validation,
    best-model tracking and periodic checkpointing.

    Args:
        config: argparse-style namespace (see ``get_config``) holding model,
            data, optimization, logging and checkpoint settings.
    """
    def __init__(self, config):
        self.config = config
        self.device = torch.device(config.device)

        # Set random seed
        torch.manual_seed(config.seed)
        np.random.seed(config.seed)

        # Initialize model
        self.setup_model()

        # Initialize data loaders
        self.setup_data_loaders()

        # Initialize optimizer and scheduler (scheduler length depends on the
        # train loader, so this must run after setup_data_loaders)
        self.setup_optimizer()

        # Initialize loss function
        self.criterion = ContrastiveLoss(temperature=config.temperature)

        # Initialize logging / progress state
        self.best_r3 = 0.0
        self.step = 0
        self.epoch = 0
        # True once every batch of ``self.epoch`` has been trained; save_model
        # uses it to record which epoch a resumed run should start from.
        self._epoch_completed = False

        # Create save directory
        os.makedirs(config.save_dir, exist_ok=True)

        # Initialize wandb (optional)
        if config.use_wandb:
            wandb.init(
                project="moclip-training",
                name=config.exp_name,
                config=vars(config)
            )

    def setup_model(self):
        """Load CLIP, apply the freezing policy and build the aligned model.

        Freezing policy (from ``config``):
          * ``freeze_clip`` -> every CLIP parameter is frozen.
          * ``clip_finetune_layers > 0`` -> only the last N text-encoder
            layers, the text final layer norm and the text projection train.
          * otherwise -> all CLIP parameters stay trainable.

        The motion input dimension is 251 for 'kit', 263 for 't2m', and
        ``config.input_dim`` for any other dataset name.
        """
        print("Initializing model...")

        # Load pre-trained CLIP model
        self.clip_model = CLIPModel.from_pretrained(self.config.clip_model_name)
        self.tokenizer = CLIPTokenizer.from_pretrained(self.config.clip_model_name)

        # Selective freezing of CLIP parameters
        if self.config.freeze_clip:
            # Completely freeze CLIP
            for param in self.clip_model.parameters():
                param.requires_grad = False
            print("CLIP model parameters completely frozen")
        elif hasattr(self.config, 'clip_finetune_layers') and self.config.clip_finetune_layers > 0:
            # Selective freezing: freeze everything, then unfreeze the tail
            for param in self.clip_model.parameters():
                param.requires_grad = False

            text_encoder = self.clip_model.text_model.encoder
            total_layers = len(text_encoder.layers)
            finetune_layers = min(self.config.clip_finetune_layers, total_layers)

            # Unfreeze last few layers of the text encoder
            for i in range(total_layers - finetune_layers, total_layers):
                for param in text_encoder.layers[i].parameters():
                    param.requires_grad = True

            # Unfreeze final layer norm
            for param in self.clip_model.text_model.final_layer_norm.parameters():
                param.requires_grad = True

            # Unfreeze the text projection. In Hugging Face CLIPModel it is an
            # attribute of the top-level model, not of ``text_model``; the
            # ``text_model`` lookup is kept only as a fallback.
            text_projection = getattr(self.clip_model, 'text_projection', None)
            if text_projection is None:
                text_projection = getattr(self.clip_model.text_model, 'text_projection', None)
            if isinstance(text_projection, nn.Parameter):
                text_projection.requires_grad = True
            elif isinstance(text_projection, nn.Module):
                for param in text_projection.parameters():
                    param.requires_grad = True

            print(f"CLIP text encoder: freeze first {total_layers - finetune_layers} layers, finetune last {finetune_layers} layers")
        else:
            print("All CLIP model parameters are trainable")

        # Set correct input_dim based on dataset
        if self.config.dataset_name == 'kit':
            input_dim = 251
        elif self.config.dataset_name == 't2m':
            input_dim = 263
        else:
            input_dim = self.config.input_dim

        # Create motion encoder
        self.motion_encoder = MotionEncoder(
            input_dim=input_dim,
            embed_dim=self.config.embed_dim,
            num_heads=self.config.num_heads,
            num_layers=self.config.num_layers,
            dim_feedforward=self.config.dim_feedforward,
            dropout=self.config.dropout,
            max_seq_length=self.config.max_seq_length
        )

        # Create complete model
        self.model = ClipMotionAlignModel(
            clip_model=self.clip_model,
            motion_encoder=self.motion_encoder,
            temperature=self.config.temperature
        ).to(self.device)

        # Count trainable parameters
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print("Model created successfully")
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params*100:.1f}%)")

    def setup_data_loaders(self):
        """Build the train and validation loaders for ``config.dataset_name``.

        Dataset paths and motion constants are hard-coded for 't2m'
        (HumanML3D) and 'kit' (KIT-ML). For any other name those fields are
        left unset on ``opt``; NOTE(review): presumably get_dataset_loader then
        fails, so confirm before using another dataset.
        Validation uses the 'test' split.
        """
        print("Initializing data loaders...")

        # Set dataset parameters
        opt = Namespace()
        opt.dataset_name = self.config.dataset_name
        opt.batch_size = self.config.batch_size
        opt.device = self.device
        opt.max_length = self.config.max_text_length
        opt.feat_bias = 5
        opt.max_text_len = 20
        opt.unit_length = 4

        if self.config.dataset_name == 't2m':
            opt.joints_num = 22
            opt.dim_pose = 263
            opt.max_motion_length = 196
            opt.radius = 4
            opt.fps = 20
            opt.data_root = './dataset/HumanML3D'
            opt.motion_dir = os.path.join(opt.data_root, 'new_joint_vecs')
            opt.text_dir = os.path.join(opt.data_root, 'texts')
            opt.mean_path = os.path.join(opt.data_root, 'Mean.npy')
            opt.std_path = os.path.join(opt.data_root, 'Std.npy')
            opt.split_dir = os.path.join(opt.data_root, 'train_val.txt')
            opt.meta_dir = './checkpoints/t2m/clip/meta'
            opt.eval_meta_dir = './dataset'
            opt.glove_dir = './dataset'
        elif self.config.dataset_name == 'kit':
            opt.joints_num = 21
            opt.dim_pose = 251
            opt.max_motion_length = 196
            opt.radius = 4
            opt.fps = 12.5
            opt.data_root = './dataset/KIT-ML'
            opt.motion_dir = os.path.join(opt.data_root, 'new_joint_vecs')
            opt.text_dir = os.path.join(opt.data_root, 'texts')
            opt.mean_path = os.path.join(opt.data_root, 'Mean.npy')
            opt.std_path = os.path.join(opt.data_root, 'Std.npy')
            opt.split_dir = os.path.join(opt.data_root, 'train.txt')
            opt.meta_dir = './checkpoints/kit/clip/meta'
            opt.eval_meta_dir = './dataset'
            opt.glove_dir = './dataset'

        # Create data loaders
        self.train_loader = get_dataset_loader(
            opt, batch_size=self.config.batch_size,
            split='train', mode='train'
        )

        self.val_loader = get_dataset_loader(
            opt, batch_size=self.config.eval_batch_size,
            split='test', mode='train'
        )

        print(f"Training data: {len(self.train_loader)} batches")
        print(f"Validation data: {len(self.val_loader)} batches")

    def setup_optimizer(self):
        """Initialize the AdamW optimizer and the cosine warmup LR scheduler.

        Motion-encoder parameters use ``config.learning_rate``. Trainable CLIP
        parameters, if any, use a 10x smaller rate. Any other trainable
        parameters owned by ``self.model`` (outside ``clip_model`` and
        ``motion_encoder``) get their own group at the base rate, so they are
        not silently left out of optimization. The first 10% of all training
        steps are linear warmup.
        """
        base_lr = self.config.learning_rate
        motion_params = list(self.motion_encoder.parameters())

        # Layered learning rates
        if self.config.freeze_clip:
            # Only train motion encoder
            params = [{'params': motion_params, 'lr': base_lr}]
        else:
            # Set different learning rates for different parts
            clip_params = [p for p in self.clip_model.parameters() if p.requires_grad]

            if clip_params:
                # CLIP uses smaller learning rate, motion encoder uses normal learning rate
                params = [
                    {'params': clip_params, 'lr': base_lr * 0.1, 'name': 'clip'},
                    {'params': motion_params, 'lr': base_lr, 'name': 'motion'}
                ]
                print(f"Using layered learning rates: CLIP={base_lr * 0.1:.2e}, Motion={base_lr:.2e}")
            else:
                params = [{'params': motion_params, 'lr': base_lr}]

        # Trainable parameters that belong to the wrapper model itself
        # (e.g. extra heads); identity-based so shared tensors are not duplicated.
        grouped_ids = {id(p) for p in self.clip_model.parameters()}
        grouped_ids.update(id(p) for p in motion_params)
        extra_params = [
            p for p in self.model.parameters()
            if p.requires_grad and id(p) not in grouped_ids
        ]
        if extra_params:
            params.append({'params': extra_params, 'lr': base_lr, 'name': 'other'})
            print(f"Optimizing {len(extra_params)} additional model parameter tensors")

        self.optimizer = optim.AdamW(
            params,
            weight_decay=self.config.weight_decay,
            betas=(0.9, 0.999)
        )

        # Learning rate scheduler
        total_steps = len(self.train_loader) * self.config.num_epochs
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=int(0.1 * total_steps),
            num_training_steps=total_steps
        )

        print(f"Optimizer initialized, total training steps: {total_steps}")

    def _prepare_batch(self, batch_data):
        """Tokenize captions and move one loader batch onto ``self.device``.

        Args:
            batch_data: ``(captions, motions, lengths)`` tuple from the loaders.

        Returns:
            ``(motions, lengths, input_ids, attention_mask)`` on ``self.device``.
        """
        captions, motions, lengths = batch_data

        # Data preprocessing
        captions = [cap.lower() for cap in captions]

        # Text encoding
        text_inputs = self.tokenizer(
            captions,
            padding=True,
            truncation=True,
            max_length=self.config.max_text_length,
            return_tensors="pt"
        )
        input_ids = text_inputs["input_ids"].to(self.device)
        attention_mask = text_inputs["attention_mask"].to(self.device)

        # Motion data processing: stack a list of per-sample arrays if needed
        if isinstance(motions, list):
            motions = torch.stack([torch.tensor(m, dtype=torch.float32) for m in motions], dim=0)
        motions = motions.float().to(self.device)
        lengths = lengths.to(self.device)

        return motions, lengths, input_ids, attention_mask

    def train_epoch(self):
        """Train for one pass over ``self.train_loader``.

        Runs ``validate`` every ``config.eval_interval`` global steps and logs
        to wandb every ``config.log_interval`` steps when enabled. An interval
        of 0 disables the corresponding action.

        Returns:
            ``(avg_loss, avg_acc)`` averaged over the batches processed this
            epoch (both 0.0 if the loader yielded nothing).
        """
        self.model.train()
        total_loss = 0.0
        total_acc = 0.0
        num_batches = 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {self.epoch + 1}")

        for batch_data in pbar:
            motions, lengths, input_ids, attention_mask = self._prepare_batch(batch_data)

            # Forward pass
            self.optimizer.zero_grad()

            motion_embeddings, text_embeddings = self.model(
                motions, lengths, input_ids, attention_mask
            )

            # Compute loss
            loss_dict = self.criterion(motion_embeddings, text_embeddings)
            loss = loss_dict['total_loss']

            # Backward pass
            loss.backward()

            # Gradient clipping
            if self.config.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)

            self.optimizer.step()
            self.scheduler.step()

            # Update statistics
            loss_value = loss.item()
            acc_value = loss_dict['avg_acc'].item()
            total_loss += loss_value
            total_acc += acc_value
            num_batches += 1

            # Update progress bar. With layered LRs, group 0 is the CLIP group,
            # so this shows the (10x smaller) CLIP learning rate.
            current_lr = self.scheduler.get_last_lr()[0]
            pbar.set_postfix({
                'loss': f"{loss_value:.4f}",
                'acc': f"{acc_value*100:.1f}%",
                'lr': f"{current_lr:.2e}"
            })

            # Log to wandb
            log_interval = self.config.log_interval
            if self.config.use_wandb and log_interval > 0 and self.step % log_interval == 0:
                wandb.log({
                    'train/loss': loss_value,
                    'train/acc_m2t': loss_dict['acc_m2t'].item(),
                    'train/acc_t2m': loss_dict['acc_t2m'].item(),
                    'train/avg_acc': acc_value,
                    'train/lr': current_lr,
                    'step': self.step
                })

            self.step += 1

            # Regular validation
            eval_interval = self.config.eval_interval
            if eval_interval > 0 and self.step % eval_interval == 0:
                self.validate()
                self.model.train()  # Return to training mode

        denom = max(num_batches, 1)
        avg_loss = total_loss / denom
        avg_acc = total_acc / denom

        return avg_loss, avg_acc

    def validate(self):
        """Embed the whole validation loader and report retrieval recall.

        Computes R@1/3/5/10 with ``calculate_recall_at_k`` on the concatenated
        embeddings, prints them, logs to wandb if enabled, and saves
        ``best_model.pt`` whenever 'Average R@3' exceeds ``self.best_r3``.
        Leaves the model in eval mode.

        Returns:
            The dict returned by ``calculate_recall_at_k``.
        """
        print("\nStarting validation...")
        self.model.eval()

        all_motion_embeddings = []
        all_text_embeddings = []

        with torch.no_grad():
            for batch_data in tqdm(self.val_loader, desc="Validating"):
                motions, lengths, input_ids, attention_mask = self._prepare_batch(batch_data)

                # Forward pass
                motion_embeddings, text_embeddings = self.model(
                    motions, lengths, input_ids, attention_mask
                )

                all_motion_embeddings.append(motion_embeddings.cpu().numpy())
                all_text_embeddings.append(text_embeddings.cpu().numpy())

        # Merge all embeddings
        motion_embeddings = np.concatenate(all_motion_embeddings, axis=0)
        text_embeddings = np.concatenate(all_text_embeddings, axis=0)

        # Calculate recall@k metrics
        recall_results = calculate_recall_at_k(
            motion_embeddings, text_embeddings,
            k_values=[1, 3, 5, 10]
        )

        # Focus on R@3 (used for best-model selection)
        r3_score = recall_results['Average R@3']

        print("Validation results:")
        print(f"  R@1: {recall_results['Average R@1']*100:.2f}%")
        print(f"  R@3: {r3_score*100:.2f}%")
        print(f"  R@5: {recall_results['Average R@5']*100:.2f}%")
        print(f"  R@10: {recall_results['Average R@10']*100:.2f}%")

        # Log to wandb
        if self.config.use_wandb:
            wandb.log({
                'val/r1': recall_results['Average R@1'],
                'val/r3': r3_score,
                'val/r5': recall_results['Average R@5'],
                'val/r10': recall_results['Average R@10'],
                'val/m2t_r3': recall_results['Motion-to-Text R@3'],
                'val/t2m_r3': recall_results['Text-to-Motion R@3'],
                'step': self.step
            })

        # Save the best model
        if r3_score > self.best_r3:
            self.best_r3 = r3_score
            self.save_model('best_model.pt')
            print(f"🎉 New best R@3: {r3_score*100:.2f}%")

        return recall_results

    def save_model(self, filename):
        """Save model, optimizer, scheduler and progress state to ``save_dir``.

        ``'epoch'`` is the epoch index active at save time. ``'next_epoch'`` is
        the epoch a resumed run should start from: ``epoch + 1`` if that epoch
        had finished training, otherwise ``epoch`` (a mid-epoch save restarts
        the interrupted epoch).
        """
        epoch_completed = getattr(self, '_epoch_completed', False)
        checkpoint = {
            'epoch': self.epoch,
            'next_epoch': self.epoch + 1 if epoch_completed else self.epoch,
            'step': self.step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_r3': self.best_r3,
            'config': vars(self.config)
        }

        save_path = os.path.join(self.config.save_dir, filename)
        torch.save(checkpoint, save_path)
        print(f"Model saved to: {save_path}")

    def load_model(self, checkpoint_path):
        """Restore state written by ``save_model`` so ``train`` can resume.

        Training resumes at ``'next_epoch'``. Checkpoints written before that
        key existed fall back to ``'epoch'``.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        self.epoch = checkpoint.get('next_epoch', checkpoint['epoch'])
        self.step = checkpoint['step']
        self.best_r3 = checkpoint['best_r3']
        self._epoch_completed = False

        print(f"Model loaded from {checkpoint_path}")
        print(f"Resume training - Epoch: {self.epoch}, Step: {self.step}, Best R@3: {self.best_r3*100:.2f}%")

    def train(self):
        """Main training loop: epochs ``self.epoch`` .. ``num_epochs - 1``.

        After each epoch it validates (possibly saving ``best_model.pt``).
        Every ``save_interval`` epochs it writes ``checkpoint_epoch_{N}.pt``,
        where N is 1-based; a ``save_interval`` of 0 disables this.
        """
        print("🚀 Starting MoCLIP training...")
        print(f"Configuration: {self.config}")

        for epoch in range(self.epoch, self.config.num_epochs):
            self.epoch = epoch
            self._epoch_completed = False

            print(f"\n{'='*50}")
            print(f"Epoch {epoch+1}/{self.config.num_epochs}")
            print(f"{'='*50}")

            # Train one epoch
            train_loss, train_acc = self.train_epoch()
            # From here on, checkpoints should resume at the following epoch
            self._epoch_completed = True

            print(f"Training results: Loss={train_loss:.4f}, Acc={train_acc*100:.2f}%")

            # Validate at the end of every epoch
            self.validate()

            # Save checkpoint
            save_interval = self.config.save_interval
            if save_interval > 0 and (epoch + 1) % save_interval == 0:
                self.save_model(f'checkpoint_epoch_{epoch+1}.pt')

        print("🎉 Training completed!")
        print(f"Best R@3: {self.best_r3*100:.2f}%")

        if self.config.use_wandb:
            wandb.finish()


def get_config(argv=None):
    """Build the MoCLIP training argument parser and parse the options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]`` exactly as before; passing a list allows
            programmatic use and testing.

    Returns:
        argparse.Namespace holding every training option.
    """
    parser = argparse.ArgumentParser(description='MoCLIP Training')

    # Basic configuration
    parser.add_argument('--exp_name', type=str, default=f'moclip_{datetime.now().strftime("%Y%m%d_%H%M%S")}')
    parser.add_argument('--save_dir', type=str, default='./checkpoints/moclip_training')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--seed', type=int, default=42)

    # Data configuration
    parser.add_argument('--dataset_name', type=str, default='kit', help='数据集名称: t2m 或 kit')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--eval_batch_size', type=int, default=32)
    parser.add_argument('--max_text_length', type=int, default=77)

    # Model configuration
    parser.add_argument('--clip_model_name', type=str, default='openai/clip-vit-large-patch14')
    parser.add_argument('--freeze_clip', action='store_true', help='冻结CLIP参数')
    parser.add_argument('--clip_finetune_layers', type=int, default=2, help='CLIP微调的层数（从最后开始）')
    parser.add_argument('--input_dim', type=int, default=251, help='输入维度，t2m: 263, kit: 251')
    parser.add_argument('--embed_dim', type=int, default=768)
    parser.add_argument('--num_heads', type=int, default=8)
    parser.add_argument('--num_layers', type=int, default=4)
    parser.add_argument('--dim_feedforward', type=int, default=2048)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--max_seq_length', type=int, default=196)
    parser.add_argument('--temperature', type=float, default=0.07)

    # Training configuration
    parser.add_argument('--num_epochs', type=int, default=50)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--grad_clip', type=float, default=1.0)

    # Evaluation and checkpointing (intervals: steps for eval/log, epochs for save)
    parser.add_argument('--eval_interval', type=int, default=1000)
    parser.add_argument('--save_interval', type=int, default=5)
    parser.add_argument('--log_interval', type=int, default=100)

    # wandb
    parser.add_argument('--use_wandb', action='store_true', help='使用wandb记录')

    # Resume training
    parser.add_argument('--resume', type=str, default=None, help='恢复训练的检查点路径')

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Parse CLI options, then build the trainer (model, data, optimizer).
    cli_config = get_config()
    moclip_trainer = MoClipTrainer(cli_config)

    # Restore model/optimizer/scheduler state when a checkpoint path is given.
    resume_path = cli_config.resume
    if resume_path:
        moclip_trainer.load_model(resume_path)

    # Run the full training loop.
    moclip_trainer.train()