import torch
import numpy as np
import time
import copy
import argparse
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import TUDataset
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
from gnn_models import GCN, GIN_Net, APPNP_Net
from sklearn.model_selection import KFold
from utils import preprocess
import torch.nn.functional as F
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
import warnings
warnings.filterwarnings("ignore")

class EarlyStopping:
    """Flag that training should stop once a gap has been seen often enough.

    Each call compares ``validation_loss - train_loss`` with ``min_delta``.
    Every call where the gap is strictly larger increments ``counter``; the
    counter is cumulative (it is never reset by a call with a small gap).
    When ``counter`` reaches ``tolerance`` the ``early_stop`` flag is set and
    stays set.

    Attributes:
        tolerance: number of over-threshold calls needed to trigger the stop.
        min_delta: gap that must be strictly exceeded for a call to count.
        counter: number of over-threshold calls seen so far.
        early_stop: True once ``counter >= tolerance``.
    """

    def __init__(self, tolerance=10, min_delta=0):
        self.tolerance, self.min_delta = tolerance, min_delta
        self.counter, self.early_stop = 0, False

    def __call__(self, train_loss, validation_loss):
        """Record one observation; returns nothing, inspect ``early_stop``."""
        gap = validation_loss - train_loss
        if not (gap > self.min_delta):
            # Gap within the allowed margin: nothing to count.
            return
        self.counter += 1
        if self.counter >= self.tolerance:
            self.early_stop = True


def train(model, criterion, optimizer, train_loader, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: callable module; invoked as ``model(batch)``.
        criterion: loss callable, invoked as ``criterion(output, batch.y)``.
        optimizer: optimizer stepping the model parameters.
        train_loader: iterable of batches exposing a ``y`` attribute.
        device: not used here; batches are fed to the model as yielded.

    Returns:
        The mean of the per-batch loss values (via ``np.mean``).
    """
    model.train()
    batch_losses = []
    for batch in train_loader:
        logits = model(batch)
        batch_loss = criterion(logits, batch.y)
        # Record the scalar before the graph is consumed by backward().
        batch_losses.append(batch_loss.detach().item())
        batch_loss.backward()
        optimizer.step()
        # Gradients are cleared after the update, ready for the next batch.
        optimizer.zero_grad()

    return np.mean(batch_losses)



# def weighted_train(model, criterion, optimizer, train_loader, device):
#     model.train()
#     # loss_criterion = torch.nn.CrossEntropyLoss()
#     train_loss = []
#     for data, weight in train_loader:  # Iterate in batches over the training dataset.
#         data, weight = data.to(device), weight.to(device)
#         out = model(data)  # Perform a single forward pass.
#         # print(criterion(out, data.y).shape)
#         loss = (criterion(out, data.y) * weight).sum()
        
#         # loss = criterion(out, data.y)
#         train_loss.append(loss.detach().item())
#         loss.backward()  # Derive gradients.
#         optimizer.step()  # Update parameters based on gradients.
#         optimizer.zero_grad()  # Clear gradients.
    
#     return np.mean(train_loss)

# def eval_weighted_train(model, loader, device):
#     model.eval()

#     correct = 0
#     # n = 0
#     for data, weight in loader:  # Iterate in batches over the training/test dataset.
#         data, weight = data.to(device), weight.to(device)
#         out = model(data)
#         pred = out.argmax(dim=1)  # Use the class with highest probability.
#         # gnd = data.y.argmax(dim=1)
#         # n = n + len(gnd)
#         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
#     return correct / len(loader.dataset)  # Derive ratio of correct predictions.



# def eval_val_and_test(model, loader, device):
#     model.eval()

#     correct = 0
#     # n = 0
#     for data in loader:  # Iterate in batches over the training/test dataset.
#         # data = data.to(device)
#         out = model(data)
#         pred = out.argmax(dim=1)  # Use the class with highest probability.
#         # gnd = data.y.argmax(dim=1)
#         # n = n + len(gnd)
#         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
#     return correct / len(loader.dataset)  # Derive ratio of correct predictions.

from sklearn.metrics import balanced_accuracy_score, f1_score, roc_auc_score

@torch.no_grad()
def eval_train(model, loader, device):
    """Return the classification accuracy of ``model`` over ``loader``.

    Runs with autograd disabled so that evaluation does not build
    computation graphs.

    Args:
        model: callable module; ``model(batch)`` must return per-class scores
            with classes along dim 1.
        loader: iterable of batches exposing ``y``; must also expose a
            ``dataset`` attribute whose length is the number of samples.
        device: not used here; batches are fed to the model as yielded.

    Returns:
        Number of correct predictions divided by ``len(loader.dataset)``.

    Raises:
        ZeroDivisionError: if ``loader.dataset`` is empty.
    """
    model.eval()

    correct = 0
    for data in loader:
        pred = model(data).argmax(dim=1)  # Highest-scoring class per sample.
        correct += int((pred == data.y).sum())
    return correct / len(loader.dataset)


@torch.no_grad()
def eval_val_and_test(model, loader, device):
    """Return ``(accuracy, balanced_accuracy)`` of ``model`` over ``loader``.

    Runs with autograd disabled so that evaluation does not build
    computation graphs.

    Args:
        model: callable module; ``model(batch)`` must return per-class scores
            with classes along dim 1.
        loader: iterable of batches exposing ``y``; must also expose a
            ``dataset`` attribute whose length is the number of samples.
        device: not used here; batches are fed to the model as yielded.

    Returns:
        A 2-tuple: the fraction of correct predictions
        (``correct / len(loader.dataset)``) and the value of
        ``balanced_accuracy_score`` over all collected labels/predictions.

    Raises:
        ZeroDivisionError: if ``loader.dataset`` is empty.
    """
    model.eval()

    correct = 0
    pred_list = []
    y_list = []
    for data in loader:
        pred = model(data).argmax(dim=1)  # Highest-scoring class per sample.
        correct += int((pred == data.y).sum())
        # Collect plain Python numbers rather than one 0-d tensor per sample.
        pred_list.extend(pred.cpu().tolist())
        y_list.extend(data.y.cpu().tolist())

    return correct / len(loader.dataset), balanced_accuracy_score(y_list, pred_list)




def train_loop(model, epochs, train_loader, val_loader, test_loader, fold_num, device):
    """Train ``model`` for up to ``epochs`` epochs and report test accuracy.

    The model parameters are reset first, then the model is trained with Adam
    (lr=1e-2, weight_decay=5e-4) and cross-entropy loss. Progress is printed
    every 100 epochs and when early stopping triggers.

    Args:
        model: module exposing ``reset_parameters()`` and ``parameters()``.
        epochs: maximum number of epochs.
        train_loader, val_loader, test_loader: loaders for the three splits.
        fold_num: zero-based fold index, used only in log messages.
        device: forwarded to the train/eval helpers.

    Returns:
        dict with two scalar accuracies:
            ``"test_acc_for_best_val"``: test accuracy measured at the epoch
                with the highest validation accuracy (the untrained model's
                test accuracy if validation accuracy never rises above 0).
            ``"test_acc_for_last_epoch"``: test accuracy of the final model.
    """
    model.reset_parameters()
    train_criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    early_stopping = EarlyStopping(tolerance=50, min_delta=0.1)

    best_val_acc = 0
    with torch.no_grad():
        # eval_val_and_test returns (accuracy, balanced accuracy); keep only
        # the accuracy so the reported value is always a scalar.
        best_val_test = eval_val_and_test(model, test_loader, device)[0]

    for epoch in tqdm(range(1, epochs + 1)):
        train(model, train_criterion, optimizer, train_loader, device)

        with torch.no_grad():
            train_acc = eval_train(model, train_loader, device)
            val_acc = eval_val_and_test(model, val_loader, device)[0]

            # NOTE(review): EarlyStopping counts calls where
            # (second argument - first argument) > min_delta and is written in
            # terms of losses. Accuracies are passed here, so the counter only
            # advances when validation accuracy exceeds train accuracy by more
            # than 0.1. Confirm this is the intended criterion.
            early_stopping(train_acc, val_acc)
            stop_now = early_stopping.early_stop
            improved = val_acc > best_val_acc
            report = stop_now or epoch % 100 == 0

            # Evaluate the test split at most once per epoch, and only when
            # the value is actually needed.
            test_acc = None
            if improved or report:
                test_acc = eval_val_and_test(model, test_loader, device)[0]

        # Track the best epoch before a possible break so the epoch that
        # triggers early stopping is still considered.
        if improved:
            best_val_acc = val_acc
            best_val_test = test_acc

        if stop_now:
            print('Early breaking!')
        if report:
            print(f'Fold: {fold_num+1:01d}, Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')
        if stop_now:
            break

    with torch.no_grad():
        last_epoch_acc = eval_val_and_test(model, test_loader, device)[0]

    return {
        "test_acc_for_best_val": best_val_test,
        "test_acc_for_last_epoch": last_epoch_acc,
    }



# def ogb_eval_train(model, loader, device):
#     model.eval()

#     correct = 0
#     # n = 0
#     pred_list = []
#     y_list = []
#     for data in loader:  # Iterate in batches over the training/test dataset.
#         # data = data.to(device)
#         out = model(data)
#         pred = out.argmax(dim=1)  # Use the class with highest probability.
#         # gnd = data.y.argmax(dim=1)
#         # n = n + len(gnd)
#         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
#         pred_list.extend(pred.cpu())
#         y_list.extend(data.y.cpu())
#     # acc
#     return correct / len(loader.dataset), roc_auc_score(y_list, pred_list)  # Derive ratio of correct predictions.


# def ogb_eval_val_and_test(model, loader, device):
#     model.eval()

#     correct = 0
#     # n = 0
#     pred_list = []
#     y_list = []
#     for data in loader:  # Iterate in batches over the training/test dataset.
#         # data = data.to(device)
#         out = model(data)
#         pred = out.argmax(dim=1)  # Use the class with highest probability.
#         # gnd = data.y.argmax(dim=1)
#         # n = n + len(gnd)
#         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
#         pred_list.extend(pred.cpu())
#         y_list.extend(data.y.cpu())
    
#     return correct / len(loader.dataset), roc_auc_score(y_list, pred_list) # Derive ratio of correct predictions.



def ogb_train(model, criterion, optimizer, train_loader, device):
    """Run one optimisation pass over ``train_loader`` with a binary loss.

    The loss is always binary cross-entropy with logits, computed on the
    flattened model output and flattened float targets. Targets that are NaN
    are treated as unlabeled and excluded from the loss, so a single missing
    label no longer turns the loss and gradients into NaN. Batches without
    any labeled target are skipped.

    Args:
        model: callable module; invoked as ``model(batch, emb=True)``.
        criterion: accepted for signature compatibility; not used.
        optimizer: optimizer stepping the model parameters.
        train_loader: iterable of batches exposing ``to(device)`` and ``y``.
        device: device each batch is moved to.

    Returns:
        Mean loss over the batches that contributed an update, or NaN when
        no batch did (e.g. an empty loader).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for data in train_loader:
        data = data.to(device)
        target = data.y.view(-1).float()
        labeled = ~torch.isnan(target)
        if not bool(labeled.any()):
            # Nothing to learn from in this batch.
            continue

        optimizer.zero_grad()
        out = model(data, emb=True)
        loss = F.binary_cross_entropy_with_logits(out.view(-1)[labeled], target[labeled])
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

# Evaluation
@torch.no_grad()
def ogb_eval(model, loader, device, name):
    """Score ``model`` on ``loader`` using the OGB evaluator for ``name``.

    Args:
        model: callable module; invoked as ``model(batch, emb=True)``.
        loader: iterable of batches exposing ``to(device)`` and ``y``.
        device: device each batch is moved to.
        name: dataset name handed to ``Evaluator``.

    Returns:
        The ``"rocauc"`` entry of the evaluator's result.
    """
    scorer = Evaluator(name=name)
    model.eval()

    targets = []
    outputs = []
    for batch in loader:
        batch = batch.to(device)
        outputs.append(model(batch, emb=True))
        # Targets are stacked as a column per batch before concatenation.
        targets.append(batch.y.view(-1, 1))

    scores = scorer.eval({
        "y_true": torch.cat(targets, dim=0),
        "y_pred": torch.cat(outputs, dim=0),
    })
    return scores["rocauc"]

def ogb_train_loop(model, epochs, train_loader, val_loader, test_loader, fold_num, device, name):
    """Train ``model`` for ``epochs`` epochs on an OGB dataset and report scores.

    The model parameters are reset first, then the model is trained with Adam
    (lr=1e-3, weight_decay=5e-4). Progress is printed every 20 epochs. No
    early stopping is applied: training always runs for all ``epochs``.

    The scores are whatever ``ogb_eval`` returns (its ``"rocauc"`` value); the
    "Acc" wording in log messages and result keys is kept unchanged.

    Args:
        model: module exposing ``reset_parameters()`` and ``parameters()``.
        epochs: number of epochs.
        train_loader, val_loader, test_loader: loaders for the three splits.
        fold_num: zero-based fold index, used only in log messages.
        device: forwarded to the train/eval helpers.
        name: dataset name forwarded to ``ogb_eval``.

    Returns:
        dict with:
            ``"test_acc_for_best_val"``: test score measured at the epoch with
                the highest validation score (the untrained model's test
                score if the validation score never rises above 0).
            ``"test_acc_for_last_epoch"``: test score of the final model.
    """
    model.reset_parameters()
    # ogb_train computes binary cross-entropy with logits itself and does not
    # use the criterion it is handed; pass the matching loss so the call site
    # reflects what is actually optimised.
    train_criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

    best_val_acc = 0
    best_val_test = ogb_eval(model, test_loader, device, name)

    for epoch in tqdm(range(1, epochs + 1)):
        ogb_train(model, train_criterion, optimizer, train_loader, device)
        # ogb_eval already runs with autograd disabled.
        val_acc = ogb_eval(model, val_loader, device, name)

        improved = val_acc > best_val_acc
        report = epoch % 20 == 0

        # Evaluate the test split at most once per epoch, and only when the
        # value is actually needed.
        if improved or report:
            test_acc = ogb_eval(model, test_loader, device, name)

        if report:
            print(f'Fold: {fold_num+1:01d}, Epoch: {epoch:03d}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

        if improved:
            best_val_acc = val_acc
            best_val_test = test_acc

    last_epoch_acc = ogb_eval(model, test_loader, device, name)

    return {
        "test_acc_for_best_val": best_val_test,
        "test_acc_for_last_epoch": last_epoch_acc,
    }














# def bal_train_loop(model, epochs, train_loader, val_loader, test_loader, fold_num, device):

#     # reset model params
#     model.reset_parameters()
#     # some gnn training operators
#     train_criterion = torch.nn.CrossEntropyLoss()
#     val_test_criterion = torch.nn.CrossEntropyLoss()
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
#     # scheduler = StepLR(optimizer, step_size=100, gamma=0.5)
#     early_stopping = EarlyStopping(tolerance=50, min_delta=0.1)


#     # criteria: best_val, last epoch, 
#     output_dict = {}
#     best_val_acc = 0
#     best_bal_val_acc = 0
#     for epoch in range(1, epochs+1):
#         train_loss = train(model, train_criterion, optimizer, train_loader, device)
#         train_acc = eval_train(model, train_loader, device)
#         # scheduler.step()

#         with torch.no_grad(): 
#             val_acc, bal_val_acc = eval_val_and_test(model, val_loader, device)
#         t4 = time.time()
#         early_stopping(train_acc, bal_val_acc)

        
#         if early_stopping.early_stop:
#             test_acc, bal_test_acc = eval_val_and_test(model, test_loader, device)
#             print(f'Early breaking!')
#             print(f'Fold: {fold_num+1:01d}, Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Bal Val Acc: {best_val_acc:.4f}, Test Acc: {test_acc:.4f}, Bal Test Acc: {bal_test_acc:.4f}')
#             break
#         if epoch % 1 == 0:
#             test_acc, bal_test_acc = eval_val_and_test(model, test_loader, device)
#             print(f'Fold: {fold_num+1:01d}, Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Bal Val Acc: {best_val_acc:.4f}, Test Acc: {test_acc:.4f}, Bal Test Acc: {bal_test_acc:.4f}')
            
#         if bal_val_acc > best_bal_val_acc:
#             best_bal_val_acc = bal_val_acc
#             # best_val_epoch = epoch
#             best_model = copy.deepcopy(model)
    
#     # last_epoch_acc = eval_val_and_test(model, test_loader, device)

#     output_dict["test_acc_for_best_val"] = eval_val_and_test(best_model, test_loader, device)[1]
#     # output_dict["test_acc_for_last_epoch"] = last_epoch_acc
    
#     return output_dict


# def weighted_train_loop(model, epochs, train_loader, val_loader, test_loader, fold_num, device):

#     # to record the acc trace
#     acc = []
#     # reset model params
#     model.reset_parameters()
#     # some gnn training operators
#     train_criterion = torch.nn.CrossEntropyLoss()
#     val_test_criterion = torch.nn.CrossEntropyLoss()
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
#     scheduler = StepLR(optimizer, step_size=100, gamma=0.5)
#     early_stopping = EarlyStopping(tolerance=50, min_delta=0.1)

#     for epoch in range(1, epochs+1):
#         t1 = time.time()
#         train_loss = weighted_train(model, train_criterion, optimizer, train_loader, device)
#         t2 = time.time()
#         train_acc = eval_weighted_train(model, train_loader, device)
#         t3 = time.time()
#         scheduler.step()

#         with torch.no_grad(): 
#             val_acc = eval_val_and_test(model, val_loader, device)
#         t4 = time.time()
#         early_stopping(train_acc, val_acc)

        
#         if early_stopping.early_stop:
#             test_acc = eval_val_and_test(model, test_loader, device)
#             print(f'Early breaking!')
#             print(f'Fold: {fold_num+1:01d}, Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')
#             break
#         if epoch % 20 == 0:
#             test_acc = eval_val_and_test(model, test_loader, device)
#             print(f'Fold: {fold_num+1:01d}, Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')
        
#         # print(t2-t1, t3-t2, t4-t3)
#     # train_time.append(time.time() - ts)
#     # test_acc = eval_val_and_test(model, test_loader, device)
#     # acc.append(test_acc)


#     return eval_val_and_test(model, test_loader, device)