import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
import random
import os
import sys
from tqdm import tqdm
import time
import dgl
from scipy.sparse import vstack as s_vstack
from sklearn.preprocessing import StandardScaler

from utils import args_parser
from utils import metric
import utils.data_load as dl

from model.layer import FC, Wrap_Embedding

from utils.sampler import HypergraphNeighborSampler
from model.CoNHD_GD import CoNHD_GD, CoNHD_GD_Layer, CoNHDDiffScorer
from model.CoNHD_ADMM import CoNHD_ADMM, CoNHD_ADMM_Layer

# Embedders whose forward pass returns the (co_feat, co_eid) pair consumed by the "im" scorer.
_CONHD_EMBEDDERS = ("CoNHD_GD", "CoNHD_ADMM")


def _check_supported(args):
    """Fail fast unless ``args`` selects a model combination the batch loop can run.

    Only the CoNHD embedders paired with the "im" scorer are wired up below. Any
    other combination used to crash mid-epoch with a NameError on an undefined
    variable (``v``/``e`` or ``predictions``).

    Raises:
        NotImplementedError: for an unsupported ``args.embedder`` or ``args.scorer``.
    """
    if args.embedder not in _CONHD_EMBEDDERS:
        raise NotImplementedError("Unsupported embedder: {}".format(args.embedder))
    if args.scorer != "im":
        raise NotImplementedError("Unsupported scorer: {}".format(args.scorer))


def _predict_batch(input_nodes, blocks, initembedder, embedder, scorer):
    """Run the model on one mini-batch and return ``(predictions, nodelabels)``.

    ``blocks`` are the DGL blocks yielded by the dataloader. They are moved to
    the module-level ``device`` before the forward pass.
    """
    blocks = [b.to(device) for b in blocks]
    v_feat = initembedder(input_nodes['node'].to(device))
    # Only the second half of the sampled blocks is fed to the embedder.
    co_feat, co_eid = embedder(blocks[(len(blocks) // 2):], v_feat)
    # The scorer also returns the targets that ``loss_fn`` compares against the predictions.
    predictions, nodelabels, _, _ = scorer(blocks[-1], co_feat, co_eid)
    return predictions, nodelabels


def run_epoch(args, data, dataloader, initembedder, embedder, scorer, optim, scheduler, loss_fn, opt="train"):
    """Run one pass over ``dataloader``, taking optimizer steps when ``opt == "train"``.

    Args:
        args: parsed CLI arguments. ``args.embedder`` and ``args.scorer`` select the model path.
        data: unused; kept for call-site compatibility.
        dataloader: iterable of ``(input_nodes, output_nodes, blocks)`` mini-batches.
        initembedder: maps input node ids to initial node features.
        embedder: produces co-representations from the blocks and initial features.
        scorer: maps co-representations to predictions (and returns the targets).
        optim: optimizer, stepped once per batch in training mode.
        scheduler: not stepped here; returned unchanged.
        loss_fn: criterion applied to ``(predictions, nodelabels)``.
        opt: ``"train"`` to back-propagate. Any other value only evaluates; the
            caller is responsible for disabling autograd (e.g. ``torch.no_grad()``).

    Returns:
        ``(total_pred, total_label, mean_loss, embedder, scorer, optim, scheduler)``.
        ``total_pred`` and ``total_label`` are lists of detached per-batch tensors.
        ``mean_loss`` is the batch-size-weighted mean of the per-batch losses, or
        NaN if the loader yielded no samples.

    Raises:
        NotImplementedError: if ``args`` selects an unsupported embedder/scorer.
    """
    _check_supported(args)
    is_train = opt == "train"
    total_pred, total_label = [], []
    num_data = 0
    total_loss = 0.0

    ts = time.time()
    for input_nodes, _output_nodes, blocks in dataloader:
        predictions, nodelabels = _predict_batch(input_nodes, blocks, initembedder, embedder, scorer)
        total_pred.append(predictions.detach())
        total_label.append(nodelabels.detach())

        loss = loss_fn(predictions, nodelabels)
        if is_train:
            optim.zero_grad()
            loss.backward()
            optim.step()

        batch_size = predictions.shape[0]
        num_data += batch_size
        total_loss += loss.item() * batch_size
        if is_train:
            # Return cached allocator blocks between training batches.
            torch.cuda.empty_cache()

    print("Time : ", time.time() - ts)

    # Guard the empty-loader case instead of raising ZeroDivisionError.
    mean_loss = total_loss / num_data if num_data else float("nan")
    return total_pred, total_label, mean_loss, embedder, scorer, optim, scheduler


def run_test_epoch(args, data, testdataloader, initembedder, embedder, scorer, loss_fn):
    """Evaluate the model on ``testdataloader`` without updating any parameters.

    Args:
        args: parsed CLI arguments. ``args.embedder`` and ``args.scorer`` select the model path.
        data: unused; kept for call-site compatibility.
        testdataloader: iterable of ``(input_nodes, output_nodes, blocks)`` mini-batches.
        initembedder, embedder, scorer: the model components, as in ``run_epoch``.
        loss_fn: criterion applied to ``(predictions, nodelabels)``.

    Returns:
        ``(total_pred, total_label, mean_loss, embedder, scorer)``. ``mean_loss`` is
        the batch-size-weighted mean loss, or NaN if no samples were seen.

    Raises:
        NotImplementedError: if ``args`` selects an unsupported embedder/scorer.
    """
    _check_supported(args)
    total_pred, total_label = [], []
    num_data = 0
    total_loss = 0.0

    for input_nodes, _output_nodes, blocks in testdataloader:
        predictions, nodelabels = _predict_batch(input_nodes, blocks, initembedder, embedder, scorer)
        total_pred.append(predictions.detach())
        total_label.append(nodelabels.detach())

        batch_size = predictions.shape[0]
        num_data += batch_size
        total_loss += loss_fn(predictions, nodelabels).item() * batch_size

    mean_loss = total_loss / num_data if num_data else float("nan")
    return total_pred, total_label, mean_loss, embedder, scorer

# Make Output Directory --------------------------------------------------------------------------------------------------------------
# Results live under <root>/<task>/<dataset>/<optimisation config>/<model>/<params>/,
# with an extra /<seed>/ level when evaluating on the test split.
args = args_parser.parse_args()
if args.task_type is None:
    args.task_type = 'fit_diffusion_from_node_feat'

_results_root = "results_test/" if args.evaltype == "test" else "results/"
_optim_config = f'optim_{args.optim}_hedge_reg_{args.hedge_reg}_node_reg_{args.node_reg}_dim_{args.feat_dim}'
outputdir = (_results_root + args.task_type + "/" + args.dataset_name + "/" + _optim_config + "/"
             + args.model_name + "/" + args.param_name + "/")
if args.evaltype == "test":
    outputdir += str(args.seed) + "/"

# Skip runs whose results already exist unless recalculation was requested.
# NOTE(review): nothing in this script writes "log_test_confusion.txt" (test results go to
# "log_test.txt"), so this check only fires for outputs produced elsewhere — confirm intended.
if args.recalculate is False and os.path.isfile(outputdir + "log_test_confusion.txt"):
    sys.exit("Already Run")

os.makedirs(outputdir, exist_ok=True)
print("OutputDir = " + outputdir)

# Initialization --------------------------------------------------------------------
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)
dataset_name = args.dataset_name  # e.g. 'citeseer', 'cora'

# Seed the Python, NumPy, PyTorch (and CUDA when present) and DGL RNGs from args.seed.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)
dgl.seed(args.seed)

# Validation runs every `valid_epoch` epochs.
valid_epoch = args.valid_epoch

if os.path.isfile(outputdir + "checkpoint.pt") and args.recalculate is False:
    # A checkpoint exists: training resumes from it further below.
    print("Start from checkpoint")
elif (args.recalculate is False and args.evaltype == "valid") and os.path.isfile(outputdir + "log_valid_micro.txt"):
    # Replay an earlier validation log to decide whether that run already finished
    # (early-stopped or reached args.epochs); if so, skip this run.
    # NOTE(review): this script logs validation to "log_valid.txt", not "log_valid_micro.txt",
    # and the comparison below treats larger values as better, whereas the tracked metric
    # here is MAE (lower is better) — confirm this branch is still meant to apply.
    max_acc = 0
    cur_patience = 0
    epoch = 0
    with open(outputdir + "log_valid_micro.txt", "r") as f:
        for line in f.readlines():
            ep_str = line.rstrip().split(":")[0].split(" ")[0]
            acc_str = line.rstrip().split(":")[-1]
            epoch = int(ep_str)
            if max_acc < float(acc_str):
                cur_patience = 0
                max_acc = float(acc_str)
            else:
                cur_patience += 1
            if cur_patience > args.patience:
                break
    if cur_patience > args.patience or epoch == args.epochs:
        sys.exit("Already Run by log valid micro txt")
else:
    # Fresh run: delete artifacts a previous run may have left in outputdir.
    # "log_valid.txt" and "log_test.txt" are the logs this script appends to, so they
    # must be cleared too, or a new run's entries get interleaved with stale ones.
    for stale_name in ("log_train.txt",
                       "log_valid.txt", "log_valid_micro.txt", "log_valid_confusion.txt", "log_valid_macro.txt",
                       "log_test.txt", "log_test_micro.txt", "log_test_confusion.txt", "log_test_macro.txt",
                       "embedder.pt", "scorer.pt", "evaluation.txt"):
        stale_path = outputdir + stale_name
        if os.path.isfile(stale_path):
            os.remove(stale_path)
            
# Data -----------------------------------------------------------------------------
# Load the hypergraph and its hyperedge-id splits: get_data(0/1/2) are used as the
# train / valid / test splits (the test split is only loaded in "test" mode).
data = dl.Hypergraph(args)
train_data = data.get_data(0)
valid_data = data.get_data(1)
if args.evaltype == "test":
    test_data = data.get_data(2)
# Per-hop fan-outs for both relations: num_layers * 2 + 1 sampled hops, then one unsampled hop.
# NOTE(review): `[d] * k` repeats the SAME dict object k times — fine as long as
# HypergraphNeighborSampler never mutates these dicts; -1 is assumed to mean "all neighbours".
ls = [{('node', 'in', 'edge'): args.node_sampling, ('edge', 'con', 'node'): args.hedge_sampling}] * (args.num_layers * 2 + 1) + \
    [{('node', 'in', 'edge'): -1, ('edge', 'con', 'node'): -1}]  # do not sample the last layer (node-hedge pairs)
# Same number of hops as `ls`, but without sampling at any hop (used by the test loader).
full_ls = [{('node', 'in', 'edge'): -1, ('edge', 'con', 'node'): -1}] * (args.num_layers * 2 + 1 + 1)
# Build the node<->hyperedge graph; the weighted variant additionally receives data.hedge2nodePE.
# Presumably data.diff_feat ends up as the 'diff_feat' edge feature read during training — confirm in dl.
if data.weight_flag:
    g = dl.gen_weighted_DGLGraph_with_diff_feat(data.hedge2node, data.hedge2nodePE,  
                                                data.node_hedge_indexes_dict, data.diff_feat)
else:
    g = dl.gen_DGLGraph_with_diff_feat(data.hedge2node, 
                                       data.node_hedge_indexes_dict, data.diff_feat)

fullsampler = HypergraphNeighborSampler(full_ls)
sampler = HypergraphNeighborSampler(ls)

# Optionally keep the graph, split ids and hyperedge features on `device`.
if args.use_gpu:
    g = g.to(device)
    train_data = train_data.to(device)
    valid_data = valid_data.to(device)
    if args.evaltype == "test":
        test_data = test_data.to(device)
    data.e_feat = data.e_feat.to(device)
    
# ensure corresponding edges in 'in' and 'con' have same ids
# (each 'in' edge u->e and its reverse 'con' edge e->u share one id). Note: `assert`
# is skipped under `python -O`.
assert (g.edge_ids(*g.edges(etype='in'), etype='in') 
        == g.edge_ids(*reversed(g.edges(etype='in')), etype='con')).all(), \
        "corresponding \'in\' and \'con\' edges should have the same ids! "

# Seeds are hyperedge ids (ntype "edge"); args.bs == -1 means one full-size batch.
# Train/valid use the sampled `sampler`; test uses the unsampled `fullsampler`.
# NOTE(review): training uses shuffle=False and validation scores on sampled
# neighbourhoods — confirm both are intended.
dataloader = dgl.dataloading.NodeDataLoader( g, {"edge": train_data}, sampler, 
                                            batch_size=len(train_data) if args.bs==-1 else args.bs, 
                                            shuffle=False, drop_last=False) # , num_workers=4
validdataloader = dgl.dataloading.NodeDataLoader(g, {"edge": valid_data}, sampler, 
                                                 batch_size=len(valid_data) if args.bs==-1 else args.bs, 
                                                 shuffle=False, drop_last=False)
if args.evaltype == "test":
    testdataloader = dgl.dataloading.NodeDataLoader(g, {"edge": test_data}, fullsampler, 
                                                    batch_size=len(test_data) if args.bs==-1 else args.bs, 
                                                    shuffle=False, drop_last=False)

# Record data-dependent sizes on args; order_dim is passed as weight_dim to the embedder.
args.input_edim = data.e_feat.size(1)
args.order_dim = data.order_dim

# init embedder
# The initial node features are precomputed, loaded from disk, standardised and frozen.
args.input_vdim = args.feat_dim
_feat_dir = os.path.join(args.data_inputdir, args.task_type, args.dataset_name)
init_feat_fname = os.path.join(_feat_dir, f'node_origin_feat_dim_{args.feat_dim}.npy')
feat_origin_index_fname = os.path.join(_feat_dir, 'node_origin_indexes.txt')
node_list = np.arange(data.numnodes).astype('int')
print("load exist init features: ")
print(init_feat_fname, flush=True)
A = np.load(init_feat_fname)

# Row r of the feature file belongs to the node named on line r of the index file.
# Reorder rows so that row i matches internal node i (via data.node_orgindex).
with open(feat_origin_index_fname, 'r') as f:
    feat_origin_indexes = {name.strip(): row for row, name in enumerate(f.readlines())}
sort_indexes = np.array([feat_origin_indexes[data.node_orgindex[i]] for i in range(data.numnodes)])
A = A[sort_indexes, :]

# Standardise each feature column, then move to the device as float32.
A = StandardScaler().fit_transform(A).astype('float32')
A = torch.tensor(A).to(device)

initembedder = Wrap_Embedding(data.numnodes, args.input_vdim, scale_grad_by_freq=False, padding_idx=0, sparse=False)
initembedder.weight = nn.Parameter(A)

# Freeze the lookup table: none of its parameters are trained.
for param in initembedder.parameters():
    param.requires_grad_(False)

print("Model:", args.embedder)
# Model init: build the co-representation embedder selected by args.embedder.
if args.embedder == "CoNHD_GD":
    embedder = CoNHD_GD(CoNHD_GD_Layer, args.PE_Block, args.input_vdim, args.co_rep_dim,
                        weight_dim=args.order_dim, num_layers=args.num_layers, num_heads=args.num_heads,
                        num_inds=args.num_inds, att_type_v=args.att_type_v, att_type_e=args.att_type_e,
                        num_att_layer=args.num_att_layer, dropout=args.dropout, input_dropout=args.input_dropout,
                        weight_flag=data.weight_flag, layernorm=args.layernorm,
                        node_agg=args.node_agg, hedge_agg=args.hedge_agg).to(device)
elif args.embedder == "CoNHD_ADMM":
    embedder = CoNHD_ADMM(CoNHD_ADMM_Layer, args.PE_Block, args.input_vdim, args.co_rep_dim,
                          weight_dim=args.order_dim, num_layers=args.num_layers, num_heads=args.num_heads,
                          num_inds=args.num_inds, att_type_v=args.att_type_v, att_type_e=args.att_type_e,
                          num_att_layer=args.num_att_layer, dropout=args.dropout, input_dropout=args.input_dropout,
                          weight_flag=data.weight_flag, layernorm=args.layernorm).to(device)
else:
    # Stop here instead of failing later with a NameError on `embedder`.
    sys.exit("Not Implemented embedder")

print("Embedder to Device")
print("Scorer = ", args.scorer)
# Pick the scorer head selected by args.scorer.
if args.scorer == "sm":
    scorer = FC(args.dim_vertex + args.dim_edge, args.dim_edge, args.feat_dim, args.scorer_num_layers).to(device)
elif args.scorer == "im":  # whatsnet
    if args.embedder == "CoNHD_GD" or args.embedder == "CoNHD_ADMM":
        scorer = CoNHDDiffScorer(args.feat_dim, args.co_rep_dim, dim_hidden=args.dim_hidden, num_layers=args.scorer_num_layers).to(device)
    else:
        sys.exit("Not Implemented scorer")
else:
    sys.exit("Not Implemented scorer")

# Embedder and scorer are optimised jointly; initembedder is not included.
trainable_params = list(embedder.parameters()) + list(scorer.parameters())
if args.optimizer == "adam":
    optim = torch.optim.Adam(trainable_params, lr=args.lr)
elif args.optimizer == "adamw":
    optim = torch.optim.AdamW(trainable_params, lr=args.lr)
elif args.optimizer == "rms":
    # Must bind `optim`: otherwise the scheduler below would receive the `torch.optim`
    # module (imported under the same name) instead of an optimizer instance.
    optim = torch.optim.RMSprop(trainable_params, lr=args.lr)
else:
    sys.exit("Not Implemented optimizer")
# Learning rate decays by a factor of args.gamma at every scheduler.step().
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=args.gamma)
loss_fn = nn.L1Loss()

# Train =================================================================================================================================================================================
# Early-stopping / model-selection state; overwritten below when resuming from a checkpoint.
patience = 0
best_eval_mae = 100000000
best_eval_epoch = 0
epoch_start = 1
if os.path.isfile(outputdir + "checkpoint.pt") and args.recalculate is False:
    # map_location lets a checkpoint written on one device (e.g. GPU) be resumed on
    # whichever device this run uses (e.g. a CPU-only host).
    checkpoint = torch.load(outputdir + "checkpoint.pt", map_location=device)
    epoch_start = checkpoint['epoch'] + 1
    embedder.load_state_dict(checkpoint['embedder'])
    scorer.load_state_dict(checkpoint['scorer'])
    optim.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    best_eval_mae = checkpoint['best_eval_mae']
    best_eval_epoch = checkpoint['best_eval_epoch']
    patience = checkpoint['patience']

    print("Load {} epoch trainer".format(epoch_start))
    print("best_eval_mae = {}\tpatience = {}".format(best_eval_mae, patience))
    

# Append a start marker (with the first epoch to run) to the training log.
with open(outputdir + "log_train.txt", "+a") as f:
    f.write(time.strftime('%Y-%m-%d %H:%M:%S, ',time.localtime(time.time())))
    f.write("Training start, epoch: %d\n" % (epoch_start))
# Bind `epoch` up front so it stays defined even if the loop below runs zero times.
epoch = epoch_start
for epoch in tqdm(range(epoch_start, args.epochs + 1), desc='Epoch'): # tqdm
    
    # break if patience from checkpoint is larger than the threshold
    if patience > args.patience:
        break
    
    print("Training")
    
    # Training stage
    embedder.train()
    scorer.train()
    
    # One optimisation pass, then epoch-level MSE / MAE over all training predictions.
    total_pred, total_label, train_loss, embedder, scorer, optim, scheduler = run_epoch(args, data, dataloader, initembedder, embedder, scorer, optim, scheduler, loss_fn, opt="train")
    total_pred = torch.cat(total_pred)
    total_label = torch.cat(total_label, dim=0)
    total_mse = torch.nn.functional.mse_loss(total_pred, total_label)
    total_mae = torch.nn.functional.l1_loss(total_pred, total_label)
    # The learning rate is decayed once per epoch.
    scheduler.step()
    print("%d epoch: Training loss : %.4f / Training mse : %.4f / Training mae : %.4f\n" % (epoch, train_loss, total_mse, total_mae), flush=True)
    with open(outputdir + "log_train.txt", "+a") as f:
        f.write(time.strftime('%Y-%m-%d %H:%M:%S, ',time.localtime(time.time())))
        f.write("%d epoch: Training loss : %.4f / Training mse : %.4f / Training mae : %.4f\n" % (epoch, train_loss, total_mse, total_mae))
        
    # Validation (every `valid_epoch` epochs) ========================================================================================================================================
    if epoch % valid_epoch == 0:
        embedder.eval()
        scorer.eval()
        
        # opt="valid": run_epoch performs no backward pass or optimizer step.
        with torch.no_grad():
            total_pred, total_label, eval_loss, embedder, scorer, optim, scheduler = run_epoch(args, data, validdataloader, initembedder, embedder, scorer, optim, scheduler, loss_fn, opt="valid")
        # Epoch-level validation MSE / MAE over all batches.
        total_label = torch.cat(total_label, dim=0)
        total_pred = torch.cat(total_pred)
        total_mse = torch.nn.functional.mse_loss(total_pred, total_label)
        total_mae = torch.nn.functional.l1_loss(total_pred, total_label)
        print("Valid {} epoch:Test Loss:{} / mse:{} / mae:{}\n".format(epoch, eval_loss, total_mse, total_mae), flush=True)
        with open(outputdir + "log_valid.txt", "+a") as f:
            f.write(time.strftime('%Y-%m-%d %H:%M:%S, ',time.localtime(time.time())))
            f.write("{} epoch:Test Loss:{} / mse:{} / mae:{}\n".format(epoch, eval_loss, total_mse, total_mae))
        
        # Model selection on validation MAE (lower is better).
        # NOTE(review): total_mae is a 0-d tensor, so best_eval_mae becomes a tensor after the
        # first improvement and the "{}"-formatted log lines show its tensor repr.
        if best_eval_mae > total_mae:
            best_eval_mae = total_mae
            best_eval_epoch = epoch
            print('Best Eval MAE: %f' % best_eval_mae)
            patience = 0
            # Best weights are only persisted in test mode or when --save_best_epoch is set.
            if args.evaltype == "test" or args.save_best_epoch:
                print("Model Save")
                modelsavename = outputdir + "embedder.pt"
                torch.save(embedder.state_dict(), modelsavename)
                scorersavename = outputdir + "scorer.pt"
                torch.save(scorer.state_dict(), scorersavename)
        else:
            # patience counts consecutive validation rounds without improvement.
            patience += 1

        # Early stop; note this exits before the checkpoint below is written.
        if patience > args.patience:
            break
        
        # Resumable state after each validation round; resuming restarts at epoch + 1.
        torch.save({
            'epoch': epoch,
            'embedder': embedder.state_dict(),
            'scorer' : scorer.state_dict(),
            'scheduler' : scheduler.state_dict(),
            'optimizer': optim.state_dict(),
            'best_eval_mae' : best_eval_mae,
            'best_eval_epoch': best_eval_epoch, 
            'patience' : patience
            }, outputdir + "checkpoint.pt")

if args.evaltype == "test":
    print("Test")
    print(f"best eval epoch: {best_eval_epoch}")

    # Restore the weights saved at the best validation MAE. map_location lets weights
    # saved on one device be restored on whichever device this run uses.
    embedder.load_state_dict(torch.load(outputdir + "embedder.pt", map_location=device))
    scorer.load_state_dict(torch.load(outputdir + "scorer.pt", map_location=device))

    embedder.eval()
    scorer.eval()

    with torch.no_grad():
        total_pred, total_label, test_loss, embedder, scorer = run_test_epoch(args, data, testdataloader, initembedder, embedder, scorer, loss_fn)
    # MSE / MAE over the whole test split.
    total_label = torch.cat(total_label, dim=0)
    total_pred = torch.cat(total_pred)
    total_mse = torch.nn.functional.mse_loss(total_pred, total_label)
    total_mae = torch.nn.functional.l1_loss(total_pred, total_label)

    print("Test {} epoch, {} best eval epoch:Test Loss:{} / mse:{}/ mae:{}\n".format(epoch, best_eval_epoch, test_loss, total_mse, total_mae), flush=True)
    with open(outputdir + "log_test.txt", "+a") as f:
        f.write(time.strftime('%Y-%m-%d %H:%M:%S, ',time.localtime(time.time())))
        f.write("{} epoch, {} best eval epoch:Test Loss:{} / mse:{}/ mae:{}\n".format(epoch, best_eval_epoch, test_loss, total_mse, total_mae))

# The run completed, so the resume checkpoint is no longer needed.
if os.path.isfile(outputdir + "checkpoint.pt"):
    os.remove(outputdir + "checkpoint.pt")

