import sys
sys.path.append('./')
from improved_filter import ImprovedCausalFilter
import torch
import argparse
import os
import os.path as osp
import numpy as np
from datetime import datetime
from torch_geometric.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import LEConv
from utils.mask import set_masks, clear_masks
from utils.logger import Logger
from utils.helper import set_seed, args_print
from utils.get_subgraph import split_graph, relabel
from gnn import SPMotifNet1
from paper_dataset import PaperDataset


class CausalAttNet(nn.Module):
    """Score every edge of a graph batch and split the edges into two subsets.

    Node embeddings are produced by two stacked ``LEConv`` layers (ReLU in
    between) that receive ``data.edge_attr`` flattened to 1-D as edge weights.
    Each edge is scored by an MLP applied to the concatenation of its two
    endpoint embeddings, and ``split_graph`` partitions the edges using those
    scores together with ``causal_ratio``.

    Args:
        causal_ratio: forwarded unchanged to ``split_graph``; presumably the
            fraction of edges placed in the first ("causal") subset — confirm
            in ``utils.get_subgraph.split_graph``.
        in_channels: feature dimension of ``data.x``.
        channels: hidden width of both ``LEConv`` layers.
    """

    def __init__(self, causal_ratio, in_channels, channels):
        super().__init__()
        hidden = channels
        self.conv1 = LEConv(in_channels=in_channels, out_channels=hidden)
        self.conv2 = LEConv(in_channels=hidden, out_channels=hidden)
        # Edge scorer: [h_src || h_dst] (2*hidden) -> 4*hidden -> one scalar.
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden, 4 * hidden),
            nn.ReLU(),
            nn.Linear(4 * hidden, 1),
        )
        self.ratio = causal_ratio

    def forward(self, data):
        """Return ``(first_subset, second_subset, edge_scores)``.

        Each subset is the 3-tuple produced by ``split_graph`` for that side
        (named edge index / edge attr / edge weight at the call sites).
        ``edge_scores`` holds one raw MLP score per column of
        ``data.edge_index``.
        """
        edge_index = data.edge_index
        edge_weight = data.edge_attr.view(-1)

        node_emb = self.conv1(data.x, edge_index, edge_weight).relu()
        node_emb = self.conv2(node_emb, edge_index, edge_weight)

        src, dst = edge_index
        pair_emb = torch.cat((node_emb[src], node_emb[dst]), dim=-1)
        edge_scores = self.mlp(pair_emb).view(-1)

        (c_idx, c_attr, c_w), (s_idx, s_attr, s_w) = split_graph(data, edge_scores, self.ratio)
        return (c_idx, c_attr, c_w), (s_idx, s_attr, s_w), edge_scores


def _parse_seeds(raw):
    """Turn the raw ``--seed`` option into a non-empty list of ints.

    Accepts a Python literal such as ``'[42]'`` or ``'[0, 1, 2]'`` as well as a
    bare ``'7'``. ``ast.literal_eval`` only evaluates literals, so unlike
    ``eval`` it cannot execute arbitrary code passed on the command line.

    Raises:
        ValueError, SyntaxError, TypeError: if ``raw`` is missing, empty, or is
            not an int / iterable of ints.
    """
    import ast

    if raw is None:
        raise ValueError('a value is required, e.g. --seed "[42]"')
    value = ast.literal_eval(raw) if isinstance(raw, str) else raw
    seeds = [value] if isinstance(value, int) else [int(s) for s in value]
    if not seeds:
        raise ValueError('at least one seed is required')
    return seeds


def _att_input(data, feature_filter):
    """Return the batch fed to the attention network.

    Without a filter this is ``data`` itself. With one, it is a new batch of
    the same type whose ``x`` is ``feature_filter(data.x)`` and whose
    ``edge_index``/``edge_attr``/``y``/``batch`` are shared with ``data``.
    """
    if feature_filter is None:
        return data
    return type(data)(x=feature_filter(data.x), edge_index=data.edge_index,
                      edge_attr=data.edge_attr, y=data.y, batch=data.batch)


def _set_mode(training, *modules):
    """Put every non-None module into train (``True``) or eval (``False``) mode."""
    for module in modules:
        if module is not None:
            module.train(training)


def _build_models(args, num_feat, num_classes, device):
    """Create fresh classifier, attention net, optional filter and their optimizer."""
    g = SPMotifNet1(in_channels=num_feat, num_classes=num_classes).to(device)
    hidden_ch = g.node_emb.out_features
    att_net = CausalAttNet(args.r, in_channels=num_feat, channels=hidden_ch).to(device)

    feature_filter = None
    if args.use_causal_filter:
        feature_filter = ImprovedCausalFilter(
            input_dim=num_feat,
            lambda_init=args.filter_lambda_init,
            lambda_min=args.filter_lambda_min,
            decay_rate=args.filter_decay,
            temperature=args.filter_temp,
            residual_weight=args.filter_residual,
            normalize=False,
            dropout=0.1
        ).to(device)

    params = list(g.parameters()) + list(att_net.parameters())
    if feature_filter is not None:
        params += list(feature_filter.parameters())
    optimizer = torch.optim.Adam(params, lr=args.net_lr)
    return g, att_net, feature_filter, optimizer


def _train_epoch(loader, g, att_net, feature_filter, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The loss is the cross-entropy of ``g`` on the first edge subset (with the
    attention edge weights installed via ``set_masks``) plus the cross-entropy
    on the second subset (masks cleared).
    """
    _set_mode(True, g, att_net, feature_filter)
    total_loss = 0.0
    for data in loader:
        data = data.to(device)
        (cei, cea, cew), (fei, fea, _), _ = att_net(_att_input(data, feature_filter))

        set_masks(cew, g)
        out_c = g(x=data.x, edge_index=cei, edge_attr=cea, batch=data.batch)
        loss_c = criterion(out_c, data.y.view(-1))

        clear_masks(g)
        out_f = g(x=data.x, edge_index=fei, edge_attr=fea, batch=data.batch)
        loss_f = criterion(out_f, data.y.view(-1))

        loss = loss_c + loss_f
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    return total_loss / max(len(loader), 1)


def _evaluate(loader, g, att_net, feature_filter, device):
    """Return the accuracy (in [0, 1]) of ``g`` on the first edge subset over ``loader``.

    All modules, including the feature filter, are switched to eval mode so
    that dropout inside the filter is disabled during evaluation.
    """
    _set_mode(False, g, att_net, feature_filter)
    correct = 0
    total = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            (cei, cea, cew), _, _ = att_net(_att_input(data, feature_filter))
            set_masks(cew, g)
            out = g(x=data.x, edge_index=cei, edge_attr=cea, batch=data.batch)
            clear_masks(g)
            pred = out.argmax(dim=1)
            correct += (pred == data.y.view(-1)).sum().item()
            total += data.y.numel()
    return correct / total if total else 0.0


def _run_seed(seed, args, loaders, num_feat, num_classes, device, log_dir, logger):
    """Train fresh models under ``seed``; return (train_acc, test_acc) of the best-val checkpoint.

    Both accuracies are in [0, 1]. ``best.pth`` / ``best_filter.pth`` in
    ``log_dir`` are overwritten by each seed, so after the run they hold the
    last seed's best checkpoint.
    """
    train_loader, val_loader, test_loader = loaders

    set_seed(seed)
    # Build models AFTER seeding so each seed gets its own initialisation
    # instead of continuing to train the previous seed's weights.
    g, att_net, feature_filter, optimizer = _build_models(args, num_feat, num_classes, device)
    criterion = nn.CrossEntropyLoss()
    ckpt_path = osp.join(log_dir, 'best.pth')

    # Early-stopping state is per seed; -1 guarantees epoch 0 writes a checkpoint.
    best_val_acc = -1.0
    best_train_acc = 0.0
    stale = 0
    for epoch in range(args.epoch):
        avg_loss = _train_epoch(train_loader, g, att_net, feature_filter, optimizer, criterion, device)
        val_acc = _evaluate(val_loader, g, att_net, feature_filter, device)
        train_acc = _evaluate(train_loader, g, att_net, feature_filter, device)
        logger.info(f'Epoch {epoch:03d} | Loss {avg_loss:.4f} | Train Acc {train_acc:.4f} | Val Acc {val_acc:.4f}')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_train_acc = train_acc
            stale = 0
            state = {'g': g.state_dict(), 'att': att_net.state_dict()}
            if feature_filter is not None:
                # The filter is part of the model: test must use the matching state.
                state['filter'] = feature_filter.state_dict()
            torch.save(state, ckpt_path)
        else:
            stale += 1
        if stale >= args.patience:
            break
        if feature_filter is not None:
            feature_filter.step()

    ckpt = torch.load(ckpt_path, map_location=device)
    g.load_state_dict(ckpt['g'])
    att_net.load_state_dict(ckpt['att'])
    if feature_filter is not None:
        feature_filter.load_state_dict(ckpt['filter'])
        torch.save(feature_filter.state_dict(), osp.join(log_dir, 'best_filter.pth'))

    test_acc = _evaluate(test_loader, g, att_net, feature_filter, device)
    logger.info(f'Seed {seed} | Best Val Acc {best_val_acc:.4f} | Test Acc {test_acc:.4f}')
    return best_train_acc, test_acc


def main():
    """Parse CLI options, train one model per seed and report mean ± std accuracy.

    For every seed, fresh models are trained with early stopping on validation
    accuracy. The best-validation checkpoint is then evaluated on the test
    split. Reported numbers are percentages aggregated across seeds; the std
    is 0 when a single seed is used.
    """
    parser = argparse.ArgumentParser('CRCG Paper Node Classification')
    parser.add_argument('--cuda',      default=0,           type=int)
    parser.add_argument('--datadir',   default='paper',     type=str)
    parser.add_argument('--epoch',     default=50,         type=int)
    parser.add_argument('--r',         default=0.7,        type=float)
    parser.add_argument('--batch_size',default=128,           type=int)
    parser.add_argument('--net_lr',    default=0.01,        type=float)
    parser.add_argument('--seed',      nargs='?', default='[42]')
    parser.add_argument('--patience',  default=20,          type=int,
                        help='epochs without val improvement before early stop')
    parser.add_argument('--use_causal_filter', action='store_true')
    parser.add_argument('--filter_lambda_init', type=float, default=1.0)
    parser.add_argument('--filter_lambda_min', type=float, default=-2.0)
    parser.add_argument('--filter_decay', type=float, default=0.99)
    parser.add_argument('--filter_temp', type=float, default=1.0)
    parser.add_argument('--filter_residual', type=float, default=0.2)
    args = parser.parse_args()
    try:
        args.seed = _parse_seeds(args.seed)
    except (ValueError, SyntaxError, TypeError) as err:
        parser.error(f'invalid --seed value {args.seed!r}: {err}')
    if args.epoch < 1:
        parser.error('--epoch must be at least 1')

    device = torch.device(f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu')

    train_ds = PaperDataset(args.datadir, mode='train')
    val_ds   = PaperDataset(args.datadir, mode='val')
    test_ds  = PaperDataset(args.datadir, mode='test')
    loaders = (
        DataLoader(train_ds, batch_size=args.batch_size, shuffle=True),
        DataLoader(val_ds,   batch_size=args.batch_size, shuffle=False),
        DataLoader(test_ds,  batch_size=args.batch_size, shuffle=False),
    )

    num_feat    = train_ds.data.x.size(1)
    num_classes = int(train_ds.data.y.max().item()) + 1

    exp_name = f'CRCG-Paper.r{args.r}.lr{args.net_lr}.bs{args.batch_size}.' + datetime.now().strftime('%Y%m%d-%H%M%S')
    log_dir  = osp.join('logs', exp_name)
    os.makedirs(log_dir, exist_ok=True)
    logger   = Logger.init_logger(filename=osp.join(log_dir, 'train.log'))
    args_print(args, logger)

    train_accs = []
    test_accs = []
    for seed in args.seed:
        train_acc, test_acc = _run_seed(seed, args, loaders, num_feat, num_classes,
                                        device, log_dir, logger)
        train_accs.append(train_acc * 100)
        test_accs.append(test_acc * 100)

    train_mean, train_std = np.mean(train_accs), np.std(train_accs)
    test_mean, test_std = np.mean(test_accs), np.std(test_accs)
    logger.info(f'Final Train Acc: {train_mean:.2f} ± {train_std:.2f}')
    logger.info(f'Test Acc: {test_mean:.2f} ± {test_std:.2f}')
    print(f'Train Acc  {train_mean:.2f} ± {train_std:.2f}  Test Acc  {test_mean:.2f} ± {test_std:.2f}')


if __name__ == '__main__':
    main()
