from __future__ import print_function

import time
import numpy as np
import torch
import torch.nn as nn

from utils.accuracy import AverageMeter, accuracy
from progress.bar import Bar
import copy, time


def _graph_to_cuda(graph):
    """Move the four tensors the model consumes from ``graph`` to the GPU.

    Returns ``(x, edge_index, edge_attr, batch)``, each transferred with
    ``.cuda(non_blocking=True)``.
    """
    return (
        graph.x.cuda(non_blocking=True),
        graph.edge_index.cuda(non_blocking=True),
        graph.edge_attr.cuda(non_blocking=True),
        graph.batch.cuda(non_blocking=True),
    )


def _call_criterion(criterion, outputs, targets, epoch):
    """Evaluate ``criterion``, passing ``epoch`` only if it is accepted."""
    try:
        return criterion(outputs, targets, epoch)
    except TypeError:
        # The criterion does not take a third positional argument; fall back
        # to the plain (outputs, targets) call. Only TypeError is caught so
        # that KeyboardInterrupt and genuine runtime failures propagate.
        return criterion(outputs, targets)


def train_base(trainloader, model, optimizer, criterion, criterion_supcon=None, epoch=None, args=None):
    """Run one training epoch and return the average loss and top-1 accuracy.

    Args:
        trainloader: Iterable yielding ``(inputs, inputs_aug, inputs_aug_1)``
            triples. Each element must expose ``x``, ``edge_index``,
            ``edge_attr`` and ``batch``; the first must also expose ``y``
            (the targets) and ``size(0)``.
        model: Called as ``model(x, edge_index, edge_attr, batch)``; must
            also provide ``forward_bcl`` with the same arguments, returning
            ``(centers_logits, outputs, x_feat)``, when ``args.bcl`` is set.
        optimizer: Optimizer stepped once per batch.
        criterion: Main loss. In BCL mode it is called as
            ``criterion(centers, logits, features, targets)``; otherwise as
            ``criterion(outputs, targets, epoch)`` with a fallback to
            ``criterion(outputs, targets)``.
        criterion_supcon: Optional extra loss, called as
            ``criterion_supcon(stacked_outputs, labels=targets, epoch=epoch)``
            where ``stacked_outputs`` stacks the clean and first-augmented
            outputs along ``dim=1``.
        epoch: Value forwarded to the criteria.
        args: Namespace read for ``bcl`` and, in BCL mode, ``num_class``.
            ``None`` (or a missing ``bcl`` attribute) selects the non-BCL path.

    Returns:
        Tuple ``(losses.avg, top1.avg)`` taken from the running meters.
    """
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # `args` defaults to None, so read the flag defensively (and only once).
    use_bcl = bool(getattr(args, "bcl", False))
    use_supcon = criterion_supcon is not None

    end = time.time()

    bar = Bar('Training', max=len(trainloader))

    for batch_idx, (inputs, inputs_aug, inputs_aug_1) in enumerate(trainloader):
        data_time.update(time.time() - end)

        view = _graph_to_cuda(inputs)
        targets = inputs.y.cuda(non_blocking=True)

        # Augmented views are only transferred when a branch consumes them.
        view_aug = _graph_to_cuda(inputs_aug) if (use_bcl or use_supcon) else None

        if use_bcl:
            view_aug_1 = _graph_to_cuda(inputs_aug_1)

            centers_logits, outputs, x_feat = model.forward_bcl(*view)
            centers_logits_1, outputs_1, x_feat_1 = model.forward_bcl(*view_aug)
            centers_logits_2, outputs_2, x_feat_2 = model.forward_bcl(*view_aug_1)

            batch_size = outputs_2.shape[0]
            chunks = [batch_size, batch_size, batch_size]

            feat_mlp = torch.cat([x_feat, x_feat_1, x_feat_2])
            logits = torch.cat([outputs, outputs_1, outputs_2])
            centers = torch.cat([centers_logits, centers_logits_1, centers_logits_2])
            centers = centers[:args.num_class]

            # Contrastive features come from the two augmented views only;
            # the logits passed to the criterion come from the clean view.
            _, f2, f3 = torch.split(feat_mlp, chunks, dim=0)
            features = torch.cat([f2.unsqueeze(1), f3.unsqueeze(1)], dim=1)
            logits = torch.split(logits, chunks, dim=0)[0]
            loss = criterion(centers, logits, features, targets)
        else:
            outputs = model(*view)
            loss = _call_criterion(criterion, outputs, targets, epoch)

        if use_supcon:
            if use_bcl:
                supconloss = criterion_supcon(
                    torch.stack([outputs, outputs_1], dim=1), labels=targets, epoch=epoch
                )
                loss = loss + supconloss + 0.1 * torch.logsumexp(
                    torch.stack([loss, supconloss], dim=-1), dim=-1
                )
            else:
                # Fixed: this path previously read an undefined `supconloss`
                # and computed the "augmented" loss from the clean outputs.
                # It now mirrors the BCL path: the supcon loss is taken over
                # the (clean, augmented) output pair and the auxiliary loss
                # over the augmented outputs.
                outputs_aug = model(*view_aug)
                loss_aug = _call_criterion(criterion, outputs_aug, targets, epoch)
                supconloss = criterion_supcon(
                    torch.stack([outputs, outputs_aug], dim=1), labels=targets, epoch=epoch
                )
                loss = loss + loss_aug + supconloss + 0.1 * torch.logsumexp(
                    torch.stack([loss, loss_aug, supconloss], dim=-1), dim=-1
                )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # NOTE(review): the meters are weighted by inputs.size(0); confirm
        # that this is the number of samples for the batch type in use.
        n_items = inputs.size(0)
        losses.update(loss.item(), n_items)

        # measure accuracy on the clean-view outputs (top-5 is computed by
        # `accuracy` but not tracked here)
        prec1, _ = accuracy(outputs, targets, topk=(1, 5))
        top1.update(prec1.item(), n_items)

        # record
        batch_time.update(time.time() - end)
        end = time.time()

        # plot
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \
                     'Loss: {loss:.4f}'.format(
            batch=batch_idx + 1,
            size=len(trainloader),
            data=data_time.avg,
            bt=batch_time.avg,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=losses.avg,
        )
        bar.next()
    bar.finish()

    return losses.avg, top1.avg