#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Main training script for point cloud classification with contrastive PointNet
- Argument parsing
- Dataset and DataLoader initialization
- Model/optimizer/scheduler setup
- Training loop (with contrastive + classification loss)
- Per-epoch evaluation on the training, test_seen, and test_unseen datasets
- Logging, misclassified-sample dumps, and best-model checkpoint saving
"""
import os
import random
import argparse
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

# Import custom modules
from dataset import PointCloudDataset
from model import SimplePointNet
from loss_functions import contrastive_nt_xent, classification_loss
from evaluation import evaluate


def _parse_args():
    """Define the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="Train PointNet for Point Cloud Classification with Contrastive Loss")
    # Data paths
    parser.add_argument('--data_root', type=str, default='./FlatLab_Stage_1_modify/Data_op_PointCloud_Obj',
                        help='Root directory of the point cloud dataset')
    parser.add_argument('--train_dir', type=str, default='train', help='Training set subdirectory')
    parser.add_argument('--test_seen_dir', type=str, default='test_seen', help='Seen test set subdirectory')
    parser.add_argument('--test_unseen_dir', type=str, default='test_unseen', help='Unseen test set subdirectory')
    parser.add_argument('--classes', nargs='+', default=['Strategy_A', 'Strategy_B', 'Strategy_C'],
                        help='List of class names for classification')
    # Model hyperparameters
    parser.add_argument('--num_points', type=int, default=2048,
                        help='Fixed number of points for each point cloud')
    parser.add_argument('--batch_size', type=int, default=32, help='Training batch size')
    parser.add_argument('--epochs', type=int, default=150, help='Total training epochs')
    parser.add_argument('--lr', type=float, default=0.0005, help='Initial learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='L2 weight decay for optimizer')
    # Learning-rate schedule (defaults match the previously hard-coded StepLR settings)
    parser.add_argument('--lr_step_size', type=int, default=30,
                        help='StepLR: multiply the learning rate by --lr_gamma every N epochs')
    parser.add_argument('--lr_gamma', type=float, default=0.5,
                        help='StepLR: multiplicative learning-rate decay factor')
    # Training settings
    parser.add_argument('--device', type=str, default='cuda:0', help='Computation device (cuda:0/cpu)')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
    parser.add_argument('--log_file', type=str, default='log.txt', help='Log file name')
    parser.add_argument('--num_workers', type=int, default=8, help='Number of workers for DataLoader')
    parser.add_argument('--save_dir', type=str, default='./checkpoints', help='Directory for saving checkpoints/logs')
    # Contrastive loss settings
    parser.add_argument('--contrast_weight', type=float, default=0.5,
                        help='Weight of contrastive loss in total loss')
    parser.add_argument('--contrast_temp', type=float, default=0.3,
                        help='Temperature for contrastive NT-Xent loss')
    return parser.parse_args()


def _set_seed(seed):
    """Seed the PyTorch, NumPy and Python `random` generators."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def _resolve_device(requested, seed):
    """Return `requested` as a torch.device when it is a CUDA device and CUDA is available, else CPU.

    Args:
        requested: Device string from the command line (e.g. 'cuda:0' or 'cpu').
        seed: Seed applied to the CUDA generator when a CUDA device is selected.
    """
    if requested.startswith('cuda') and torch.cuda.is_available():
        device = torch.device(requested)
        torch.cuda.manual_seed(seed)
        return device
    # Anything else (including an unavailable CUDA request) silently falls back to CPU;
    # main() prints the chosen device so the fallback is visible.
    return torch.device('cpu')


def _build_loaders(args):
    """Build the train/test_seen/test_unseen datasets, print their sizes, and return their DataLoaders.

    Returns:
        (train_loader, test_seen_loader, test_unseen_loader)
    """
    train_ds = PointCloudDataset(os.path.join(args.data_root, args.train_dir),
                                 args.classes, args.num_points, augmentation=True)
    test_seen_ds = PointCloudDataset(os.path.join(args.data_root, args.test_seen_dir),
                                     args.classes, args.num_points, augmentation=False)
    test_unseen_ds = PointCloudDataset(os.path.join(args.data_root, args.test_unseen_dir),
                                       args.classes, args.num_points, augmentation=False)

    # drop_last=True: every training batch contains exactly batch_size samples.
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, drop_last=True)
    test_seen_loader = DataLoader(test_seen_ds, batch_size=args.batch_size, shuffle=False,
                                  num_workers=args.num_workers)
    test_unseen_loader = DataLoader(test_unseen_ds, batch_size=args.batch_size, shuffle=False,
                                    num_workers=args.num_workers)

    print(f"Dataset Statistics: Train={len(train_ds)}, Test_seen={len(test_seen_ds)}, Test_unseen={len(test_unseen_ds)}")
    if len(train_loader) == 0:
        # With drop_last=True a training set smaller than one batch yields no batches at all,
        # so the model would never be updated and all reported losses would be 0.
        print(f"Warning: training set ({len(train_ds)} samples) is smaller than batch_size "
              f"({args.batch_size}); no training batches will be produced.")
    return train_loader, test_seen_loader, test_unseen_loader


def _write_misclassified(path, misclassified):
    """Write (filename, true, pred) triples to `path`, one tab-separated line per sample (overwrites)."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(f"{fname}\ttrue={true}\tpred={pred}\n" for fname, true, pred in misclassified)


def _train_one_epoch(model, loader, optimizer, device, args, epoch):
    """Run one training epoch with the CE + weighted contrastive objective.

    Returns:
        Mean (CE loss, contrastive loss, total loss) over the batches processed.
        All three are 0.0 when the loader yields no batches.
    """
    model.train()
    running_total_loss = 0.0
    running_ce_loss = 0.0
    running_con_loss = 0.0
    step = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch}/{args.epochs}")
    for pts, labels, _, obj_ids in pbar:
        pts = pts.to(device)
        labels = labels.to(device)
        # The contrastive loss receives object ids as plain strings.
        obj_ids_list = [str(x) for x in obj_ids]

        optimizer.zero_grad()
        # return_feat=True makes the model also return the features fed to the contrastive loss.
        logits, feats = model(pts, return_feat=True)
        ce_loss = classification_loss(logits, labels)
        con_loss = contrastive_nt_xent(feats, labels, obj_ids_list, temperature=args.contrast_temp)
        total_loss = ce_loss + args.contrast_weight * con_loss

        total_loss.backward()
        optimizer.step()

        running_total_loss += total_loss.item()
        running_ce_loss += ce_loss.item()
        running_con_loss += con_loss.item()
        step += 1

        pbar.set_postfix({
            'CE Loss': f"{running_ce_loss/step:.4f}",
            'Con Loss': f"{running_con_loss/step:.4f}",
            'Total Loss': f"{running_total_loss/step:.4f}"
        })

    denom = max(1, step)  # avoid division by zero when the loader is empty
    return running_ce_loss / denom, running_con_loss / denom, running_total_loss / denom


def main():
    """Train SimplePointNet with CE + contrastive loss, evaluating and checkpointing every epoch.

    Per epoch this writes misclassified-sample files for test_seen/test_unseen into
    --save_dir. It saves the best model (by accuracy) for each test split and appends
    a summary line to the log file.
    """
    args = _parse_args()

    # --------------------------
    # Initialization Settings
    # --------------------------
    os.makedirs(args.save_dir, exist_ok=True)
    _set_seed(args.seed)
    device = _resolve_device(args.device, args.seed)
    print(f"Using computation device: {device}")

    # --------------------------
    # Dataset and DataLoader
    # --------------------------
    train_loader, test_seen_loader, test_unseen_loader = _build_loaders(args)

    # --------------------------
    # Model, Optimizer, Scheduler
    # --------------------------
    model = SimplePointNet(num_classes=len(args.classes)).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    # --------------------------
    # Logging + Training Loop
    # --------------------------
    log_fpath = os.path.join(args.save_dir, args.log_file)
    # `with` guarantees the log is flushed and closed even if training raises.
    with open(log_fpath, 'a', encoding='utf-8') as logf:
        logf.write("\n--- New Training Run ---\n")
        logf.write(f"Seed: {args.seed}, Num Points: {args.num_points}, Batch Size: {args.batch_size}\n")
        logf.write(f"Contrast Weight: {args.contrast_weight}, Contrast Temp: {args.contrast_temp}\n")
        logf.write(f"LR: {args.lr}, Weight Decay: {args.weight_decay}, "
                   f"LR Step Size: {args.lr_step_size}, LR Gamma: {args.lr_gamma}\n")
        logf.flush()

        best_seen_acc = 0.0  # Best accuracy on test_seen set
        best_unseen_acc = 0.0  # Best accuracy on test_unseen set

        for epoch in range(1, args.epochs + 1):
            avg_ce_loss, avg_con_loss, avg_total_loss = _train_one_epoch(
                model, train_loader, optimizer, device, args, epoch)
            scheduler.step()

            # NOTE(review): train accuracy is measured on train_loader itself. That loader
            # shuffles, drops the last partial batch, and reads a dataset built with
            # augmentation=True, so this is not accuracy on the clean training set.
            train_acc, _, _ = evaluate(model, train_loader, device)
            seen_acc, seen_per_class, seen_mis = evaluate(model, test_seen_loader, device)
            unseen_acc, unseen_per_class, unseen_mis = evaluate(model, test_unseen_loader, device)

            # Per-epoch misclassified samples
            _write_misclassified(os.path.join(args.save_dir, f"misclassified_test_seen_epoch{epoch}.txt"), seen_mis)
            _write_misclassified(os.path.join(args.save_dir, f"misclassified_test_unseen_epoch{epoch}.txt"), unseen_mis)

            # Best-model checkpoints (strict improvement only)
            improved_seen = seen_acc > best_seen_acc
            if improved_seen:
                best_seen_acc = seen_acc
                torch.save(model.state_dict(), os.path.join(args.save_dir, "best_model_test_seen.pth"))
                _write_misclassified(os.path.join(args.save_dir, "best_misclassified_test_seen.txt"), seen_mis)
            improved_unseen = unseen_acc > best_unseen_acc
            if improved_unseen:
                best_unseen_acc = unseen_acc
                torch.save(model.state_dict(), os.path.join(args.save_dir, "best_model_test_unseen.pth"))
                _write_misclassified(os.path.join(args.save_dir, "best_misclassified_test_unseen.txt"), unseen_mis)

            # Log and print results
            log_line = (
                f"Epoch {epoch:3d} | Train CE: {avg_ce_loss:.6f} | Train Con: {avg_con_loss:.6f} | Train Total: {avg_total_loss:.6f} | "
                f"Train Acc: {train_acc:.4f} | Test Seen Acc: {seen_acc:.4f} | Test Unseen Acc: {unseen_acc:.4f} | "
                f"Best Seen: {best_seen_acc:.4f} | Best Unseen: {best_unseen_acc:.4f}\n"
            )
            print(log_line.strip())
            logf.write(log_line)
            logf.write(f"  Test Seen Per-Class Acc: {seen_per_class}\n")
            logf.write(f"  Test Unseen Per-Class Acc: {unseen_per_class}\n")
            if improved_seen:
                logf.write(f"  -> Updated Best Test Seen Model (Acc: {best_seen_acc:.4f})\n")
            if improved_unseen:
                logf.write(f"  -> Updated Best Test Unseen Model (Acc: {best_unseen_acc:.4f})\n")
            logf.flush()

        logf.write("--- Training Finished Successfully ---\n")

    print(f"Training Finished! All Logs and Checkpoints are saved in: {args.save_dir}")

if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()