from re import I
import torch

from utils import accuracy, save_model, F1Score, AverageMeter

from tqdm import tqdm

from torch_geometric.loader import DataLoader

class GraphTrainer:
    """Epoch-level training/evaluation driver for a graph neural network.

    Two modes are supported, chosen by ``types``:

    * ``'transductive'``: train/val/test nodes live in the same graph(s) and
      are selected with ``train_mask`` / ``val_mask`` / ``test_mask``; the
      score is ``accuracy`` on the masked nodes. The data may be a single
      graph or a loader over (sub)graphs.
    * ``'inductive'``: train/val/test come from separate loaders; logits are
      thresholded at 0 and scored with micro-averaged ``F1Score``.

    Running metrics live on the caller-supplied ``args`` namespace
    (``train_acc``, ``best_val_acc``, ``test_acc``, ``best_epoch``); the model
    is saved via ``save_model`` whenever the validation score improves.
    Summary-writer tags keep the existing ``'Accuray/...'`` spelling so new
    runs stay comparable with earlier logs.
    """

    def __init__(self, types='transductive'):
        """Select the per-epoch training and evaluation routines.

        Args:
            types: ``'transductive'`` or ``'inductive'``.

        Raises:
            ValueError: if ``types`` is not one of the supported modes; an
                instance without ``trainer``/``evaluate`` would be unusable.
        """
        if types == 'transductive':
            self.trainer = self.train_transductive
            self.evaluate = self.test_transductive
        elif types == 'inductive':
            self.trainer = self.train_inductive
            self.evaluate = self.test_inductive
        else:
            raise ValueError(f'Error ! Select transductive or inductive (got {types!r})')

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _forward(net, data):
        """Run ``net`` on a graph batch, passing ``edge_index_interface`` if the batch has it, else ``None``."""
        return net(data.x, data.edge_index, getattr(data, 'edge_index_interface', None))

    @staticmethod
    def _as_batches(item):
        """Normalise one loader item to a list of batches.

        Some loaders yield a list of batches per step (presumably a
        ``DataListLoader``-style loader for multi-device use — confirm);
        ordinary loaders yield a single batch.
        """
        return item if isinstance(item, list) else [item]

    @staticmethod
    def _is_loader(obj):
        """Return True if ``obj`` is a PyTorch ``DataLoader``.

        PyG's ``torch_geometric.loader.DataLoader`` subclasses
        ``torch.utils.data.DataLoader``, so this also accepts the other PyG
        subgraph loaders that derive directly from the torch class.
        """
        from torch.utils.data import DataLoader as TorchDataLoader
        return isinstance(obj, TorchDataLoader)

    @staticmethod
    def _postfix(args):
        """Format the best-so-far metrics stored on ``args`` for the progress bar."""
        return {
            'Best_epoch': f'{args.best_epoch:d}',
            'Train_acc': f'{args.train_acc:.4f}',
            'Best_val_acc': f'{args.best_val_acc:.4f}',
            'Test_acc': f'{args.test_acc:.4f}',
        }

    @staticmethod
    def _record(net, val_acc, test_acc, writer, epoch, args):
        """Track the best validation score and log this epoch's val/test scores.

        On a strictly better ``val_acc`` the best-so-far fields on ``args`` are
        updated (``test_acc`` is the test score at that epoch) and ``net`` is
        saved via ``save_model``. Scores are written to ``writer`` if given.
        """
        if val_acc > args.best_val_acc:
            args.best_val_acc = val_acc
            args.test_acc = test_acc
            args.best_epoch = epoch
            save_model(args, net)
        if writer is not None:
            writer.add_scalar('Accuray/val', val_acc, epoch)
            writer.add_scalar('Accuray/test', test_acc, epoch)

    # ----------------------------------------------------------------- main loop

    def train(self, net, criterion, train_data, val_data, test_data, optimizer, writer, args, progress_bar=True):
        """Run ``args.epoch`` epochs, each a training pass followed by evaluation.

        Args:
            net: model called as ``net(x, edge_index, edge_index_interface)``.
            criterion: loss function called as ``criterion(logits, targets)``.
            train_data, val_data, test_data: a graph or a loader, as expected
                by the routines selected in ``__init__``.
            optimizer: optimizer stepping ``net``'s parameters.
            writer: summary writer with ``add_scalar``/``add_hparams``
                (presumably ``torch.utils.tensorboard.SummaryWriter``), or
                ``None`` to disable logging.
            args: namespace providing ``epoch``, ``device``, ``lr``, ``wd``,
                ``nprocs`` and the metric fields ``best_epoch``,
                ``train_acc``, ``best_val_acc``, ``test_acc`` (read for the
                progress bar and updated in place by the epoch routines).
            progress_bar: show a tqdm bar over epochs.
        """
        epochs = range(1, args.epoch + 1)
        outer = None
        if progress_bar:
            outer = tqdm(total=len(epochs), desc='Epoch', ncols=150, postfix=self._postfix(args), position=0)
        try:
            for epoch in epochs:
                self.trainer(net, criterion, train_data, optimizer, writer, epoch, args)
                self.evaluate(net, val_data, test_data, writer, epoch, args)
                if outer is not None:
                    outer.set_postfix(**self._postfix(args))
                    outer.update(1)
        finally:
            # Close the bar even if an epoch raises, so the terminal is not left garbled.
            if outer is not None:
                outer.close()
        if writer is not None:
            writer.add_hparams({'lr': args.lr, 'wd': args.wd, 'N': args.nprocs}, {'Best/val_acc': args.best_val_acc, 'Best/test_acc': args.test_acc, 'Best/epoch': args.best_epoch})

    # -------------------------------------------------------------- transductive

    def train_iterative(self, net, criterion, train_loader, optimizer, writer, epoch, args):
        """One transductive training epoch over a loader of (sub)graphs.

        Only ``train_mask`` nodes contribute to the loss. Batches without any
        training nodes are skipped: a mean-reduced loss over an empty tensor
        is NaN, and back-propagating it would corrupt the weights.

        Sets ``args.train_acc`` to the node-weighted mean accuracy and logs
        the last batch's loss (if any batch was trained on).
        """
        net.train()
        train_acc = AverageMeter()
        loss = None  # stays None if the loader yields no trainable batch
        for tr in train_loader:
            tr = tr.to(args.device)
            target = tr.y[tr.train_mask]
            n_train = target.size(0)
            if n_train == 0:
                continue
            optimizer.zero_grad()
            logits = self._forward(net, tr)[tr.train_mask]
            loss = criterion(logits, target)
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                train_acc.update(accuracy(logits, target), n_train)

        args.train_acc = train_acc.avg
        if writer is not None:
            if loss is not None:
                writer.add_scalar('Loss/train', loss.item(), epoch)
            writer.add_scalar('Accuray/train', args.train_acc, epoch)

    @torch.no_grad()
    def test_iterative(self, net, val_loader, test_loader, writer, epoch, args):
        """Transductive evaluation over a loader of (sub)graphs.

        Transductive evaluation reads ``val_mask`` and ``test_mask`` from the
        same batches, so ``val_loader`` must equal ``test_loader``; otherwise
        an error is printed and evaluation is skipped. Batches with no nodes
        under a mask do not contribute to that mask's mean accuracy.
        """
        if val_loader != test_loader:
            print('Error ! it is not transductive learning !!')
            return
        net.eval()
        val_meter = AverageMeter()
        test_meter = AverageMeter()
        for data in test_loader:
            data = data.to(args.device)
            logits = self._forward(net, data)
            for meter, mask in ((val_meter, data.val_mask), (test_meter, data.test_mask)):
                target = data.y[mask]
                if target.size(0) > 0:
                    meter.update(accuracy(logits[mask], target), target.size(0))
        self._record(net, val_meter.avg, test_meter.avg, writer, epoch, args)

    def train_direct(self, net, criterion, data, optimizer, writer, epoch, args):
        """One full-batch transductive training step on a single graph.

        Sets ``args.train_acc`` to the accuracy on ``train_mask`` nodes.
        """
        net.train()
        data = data.to(args.device)
        optimizer.zero_grad()
        logits = self._forward(net, data)[data.train_mask]
        target = data.y[data.train_mask]
        loss = criterion(logits, target)
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            args.train_acc = accuracy(logits, target)
        if writer is not None:
            writer.add_scalar('Loss/train', loss.item(), epoch)
            writer.add_scalar('Accuray/train', args.train_acc, epoch)

    def train_transductive(self, net, criterion, data, optimizer, writer, epoch, args):
        """Dispatch to mini-batch (loader) or full-batch (single graph) transductive training."""
        if self._is_loader(data):
            self.train_iterative(net, criterion, data, optimizer, writer, epoch, args)
        else:
            self.train_direct(net, criterion, data, optimizer, writer, epoch, args)

    @torch.no_grad()
    def test_transductive(self, net, val_data, test_data, writer, epoch, args):
        """Dispatch to mini-batch (loader) or full-batch (single graph) transductive evaluation."""
        if self._is_loader(val_data):
            self.test_iterative(net, val_data, test_data, writer, epoch, args)
        else:
            self.test_direct(net, val_data, test_data, writer, epoch, args)

    @torch.no_grad()
    def test_direct(self, net, val_data, test_data, writer, epoch, args):
        """Full-batch transductive evaluation on a single graph.

        ``val_data`` must equal ``test_data``; val and test scores come from
        ``val_mask`` and ``test_mask`` of that one graph. Otherwise an error
        is printed and evaluation is skipped.
        """
        if val_data != test_data:
            print('Error ! it is not transductive learning !!')
            return
        data = val_data.to(args.device)
        net.eval()
        logits = self._forward(net, data)
        val_acc = accuracy(logits[data.val_mask], data.y[data.val_mask])
        test_acc = accuracy(logits[data.test_mask], data.y[data.test_mask])
        self._record(net, val_acc, test_acc, writer, epoch, args)

    # ---------------------------------------------------------------- inductive

    def train_inductive(self, net, criterion, train_loader, optimizer, writer, epoch, args):
        """One inductive training epoch; every node of every batch is a training target.

        Training score is micro-F1 of ``logits > 0`` against ``y``, weighted by
        the number of targets per batch; stored in ``args.train_acc``.
        """
        net.train()
        train_acc = AverageMeter()
        f1_score = F1Score('micro')
        loss = None  # stays None if the loader is empty
        for item in train_loader:
            for batch in self._as_batches(item):
                batch = batch.to(args.device)
                optimizer.zero_grad()
                logits = self._forward(net, batch)
                loss = criterion(logits, batch.y)
                loss.backward()
                optimizer.step()
                with torch.no_grad():
                    preds = (logits > 0).float()
                    train_acc.update(f1_score(preds, batch.y), batch.y.size(0))

        args.train_acc = train_acc.avg
        if writer is not None:
            if loss is not None:
                writer.add_scalar('Loss/train', loss.item(), epoch)
            writer.add_scalar('Accuray/train', args.train_acc, epoch)

    @torch.no_grad()
    def epoch_iteration_inference(self, net, data_loader, args):
        """Return the node-weighted mean micro-F1 of ``net`` over ``data_loader``.

        Predictions are ``logits > 0``; each batch is weighted by its node
        count ``x.size(0)``.
        """
        acc = AverageMeter()
        f1_score = F1Score('micro')
        for item in data_loader:
            for batch in self._as_batches(item):
                batch = batch.to(args.device)
                preds = (self._forward(net, batch) > 0).float()
                acc.update(f1_score(preds, batch.y), batch.x.size(0))
        return acc.avg

    @torch.no_grad()
    def test_inductive(self, net, val_loader, test_loader, writer, epoch, args):
        """Inductive evaluation: score val and test loaders, then record/log the result."""
        net.eval()
        val_acc = self.epoch_iteration_inference(net, val_loader, args)
        test_acc = self.epoch_iteration_inference(net, test_loader, args)
        self._record(net, val_acc, test_acc, writer, epoch, args)