import os
import sys
import logging
import datetime
import os.path as osp
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, mean_squared_error, mean_absolute_error, r2_score
from omegaconf import OmegaConf
import argparse
import wandb

from mld.models.pose_penetration_detector import PosePenetrationDetector
from mld.datasets.pose_penetration_dataset import PosePenetrationDataset, get_dataloader

def calculate_joint_metrics_simple(all_preds, all_labels, num_joints=None):
    """Score pairwise joint predictions as one flat binary classification problem.

    Only the strict upper triangle (``np.triu_indices(num_joints, k=1)``) of
    each ``(J, J)`` matrix is scored, so the diagonal and the lower triangle
    are ignored. This presumes the pair matrices are symmetric — TODO confirm
    against the dataset.

    Args:
        all_preds: Predicted labels, array-like of shape ``(N, J, J)``.
        all_labels: Ground-truth labels with the same shape as ``all_preds``.
        num_joints: Number of joints ``J`` to score. Defaults to
            ``all_preds.shape[1]`` (22 for the 22-joint setup the callers use).

    Returns:
        Tuple ``(accuracy, precision, recall, f1)`` computed by scikit-learn
        on the flattened upper-triangle entries; precision, recall and F1 are
        reported as 0 when they are undefined.

    Raises:
        ValueError: If the inputs have different shapes or fewer than three
            dimensions.
    """
    # asarray avoids a copy when the callers already pass numpy arrays.
    all_preds = np.asarray(all_preds)
    all_labels = np.asarray(all_labels)

    if all_preds.shape != all_labels.shape:
        raise ValueError(
            f"all_preds and all_labels must have the same shape, got "
            f"{all_preds.shape} and {all_labels.shape}"
        )
    if all_preds.ndim < 3:
        raise ValueError(
            f"expected arrays of shape (N, J, J), got {all_preds.shape}"
        )

    if num_joints is None:
        num_joints = all_preds.shape[1]

    # Index pairs (row < col): each unordered joint pair is counted once.
    rows, cols = np.triu_indices(num_joints, k=1)

    preds_flat = all_preds[:, rows, cols].ravel()
    labels_flat = all_labels[:, rows, cols].ravel()

    return (
        accuracy_score(labels_flat, preds_flat),
        precision_score(labels_flat, preds_flat, zero_division=0),
        recall_score(labels_flat, preds_flat, zero_division=0),
        f1_score(labels_flat, preds_flat, zero_division=0),
    )

def main():
    """Train a ``PosePenetrationDetector`` from the YAML config given by ``--cfg``.

    Side effects:
        * creates ``<cfg.logging.log_dir>/<cfg.name>/<timestamp>/checkpoints``
          and stores that run directory in ``cfg.output_dir``;
        * logs to ``output.log`` in the run directory and to stdout;
        * writes the config to ``config.yaml`` in the run directory;
        * starts a wandb run in project ``pose-penetration-detector`` and
          finishes it after the last epoch;
        * saves ``checkpoints/best_model.pth`` whenever the selection metric
          improves (validation F1 for the classification tasks, validation
          MSE for ``joint_score``);
        * saves ``checkpoints/checkpoint_epoch_<k>.pth`` every
          ``cfg.logging.save_every_n_epochs`` epochs (0 disables this).

    Raises:
        ValueError: If ``cfg.model.task_type`` is not a supported task type.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, required=True, help='config file path')
    args = parser.parse_args()

    cfg = OmegaConf.load(args.cfg)

    # DictConfig.get is used rather than getattr(..., default): in older
    # OmegaConf releases attribute access on a missing key returned None
    # instead of raising, which silently bypassed the getattr default.
    task_type = cfg.model.get('task_type', 'binary')
    scaling_factor = cfg.model.get('scaling_factor', 1.0)

    # task_type -> (loss class, validation metric used for model selection,
    #               whether a larger value of that metric is better)
    task_specs = {
        'binary': (nn.CrossEntropyLoss, 'f1', True),
        'joint_score': (nn.MSELoss, 'mse', False),
        'joint_binary': (nn.CrossEntropyLoss, 'f1', True),
        'joint_joint_binary': (nn.CrossEntropyLoss, 'f1', True),
    }
    # Fail fast, before any directories, log files or wandb runs are created.
    if task_type not in task_specs:
        raise ValueError(f"Unknown task_type: {task_type}")
    criterion_cls, select_key, higher_is_better = task_specs[task_type]
    metric_name = select_key.upper()  # 'F1' or 'MSE' for the log message

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    name_time_str = osp.join(cfg.name, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
    cfg.output_dir = osp.join(cfg.logging.log_dir, name_time_str)
    checkpoint_dir = osp.join(cfg.output_dir, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)  # also creates cfg.output_dir

    stream_handler = logging.StreamHandler(sys.stdout)
    file_handler = logging.FileHandler(osp.join(cfg.output_dir, 'output.log'))
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
                        datefmt="%m/%d/%Y %H:%M:%S",
                        handlers=[file_handler, stream_handler])
    logger = logging.getLogger(__name__)

    OmegaConf.save(cfg, osp.join(cfg.output_dir, 'config.yaml'))

    wandb.init(
        project="pose-penetration-detector",
        name=name_time_str,
        config=OmegaConf.to_container(cfg, resolve=True),
        dir=cfg.output_dir
    )

    model = PosePenetrationDetector(
        input_dim=cfg.model.input_dim,
        hidden_dims=cfg.model.hidden_dims,
        output_dim=cfg.model.output_dim,
        dropout=cfg.model.dropout,
        activation=cfg.model.activation,
        task_type=task_type,
        scaling_factor=scaling_factor
    ).to(device)

    train_loader = get_dataloader(cfg, 'train')
    val_loader = get_dataloader(cfg, 'val')

    criterion = criterion_cls()
    # Start from the worst possible value so the first epoch always writes
    # best_model.pth, even when the validation F1 is exactly 0.
    best_metric = float('-inf') if higher_is_better else float('inf')

    optimizer = optim.Adam(
        model.parameters(),
        lr=cfg.train.learning_rate,
        weight_decay=cfg.train.weight_decay
    )

    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=cfg.train.num_epochs,
        T_mult=1
    )

    num_epochs = cfg.train.num_epochs
    save_every = cfg.logging.save_every_n_epochs

    for epoch in range(num_epochs):
        logger.info("\nEpoch %d/%d", epoch + 1, num_epochs)

        train_metrics = train_epoch(model, train_loader, criterion, optimizer, device, epoch, task_type)
        logger.info("Train metrics: %s", train_metrics)

        val_metrics = validate(model, val_loader, criterion, device, epoch, task_type)
        logger.info("Validation metrics: %s", val_metrics)

        # Record the LR used during this epoch before the scheduler advances it.
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step()

        wandb.log({
            "epoch": epoch,
            "learning_rate": current_lr,
            **{f"train/{k}": v for k, v in train_metrics.items()},
            **{f"val/{k}": v for k, v in val_metrics.items()}
        })

        current_metric = val_metrics[select_key]
        if higher_is_better:
            is_better = current_metric > best_metric
        else:
            is_better = current_metric < best_metric

        if is_better:
            best_metric = current_metric
            save_path = osp.join(checkpoint_dir, 'best_model.pth')
            _save_checkpoint(save_path, epoch, model, optimizer, val_metrics, task_type)
            logger.info("Saved best model to %s with %s: %.3f", save_path, metric_name, best_metric)

        # A falsy interval (0) disables periodic checkpoints instead of
        # raising ZeroDivisionError in the modulo.
        if save_every and (epoch + 1) % save_every == 0:
            save_path = osp.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
            _save_checkpoint(save_path, epoch, model, optimizer, val_metrics, task_type)
            logger.info("Saved checkpoint to %s", save_path)

    wandb.finish()


def _save_checkpoint(path, epoch, model, optimizer, val_metrics, task_type):
    """Write a training checkpoint to *path* with ``torch.save``.

    The saved dict has the keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict``, ``val_metrics`` and ``task_type``.
    """
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_metrics': val_metrics,
        'task_type': task_type
    }, path)

def train_epoch(model, train_loader, criterion, optimizer, device, epoch, task_type):
    """Run one training epoch over *train_loader* and return aggregate metrics.

    Args:
        model: Network called as ``model(poses)``; put in train mode here.
        train_loader: Iterable of batches. Each batch is indexed with
            ``'pose'`` and ``'label'`` and both are moved to *device*. Its
            ``len`` is used for the wandb batch counter.
        criterion: Loss module applied to the model outputs and labels.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index, used only for the wandb batch counter.
        task_type: One of ``'binary'``, ``'joint_score'``, ``'joint_binary'``,
            ``'joint_joint_binary'``.

    Returns:
        dict with ``'loss'`` (mean per-batch loss) plus ``'mse'``, ``'mae'``,
        ``'r2'`` for ``'joint_score'``, or ``'accuracy'``, ``'precision'``,
        ``'recall'``, ``'f1'`` for the classification tasks.

    Raises:
        ValueError: For an unknown *task_type* (checked before any work is
            done) or if *train_loader* yields no batches.
    """
    # Axis that argmax reduces to turn logits into class predictions. The
    # joint tasks reshape outputs to (-1, 2) for the loss, so presumably
    # outputs are [B, 2], [B, J, 2] and [B, J, J, 2] respectively — TODO confirm.
    class_dims = {'binary': 1, 'joint_binary': 2, 'joint_joint_binary': 3}
    if task_type != 'joint_score' and task_type not in class_dims:
        raise ValueError(f"Unknown task_type: {task_type}")

    model.train()
    total_loss = 0.0
    num_batches = 0
    all_preds = []
    all_labels = []

    for batch_idx, batch in enumerate(tqdm(train_loader, desc='Training')):
        poses = batch['pose'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        outputs = model(poses)

        if task_type in ('joint_binary', 'joint_joint_binary'):
            # One 2-way decision per joint (or joint pair): flatten logits to
            # [items, 2] and labels to [items]. reshape, unlike view, also
            # accepts non-contiguous model outputs.
            loss = criterion(outputs.reshape(-1, 2), labels.reshape(-1))
        else:
            loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # single device->host sync per batch
        total_loss += batch_loss
        num_batches += 1

        if task_type == 'joint_score':
            all_preds.extend(outputs.detach().cpu().numpy())
        else:
            preds = torch.argmax(outputs, dim=class_dims[task_type])
            all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        if batch_idx % 10 == 0:
            wandb.log({
                "train/batch_loss": batch_loss,
                "train/batch": epoch * len(train_loader) + batch_idx
            })

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    mean_loss = total_loss / num_batches

    if task_type == 'joint_score':
        all_preds = np.array(all_preds)
        all_labels = np.array(all_labels)
        return {
            'loss': mean_loss,
            'mse': mean_squared_error(all_labels, all_preds),
            'mae': mean_absolute_error(all_labels, all_preds),
            'r2': r2_score(all_labels, all_preds)
        }

    if task_type == 'binary':
        accuracy = accuracy_score(all_labels, all_preds)
        precision = precision_score(all_labels, all_preds, zero_division=0)
        recall = recall_score(all_labels, all_preds, zero_division=0)
        f1 = f1_score(all_labels, all_preds, zero_division=0)
    elif task_type == 'joint_binary':
        all_preds = np.array(all_preds)    # [N, J]
        all_labels = np.array(all_labels)  # [N, J]
        # Score each joint column independently, then macro-average over joints.
        per_joint = np.array([
            [accuracy_score(y, p),
             precision_score(y, p, zero_division=0),
             recall_score(y, p, zero_division=0),
             f1_score(y, p, zero_division=0)]
            for y, p in zip(all_labels.T, all_preds.T)
        ])
        accuracy, precision, recall, f1 = per_joint.mean(axis=0)
    else:  # 'joint_joint_binary'
        all_preds = np.array(all_preds)    # [N, J, J]
        all_labels = np.array(all_labels)  # [N, J, J]
        accuracy, precision, recall, f1 = calculate_joint_metrics_simple(all_preds, all_labels)

    return {
        'loss': mean_loss,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

def validate(model, val_loader, criterion, device, epoch, task_type):
    """Evaluate *model* on *val_loader* without gradients and return metrics.

    Args:
        model: Network called as ``model(poses)``; put in eval mode here.
        val_loader: Iterable of batches. Each batch is indexed with
            ``'pose'`` and ``'label'`` and both are moved to *device*. Its
            ``len`` is used for the wandb batch counter.
        criterion: Loss module applied to the model outputs and labels.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index, used only for the wandb batch counter.
        task_type: One of ``'binary'``, ``'joint_score'``, ``'joint_binary'``,
            ``'joint_joint_binary'``.

    Returns:
        dict with ``'loss'`` (mean per-batch loss) plus ``'mse'``, ``'mae'``,
        ``'r2'`` for ``'joint_score'``, or ``'accuracy'``, ``'precision'``,
        ``'recall'``, ``'f1'`` for the classification tasks.

    Raises:
        ValueError: For an unknown *task_type* (checked before any work is
            done) or if *val_loader* yields no batches.
    """
    # Axis that argmax reduces to turn logits into class predictions. The
    # joint tasks reshape outputs to (-1, 2) for the loss, so presumably
    # outputs are [B, 2], [B, J, 2] and [B, J, J, 2] respectively — TODO confirm.
    class_dims = {'binary': 1, 'joint_binary': 2, 'joint_joint_binary': 3}
    if task_type != 'joint_score' and task_type not in class_dims:
        raise ValueError(f"Unknown task_type: {task_type}")

    model.eval()
    total_loss = 0.0
    num_batches = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(val_loader, desc='Validation')):
            poses = batch['pose'].to(device)
            labels = batch['label'].to(device)

            outputs = model(poses)

            if task_type in ('joint_binary', 'joint_joint_binary'):
                # One 2-way decision per joint (or joint pair): flatten logits
                # to [items, 2] and labels to [items]. reshape, unlike view,
                # also accepts non-contiguous model outputs.
                loss = criterion(outputs.reshape(-1, 2), labels.reshape(-1))
            else:
                loss = criterion(outputs, labels)

            batch_loss = loss.item()  # single device->host sync per batch
            total_loss += batch_loss
            num_batches += 1

            if task_type == 'joint_score':
                all_preds.extend(outputs.cpu().numpy())
            else:
                preds = torch.argmax(outputs, dim=class_dims[task_type])
                all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            if batch_idx % 10 == 0:
                wandb.log({
                    "val/batch_loss": batch_loss,
                    "val/batch": epoch * len(val_loader) + batch_idx
                })

    if num_batches == 0:
        raise ValueError("val_loader yielded no batches")
    mean_loss = total_loss / num_batches

    if task_type == 'joint_score':
        all_preds = np.array(all_preds)
        all_labels = np.array(all_labels)
        return {
            'loss': mean_loss,
            'mse': mean_squared_error(all_labels, all_preds),
            'mae': mean_absolute_error(all_labels, all_preds),
            'r2': r2_score(all_labels, all_preds)
        }

    if task_type == 'binary':
        accuracy = accuracy_score(all_labels, all_preds)
        precision = precision_score(all_labels, all_preds, zero_division=0)
        recall = recall_score(all_labels, all_preds, zero_division=0)
        f1 = f1_score(all_labels, all_preds, zero_division=0)
    elif task_type == 'joint_binary':
        all_preds = np.array(all_preds)    # [N, J]
        all_labels = np.array(all_labels)  # [N, J]
        # Score each joint column independently, then macro-average over joints.
        per_joint = np.array([
            [accuracy_score(y, p),
             precision_score(y, p, zero_division=0),
             recall_score(y, p, zero_division=0),
             f1_score(y, p, zero_division=0)]
            for y, p in zip(all_labels.T, all_preds.T)
        ])
        accuracy, precision, recall, f1 = per_joint.mean(axis=0)
    else:  # 'joint_joint_binary'
        all_preds = np.array(all_preds)    # [N, J, J]
        all_labels = np.array(all_labels)  # [N, J, J]
        accuracy, precision, recall, f1 = calculate_joint_metrics_simple(all_preds, all_labels)

    return {
        'loss': mean_loss,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

# Script entry point: parse --cfg and run the training loop in main().
if __name__ == '__main__':
    main() 