import os
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import wandb
from dataset import ProcessedDataset, collate_fn
from model import FingerprintGenerator
from loss import InfoNCELoss
import argparse

# Command-line interface of the training script: where to read data and write
# checkpoints, plus the optimisation hyperparameters.
parser = argparse.ArgumentParser()

# Filesystem locations.
parser.add_argument(
    "--preprocessed_data_dir",
    type=str,
    default="../preprocessed_vox_train",
    help="Directory containing preprocessed .pt files",
)
parser.add_argument(
    "--checkpoint_dir",
    type=str,
    default="./chkpt",
    help="Directory to save checkpoints",
)

# Training schedule.
parser.add_argument(
    "--num_epochs",
    type=int,
    default=30,
    help="Number of epochs to train",
)
parser.add_argument(
    "--batch_size",
    type=int,
    default=64,
    help="Batch size",
)

# Loss and optimizer hyperparameters.
parser.add_argument(
    "--temperature",
    type=float,
    default=0.05,
    help="Temperature for InfoNCE loss",
)
parser.add_argument(
    "--lr",
    type=float,
    default=1e-3,
    help="Learning rate",
)
parser.add_argument(
    "--weight_decay",
    type=float,
    default=1e-3,
    help="Weight decay",
)

# Parsed once at import time; every later stage reads its settings from `args`.
args = parser.parse_args()
# --- Runtime configuration ---------------------------------------------------
# Enable cuDNN and its benchmark mode (see the PyTorch `torch.backends.cudnn`
# documentation for what benchmark mode trades off).
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
# Train on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Data ----------------------------------------------------------------------
# Batches are assembled by the project's `collate_fn` and reshuffled each epoch.
dataset = ProcessedDataset(args.preprocessed_data_dir, train=True)
dataloader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

# --- Model ---------------------------------------------------------------------
model = FingerprintGenerator()

# Query the GPU count once and report the device that will really be used;
# previously the CPU fallback was announced as "Using single GPU".
num_gpus = torch.cuda.device_count()
if num_gpus > 1:
    print(f"Using {num_gpus} GPUs")
    # Replicate the model across every visible GPU.
    model = torch.nn.DataParallel(model, device_ids=list(range(num_gpus)))
elif device.type == "cuda":
    print("Using single GPU")
else:
    print("Using CPU")
model.to(device)

# --- Optimisation --------------------------------------------------------------
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)
# Cosine schedule spanning the whole run; it is stepped once per epoch by the
# training loop, so T_max is expressed in epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=args.num_epochs,
    eta_min=1e-5,
)

loss_contrastive_fn = InfoNCELoss(temperature=args.temperature)

# --- Experiment tracking -------------------------------------------------------
# Record every CLI argument as the run configuration.
wandb.init(
    project="SpeeCheck",
    config=vars(args),
)

# Create the checkpoint directory before training starts so that an unusable
# path fails immediately instead of after a full epoch of compute.
os.makedirs(args.checkpoint_dir, exist_ok=True)

# Main training loop: one pass over `dataloader` per epoch, a checkpoint after
# every epoch, and per-step metrics sent to wandb.
for epoch in range(args.num_epochs):
    model.train()
    running_loss = 0.0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.num_epochs}")

    for batch in pbar:
        # Move every tensor of the batch to the training device.
        anchor_emb = batch["anchor_emb"].to(device)
        anchor_mask = batch["anchor_mask"].to(device)
        pos_emb = batch["positive_emb"].to(device)
        pos_mask = batch["positive_mask"].to(device)
        neg_emb = batch["negative_emb"].to(device)
        neg_mask = batch["negative_mask"].to(device)

        optimizer.zero_grad()
        # The same model encodes anchors, positives and negatives; the input
        # names are rebound to the resulting embeddings.
        anchor_emb = model(anchor_emb, anchor_mask)
        pos_emb = model(pos_emb, pos_mask)
        neg_emb = model(neg_emb, neg_mask)

        loss_contrastive = loss_contrastive_fn(anchor_emb, pos_emb, neg_emb)

        # The contrastive term is currently the whole objective.
        loss = loss_contrastive
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it below: each `.item()` call copies
        # the value to the host and waits for the device to finish.
        loss_value = loss.item()
        running_loss += loss_value

        # Monitoring-only similarity statistics; no gradients are needed.
        with torch.no_grad():
            # Positives / negatives per anchor, inferred from the row counts.
            # NOTE(review): repeat_interleave assumes the rows of pos_emb and
            # neg_emb are grouped contiguously per anchor — confirm against
            # collate_fn.
            P = pos_emb.size(0) // anchor_emb.size(0)
            N = neg_emb.size(0) // anchor_emb.size(0)
            anchor_pos = anchor_emb.repeat_interleave(P, dim=0)
            anchor_neg = anchor_emb.repeat_interleave(N, dim=0)
            pos_sim = F.cosine_similarity(anchor_pos, pos_emb, dim=1)
            neg_sim = F.cosine_similarity(anchor_neg, neg_emb, dim=1)

            # Pairwise anchor-to-anchor similarity via broadcasting.
            anchor_sim = F.cosine_similarity(
                anchor_emb.unsqueeze(1),      # [B, 1, D]
                anchor_emb.unsqueeze(0),      # [1, B, D]
                dim=-1,
            )
            # Drop the diagonal (each anchor compared with itself).
            mask = ~torch.eye(anchor_sim.size(0), dtype=torch.bool, device=anchor_sim.device)
            cross_anchor_sim = anchor_sim.masked_select(mask)

        wandb.log({
            "epoch": epoch + 1,
            "batch_loss": loss_value,
            "pos_sim": pos_sim.mean().item(),
            "neg_sim": neg_sim.mean().item(),
            "cross_anchor_sim": cross_anchor_sim.mean().item(),
            "lr": scheduler.get_last_lr()[0],
        })
        pbar.set_postfix(loss=loss_value)

    avg_loss = running_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{args.num_epochs}], Loss: {avg_loss:.4f}")
    # The learning-rate schedule advances once per epoch.
    scheduler.step()

    # NOTE(review): when the model was wrapped in DataParallel (multi-GPU
    # path), the saved keys carry a "module." prefix; loaders must wrap the
    # model the same way or strip the prefix.
    checkpoint_path = os.path.join(args.checkpoint_dir, f"SpeeCheck_epoch_{epoch+1}.pth")
    torch.save(model.state_dict(), checkpoint_path)
