import argparse
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data, InMemoryDataset, DataLoader
from torch_geometric.utils import to_scipy_sparse_matrix


from models_with_filter import ChebyNetWithFilter
from script import utility


class PaperNodeDataset(InMemoryDataset):
    """In-memory collection of graphs loaded from a pickled ``.npy`` file.

    The file is read as a dict with the keys ``'edge_index'``, ``'features'``
    and ``'role_id'``; the three values are iterated in lockstep, one entry
    per graph. Each graph becomes a ``Data`` object with ``x`` (float
    features), ``edge_index`` (long) and ``y`` (long labels taken from
    ``'role_id'``).

    Args:
        npy_path: Path of the ``.npy`` file to load.
        transform: Optional transform forwarded to ``InMemoryDataset``.
    """

    def __init__(self, npy_path, transform=None):
        # No root directory: nothing is downloaded or cached on disk.
        super().__init__(None, transform)
        self.data, self.slices = self.load_npy(npy_path)

        # Width of the collated feature matrix, exposed via ``num_features``.
        self._num_features = self.data.x.size(1)

    @property
    def num_features(self):
        """Number of feature columns in ``x``, captured at construction time."""
        return self._num_features

    def load_npy(self, npy_path):
        """Read ``npy_path`` and collate its graphs.

        Args:
            npy_path: Path of the ``.npy`` file holding the pickled dict.

        Returns:
            The ``(data, slices)`` pair produced by ``self.collate``.
        """
        # SECURITY NOTE: allow_pickle=True unpickles arbitrary objects, which
        # can execute code. Only load files from a trusted source.
        raw = np.load(npy_path, allow_pickle=True)
        # A pickled dict is presumably stored as a 0-d object array; ``item()``
        # unwraps it. Anything without ``item`` is used as-is.
        data_dict = raw.item() if hasattr(raw, 'item') else raw

        data_list = []
        # NOTE: zip stops at the shortest of the three sequences.
        for edges, features, role_ids in zip(data_dict['edge_index'],
                                             data_dict['features'],
                                             data_dict['role_id']):
            edge_index = torch.tensor(edges, dtype=torch.long)
            if edge_index.numel() == 0:
                # An edgeless graph (e.g. ``[]``) would otherwise stay 1-D;
                # normalise it to the (2, 0) layout used for edge_index.
                edge_index = edge_index.new_empty((2, 0))
            elif edge_index.ndim == 2 and edge_index.size(0) != 2:
                # Stored as (num_edges, 2): flip to (2, num_edges).
                edge_index = edge_index.t()
            edge_index = edge_index.contiguous()

            x = torch.tensor(features, dtype=torch.float)
            y = torch.tensor(role_ids, dtype=torch.long)

            data_list.append(Data(x=x, edge_index=edge_index, y=y))

        return self.collate(data_list)


def process_graph_data_for_chebnet(data, device, gso_type='sym_norm_lap'):
    """Build the graph shift operator tensor for a single graph.

    The graph's ``edge_index`` is turned into a SciPy sparse adjacency matrix
    sized by the number of rows in ``data.x``. That matrix is passed through
    ``utility.calc_gso`` (with ``gso_type``) and ``utility.calc_chebynet_gso``,
    and the result is handed to ``utility.cnv_sparse_mat_to_coo_tensor``
    together with ``device``.

    Args:
        data: Graph object exposing ``x`` and ``edge_index``.
        device: Device forwarded to the sparse-to-tensor conversion.
        gso_type: Operator type forwarded to ``utility.calc_gso``.

    Returns:
        Whatever ``utility.cnv_sparse_mat_to_coo_tensor`` returns
        (presumably a sparse COO tensor on ``device`` — confirm in ``utility``).
    """
    node_count = data.x.size(0)
    adjacency = to_scipy_sparse_matrix(data.edge_index, num_nodes=node_count)
    cheb_operator = utility.calc_chebynet_gso(utility.calc_gso(adjacency, gso_type))
    return utility.cnv_sparse_mat_to_coo_tensor(cheb_operator, device)


def train(model, loader, optimizer, device, gso_type='sym_norm_lap'):
    """Run one training epoch and return ``(mean_batch_loss, accuracy)``.

    Each batch is split back into its individual graphs; every graph gets its
    own graph shift operator and forward pass. The per-graph NLL losses are
    averaged over the batch, followed by one backward pass and one optimizer
    step per batch.

    Args:
        model: Module called as ``model(x, gso)``. ``F.nll_loss`` is applied
            directly to its output, so it is assumed to return per-node
            log-probabilities — confirm against the model.
        loader: Iterable of batches supporting ``to``, ``to_data_list`` and
            ``num_graphs``.
        optimizer: Optimizer stepped once per non-empty batch.
        device: Device the batches and operators are moved to.
        gso_type: Forwarded to ``process_graph_data_for_chebnet``.

    Returns:
        Tuple of the loss averaged over the batches that were stepped, and
        the fraction of correctly predicted nodes. Either value is ``0.0``
        when there was nothing to average (empty loader / no nodes), instead
        of raising ``ZeroDivisionError``.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    # Count stepped batches ourselves so the loader does not need __len__
    # and skipped (empty) batches do not dilute the mean loss.
    num_batches = 0

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()

        batch_loss = 0

        for single_data in data.to_data_list():
            gso = process_graph_data_for_chebnet(single_data, device, gso_type)
            out = model(single_data.x, gso)
            batch_loss += F.nll_loss(out, single_data.y)

            pred = out.argmax(dim=1)
            total_correct += (pred == single_data.y).sum().item()
            total_samples += single_data.y.size(0)

        if data.num_graphs > 0:
            # Average over graphs so the step size does not depend on batch size.
            batch_loss = batch_loss / data.num_graphs
            batch_loss.backward()
            optimizer.step()
            total_loss += batch_loss.item()
            num_batches += 1

    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = total_correct / total_samples if total_samples else 0.0
    return mean_loss, accuracy


def evaluate(model, loader, device, gso_type='sym_norm_lap'):
    """Compute node-level accuracy of ``model`` over every graph in ``loader``.

    The model is switched to eval mode and gradients are disabled. Each batch
    is split into its individual graphs, and each graph is evaluated with its
    own graph shift operator.

    Args:
        model: Module called as ``model(x, gso)``; the arg-max over dim 1 of
            its output is taken as the predicted class.
        loader: Iterable of batches supporting ``to`` and ``to_data_list``.
        device: Device the batches and operators are moved to.
        gso_type: Forwarded to ``process_graph_data_for_chebnet``.

    Returns:
        Fraction of correctly predicted nodes, or ``0.0`` when the loader
        yields no nodes (previously a ``ZeroDivisionError``).
    """
    model.eval()
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            for single_data in data.to_data_list():
                gso = process_graph_data_for_chebnet(single_data, device, gso_type)
                out = model(single_data.x, gso)
                pred = out.argmax(dim=1)
                total_correct += (pred == single_data.y).sum().item()
                total_samples += single_data.y.size(0)

    if total_samples == 0:
        return 0.0
    return total_correct / total_samples


def str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string as ``True``
    (so ``--enable_bias False`` would enable the bias); this parser makes the
    textual value meaningful.

    Args:
        value: A bool, or a string spelling one (case-insensitive).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--Ko', type=int, default=3, help='K order Chebyshev polynomials')
    parser.add_argument('--Kl', type=int, default=2, help='K layer')
    parser.add_argument('--gso_type', type=str, default='sym_norm_lap', choices=['sym_norm_lap', 'rw_norm_lap'])
    # str2bool instead of type=bool: bool('False') is True, so the flag could
    # never be switched off from the command line.
    parser.add_argument('--enable_bias', type=str2bool, default=True)
    parser.add_argument('--droprate', type=float, default=0.5)
    parser.add_argument('--no_cuda', action='store_true')
    parser.add_argument('--use_causal_filter', action='store_true', default=False)
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    print(f"Using device: {device}, Causal filter: {'Enabled' if args.use_causal_filter else 'Disabled'}")

    # Fixed seeds for Python, NumPy and PyTorch RNGs.
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    for i in ['casual_1_3']:
        print(f'\nProcessing dataset split: {i}')

        train_dataset = PaperNodeDataset(f'./data/paper/train_{i}.npy')
        val_dataset   = PaperNodeDataset(f'./data/paper/val_{i}.npy')
        test_dataset  = PaperNodeDataset(f'./data/paper/test_{i}.npy')

        train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
        val_loader   = DataLoader(val_dataset, batch_size=128, shuffle=False)
        test_loader  = DataLoader(test_dataset, batch_size=128, shuffle=False)

        num_features = train_dataset.num_features
        # Class count is derived from the largest label in the training split
        # only. NOTE(review): labels present only in val/test would exceed it.
        num_classes  = int(train_dataset.data.y.max().item()) + 1

        # Only passed to the model when the causal filter is enabled.
        filter_config = {
            'input': {'lambda_init': 1.0, 'decay_rate': 0.99},
            'hidden': {'lambda_init': 1.0, 'decay_rate': 0.99},
        }

        model = ChebyNetWithFilter(
            n_feat=num_features,
            n_hid=args.hidden_dim,
            n_class=num_classes,
            enable_bias=args.enable_bias,
            Ko=args.Ko,
            Kl=args.Kl,
            droprate=args.droprate,
            use_causal_filter=args.use_causal_filter,
            filter_config=filter_config if args.use_causal_filter else None,
            task='node'
        ).to(device)

        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        # Model selection bookkeeping: remember the epoch with the best
        # validation accuracy and the test accuracy observed at that epoch.
        best_val_acc = 0
        best_epoch = 0
        test_acc_at_best_val = 0.0
        # Per-epoch accuracies (in percent); summarised after training.
        final_train_accuracies = []
        final_test_accuracies = []

        for epoch in range(1, args.epochs + 1):
            loss, train_acc = train(model, train_loader, optimizer, device, args.gso_type)
            val_acc = evaluate(model, val_loader, device, args.gso_type)
            test_acc = evaluate(model, test_loader, device, args.gso_type)

            if args.use_causal_filter:
                # Advance the filter's per-epoch state once per epoch.
                model.step_epoch()

            final_train_accuracies.append(train_acc * 100)
            final_test_accuracies.append(test_acc * 100)

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_epoch = epoch
                test_acc_at_best_val = test_acc

            if epoch % 10 == 0:
                print(f'Epoch {epoch:03d}: Loss {loss:.4f} | Train Acc {train_acc:.4f} | Val Acc {val_acc:.4f} | Test Acc {test_acc:.4f}')
                if args.use_causal_filter:
                    print(model.get_filter_info())

        # NOTE: these statistics are taken over the epochs of this single run,
        # not over independent runs.
        final_train_mean = np.mean(final_train_accuracies)
        final_train_std = np.std(final_train_accuracies)
        final_test_mean = np.mean(final_test_accuracies)
        final_test_std = np.std(final_test_accuracies)

        print(f'Train Acc  {final_train_mean:.2f} ± {final_train_std:.2f}  |  Test Acc  {final_test_mean:.2f} ± {final_test_std:.2f}')
        print(f'Best Val Acc {best_val_acc:.4f} at epoch {best_epoch:03d} | Test Acc at best val {test_acc_at_best_val:.4f}')