import sys
sys.path.append('./')
from improved_filter import ImprovedCausalFilter
import copy
import torch
import argparse
import os.path as osp
from torch_geometric.data import DataLoader

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import LEConv, BatchNorm
from utils.mask import set_masks, clear_masks
import time
import os
import numpy as np
from torch.autograd import grad
from utils.logger import Logger
from datetime import datetime
from utils.helper import set_seed, args_print
from utils.get_subgraph import split_graph, split_batch, relabel
from gnn import SPMotifNet
from molecular_dataset import MolecularDataset

class CausalAttNet(nn.Module):
    """Score every edge and split each input graph into two edge subsets.

    Two LEConv layers embed the nodes. An MLP scores each edge from the
    concatenated embeddings of its two endpoints. ``split_graph`` then
    partitions the edges using ``causal_ratio``, and ``relabel`` builds the
    node tensors for each part.

    Args:
        causal_ratio: ratio forwarded to ``split_graph``.
        in_channels: node feature width of the incoming ``data.x``.
        channels: hidden width of both LEConv layers.
    """

    def __init__(self, causal_ratio, in_channels, channels):
        super().__init__()
        self.conv1 = LEConv(in_channels=in_channels, out_channels=channels)
        self.conv2 = LEConv(in_channels=channels, out_channels=channels)
        hidden = channels * 4
        self.mlp = nn.Sequential(
            nn.Linear(channels * 2, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )
        self.ratio = causal_ratio

    def forward(self, data):
        """Return ``(causal_part, conf_part, edge_score)``.

        Each part is a tuple ``(x, edge_index, edge_attr, edge_weight, batch)``.
        ``edge_score`` holds one raw MLP score per edge of ``data``.
        """
        # edge_attr is flattened and used as the scalar edge weight of LEConv.
        weights = data.edge_attr.view(-1)
        h = F.relu(self.conv1(data.x, data.edge_index, weights))
        h = self.conv2(h, data.edge_index, weights)

        src, dst = data.edge_index
        pair_rep = torch.cat([h[src], h[dst]], dim=-1)
        edge_score = self.mlp(pair_rep).view(-1)

        causal_part, conf_part = split_graph(data, edge_score, self.ratio)
        causal_edge_index, causal_edge_attr, causal_edge_weight = causal_part
        conf_edge_index, conf_edge_attr, conf_edge_weight = conf_part

        causal_x, causal_edge_index, causal_batch, _ = relabel(h, causal_edge_index, data.batch)
        conf_x, conf_edge_index, conf_batch, _ = relabel(h, conf_edge_index, data.batch)

        causal = (causal_x, causal_edge_index, causal_edge_attr, causal_edge_weight, causal_batch)
        conf = (conf_x, conf_edge_index, conf_edge_attr, conf_edge_weight, conf_batch)
        return causal, conf, edge_score

if __name__ == "__main__":
    import ast  # parses --seed safely (replaces eval on command-line text)

    parser = argparse.ArgumentParser(description='Training for Causal Feature Learning on Molecular Dataset')
    parser.add_argument('--cuda', default=0, type=int, help='cuda device')
    parser.add_argument('--datadir', default='mole/', type=str, help='directory for datasets.')
    parser.add_argument('--epoch', default=50, type=int, help='training iterations')
    parser.add_argument('--reg', default=1, type=int)
    parser.add_argument('--seed',  nargs='?', default='[1,2,3]', help='random seed')
    parser.add_argument('--channels', default=32, type=int, help='width of network')
    parser.add_argument('--commit', default='', type=str, help='experiment name')

    parser.add_argument('--pretrain', default=10, type=int, help='pretrain epoch')
    parser.add_argument('--alpha', default=1e-2, type=float, help='invariant loss')
    parser.add_argument('--r', default=0.25, type=float, help='causal_ratio')

    parser.add_argument('--batch_size', default=64, type=int, help='batch size')
    parser.add_argument('--net_lr', default=1e-3, type=float, help='learning rate for the predictor')

    parser.add_argument('--use_causal_filter', default=True, action='store_true')
    # `store_true` with default=True can never switch the filter off. This
    # companion flag writes False into the same destination so it can be disabled.
    parser.add_argument('--no_causal_filter', dest='use_causal_filter', action='store_false',
                        help='disable the causal feature filter')
    parser.add_argument('--filter_lambda_init', type=float, default=1.0)
    parser.add_argument('--filter_lambda_min', type=float, default=-2.0)
    parser.add_argument('--filter_decay', type=float, default=0.99)
    parser.add_argument('--filter_temp', type=float, default=1.0)
    parser.add_argument('--filter_residual', type=float, default=0.2)
    args = parser.parse_args()
    # literal_eval accepts Python literals such as "[1,2,3]" or "7". Unlike
    # eval(), it cannot execute arbitrary code supplied on the command line.
    args.seed = ast.literal_eval(args.seed)

    device = torch.device('cuda:%d' % args.cuda if torch.cuda.is_available() else 'cpu')

    train_dataset = MolecularDataset(osp.join(args.datadir), mode='train')
    val_dataset = MolecularDataset(osp.join(args.datadir), mode='val')
    test_dataset = MolecularDataset(osp.join(args.datadir), mode='test')

    num_features = train_dataset.data.x.size(1)
    # Assumes labels are contiguous integers starting at 0.
    num_classes = int(train_dataset.data.y.max().item()) + 1

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    n_train_data, n_val_data = len(train_dataset), len(val_dataset)
    n_test_data = len(test_dataset)

    datetime_now = datetime.now().strftime("%Y%m%d-%H%M%S")
    all_info = {'causal_acc':[], 'conf_acc':[], 'train_acc':[], 'val_acc':[], 'test_prec':[], 'train_prec':[], 'test_mrr':[], 'train_mrr':[]}
    experiment_name = f'CRCG-Molecular.{bool(args.reg)}.{args.commit}.netlr_{args.net_lr}.batch_{args.batch_size}'\
                      f'.channels_{args.channels}.pretrain_{args.pretrain}.r_{args.r}.alpha_{args.alpha}.seed_{args.seed}.{datetime_now}'
    exp_dir = osp.join('local/', experiment_name)
    os.makedirs(exp_dir, exist_ok=True)
    logger = Logger.init_logger(filename=exp_dir + '/_output_.log')
    args_print(args, logger)

    # Only one run is performed. Use the first requested seed; it was
    # previously hard-coded to 1, which silently ignored --seed. The default
    # '[1,2,3]' still yields 1.
    seeds = args.seed if isinstance(args.seed, (list, tuple)) else [args.seed]
    seed = seeds[0] if seeds else 1
    set_seed(seed)

    g = SPMotifNet(args.channels, num_classes=num_classes).to(device)
    att_net = CausalAttNet(args.r, num_features, args.channels).to(device)
    feature_filter = None
    if args.use_causal_filter:
        feature_filter = ImprovedCausalFilter(
            input_dim=num_features,
            lambda_init=args.filter_lambda_init,
            lambda_min=args.filter_lambda_min,
            decay_rate=args.filter_decay,
            temperature=args.filter_temp,
            residual_weight=args.filter_residual,
            normalize=False,
            dropout=0.1
        ).to(device)
    # NOTE(review): the training loop in this file never calls feature_filter,
    # so its parameters presumably receive no gradients from this optimizer.
    # Confirm whether the filter is meant to be trained.
    filter_params = list(feature_filter.parameters()) if feature_filter is not None else []
    params = list(g.parameters()) + list(att_net.parameters()) + filter_params
    model_optimizer = torch.optim.Adam(params, lr=args.net_lr)
    conf_opt = torch.optim.Adam(g.conf_fw.parameters(), lr=args.net_lr)

    CELoss = nn.CrossEntropyLoss(reduction="mean")
    EleCELoss = nn.CrossEntropyLoss(reduction="none")
    
    def train_mode():
        """Switch the predictor, the attention net and the optional feature filter to training mode."""
        g.train()
        att_net.train()
        if feature_filter is not None:
            feature_filter.train()

    def val_mode():
        """Switch the predictor, the attention net and the optional feature filter to evaluation mode.

        The feature filter is constructed with ``dropout=0.1``. If it is left in
        training mode, that dropout would presumably stay active and make
        evaluation nondeterministic.
        """
        g.eval()
        att_net.eval()
        if feature_filter is not None:
            feature_filter.eval()

    def test_metrics(loader, att_net):
        """Return mean edge-explanation precision and MRR@5 over ``loader``.

        For every graph, edges are ranked by ``att_net``'s edge scores.

        - Precision is the fraction of ground-truth edges (``edge_gt_att``)
          among the top ``num_gt`` ranked edges. It is 0 when a graph has no
          ground-truth edges.
        - MRR is ``1 / (rank + 1)`` of the highest ground-truth value within the
          top ``mrr_k`` edges. It is 0 if none of them is positive.

        Returns:
            (mean_precision, mean_mrr) as 0-dim CPU tensors (NaN if ``loader``
            is empty).
        """
        def metrics_batch(graph, pred_weight, mrr_k=5):
            _, _, _, num_edges, cum_edges = split_batch(graph)
            gt_all = graph.edge_gt_att.view(-1)
            precision, mrr = [], []
            for E, C in zip(num_edges.tolist(), cum_edges.tolist()):
                if E == 0:
                    # Nothing to rank. torch.max on an empty slice would raise.
                    precision.append(0.0)
                    mrr.append(0.0)
                    continue
                gt = gt_all[C:C + E]
                _, order = pred_weight[C:C + E].sort(descending=True, dim=-1)

                num_gd = int(gt.sum())
                if num_gd > 0:
                    precision.append(float(gt[order[:num_gd]].sum()) / num_gd)
                else:
                    precision.append(0.0)

                # Ground-truth values re-ordered by predicted score, truncated to k.
                top_k = gt[order][:min(E, mrr_k)]
                best_val, first_idx = torch.max(top_k, dim=0)
                mrr.append(0.0 if float(best_val) == 0.0 else 1.0 / (int(first_idx) + 1))
            return torch.tensor(precision), torch.tensor(mrr)

        precision_lst, mrr_lst = torch.FloatTensor([]), torch.FloatTensor([])
        with torch.no_grad():  # evaluation only; avoid building autograd graphs
            for graph in loader:
                graph.to(device)
                _, _, edge_score = att_net(graph)
                precision, mrr = metrics_batch(graph, edge_score)
                precision_lst = torch.cat([precision_lst, precision])
                mrr_lst = torch.cat([mrr_lst, mrr])
        return torch.mean(precision_lst), torch.mean(mrr_lst)
    
    def test_acc(loader, att_net, predictor):
        """Return the classification accuracy of ``predictor`` on the causal subgraphs.

        For each batch, ``att_net`` splits the graph. ``predictor`` classifies
        the causal part, with the causal edge weights installed as masks for
        the duration of the forward pass. The caller is expected to have put
        the modules in eval mode (see ``val_mode``).

        Returns:
            Fraction of correctly classified graphs in ``loader.dataset``.
        """
        correct = 0
        with torch.no_grad():  # evaluation only; avoid building autograd graphs
            for graph in loader:
                graph.to(device)
                node_x = feature_filter(graph.x) if feature_filter is not None else graph.x
                (causal_x, causal_edge_index, causal_edge_attr, causal_edge_weight, causal_batch), \
                    _, _ = att_net(graph)

                # When the causal subgraph keeps every node, the node features
                # are replaced by the (optionally filtered) raw input features.
                # NOTE(review): att_net emits `channels`-wide embeddings while
                # node_x is `num_features`-wide. This substitution is also not
                # applied during training. Confirm it is intended.
                if causal_x.size(0) == node_x.size(0):
                    causal_x = node_x

                set_masks(causal_edge_weight, predictor)
                out = predictor(x=causal_x, edge_index=causal_edge_index,
                                edge_attr=causal_edge_attr, batch=causal_batch)
                clear_masks(predictor)
                correct += int((out.argmax(-1).view(-1) == graph.y.view(-1)).sum())
        return float(correct) / len(loader.dataset)

    logger.info(f"# Train: {n_train_data}  #Test: {n_test_data} #Val: {n_val_data}")
    cnt, last_val_acc = 0, 0
    patience = 20  # epochs without validation improvement before stopping early

    final_train_accuracies = []
    final_test_accuracies = []

    # Start from the live models so the final evaluation is well-defined even
    # if validation accuracy never rises above 0. Without this, best_g and
    # best_att would be undefined after the loop.
    best_g, best_att, best_filter = g, att_net, feature_filter

    for epoch in range(args.epoch):
        all_loss, n_bw = 0, 0
        all_causal_loss, all_conf_loss = 0, 0

        train_mode()
        # NOTE(review): feature_filter is not applied in this loop, although
        # test_acc may apply it at evaluation time. Confirm this is intended.
        for graph in train_loader:
            n_bw += 1
            graph.to(device)

            (causal_x, causal_edge_index, causal_edge_attr, causal_edge_weight, causal_batch),\
            (conf_x, conf_edge_index, conf_edge_attr, conf_edge_weight, conf_batch), edge_score = att_net(graph)

            # Causal branch: predictor runs with the causal edge weights as masks.
            set_masks(causal_edge_weight, g)
            out = g(x=causal_x, edge_index=causal_edge_index,
                    edge_attr=causal_edge_attr, batch=causal_batch)
            causal_loss = CELoss(out, graph.y.view(-1))

            # Confounding branch: masks are cleared first, so no edge weights apply.
            clear_masks(g)
            conf_out = g.forward(x=conf_x, edge_index=conf_edge_index,
                                 edge_attr=conf_edge_attr, batch=conf_batch)
            conf_loss = CELoss(conf_out, graph.y.view(-1))

            loss = causal_loss + conf_loss

            model_optimizer.zero_grad()
            loss.backward()
            model_optimizer.step()

            all_loss += loss.item()
            all_causal_loss += causal_loss.item()
            all_conf_loss += conf_loss.item()

        val_mode()
        val_acc = test_acc(val_loader, att_net, g)

        # Early-stopping bookkeeping runs before feature_filter.step(), so the
        # snapshot captures exactly the filter state that produced val_acc.
        if val_acc > last_val_acc:
            last_val_acc = val_acc
            cnt = 0
            best_g, best_att = copy.deepcopy(g), copy.deepcopy(att_net)
            best_filter = copy.deepcopy(feature_filter)
        else:
            cnt += 1

        if feature_filter is not None:
            feature_filter.step()

        train_acc = test_acc(train_loader, att_net, g)
        final_train_accuracies.append(train_acc * 100)
        # No per-epoch test evaluation happens, so this records validation accuracy.
        final_test_accuracies.append(val_acc * 100)

        denom = max(n_bw, 1)  # guard against an empty train_loader
        logger.info(f"Epoch: {epoch:03d}, Loss: {all_loss/denom:.4f}, "
                    f"Causal: {all_causal_loss/denom:.4f}, Conf: {all_conf_loss/denom:.4f}, "
                    f"Train: {train_acc * 100:.4f}%, Val: {val_acc * 100:.4f}%")

        if cnt >= patience:
            break

    # Both statistics are taken over per-epoch values. The reported "±" on the
    # test accuracy is therefore the spread of validation accuracy across
    # epochs, not a test-set variance.
    final_train_mean = np.mean(final_train_accuracies)
    final_train_std = np.std(final_train_accuracies)
    final_test_std = np.std(final_test_accuracies)

    # Restore the best-validation snapshot, including the filter, and evaluate once on test.
    g, att_net, feature_filter = best_g, best_att, best_filter
    val_mode()
    test_acc_value = test_acc(test_loader, att_net, g)
    print(f"Train Acc  {final_train_mean:.2f} ± {final_train_std:.2f}  Test Acc  {test_acc_value * 100:.2f} ± {final_test_std:.2f}")
    if feature_filter is not None:
        torch.save(feature_filter.state_dict(), os.path.join(exp_dir, 'best_filter.pth'))
