# train.py

import torch
from torch_geometric.data import Batch
from samplers import apply_virtual_knockdown
from losses import info_nce_loss, loss_aug
from utils import to_device

def train_epoch(model, loader, sampler, optimizer, args, device):
    """Run one optimisation pass over ``loader`` and return the mean losses.

    For every batch, one ``sampler.sample(args.tau_aug)`` draw provides two
    knockdowns ``a`` / ``b`` and a candidate collection ``cand``.  Each graph
    of the batch is passed through ``apply_virtual_knockdown`` with ``a``,
    ``b`` and every entry of ``cand``; the resulting views are embedded by
    ``model`` and scored with ``info_nce_loss`` and ``loss_aug``.  The sum of
    both losses is back-propagated and ``optimizer`` takes one step per batch.

    Args:
        model: Callable as ``model(x, edge_index, edge_attr)``; switched to
            training mode with ``model.train()``.
        loader: Iterable of batches exposing ``to_data_list()``.  It does not
            need to implement ``__len__``.
        sampler: Object with a sized ``kd_meta`` attribute and a
            ``sample(tau)`` method returning ``(a, b, cand, log_p, weight)``,
            where ``log_p`` and ``weight`` are tensors and ``cand`` is a sized,
            re-iterable collection.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        args: Namespace providing ``tau_aug``, ``tau_nce`` and ``proj_out``.
        device: Target passed to ``to_device`` for every batched view.

    Returns:
        Tuple ``(mean_info_nce, mean_aug, mean_total)`` averaged over the
        batches that were actually processed.  ``(0.0, 0.0, 0.0)`` is returned
        when ``loader`` yields no batch.
    """
    model.train()
    total_i = total_a = total = 0.0
    n_batches = 0
    V = len(sampler.kd_meta)

    for batch in loader:
        graphs = batch.to_data_list()
        B = len(graphs)

        # One sampler draw is shared by every graph of the batch.
        a, b, cand, log_p, weight = sampler.sample(args.tau_aug)
        # Scalar (0-dim) outputs are broadcast to one entry per graph.
        if log_p.dim() == 0:
            log_p = log_p.expand(B)
        if weight.dim() == 0:
            weight = weight.expand(B)

        # Two student views of each graph, re-batched for a single forward pass.
        G_a = Batch.from_data_list([apply_virtual_knockdown(g, a) for g in graphs])
        G_b = Batch.from_data_list([apply_virtual_knockdown(g, b) for g in graphs])
        G_a, G_b = to_device(G_a, device), to_device(G_b, device)

        # Model output is reshaped to (B, -1, proj_out).
        Za = model(G_a.x, G_a.edge_index, G_a.edge_attr).view(B, -1, args.proj_out)
        Zb = model(G_b.x, G_b.edge_index, G_b.edge_attr).view(B, -1, args.proj_out)

        # InfoNCE loss
        l_info = info_nce_loss(Za, Zb, args.tau_nce)

        # Candidate views are laid out graph-major (all candidates of graph 0,
        # then graph 1, ...), which is the order the view() below relies on.
        flat_c = [apply_virtual_knockdown(g, kd) for g in graphs for kd in cand]
        Bc = Batch.from_data_list(flat_c)
        Bc = to_device(Bc, device)
        Zc = model(Bc.x, Bc.edge_index, Bc.edge_attr).view(B, len(cand), -1, args.proj_out)

        # Graph-level KL(Aug) loss
        l_aug = loss_aug(Za, Zb, Zc, log_p, weight, V, args.tau_aug)

        # Total loss
        loss = l_info + l_aug
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_i += l_info.item()
        total_a += l_aug.item()
        total += loss.item()
        n_batches += 1

    # Average over the batches actually seen instead of len(loader): this
    # supports loaders without __len__ and avoids dividing by zero when the
    # loader is empty.
    denom = max(n_batches, 1)
    return total_i / denom, total_a / denom, total / denom


def test_epoch(model, loader, sampler, args, device):
    model.eval()
    total_i = total_a = total = 0.0
    V = len(sampler.kd_meta)

    with torch.no_grad():
        for batch in loader:
            graphs = batch.to_data_list()
            B = len(graphs)

            a, b, cand, log_p, weight = sampler.sample(args.tau_aug)
            if log_p.dim() == 0:
                log_p = log_p.expand(B)
            if weight.dim() == 0:
                weight = weight.expand(B)

            G_a = Batch.from_data_list([apply_virtual_knockdown(g, a) for g in graphs])
            G_b = Batch.from_data_list([apply_virtual_knockdown(g, b) for g in graphs])
            G_a, G_b = to_device(G_a, device), to_device(G_b, device)

            Za = model(G_a.x, G_a.edge_index, G_a.edge_attr).view(B, -1, args.proj_out)
            Zb = model(G_b.x, G_b.edge_index, G_b.edge_attr).view(B, -1, args.proj_out)

            l_info = info_nce_loss(Za, Zb, args.tau_nce)

            flat_c = [apply_virtual_knockdown(g, kd) for g in graphs for kd in cand]
            Bc = Batch.from_data_list(flat_c)
            Bc = to_device(Bc, device)
            Zc = model(Bc.x, Bc.edge_index, Bc.edge_attr).view(B, len(cand), -1, args.proj_out)

            l_aug = loss_aug(Za, Zb, Zc, log_p, weight, V, args.tau_aug)

            total_i += l_info.item()
            total_a += l_aug.item()
            total   += (l_info + l_aug).item()

    return total_i/len(loader), total_a/len(loader), total/len(loader)
