import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb

from models import ShapeClassifier
from data import BasicPointCloudDataset

# Number of classification categories; the first NUM_CLASSES output channels
# of the model are read as class logits throughout this script.
NUM_CLASSES = 5


def normalized_triplet_loss(orig_emb, pos_emb, neg_emb, margin=5.0):
    """Triplet margin loss evaluated on L2-normalized embeddings.

    The anchor (``orig_emb``), positive and negative embeddings are each
    projected onto the unit sphere along dim 1. The standard triplet margin
    loss (Euclidean distance, mean reduction) is then applied to them.

    Because unit vectors are at most distance 2 apart, a margin above 2 keeps
    the hinge active for every triplet.

    Args:
        orig_emb: Anchor embeddings, normalized along dim 1.
        pos_emb: Positive embeddings, same shape as ``orig_emb``.
        neg_emb: Negative embeddings, same shape as ``orig_emb``.
        margin: Triplet margin.

    Returns:
        Scalar loss tensor.
    """
    anchor, positive, negative = (
        F.normalize(emb, p=2, dim=1) for emb in (orig_emb, pos_emb, neg_emb)
    )
    # Functional form with the same defaults (p=2, eps=1e-6, swap=False,
    # reduction='mean') as nn.TripletMarginLoss, without building a module per call.
    return F.triplet_margin_loss(anchor, positive, negative, margin=margin)


def test(model, dataloader, loss_function, device, args):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Args:
        model: Network called with ``pcl.permute(0, 2, 1).unsqueeze(2)``.
            The first ``NUM_CLASSES`` output channels are used as class logits.
        dataloader: Iterable of dicts with a ``'point_cloud'`` tensor and an
            ``'info'`` mapping whose ``'class'`` entry holds the labels.
        loss_function: Criterion called as ``loss_function(logits, labels)``.
            Its value is treated as a per-batch mean, which holds for the
            default ``nn.CrossEntropyLoss``, and is re-weighted by batch size.
        device: Device that point clouds and labels are moved to.
        args: Parsed CLI options. When ``args.cube`` is set, each batch is
            divided by its largest absolute coordinate, as during training.

    Returns:
        Tuple ``(average_loss, accuracy, label_accuracies)``.
        - Loss and accuracy are averaged over samples rather than batches,
          so a short final batch is not over-weighted.
        - ``label_accuracies`` maps each class index to the fraction of its
          samples predicted correctly (0.0 when the class does not occur).

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_samples = 0
    label_correct = {label: 0 for label in range(NUM_CLASSES)}
    label_total = {label: 0 for label in range(NUM_CLASSES)}

    with torch.no_grad():
        for batch in dataloader:
            pcl, info = batch['point_cloud'].to(device), batch['info']
            if args.cube:
                # Training applies this same rescaling when --cube is set;
                # evaluate on identically scaled inputs.
                pcl = pcl / pcl.abs().max()
            labels = info['class'].to(device).long()
            output = model(pcl.permute(0, 2, 1).unsqueeze(2)).squeeze()
            if output.dim() == 1:
                # squeeze() also drops the batch dim of a single-sample batch;
                # restore it so the column slice below stays valid.
                output = output.unsqueeze(0)
            output = output[:, :NUM_CLASSES]
            preds = output.argmax(dim=1)

            batch_size = labels.size(0)
            loss_sum += loss_function(output, labels).item() * batch_size
            num_correct += (preds == labels).sum().item()
            num_samples += batch_size

            for i in range(NUM_CLASSES):
                is_label = labels == i
                label_correct[i] += (is_label & (preds == i)).sum().item()
                label_total[i] += is_label.sum().item()

    if num_samples == 0:
        raise ValueError('test(): dataloader yielded no samples')

    average_loss = loss_sum / num_samples
    accuracy = num_correct / num_samples
    label_accuracies = {
        label: (label_correct[label] / label_total[label]) if label_total[label] != 0 else 0.0
        for label in range(NUM_CLASSES)
    }

    return average_loss, accuracy, label_accuracies


def _forward_2d(model, pcl):
    """Run ``model`` on a point-cloud batch and return a ``(batch, channels)`` tensor.

    The input layout (``permute(0, 2, 1).unsqueeze(2)``) matches the other
    call sites in this script. ``squeeze()`` strips singleton dims, and for a
    single-sample batch that includes the batch dim. That dim is restored so
    ``output[:, ...]`` slicing keeps working.
    """
    output = model(pcl.permute(0, 2, 1).unsqueeze(2)).squeeze()
    if output.dim() == 1:
        output = output.unsqueeze(0)
    return output


def _scale_to_cube(pcl):
    """Divide a point-cloud batch by its largest absolute coordinate (``--cube`` option)."""
    return pcl / pcl.abs().max()


def _lr_milestones(epochs, lr_jumps):
    """Return the epochs at which the LR decays: every ``lr_jumps`` epochs up to ``epochs``.

    A non-positive ``lr_jumps`` yields no milestones, i.e. a constant LR,
    instead of raising ZeroDivisionError.
    """
    if lr_jumps <= 0:
        return []
    return list(range(lr_jumps, epochs + 1, lr_jumps))


def _train_one_epoch(model, loader, optimizer, criterion, device, args, emb_start, desc):
    """Run one optimization pass over ``loader``.

    Returns:
        dict of per-sample averages with keys ``'classification_loss'``,
        ``'accuracy'``, ``'contrastive_loss'``, ``'pos_mse'`` and ``'neg_mse'``.
        The last three stay 0.0 when ``args.contr_loss_weight <= 0``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    mse_loss = nn.MSELoss()
    totals = dict.fromkeys(
        ('classification_loss', 'accuracy', 'contrastive_loss', 'pos_mse', 'neg_mse'), 0.0)
    num_samples = 0

    with tqdm(loader, desc=desc, leave=False) as pbar:
        for batch in pbar:
            pcl = batch['point_cloud'].to(device)
            if args.cube:
                pcl = _scale_to_cube(pcl)
            labels = batch['info']['class'].to(device).long()

            output = _forward_2d(model, pcl)
            logits = output[:, :NUM_CLASSES]
            # With emb_start == 0 the embedding is the whole output (the
            # logit channels when the model emits exactly NUM_CLASSES values).
            embeddings = output[:, emb_start:]

            if args.classification:
                classification_loss = criterion(logits, labels)
            else:
                classification_loss = torch.tensor(0.0, device=device)
            contrastive_loss = torch.tensor(0.0, device=device)
            pos_mse = neg_mse = 0.0

            if args.contr_loss_weight > 0:
                pos_pcl = batch['point_cloud2'].to(device)
                neg_pcl = batch['contrastive_point_cloud'].to(device)
                if args.cube:
                    # Rescale positives/negatives like the anchor so the
                    # triplet loss compares inputs on the same scale.
                    pos_pcl = _scale_to_cube(pos_pcl)
                    neg_pcl = _scale_to_cube(neg_pcl)
                pos_emb = _forward_2d(model, pos_pcl)[:, emb_start:]
                neg_emb = _forward_2d(model, neg_pcl)[:, emb_start:]
                contrastive_loss = normalized_triplet_loss(
                    embeddings, pos_emb, neg_emb, margin=args.contr_margin)
                # Diagnostics only: MSE between the raw (un-normalized) embeddings.
                pos_mse = mse_loss(embeddings.detach(), pos_emb.detach()).item()
                neg_mse = mse_loss(embeddings.detach(), neg_emb.detach()).item()

            loss = classification_loss + args.contr_loss_weight * contrastive_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            preds = logits.argmax(dim=1)
            totals['classification_loss'] += classification_loss.item() * batch_size
            totals['accuracy'] += (preds == labels).sum().item()
            totals['contrastive_loss'] += contrastive_loss.item() * batch_size
            totals['pos_mse'] += pos_mse * batch_size
            totals['neg_mse'] += neg_mse * batch_size
            num_samples += batch_size
            pbar.set_postfix(train_loss=classification_loss.item())

    if num_samples == 0:
        raise ValueError('training dataloader yielded no samples')
    return {key: value / num_samples for key, value in totals.items()}


def train_and_test(args):
    """Train ``ShapeClassifier`` as configured by ``args``.

    Each epoch does the following:
    - optimizes ``CrossEntropy`` (if ``args.classification``) plus
      ``args.contr_loss_weight`` times the normalized triplet loss;
    - prints contrastive diagnostics;
    - evaluates on the test split when classification is enabled;
    - steps the MultiStepLR schedule;
    - optionally logs to Weights & Biases.

    Args:
        args: Namespace produced by ``configArgsPCT``.

    Returns:
        The trained model.

    Raises:
        ValueError: If neither the classification loss nor the contrastive
            loss is enabled, since there would be nothing to optimize.
    """
    if not args.classification and args.contr_loss_weight <= 0:
        raise ValueError(
            'Nothing to optimize: enable --classification or set --contr_loss_weight > 0')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if args.use_wandb:
        wandb.init(project=args.wandb_proj, name=args.exp_name)

    print(f"Using device: {device}")
    print(args)

    # An output_dim below NUM_CLASSES selects the "_no_edge" data files.
    if args.output_dim >= NUM_CLASSES:
        train_file = "train_surfaces_05X05.h5"
        test_file = "test_surfaces_05X05.h5"
    else:
        train_file = "train_surfaces_05X05_no_edge.h5"
        test_file = "test_surfaces_05X05_no_edge.h5"

    train_dataset = BasicPointCloudDataset(file_path=train_file, args=args)
    test_dataset = BasicPointCloudDataset(file_path=test_file, args=args)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    model = ShapeClassifier(args).to(device)
    print(f'Number of parameters: {sum(p.numel() for p in model.parameters())}')

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
    # Multiply the LR by 0.1 every args.lr_jumps epochs.
    scheduler = MultiStepLR(optimizer, milestones=_lr_milestones(args.epochs, args.lr_jumps), gamma=0.1)

    criterion = nn.CrossEntropyLoss()
    # Embedding channels follow the logits unless the whole output is the logits.
    emb_start = 0 if args.output_dim == NUM_CLASSES else NUM_CLASSES

    for epoch in range(args.epochs):
        stats = _train_one_epoch(model, train_loader, optimizer, criterion, device, args,
                                 emb_start, desc=f'Epoch {epoch+1}/{args.epochs}')

        print(f'Epoch {epoch+1}: Contrastive Loss = {stats["contrastive_loss"]:.4f}, '
              f'Pos MSE = {stats["pos_mse"]:.4f}, Neg MSE = {stats["neg_mse"]:.4f}')

        if args.classification:
            test_loss, test_acc, label_accs = test(model, test_loader, criterion, device, args)
            print(f'LR: {optimizer.param_groups[0]["lr"]:.6f}')
            print({"epoch": epoch, "train_loss": stats['classification_loss'], "test_loss": test_loss,
                   "acc_train": stats['accuracy'], "acc_test": test_acc})
            for label, acc in label_accs.items():
                print(f"label_{label}: {acc:.4f}")

        scheduler.step()

        if args.use_wandb:
            metrics = {
                "epoch": epoch,
                "train_loss": stats['classification_loss'],
                "acc_train": stats['accuracy'],
            }
            if args.classification:
                metrics.update({
                    "test_loss": test_loss,
                    "acc_test": test_acc,
                    **{f"label_{k}": v for k, v in label_accs.items()}
                })
            if args.contr_loss_weight > 0:
                metrics["contrastive_loss"] = stats['contrastive_loss']
            wandb.log(metrics)

    return model


def configArgsPCT(argv=None):
    """Build this script's command-line parser and parse the arguments.

    Integer flags such as ``--use_wandb``, ``--classification`` and ``--cube``
    are treated as booleans by this script (0 = off, non-zero = on).

    Options without a help string are not read directly in this file. They
    reach ``ShapeClassifier`` / ``BasicPointCloudDataset`` through ``args`` and
    are presumably model/data settings.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; tests can pass a list.

    Returns:
        argparse.Namespace with one attribute per option, in declaration order.
    """
    parser = argparse.ArgumentParser(description='Point Cloud Recognition')
    parser.add_argument('--wandb_proj', type=str, default='canonNet',
                        help='Weights & Biases project name (used when --use_wandb is non-zero).')
    parser.add_argument('--exp_name', type=str, default='exp',
                        help='Run name; the trained weights are saved to <exp_name>.pt.')
    parser.add_argument('--batch_size', type=int, default=1024,
                        help='Mini-batch size for the train and test loaders.')
    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')
    parser.add_argument('--lr', type=float, default=0.01, help='Initial Adam learning rate.')
    parser.add_argument('--use_wandb', type=int, default=1,
                        help='1 to log metrics to Weights & Biases, 0 to disable.')
    parser.add_argument('--contr_margin', type=float, default=1.0,
                        help='Margin of the triplet loss on L2-normalized embeddings.')
    parser.add_argument('--use_lap_reorder', type=int, default=1)
    parser.add_argument('--lap_eigenvalues_dim', type=int, default=0)
    parser.add_argument('--use_second_deg', type=int, default=1)
    parser.add_argument('--lpe_normalize', type=int, default=1)
    parser.add_argument('--std_dev', type=float, default=0.05)
    parser.add_argument('--max_curve_diff', type=float, default=2.0)
    parser.add_argument('--min_curve_diff', type=float, default=0.05)
    parser.add_argument('--clip', type=float, default=0.25)
    parser.add_argument('--contr_loss_weight', type=float, default=0.1,
                        help='Weight of the triplet loss; 0 disables the contrastive branch.')
    parser.add_argument('--lpe_dim', type=int, default=0)
    parser.add_argument('--use_xyz', type=int, default=1)
    parser.add_argument('--classification', type=int, default=1,
                        help='1 to train with cross-entropy and evaluate on the test set each epoch.')
    parser.add_argument('--rotate_data', type=int, default=1)
    parser.add_argument('--cube', type=int, default=0,
                        help='1 to divide each batch by its largest absolute coordinate.')
    parser.add_argument('--num_neurons_per_layer', type=int, default=64)
    parser.add_argument('--num_mlp_layers', type=int, default=5)
    parser.add_argument('--output_dim', type=int, default=5,
                        help='Model output width: the first NUM_CLASSES channels are logits, '
                             'extra channels form the embedding; values below NUM_CLASSES '
                             'select the *_no_edge.h5 data files.')
    parser.add_argument('--lr_jumps', type=int, default=15,
                        help='Multiply the learning rate by 0.1 every N epochs.')
    parser.add_argument('--sampled_points', type=int, default=20)
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: parse CLI options, train (and evaluate), then
    # checkpoint the learned weights as <exp_name>.pt.
    cli_args = configArgsPCT()
    trained_model = train_and_test(cli_args)
    checkpoint_path = f'{cli_args.exp_name}.pt'
    torch.save(trained_model.state_dict(), checkpoint_path)
