import os

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from models import PointNet2Reg
from utils import (
    get_args, PoseDataset,
    denormalize_position,
    normalize_quat_tensor, quaternion_geodesic_distance,
    save_prediction_ply, save_pred_txt
)

# Start offsets of the four 7-value pose blocks packed into each prediction /
# target vector. Each block is laid out as [x, y, z, q0, q1, q2, q3]
# (quaternion component order is whatever the utils helpers expect).
_POSE_OFFSETS = (0, 7, 14, 21)


def _split_poses(vec):
    """Split a packed pose tensor into stacked positions and quaternions.

    Args:
        vec: 2-D tensor whose columns hold the four pose blocks at
            ``_POSE_OFFSETS`` (at least 28 columns).

    Returns:
        ``(positions, quats)`` with the four pose blocks concatenated along
        dim 0 in ``_POSE_OFFSETS`` order, i.e. shapes (4*B, 3) and (4*B, 4).
    """
    positions = torch.cat([vec[:, b:b + 3] for b in _POSE_OFFSETS], dim=0)
    quats = torch.cat([vec[:, b + 3:b + 7] for b in _POSE_OFFSETS], dim=0)
    return positions, quats


def _pose_losses(pred, gt):
    """Return ``(pos_loss, rot_loss)`` between predicted and target poses.

    ``pos_loss`` is the MSE over all position components. ``rot_loss`` is the
    mean squared geodesic distance between the (re-normalized) quaternions.
    """
    pos_pred, quat_pred = _split_poses(pred)
    pos_gt, quat_gt = _split_poses(gt)
    pos_loss = torch.mean((pos_pred - pos_gt) ** 2)
    # Both sides are normalized so the geodesic distance is well defined even
    # though the raw network output is not constrained to the unit sphere.
    angles = quaternion_geodesic_distance(
        normalize_quat_tensor(quat_pred), normalize_quat_tensor(quat_gt)
    )
    rot_loss = torch.mean(angles ** 2)
    return pos_loss, rot_loss


def _evaluate(model, loader, device, name, epoch):
    """Print mean position MSE and mean squared rotation error over *loader*.

    Leaves the model in eval mode; the training loop switches it back with
    ``model.train()`` at the start of each epoch.
    """
    model.eval()
    pos_errs = []
    rot_errs = []
    with torch.no_grad():
        for pc, gt, _, _ in loader:
            pred = model(pc.to(device))
            pos_loss, rot_loss = _pose_losses(pred, gt.to(device))
            pos_errs.append(pos_loss.item())
            rot_errs.append(rot_loss.item())

    if not pos_errs:
        # np.mean([]) would warn and yield NaN; report the empty split instead.
        print(f"[Epoch {epoch}] {name}: empty loader, skipped")
        return
    print(f"[Epoch {epoch}] {name} Pos MSE = {np.mean(pos_errs):.6f}, "
          f"Rot (rad^2) = {np.mean(rot_errs):.6f}")


def _denormalize_prediction(pred28, trans):
    """Return a copy of *pred28* with every pose position mapped back by *trans*.

    Quaternion entries are left untouched; only the xyz slice of each pose
    block is passed through ``denormalize_position``.
    """
    world = pred28.copy()
    for b in _POSE_OFFSETS:
        world[b:b + 3] = denormalize_position(world[b:b + 3], trans)
    return world


def _predict_and_save(model, loader, device, split_name):
    """Run inference over *loader* and write per-sample prediction files.

    For each sample, ``pos_predict.ply`` and the text prediction are written
    next to the sample's source PLY. Only element 0 of each batch is used, so
    *loader* is expected to have ``batch_size=1``.
    """
    print(f"\n==== Predicting {split_name} ====")
    model.eval()
    with torch.no_grad():
        for pc, _, ply_path, transform in loader:
            pred = model(pc.to(device))[0].cpu().numpy()

            transform0 = {
                "center": transform["center"][0].cpu().numpy(),
                "scale": transform["scale"][0].cpu().numpy(),
            }
            pred_world = _denormalize_prediction(pred, transform0)

            init_ply_path = ply_path[0]
            dir_path = os.path.dirname(init_ply_path)
            out_path = os.path.join(dir_path, "pos_predict.ply")

            save_prediction_ply(pred_world, out_path, init_ply_path)
            save_pred_txt(pred_world, dir_path)
    print(f"Done: {split_name}")


def main():
    """Train PointNet2Reg, save its weights, then export predictions.

    Trains on the "train" split for ``args.epochs`` epochs and saves the
    state dict to ``<dataset_root>.pth``. It then writes prediction files for
    the train, test_seen and test_unseen splits.
    """
    args = get_args()

    train_set = PoseDataset(args.dataset_root, "train", args.num_points)
    test_seen_set = PoseDataset(args.dataset_root, "test_seen", args.num_points)
    test_unseen_set = PoseDataset(args.dataset_root, "test_unseen", args.num_points)

    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, drop_last=False)
    test_seen_loader = DataLoader(test_seen_set, batch_size=1, shuffle=False)
    test_unseen_loader = DataLoader(test_unseen_set, batch_size=1, shuffle=False)

    device = torch.device(args.gpu if torch.cuda.is_available() else "cpu")
    model = PointNet2Reg().to(device)
    optim = torch.optim.Adam(model.parameters(), lr=args.lr, eps=1e-4)

    pos_weight = 10.0
    rot_weight = 1.0

    for epoch in range(1, args.epochs + 1):
        model.train()
        total_loss = 0.0
        total_pos = 0.0
        total_rot = 0.0
        n_steps = 0
        n_skipped = 0
        for pc, gt, _, _ in tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs}"):
            pred = model(pc.to(device))
            pos_loss, rot_loss = _pose_losses(pred, gt.to(device))
            loss = pos_weight * pos_loss + rot_weight * rot_loss

            optim.zero_grad()
            if not torch.isfinite(loss).all():
                # Skip the update rather than poisoning the weights with NaN/Inf.
                n_skipped += 1
                continue
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optim.step()

            total_loss += loss.item()
            total_pos += pos_loss.item()
            total_rot += rot_loss.item()
            n_steps += 1

        # Average over steps actually taken. The old code divided by
        # len(train_loader) (counting skipped batches) and printed only the
        # last batch's component losses, which raised NameError on an empty loader.
        denom = max(n_steps, 1)
        print(f"[Epoch {epoch}] Train Loss = {total_loss / denom:.6f} "
              f"(pos_loss={total_pos / denom:.6f}, rot_loss={total_rot / denom:.6f})")
        if n_skipped:
            print(f"[Epoch {epoch}] skipped {n_skipped} batch(es) with non-finite loss")

        # Uncomment to enable per-epoch evaluation
        # _evaluate(model, test_seen_loader, device, "test_seen", epoch)
        # _evaluate(model, test_unseen_loader, device, "test_unseen", epoch)

    # Save model. normpath drops a trailing separator so "root/" yields
    # "root.pth" rather than a hidden "root/.pth" file.
    save_path = os.path.normpath(args.dataset_root) + ".pth"
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

    # Inference and save visualization
    train_vis_loader = DataLoader(train_set, batch_size=1, shuffle=False)
    _predict_and_save(model, train_vis_loader, device, "train")
    _predict_and_save(model, test_seen_loader, device, "test_seen")
    _predict_and_save(model, test_unseen_loader, device, "test_unseen")

if __name__ == "__main__":
    # Run training, checkpoint saving and prediction export only when this
    # file is executed as a script (not when imported).
    main()