import logging
import os
import os.path as osp
import time

import numpy as np
import torch
from tqdm import tqdm

from data import get_data, get_metric
from components.backbone import get_model
from utils.training import adjust_learning_rate, add_labels
from utils.utils import setup_seed, loss_fn, pred_fn, Dict
from utils.utils import mask_to_index

from omegaconf import OmegaConf
import hydra

# Module-level logger, named after this module, used by everything below.
log = logging.getLogger(__name__)


def get_degree(adj, undirected=True):
    """Return per-node degrees of ``adj`` as a ``torch.long`` tensor.

    The sums over ``dim=1`` and over ``dim=0`` are both computed. When
    ``undirected`` is true and the two differ, their element-wise sum is
    returned; in every other case the ``dim=1`` sum is returned unchanged.

    Args:
        adj: Adjacency object supporting ``sum(dim=...)``.
        undirected: Whether to combine both sums when they disagree.

    Returns:
        A tensor of integer degrees, one entry per node.
    """
    row_sums, col_sums = (adj.sum(dim=axis).to(torch.long) for axis in (1, 0))
    if not undirected or row_sums.equal(col_sums):
        return row_sums
    return row_sums + col_sums


def train(model, optimizer, data, train_mask, grad_norm=None, use_label=False, mask_rate=1.):
    """Perform one full-batch optimisation step and return the loss tensor.

    Each training node is kept for the loss with probability ``mask_rate``.
    When ``use_label`` is set, the training nodes that were *not* kept are
    handed to ``add_labels`` to build the model input.

    Args:
        model: Module called as ``model(x, data.adj_t)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        data: Graph object providing ``x``, ``y`` and ``adj_t``.
        train_mask: Mask of training nodes, converted with ``mask_to_index``.
        grad_norm: If not ``None``, max norm for gradient clipping.
        use_label: Whether to build the input through ``add_labels``.
        mask_rate: Threshold compared against uniform random draws.

    Returns:
        The loss tensor that was back-propagated (including
        ``model.load_balance_loss`` when the model has that attribute).
    """
    model.train()

    candidates = mask_to_index(train_mask)
    # One uniform draw per training node; draws below mask_rate are supervised.
    keep = torch.rand(candidates.shape) < mask_rate
    train_idx = candidates[keep]
    x = add_labels(data.x, data.y, candidates[~keep]) if use_label else data.x

    optimizer.zero_grad()
    out = model(x, data.adj_t)

    # Label Reuse (currently disabled)
    # n_classes = (data.y.max() + 1).item()
    # unlabel_idx = torch.cat([train_idx, mask_to_index(data.val_mask), mask_to_index(data.test_mask)])
    # x[unlabel_idx, -n_classes:] = F.softmax(out[unlabel_idx], dim=-1).detach()
    # out = model(x, data.adj_t)

    loss = loss_fn(out[train_idx], data.y[train_idx])
    # Some models expose an auxiliary loss term; add it when present.
    if hasattr(model, 'load_balance_loss'):
        loss += model.load_balance_loss
    loss.backward()

    if grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
    optimizer.step()

    return loss


def mini_train(model, optimizer, loader, grad_norm=None):
    """Train ``model`` for one pass over ``loader`` and return the summed loss.

    Args:
        model: Module called as ``model(batch.x, batch.adj_t)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        loader: Iterable of batches exposing ``x``, ``adj_t``, ``y`` and
            ``batch_size``.
        grad_norm: If not ``None``, max norm for gradient clipping.

    Returns:
        Sum over batches of ``loss.item() / batch.batch_size``.
    """
    model.train()

    clip_enabled = grad_norm is not None
    running_loss = 0
    for batch in tqdm(loader):
        optimizer.zero_grad()
        batch_loss = loss_fn(model(batch.x, batch.adj_t), batch.y)
        batch_loss.backward()
        if clip_enabled:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
        optimizer.step()
        running_loss += batch_loss.item() / batch.batch_size

    return running_loss


@torch.no_grad()
def test(model, metric, data, masks, use_label=False):
    """Evaluate ``model`` on the whole graph and score every mask in ``masks``.

    Args:
        model: Module called as ``model(x, data.adj_t)``.
        metric: Stateful metric object; it is reset, updated and computed
            once per mask.
        data: Graph object providing ``x``, ``y``, ``adj_t`` and
            ``train_mask``.
        masks: Iterable of node masks to evaluate.
        use_label: Whether to build the input through ``add_labels``.

    Returns:
        ``(accs, losses, out)``: the metric value and the loss for each mask
        (in the order of ``masks``) and the raw model output for all nodes.
    """
    model.eval()

    # NOTE(review): labels are injected from data.train_mask, not from a
    # per-run split column -- confirm this is intended when train_mask is 2-D.
    x = add_labels(data.x, data.y, data.train_mask) if use_label else data.x
    out = model(x, data.adj_t)

    # Debug variant (disabled): dump per-hop weights returned by the model.
    # out, hop_weights = model(x, data.adj_t)
    # torch.save(hop_weights, '/home/gangda/workspace/adapt-hop/processed/hop_weights/amazon_ratings_gamlp.pt')

    # Label Reuse (currently disabled)
    # n_classes = (data.y.max() + 1).item()
    # unlabel_idx = torch.cat([mask_to_index(data.val_mask), mask_to_index(data.test_mask)])
    # x[unlabel_idx, -n_classes:] = F.softmax(out[unlabel_idx], dim=-1).detach()
    # out = model(x, data.adj_t)

    pred, y = pred_fn(out, data.y)

    accs = []
    losses = []
    for subset in masks:
        # The metric accumulates state, so clear it before scoring each subset.
        metric.reset()
        metric(pred[subset], y[subset])
        accs.append(metric.compute())
        losses.append(loss_fn(out[subset], data.y[subset]))

    return accs, losses, out


@hydra.main(version_base=None, config_path='conf', config_name='config')
def main(conf):
    """Train and evaluate a model over several runs as described by ``conf``.

    For each of ``conf.ens`` ensemble rounds and ``conf.runs`` runs, a fresh
    model is trained full-batch, the checkpoint with the best validation
    metric is restored, and its test metric and logits are recorded.
    Aggregated train/val/test statistics are logged per ensemble round.
    Depending on the config flags, logits, training curves and the training
    time are then saved or logged.

    Args:
        conf: Hydra/OmegaConf config. Reads ``train``, ``model``, ``dataset``,
            ``data_dir``, ``proc_dir``, ``ckpt_dir``, ``gpu``, ``ens``,
            ``runs`` and ``log_logit``, plus the optional keys
            ``degradation``, ``log_tuning``, ``log_curve``, ``log_time`` and
            ``log_file``.
    """
    log.critical(OmegaConf.to_yaml(conf))
    train_conf = conf.train
    model_conf = conf.model
    data_conf = conf.dataset

    ## working dir
    dataset_dir = osp.join(conf.data_dir, 'pyg')
    curve_dir = osp.join(conf.proc_dir, 'curve')
    proc_dir = conf.proc_dir
    if conf.ens > 1:
        proc_dir = osp.join(proc_dir, 'ens')
    elif conf.get('degradation', False):
        proc_dir = osp.join(proc_dir, 'degradation')
    elif conf.log_logit:
        proc_dir = osp.join(proc_dir, 'logit')
    for directory in (dataset_dir, curve_dir, proc_dir, conf.ckpt_dir):
        os.makedirs(directory, exist_ok=True)
    # pid + timestamp in the name keep concurrent jobs from sharing a checkpoint file
    ckpt_path = osp.join(conf.ckpt_dir, '{}_{}_{}_{}.tar'.format(
        model_conf.name, model_conf.conv_layers, os.getpid(), int(time.time())))

    ## for fixed dataset split
    setup_seed(0)

    ## dataset
    data, num_features, num_classes, _ = get_data(root=dataset_dir, **data_conf)
    metric = get_metric(data_conf.name, num_classes)
    if train_conf.use_label:
        # the model input is widened by num_classes when labels are used as input
        num_features = num_features + num_classes

    ## model conf
    model_conf = Dict(OmegaConf.to_container(model_conf))
    arch = model_conf.name.upper()
    if arch == 'PNA':
        d = get_degree(data.adj_t, undirected=True)
        model_conf['deg'] = torch.bincount(d)
    if arch == 'ACMGCN':
        model_conf['num_nodes'] = data.num_nodes
    if arch == 'POLYNORMER':
        # Each optional dropout rate falls back to the main dropout rate.
        # (The original wrote 'global_dropout' in both branches, so
        # 'init_dropout' never received its default.)
        for key in ('global_dropout', 'init_dropout'):
            if not hasattr(model_conf, key) or getattr(model_conf, key) is None:
                model_conf[key] = model_conf['dropout']

    ## full gpu training
    device = torch.device('cuda:{}'.format(conf.gpu) if torch.cuda.is_available() else 'cpu')
    # device = torch.device('cpu')
    metric.to(device)
    data.to(device)

    ## plot convergence and generalization curve
    curve_dict = {'train_loss': [], 'test_loss': [], 'train_acc': [], 'test_acc': [], 'iteration': []}

    ## log total runtime
    total_train_time = 0.

    ## only used for collecting ensemble results
    total_logit = []
    for k in range(1, conf.ens + 1):

        best_train, best_val, best_test, runs_logit = [], [], [], []
        for i in range(1, conf.runs + 1):

            log.info(f'------------------------Run {i}------------------------')
            # NOTE(review): i*k repeats across (run, round) pairs, e.g. (1, 2) and (2, 1).
            setup_seed(i*k)

            ## dataset split
            if len(data.train_mask.shape) > 1:
                # one split per column; run i uses column i - 1
                train_mask = data.train_mask[:, i - 1]
                val_mask = data.val_mask[:, i - 1]
                test_mask = data.test_mask[:, i - 1]
            else:
                train_mask = data.train_mask
                val_mask = data.val_mask
                test_mask = data.test_mask

            model = get_model(model_conf, num_features, num_classes).to(device)

            if arch in ['APPNP', 'GPRGNN']:
                # no weight decay on the propagation module's parameters
                optimizer = torch.optim.Adam([{
                    'params': model.lin1.parameters(),
                    'weight_decay': train_conf.weight_decay, 'lr': train_conf.lr
                }, {
                    'params': model.lin2.parameters(),
                    'weight_decay': train_conf.weight_decay, 'lr': train_conf.lr
                }, {
                    'params': model.prop1.parameters(),
                    'weight_decay': 0.0, 'lr': train_conf.lr
                }], lr=train_conf.lr)
            else:
                optimizer_cls = torch.optim.AdamW if train_conf.use_adamw else torch.optim.Adam
                optimizer = optimizer_cls(
                    model.parameters(), lr=train_conf.lr, weight_decay=train_conf.weight_decay
                )

            ## for polynormer
            if hasattr(train_conf, 'local_epochs') and hasattr(train_conf, 'global_epochs'):
                train_conf.epoch = train_conf.local_epochs + train_conf.global_epochs

            # Start below any attainable metric so the first epoch always writes a
            # checkpoint; ckpt_path is shared by all runs, and starting at 0 could
            # leave a previous run's checkpoint to be loaded after the loop.
            best_val_acc, val_loss_history = float('-inf'), []
            for epoch in range(1, train_conf.epoch + 1):

                ## for polynormer
                if hasattr(train_conf, 'local_epochs') and epoch == train_conf.local_epochs + 1:
                    print("start global attention!!!!!!")
                    # resume from the best checkpoint saved so far, then switch mode
                    ckpt = torch.load(ckpt_path)
                    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
                    model.load_state_dict(ckpt['model_state_dict'])
                    model._global = True

                ## increase lr as epoch grows, can improve convergence in most cases
                if arch not in ['POLYNORMER']:
                    adjust_learning_rate(optimizer, train_conf.lr, epoch)

                tik = time.time()

                train(model, optimizer, data, train_mask,
                      grad_norm=train_conf.grad_norm,
                      use_label=train_conf.use_label,
                      mask_rate=train_conf.train_mask_rate)

                tok = time.time()
                total_train_time += tok - tik

                (train_acc, val_acc, test_acc), (train_loss, val_loss, test_loss), _ = test(
                    model, metric, data, [train_mask, val_mask, test_mask],
                    use_label=train_conf.use_label)

                curve_dict['train_acc'].append(train_acc)
                curve_dict['test_acc'].append(test_acc)
                curve_dict['train_loss'].append(train_loss)
                curve_dict['test_loss'].append(test_loss)
                curve_dict['iteration'].append(epoch)

                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'train_acc': train_acc,
                        'val_acc': val_acc,
                    }, ckpt_path)

                # "Val Acc" reports the best validation metric so far, not this epoch's
                log.info(f'Epoch: {epoch:03d}, Train Loss: {train_loss: .4f}, Val Loss: {val_loss: .4f} '
                         f'Train Acc: {train_acc:.4f}, Val Acc: {best_val_acc:.4f}')

                ## early stopping: stop once the validation loss exceeds the mean
                ## of the previous `early_stopping` epochs
                val_loss_history.append(val_loss)
                if 0 < train_conf.early_stopping < epoch:
                    tmp = torch.tensor(val_loss_history[-(train_conf.early_stopping+1): -1])
                    if val_loss > tmp.mean().item():
                        break

            ckpt = torch.load(ckpt_path)
            if not conf.get('log_tuning', False):
                model = get_model(model_conf, num_features, num_classes).to(device)
                model.load_state_dict(ckpt['model_state_dict'])
                # `_global` is a plain attribute and is not restored by
                # load_state_dict; re-apply it when the best checkpoint was saved
                # after the switch performed in the epoch loop above.
                if hasattr(train_conf, 'local_epochs') and ckpt['epoch'] > train_conf.local_epochs:
                    model._global = True
            (test_acc,), (test_loss,), logit = test(model, metric, data, [test_mask],
                                                    use_label=train_conf.use_label)
            log.info(f"[Best Model] Epoch: {ckpt['epoch']:02d}, Train: {ckpt['train_acc']:.4f}, "
                     f"Val: {ckpt['val_acc']:.4f}, Test: {test_acc:.4f}")

            best_train.append(float(ckpt['train_acc']))
            best_val.append(float(best_val_acc))
            best_test.append(float(test_acc))
            runs_logit.append(logit.cpu())

        log.critical(f'Train: {np.mean(best_train)*100:.2f} +- {np.std(best_train)*100:.2f}')
        log.critical(f'Valid: {np.mean(best_val)*100:.2f} +- {np.std(best_val)*100:.2f}')
        log.critical(f'Test: {np.mean(best_test)*100:.2f} +- {np.std(best_test)*100:.2f}')

        total_logit.append(torch.stack(runs_logit, dim=0))

        if hasattr(conf, "log_file") and conf.log_file is not None:
            with open(conf.log_file, 'a') as file:
                file.write(f'{arch}-{data_conf.ptb_ratio}: '
                           f'{np.mean(best_test):.4f}\n')

    ## log: build a descriptive model tag used in the output file names
    model_name = arch
    if model_name == 'G2GNN':
        model_name = 'G2' + model_conf.conv_type
    if model_conf.jk is not None:
        model_name += '-jk{}'.format(model_conf.jk.upper())
    if model_conf.residual is not None:
        model_name += '-res{}'.format(model_conf.residual.upper())
    if model_conf.dropout > 0:
        model_name += '-dropout{}'.format(model_conf.dropout)
    if model_conf.dropedge > 0:
        model_name += '-dropedge{}'.format(model_conf.dropedge)
    if model_conf.init_layers > 0:
        model_name += '-init{}'.format(model_conf.init_layers)
    if conf.get('log_tuning', False):
        model_name += '-dim{}'.format(model_conf.hidden_dim)
        model_name += '-epoch{}'.format(train_conf.epoch)

    if conf.log_logit:
        filename = '{}_{}_conv{}'.format(conf.dataset.name, model_name, model_conf.conv_layers)
        if hasattr(data_conf, 'rev_adj') and data_conf.rev_adj:
            filename += '_revADJ'
        if data_conf.ptb_type is not None and data_conf.ptb_ratio > 0:
            filename += '_ptb{}{}'.format(data_conf.ptb_type.upper(), data_conf.ptb_ratio)
        if conf.get('log_curve', False):
            filename += '_curve'
        if conf.ens > 1:
            # keep every ensemble round: stacked along a new leading dimension
            filename += f'_ens{conf.ens}.pt'
            logits = torch.stack(total_logit, dim=0)
        else:
            filename += '.pt'
            logits = total_logit[0]
        torch.save(logits, osp.join(proc_dir, filename))
        print('Logits saved in:', proc_dir + '/' + filename)

    if conf.get('log_curve', False):
        filename = '{}_{}_conv{}'.format(conf.dataset.name, model_name, model_conf.conv_layers)
        filename += '.pt'
        curve_dict = {key: torch.tensor(value) for key, value in curve_dict.items()}
        torch.save(curve_dict, osp.join(curve_dir, filename))

    if conf.get('log_time', False):
        # accumulated time spent inside train(), divided by the number of runs
        log.critical(f'Total Train Time: {total_train_time/conf.runs:.2f}s')


# Script entry point: main is called without arguments because the
# hydra.main decorator on it supplies the config.
if __name__ == '__main__':
    main()
