#! /usr/bin/env python
# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0

import sys
import os
import random
import torch
import torch.nn.functional as F
# import git
import numpy as np
# import wandb
import networkx as nx
from tqdm import tqdm

# This is required here by wandb sweeps.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from exp.parser import get_parser
from models.positional_encodings import append_top_k_evectors
from models.cont_models import DiagSheafDiffusion, BundleSheafDiffusion, GeneralSheafDiffusion
from models.disc_models import DiscreteDiagSheafDiffusion, DiscreteBundleSheafDiffusion, DiscreteGeneralSheafDiffusion
from utils.heterophilic import get_dataset, get_fixed_splits


def reset_wandb_env():
    """Remove all ``WANDB_*`` environment variables except the preserved ones.

    ``WANDB_PROJECT``, ``WANDB_ENTITY`` and ``WANDB_API_KEY`` are kept; every
    other variable whose name starts with ``WANDB_`` is deleted from
    ``os.environ``. Variables without that prefix are left untouched.
    """
    preserved = {
        "WANDB_PROJECT",
        "WANDB_ENTITY",
        "WANDB_API_KEY",
    }
    # Iterate over a snapshot of the keys: deleting from a mapping while
    # iterating it directly is only safe by implementation accident.
    for key in list(os.environ):
        if key.startswith("WANDB_") and key not in preserved:
            del os.environ[key]


def train(model, optimizer, data):
    """Run a single optimisation step on the training nodes.

    The model is put in training mode, evaluated on ``data.x``, and the
    negative log-likelihood over the entries selected by ``data.train_mask``
    is back-propagated before one optimizer step is taken.

    Args:
        model: callable module taking ``data.x`` and returning per-row scores
            (assumed to be log-probabilities, as ``F.nll_loss`` expects —
            TODO confirm against the model definitions).
        optimizer: optimizer whose gradients are reset and then stepped.
        data: object exposing ``x``, ``y`` and ``train_mask``.

    Returns:
        None.
    """
    model.train()
    optimizer.zero_grad()

    train_scores = model(data.x)[data.train_mask]
    train_labels = data.y[data.train_mask]

    objective = F.nll_loss(train_scores, train_labels)
    objective.backward()
    optimizer.step()


def test(model, data):
    """Evaluate the model on the train, validation and test masks.

    The model is switched to eval mode and run once on ``data.x`` without
    gradient tracking. For each mask yielded by
    ``data('train_mask', 'val_mask', 'test_mask')`` the predictions, accuracy
    and NLL loss are computed.

    Args:
        model: callable module taking ``data.x`` and returning per-row scores
            (assumed to be log-probabilities — TODO confirm).
        data: object exposing ``x`` and ``y`` that, when called with mask
            names, yields ``(name, mask)`` pairs.

    Returns:
        A tuple ``(accs, preds, losses)`` of three lists with one entry per
        mask, in the order the masks are yielded: accuracy as a Python float,
        predicted class indices as a CPU tensor, and loss as a CPU tensor.
    """
    model.eval()
    accs, preds, losses = [], [], []

    with torch.no_grad():
        logits = model(data.x)

        for _name, mask in data('train_mask', 'val_mask', 'test_mask'):
            split_scores = logits[mask]
            split_labels = data.y[mask]

            # Index of the highest score in each row is the predicted class.
            _, pred = split_scores.max(1)
            n_correct = pred.eq(split_labels).sum().item()
            acc = n_correct / mask.sum().item()

            loss = F.nll_loss(split_scores, split_labels)

            preds.append(pred.detach().cpu())
            accs.append(acc)
            losses.append(loss.detach().cpu())

    return accs, preds, losses


def get_model(args, edge_index, model_cls):
    """Instantiate ``model_cls`` and move it to the configured device.

    Args:
        args: mapping of settings; passed to the constructor and read for
            the ``'device'`` entry.
        edge_index: passed as the first constructor argument.
        model_cls: callable invoked as ``model_cls(edge_index, args)``.

    Returns:
        The result of calling ``.to(args['device'])`` on the new model.
    """
    return model_cls(edge_index, args).to(args['device'])

def fetch_model(nnodes, num_features, num_classes, adj_matrix, filter_num=None, dropout=None, reg_lambda=None, lr=None, *args):
    """Build a sheaf-diffusion model from command-line settings plus overrides.

    The base configuration comes from ``get_parser().parse_args()`` (i.e. it
    reads ``sys.argv``). Selected fields are then overridden by the keyword
    arguments, the random seeds are set, an edge index is derived from
    ``adj_matrix`` and the model is instantiated on the selected device.

    Args:
        nnodes: stored in the configuration as ``graph_size``.
        num_features: stored in the configuration as ``input_dim``.
        num_classes: stored in the configuration as ``output_dim``.
        adj_matrix: adjacency matrix exposing ``.todense()`` (presumably a
            SciPy sparse matrix — verify against caller).
        filter_num: if not ``None``, overrides ``hidden_channels``.
        dropout: if not ``None``, overrides ``dropout``.
        reg_lambda: if not ``None``, overrides ``weight_decay``.
        lr: if not ``None``, overrides ``lr``.
        *args: accepted for call-site compatibility; ignored.

    Returns:
        The constructed model, already moved to the configured device.

    Raises:
        ValueError: if the parsed ``model`` name is not a known model.
        AssertionError: if neither ``normalised`` nor ``deg_normalised`` is
            set in the parsed configuration.
    """
    parser = get_parser()
    p_args = parser.parse_args()

    # Map the command-line model name to its implementing class.
    model_classes = {
        'DiagSheafODE': DiagSheafDiffusion,
        'BundleSheafODE': BundleSheafDiffusion,
        'GeneralSheafODE': GeneralSheafDiffusion,
        'DiagSheaf': DiscreteDiagSheafDiffusion,
        'BundleSheaf': DiscreteBundleSheafDiffusion,
        'GeneralSheaf': DiscreteGeneralSheafDiffusion,
    }
    try:
        model_cls = model_classes[p_args.model]
    except KeyError:
        raise ValueError(f'Unknown model {p_args.model}') from None

    # Apply caller-supplied overrides on top of the parsed arguments.
    if filter_num is not None:
        p_args.hidden_channels = filter_num
    if dropout is not None:
        p_args.dropout = dropout
    if reg_lambda is not None:
        p_args.weight_decay = reg_lambda
    if lr is not None:
        p_args.lr = lr

    # Git SHA lookup is disabled; the field is explicitly left unset.
    p_args.sha = None
    p_args.graph_size = nnodes
    p_args.input_dim = num_features
    p_args.output_dim = num_classes
    p_args.device = torch.device(f'cuda:{p_args.cuda}' if torch.cuda.is_available() else 'cpu')
    assert p_args.normalised or p_args.deg_normalised, \
        'one of normalised / deg_normalised must be set'
    if p_args.sheaf_decay is None:
        # Fall back to the (possibly overridden) weight decay.
        p_args.sheaf_decay = p_args.weight_decay

    # Set the seed for everything
    torch.manual_seed(p_args.seed)
    torch.cuda.manual_seed(p_args.seed)
    torch.cuda.manual_seed_all(p_args.seed)
    np.random.seed(p_args.seed)
    random.seed(p_args.seed)

    # Derive a (2, num_edges) edge index from the adjacency matrix. Plain
    # ndarrays are used instead of np.matrix, and the reshape keeps the
    # result two-row even when the graph has no edges.
    # NOTE(review): G.edges() lists each undirected edge once; confirm the
    # models do not expect both directions to be present.
    G = nx.from_numpy_array(np.asarray(adj_matrix.todense()))
    edge_pairs = np.asarray(list(G.edges()), dtype=np.int64).reshape(-1, 2)
    edge_index = torch.from_numpy(np.ascontiguousarray(edge_pairs.T)).to(p_args.device)

    # A single model is built and returned; the fold count is not consulted.
    return get_model(vars(p_args), edge_index, model_cls)
        
