import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import scipy.sparse as sp
import argparse
import time
import os
import random
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_scipy_sparse_matrix


from data_loader import NPYGraphDataset
from models_with_filter import ChebyNetWithFilter


from script import utility

def set_env(seed):
    """Seed the random number generators used by this script.

    Args:
        seed: Integer seed. It is stored (as a string) in the
            ``PYTHONHASHSEED`` environment variable and handed to the
            Python, NumPy and PyTorch (CPU and CUDA) seeding functions.
    """
    # Turn the cuDNN autotuner off and request deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)

    # Every generator receives the same seed value.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def get_parameters():
    """Parse command-line options, seed the RNGs and pick the compute device.

    Boolean options (``--enable_cuda``, ``--enable_bias``) accept textual
    values such as ``True``/``False``/``1``/``0``/``yes``/``no``. The causal
    filter is on by default and can be switched off with
    ``--no_causal_filter``.

    Returns:
        tuple: ``(device, args)`` where ``device`` is a ``torch.device``
        (``cuda`` only when ``--enable_cuda`` is true and CUDA is available)
        and ``args`` is the parsed ``argparse.Namespace``.
    """

    def _str2bool(value):
        # ``type=bool`` would call bool() on the raw string, so any non-empty
        # text -- including "False" -- would be parsed as True.
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        # The empty string is kept falsy, as it was under ``type=bool``.
        if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser(description='ChebNet Training with Causal Filter')
    parser.add_argument('--enable_cuda', type=_str2bool, default=True, help='enable or disable CUDA, default as True')
    parser.add_argument('--seed', type=int, default=42, help='set the random seed for stabilizing experiment results')
    parser.add_argument('--dataset', type=str, default='molecular', choices=['molecular'], help='dataset name')
    parser.add_argument('--gso_type', type=str, default='sym_norm_lap',
                        choices=['sym_norm_lap', 'rw_norm_lap'],
                        help='graph shift operator, default as sym_norm_lap')
    parser.add_argument('--Ko', type=int, default=3, help='K order Chebyshev polynomials')
    parser.add_argument('--Kl', type=int, default=2, help='K layer')
    parser.add_argument('--lr', type=float, default=0.005, help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay (L2 penalty)')
    parser.add_argument('--n_hid', type=int, default=64, help='the channel size of hidden layer feature, default as 64')
    parser.add_argument('--enable_bias', type=_str2bool, default=True, help='default as True')
    parser.add_argument('--droprate', type=float, default=0.5, help='dropout rate, default as 0.5')
    parser.add_argument('--epochs', type=int, default=100, help='epochs, default as 100')
    parser.add_argument('--batch_size', type=int, default=128, help='batch size for training')
    parser.add_argument('--opt', type=str, default='adam', help='optimizer, default as adam')
    parser.add_argument('--use_causal_filter', action='store_true', default=True, help='Enable causal filter')
    # ``store_true`` with ``default=True`` can never yield False, so a
    # companion flag writing to the same destination provides the off switch.
    parser.add_argument('--no_causal_filter', dest='use_causal_filter', action='store_false',
                        help='Disable causal filter')

    args = parser.parse_args()
    print('Training configs: {}'.format(args))

    set_env(args.seed)

    device = torch.device('cuda' if args.enable_cuda and torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}, Causal filter: {'Enabled' if args.use_causal_filter else 'Disabled'}")

    return device, args

def process_graph_data(data, device, gso_type):
    """Build the graph shift operator tensor for a single graph.

    Args:
        data: Graph object exposing ``x`` and ``edge_index``.
        device: Device passed on to the sparse-tensor conversion helper.
        gso_type: Graph shift operator name forwarded to
            ``utility.calc_gso``.

    Returns:
        Whatever ``utility.cnv_sparse_mat_to_coo_tensor`` produces for the
        ChebyNet-adjusted operator (presumably a sparse COO tensor on
        ``device`` -- confirm in ``script.utility``).
    """
    # The adjacency is sized by the first dimension of the feature matrix.
    adjacency = to_scipy_sparse_matrix(
        data.edge_index,
        num_nodes=data.x.size(0),
    )
    shift_operator = utility.calc_chebynet_gso(
        utility.calc_gso(adjacency, gso_type)
    )
    return utility.cnv_sparse_mat_to_coo_tensor(shift_operator, device)

def train_epoch(model, train_loader, optimizer, device, gso_type):
    """Run one optimisation pass over ``train_loader``.

    Each batch is split into its individual graphs. Every graph is pushed
    through the model with its own graph shift operator, the output is
    averaged over dim 0 into a single row, and the per-graph NLL losses are
    averaged over the batch before one optimizer step is taken.

    Args:
        model: Module called as ``model(x, gso)``.
        train_loader: Iterable of batches providing ``to``, ``num_graphs``
            and ``to_data_list``.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device batches and operators are moved to.
        gso_type: Graph shift operator name for ``process_graph_data``.

    Returns:
        tuple: ``(mean batch loss, accuracy)``. The loss is averaged over
        the batches that produced an optimizer step. ``(0.0, 0.0)`` is
        returned when no graph was seen, instead of raising
        ``ZeroDivisionError``.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for batch in train_loader:
        batch = batch.to(device)
        # A batch without graphs contributes nothing to loss or accuracy.
        if batch.num_graphs == 0:
            continue

        optimizer.zero_grad()
        batch_loss = 0

        for data_i in batch.to_data_list():
            gso_i = process_graph_data(data_i, device, gso_type)
            out_i = model(data_i.x, gso_i)

            # Collapse dim 0 (presumably nodes -- TODO confirm) to one row
            # so the graph gets a single prediction.
            graph_out_i = torch.mean(out_i, dim=0, keepdim=True)

            # NOTE(review): nll_loss expects log-probabilities; assumes the
            # model's output already is -- confirm in models_with_filter.
            batch_loss += F.nll_loss(graph_out_i, data_i.y)

            pred = graph_out_i.argmax(dim=1)
            correct += (pred == data_i.y).sum().item()

        # Average over the graphs so the step size is independent of the
        # batch size.
        batch_loss = batch_loss / batch.num_graphs
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()
        total += batch.num_graphs
        num_batches += 1

    if total == 0:
        return 0.0, 0.0
    return total_loss / num_batches, correct / total

def evaluate(model, loader, device, gso_type):
    """Compute graph-level accuracy of ``model`` over ``loader``.

    Gradients are disabled. Each graph in a batch is evaluated on its own,
    with the model output averaged over dim 0 before the arg-max.

    Args:
        model: Module called as ``model(x, gso)``.
        loader: Iterable of batches providing ``to``, ``num_graphs`` and
            ``to_data_list``.
        device: Device batches and operators are moved to.
        gso_type: Graph shift operator name for ``process_graph_data``.

    Returns:
        float: Fraction of correctly classified graphs, or ``0.0`` when the
        loader yields no graphs (instead of raising ``ZeroDivisionError``).
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)

            for data_i in batch.to_data_list():
                gso_i = process_graph_data(data_i, device, gso_type)
                out_i = model(data_i.x, gso_i)
                # One prediction per graph: average dim 0, then arg-max.
                graph_out_i = torch.mean(out_i, dim=0, keepdim=True)
                pred = graph_out_i.argmax(dim=1)
                correct += (pred == data_i.y).sum().item()

            total += batch.num_graphs

    if total == 0:
        return 0.0
    return correct / total

def main():
    """Train and evaluate ``ChebyNetWithFilter`` on each configured split.

    For every dataset tag the function parses the CLI options, loads the
    train/val/test ``.npy`` files from ``./data/molecular``, trains for
    ``args.epochs`` epochs and prints progress every 10 epochs. At the end
    it prints the mean/std of the per-epoch accuracies and, for model
    selection, the best validation accuracy together with the test accuracy
    measured at that same epoch.
    """
    for i in ['crcg']:
        print(f'\nRunning training for dataset: {i}')
        device, args = get_parameters()

        train_dataset = NPYGraphDataset(f'./data/molecular/train_{i}.npy', task='graph')
        val_dataset   = NPYGraphDataset(f'./data/molecular/val_{i}.npy', task='graph')
        test_dataset  = NPYGraphDataset(f'./data/molecular/test_{i}.npy', task='graph')

        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        val_loader   = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
        test_loader  = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

        n_feat = train_dataset.num_features
        # Class count is taken from the largest label in the training split.
        n_class = int(train_dataset.data.y.max()) + 1

        filter_config = {
            'input': {'lambda_init': 1.0, 'decay_rate': 0.99},
            'hidden': {'lambda_init': 1.0, 'decay_rate': 0.99},
        }

        model = ChebyNetWithFilter(
            n_feat, args.n_hid, n_class, args.enable_bias, args.Ko, args.Kl, args.droprate,
            use_causal_filter=args.use_causal_filter,
            filter_config=filter_config if args.use_causal_filter else None,
            task='graph'
        ).to(device)

        # NOTE(review): ``args.opt`` is parsed but not consulted here; Adam
        # is always used.
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        best_val_acc = 0
        best_epoch = 0
        test_acc_at_best_val = 0.0
        final_train_accuracies = []
        final_test_accuracies = []

        for epoch in range(1, args.epochs + 1):
            train_loss, train_acc = train_epoch(model, train_loader, optimizer, device, args.gso_type)
            val_acc = evaluate(model, val_loader, device, args.gso_type)
            test_acc = evaluate(model, test_loader, device, args.gso_type)

            if args.use_causal_filter:
                # Per-epoch hook of the model (presumably advances the
                # filter's decay schedule -- confirm in models_with_filter).
                model.step_epoch()

            final_train_accuracies.append(train_acc * 100)
            final_test_accuracies.append(test_acc * 100)

            # Remember the test accuracy at the epoch with the best
            # validation accuracy, so the selected model's score is reported.
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_epoch = epoch
                test_acc_at_best_val = test_acc

            if epoch % 10 == 0:
                print(f'Epoch: {epoch:03d} | Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | Test Acc: {test_acc:.4f}')
                if args.use_causal_filter:
                    print(model.get_filter_info())

        # These statistics run over every epoch of this single run,
        # including the early ones.
        final_train_mean = np.mean(final_train_accuracies)
        final_train_std = np.std(final_train_accuracies)
        final_test_mean = np.mean(final_test_accuracies)
        final_test_std = np.std(final_test_accuracies)

        print(f'Train Acc  {final_train_mean:.2f} ± {final_train_std:.2f}  |  Test Acc  {final_test_mean:.2f} ± {final_test_std:.2f}')
        print(f'Best Val Acc: {best_val_acc:.4f} (epoch {best_epoch:03d}) | Test Acc at best val: {test_acc_at_best_val:.4f}')

# Run the training pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()