import os
import sys
import gc
import argparse
import torch
from torch.optim import Adam
import torch.nn as nn
import pandas as pd
import dgl
from dgl.dataloading import GraphDataLoader

from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from model.SAN_KAN.SAN_nodeLPE import SAN_NodeLPE
from model.SAN_KAN.module import laplace_decomp

def pyg_to_dgl(data):
    """Convert a PyG ``Data``/``Batch`` object into a DGL graph on the CPU.

    Node features are copied from ``data.x`` into ``g.ndata['feat']`` and,
    when present, edge features from ``data.edge_attr`` into
    ``g.edata['feat']``.

    The node count is taken from ``data.x`` instead of being inferred from
    ``edge_index``. If ``num_nodes`` is not given, ``dgl.graph`` sizes the graph
    as ``max(node id) + 1``, which would silently drop isolated nodes at the
    end of the id range.

    Args:
        data: object exposing ``edge_index`` (row 0 = sources, row 1 =
            destinations), ``x`` (first dimension = number of nodes) and
            optionally ``edge_attr``.

    Returns:
        A ``dgl.DGLGraph`` whose structure and features all live on the CPU.
    """
    edge_index = data.edge_index.to('cpu')
    num_nodes = data.x.size(0)

    # Giving num_nodes explicitly keeps trailing isolated nodes, so no
    # follow-up add_nodes / partial in-place feature write is needed.
    g = dgl.graph((edge_index[0], edge_index[1]), num_nodes=num_nodes)

    g.ndata['feat'] = data.x.to('cpu')
    # Keep edge features on the same (CPU) device as the graph structure and
    # skip them entirely for datasets that do not provide any.
    edge_attr = getattr(data, 'edge_attr', None)
    if edge_attr is not None:
        g.edata['feat'] = edge_attr.to('cpu')

    return g


class EarlyStopper:
    """Track a validation loss and report when training should stop.

    ``early_stop`` returns True once ``patience`` consecutive calls have
    recorded a loss of at least ``min_validation_loss + min_delta``. A loss
    strictly below the best seen so far becomes the new best and resets the
    failure counter.
    """

    def __init__(self, patience=1, min_delta=0):
        # Consecutive non-improving calls tolerated before signalling a stop.
        self.patience = patience
        # Slack above the best loss; losses inside it neither reset nor
        # advance the failure counter.
        self.min_delta = min_delta
        # Consecutive non-improving calls recorded so far.
        self.counter = 0
        # Lowest loss observed so far.
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        """Record ``validation_loss``; return True when training should stop."""
        if validation_loss < self.min_validation_loss:
            # New best: remember it and forget earlier failures.
            self.min_validation_loss, self.counter = validation_loss, 0
            return False

        threshold = self.min_validation_loss + self.min_delta
        if not validation_loss >= threshold:
            # Inside the tolerance band (or incomparable, e.g. NaN): no change.
            return False

        self.counter += 1
        return self.counter >= self.patience

# Argument parser
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, default=128, help='Input batch size for training')
parser.add_argument('--epochs', type=int, default=300, help='Number of epochs to train')
parser.add_argument('--model', type=str, default='KAA_SAN', help='model to test')
parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate (1 - keep probability)')
parser.add_argument('--patience', type=int, default=20, help='Patience for early stopping')
# NOTE(review): --n-gnn-layers is not read anywhere else in this script; the
# layer count comes from the hyper-parameter grid instead.
parser.add_argument('--n-gnn-layers', type=int, default=4, help='Number of message passing layers')
parser.add_argument('--device_num', type=int, default=0, help='the device number')
parser.add_argument('--seed', type=int, default=1, help='the random seed')
parser.add_argument("--max_freqs", type=int, default=8, help="freq in lap decomp")
parser.add_argument("--num_heads", type=int, default=1, help="number of attention heads")
parser.add_argument('--hidden_dim', type=int, default=64, help='Hidden dimension of the model')
args = parser.parse_args()

# Random seed: seed torch (CPU and CUDA) and make cuDNN deterministic so runs
# are reproducible (disables cuDNN autotuning).
random_seed = args.seed
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

device = torch.device('cuda:{}'.format(args.device_num) if torch.cuda.is_available() else 'cpu')

# ZINC 12k subset (subset=True), one object per standard split. The root
# directory is resolved relative to the current working directory.
train_dataset = ZINC('./dataset/ZINC', subset=True, split='train')
val_dataset = ZINC('./dataset/ZINC', subset=True, split='val')
test_dataset = ZINC('./dataset/ZINC', subset=True, split='test')


def apply_pos_enc(dataset, max_freqs):
    """Attach Laplacian eigen-decomposition features to every graph of ``dataset``.

    Each graph is converted to a temporary DGL graph and passed through
    ``laplace_decomp(..., max_freqs)``. The resulting ``'EigVecs'`` and
    ``'EigVals'`` node data are stored on the PyG object as ``eigvecs`` and
    ``eigvalues``.

    Returns:
        A list of the annotated data objects, in iteration order of ``dataset``.
    """
    def _annotate(graph_data):
        # The decomposition runs on a throw-away DGL copy of the graph.
        decomposed = laplace_decomp(pyg_to_dgl(graph_data), max_freqs)
        graph_data.eigvecs = decomposed.ndata['EigVecs']
        graph_data.eigvalues = decomposed.ndata['EigVals']
        return graph_data

    return [_annotate(graph_data) for graph_data in dataset]

# Precompute the positional encodings once per split, before any batching.
train_data_processed, val_data_processed, test_data_processed = (
    apply_pos_enc(split, args.max_freqs)
    for split in (train_dataset, val_dataset, test_dataset)
)

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_data_processed, args.batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(split_data, args.batch_size, shuffle=False)
    for split_data in (val_data_processed, test_data_processed)
)

# Configuration passed to SAN_NodeLPE. The hidden/out dims, layer count, grid
# size and spline order are starting values; the grid search below overwrites
# them for every configuration it evaluates.
net_params = {
    'kind': args.model,
    'in_dim': train_dataset.num_features,
    'GT_hidden_dim': args.hidden_dim,
    'GT_out_dim': args.hidden_dim,
    'GT_n_heads': args.num_heads,
    # Taken from --dropout (default 0.0) so the CLI flag actually has an effect.
    'dropout': args.dropout,
    'in_feat_dropout': 0,
    'GT_layers': 2,
    'max_freqs': args.max_freqs,
    'layer_norm': False,
    'batch_norm': True,
    'residual': True,
    'full_graph': False,
    # Reuse the module-level device so the model and the batches always agree.
    'device': device,
    'gamma': 1,
    'LPE_dim': 8,
    'LPE_n_heads': 2,
    'LPE_layers': 2,
    'spline_order': 2,
    'grid_size': 1,
    'hidden_layers': 2,
    'n_classes': 1
}

# Hyper-parameter grid explored by the search (full Cartesian product).
LR = [0.01, 0.001]
HIDDEN_DIM = [16, 32, 64, 128]
N_LAYERS = [2, 4]
GRID_SIZE = [1]
SPLINE_ORDER = [2, 3]


def _train_one_epoch(model, optimizer, loader):
    """Run one optimisation pass over ``loader``.

    Returns the average of the per-batch L1 losses, weighted by the number of
    graphs in each batch and divided by ``len(loader.dataset)``.
    """
    model.train()

    total_loss = 0
    for data in loader:
        # Build the DGL view from the CPU batch, then move both to the device.
        graph = pyg_to_dgl(data).to(device)
        data = data.to(device)
        optimizer.zero_grad()
        out = model(graph, graph.ndata['feat'], data.batch, data.eigvecs, data.eigvalues)
        loss = (out.squeeze() - data.y).abs().mean()
        loss.backward()
        # Weight by graph count so the epoch figure is a per-graph average.
        total_loss += loss.item() * data.num_graphs
        optimizer.step()
    return total_loss / len(loader.dataset)


@torch.no_grad()
def _evaluate_mae(model, loader):
    """Return the summed absolute error over ``loader`` divided by its dataset size."""
    model.eval()

    total_error = 0
    for data in loader:
        graph = pyg_to_dgl(data).to(device)
        data = data.to(device)
        out = model(graph, graph.ndata['feat'], data.batch, data.eigvecs, data.eigvalues)
        total_error += (out.squeeze() - data.y).abs().sum().item()
    return total_error / len(loader.dataset)


# Grid search: each configuration is trained from scratch. The configuration
# that reaches the lowest validation MAE at any epoch is selected.
best_val_mae = float('inf')
best_hyperparams = None
search_grid = [
    (lr, hidden_dim, n_layers, grid_size, spline_order)
    for lr in LR
    for hidden_dim in HIDDEN_DIM
    for n_layers in N_LAYERS
    for grid_size in GRID_SIZE
    for spline_order in SPLINE_ORDER
]
for lr, hidden_dim, n_layers, grid_size, spline_order in search_grid:
    net_params['GT_hidden_dim'] = hidden_dim
    net_params['GT_out_dim'] = hidden_dim
    net_params['GT_layers'] = n_layers
    net_params['grid_size'] = grid_size
    net_params['spline_order'] = spline_order

    print('Evaluating the following hyperparameters:')
    print('lr:', lr, 'hidden_dim:', hidden_dim, 'n_layers:', n_layers, 'grid_size:', grid_size,
          'spline_order:', spline_order)
    model = SAN_NodeLPE(net_params).to(net_params['device'])
    optimizer = Adam(model.parameters(), lr=lr)

    early_stopper = EarlyStopper(patience=args.patience)
    for epoch in range(1, args.epochs + 1):
        loss = _train_one_epoch(model, optimizer, train_loader)
        val_mae = _evaluate_mae(model, val_loader)

        if val_mae < best_val_mae:
            best_val_mae = val_mae
            best_hyperparams = {'lr': lr, 'hidden_dim': hidden_dim, 'n_layers': n_layers,
                                'grid_size': grid_size, 'spline_order': spline_order}
            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_mae:.4f}')

        if early_stopper.early_stop(val_mae):
            print(f"Stopped at epoch {epoch}")
            break

# Without this guard a run whose validation MAE was never finite (e.g. NaN
# losses throughout) would fail below with an unhelpful NameError.
if best_hyperparams is None:
    raise RuntimeError('Hyper-parameter search never produced a finite validation MAE; '
                       'cannot select a configuration.')

print('Best hyperparameters:')
print('lr:', best_hyperparams['lr'])
print('hidden_dim:', best_hyperparams['hidden_dim'])
print('n_layers:', best_hyperparams['n_layers'])
print('grid_size:', best_hyperparams['grid_size'])
print('spline_order:', best_hyperparams['spline_order'])

# Retrain the configuration chosen by the search several times from scratch.
# Each run reports the test MAE measured at that run's best validation epoch.
net_params['GT_hidden_dim'] = best_hyperparams['hidden_dim']
net_params['GT_out_dim'] = best_hyperparams['hidden_dim']
net_params['GT_layers'] = best_hyperparams['n_layers']
net_params['spline_order'] = best_hyperparams['spline_order']
net_params['grid_size'] = best_hyperparams['grid_size']

NUM_FINAL_RUNS = 5

val_maes = []
test_maes = []
for run in range(NUM_FINAL_RUNS):
    print()
    print(f'Run {run}:')
    print()
    gc.collect()
    model = SAN_NodeLPE(net_params).to(net_params['device'])
    total_params = sum(p.numel() for p in model.parameters())
    print('Number of parameters:', total_params)
    print()
    # Use the learning rate selected by the search, like every other tuned
    # hyper-parameter; a fixed value would discard the search result.
    optimizer = Adam(model.parameters(), lr=best_hyperparams['lr'])

    def train(epoch):
        """Run one training pass over ``train_loader``.

        Returns the graph-count-weighted average of the per-batch L1 losses.
        ``epoch`` is accepted for readability at the call site but unused.
        """
        model.train()

        total_loss = 0
        for data in train_loader:
            # Build the DGL view from the CPU batch, then move both to device.
            graph = pyg_to_dgl(data)
            graph = graph.to(device)
            data = data.to(device)
            optimizer.zero_grad()
            out = model(graph, graph.ndata['feat'], data.batch, data.eigvecs, data.eigvalues)
            loss = (out.squeeze() - data.y).abs().mean()
            loss.backward()
            total_loss += loss.item() * data.num_graphs
            optimizer.step()
        return total_loss / len(train_loader.dataset)

    @torch.no_grad()
    def test(loader):
        """Return the summed absolute error over ``loader`` divided by its dataset size."""
        model.eval()

        total_error = 0
        for data in loader:
            graph = pyg_to_dgl(data)
            graph = graph.to(device)
            data = data.to(device)
            out = model(graph, graph.ndata['feat'], data.batch, data.eigvecs, data.eigvalues)
            total_error += (out.squeeze() - data.y).abs().sum().item()
        return total_error / len(loader.dataset)

    best_val_mae = test_mae = float('inf')
    early_stopper = EarlyStopper(patience=args.patience)
    for epoch in range(1, args.epochs + 1):
        loss = train(epoch)
        val_mae = test(val_loader)

        if val_mae < best_val_mae:
            best_val_mae = val_mae
            # The test split is only scored at new validation bests (model
            # selection is done on validation alone).
            test_mae = test(test_loader)
            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
                  f'Val: {val_mae:.4f}, Test: {test_mae:.4f}')

        if early_stopper.early_stop(val_mae):
            print(f"Stopped at epoch {epoch}")
            break

    test_maes.append(test_mae)
    val_maes.append(best_val_mae)

# Summarise the final runs: mean and (sample) standard deviation across runs.
test_mae = torch.tensor(test_maes)
val_mae_runs = torch.tensor(val_maes)
print('===========================')
print(f'Final Val: {val_mae_runs.mean():.4f} ± {val_mae_runs.std():.4f}')
print(f'Final Test: {test_mae.mean():.4f} ± {test_mae.std():.4f}')

# results: one summary row per invocation, written to an Excel file. The
# output directory is resolved relative to the current working directory.
save_dir = os.path.join('..', 'results', 'graph_regression', 'ZINC')
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, '{}_KAGNNs.xlsx'.format(args.model))
result_statistic = pd.DataFrame(
    [{'Dataset': 'ZINC',
      'Model': args.model,
      'mae': float(test_mae.mean()), 'std': float(test_mae.std())}],
    columns=['Dataset', 'Model', 'mae', 'std'])
result_statistic.to_excel(save_path)
print('Mission completes.')