import sys
sys.path.append('./')
import copy
import torch
import argparse
from datasets import SPMotif
from torch_geometric.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import LEConv, BatchNorm, fps
from utils.mask import set_masks, clear_masks
import os
import numpy as np
import os.path as osp
from torch.autograd import grad
from utils.logger import Logger
from datetime import datetime
from utils.helper import set_seed, args_print
from utils.get_subgraph import split_graph, split_batch, relabel
from gnn import SPMotifNet,SPMotifNet_LEGNN
class CausalAttNet(nn.Module):
    """Score every edge of a batch and split it into "causal" and "confounding" subgraphs.

    Two LEConv layers embed the nodes (using ``edge_attr`` flattened to one
    scalar weight per edge). An MLP over the concatenated endpoint embeddings
    of each edge produces one score per edge. ``split_graph`` then partitions
    the edges using ``causal_ratio``; presumably it keeps the top-scored
    fraction as causal (TODO confirm in utils.get_subgraph).

    Args:
        causal_ratio: ratio forwarded to ``split_graph``.
        channels: hidden width. When ``None`` (the default) it falls back to
            the script-level ``args.channels``, so existing
            ``CausalAttNet(ratio)`` calls behave exactly as before.
        in_channels: node-feature width of the input graphs (previously
            hard-coded to 5).
    """

    def __init__(self, causal_ratio, channels=None, in_channels=5):
        super(CausalAttNet, self).__init__()
        if channels is None:
            # Backward-compatible default: the global parsed in the __main__ block.
            channels = args.channels
        self.conv1 = LEConv(in_channels=in_channels, out_channels=channels)
        self.conv2 = LEConv(in_channels=channels, out_channels=channels)
        self.mlp = nn.Sequential(
            nn.Linear(channels * 2, channels * 4),
            nn.ReLU(),
            nn.Linear(channels * 4, 1)
        )
        self.ratio = causal_ratio

    def forward(self, data):
        """Embed nodes, score edges once, and split the graph by those scores.

        Returns:
            ``(causal_x, causal_edge_index, causal_edge_attr, causal_edge_weight, causal_batch)``,
            ``(conf_x, conf_edge_index, conf_edge_attr, conf_edge_weight, conf_batch)``,
            and the flat per-edge ``edge_score`` tensor.
        """
        edge_weight = data.edge_attr.view(-1)
        x = F.relu(self.conv1(data.x, data.edge_index, edge_weight))
        x = self.conv2(x, data.edge_index, edge_weight)

        # Edge representation = [h_src || h_dst]; scored a single time
        # (the previous code ran this MLP pass twice and discarded the first).
        row, col = data.edge_index
        edge_rep = torch.cat([x[row], x[col]], dim=-1)
        edge_score = self.mlp(edge_rep).view(-1)

        (causal_edge_index, causal_edge_attr, causal_edge_weight), \
            (conf_edge_index, conf_edge_attr, conf_edge_weight) = split_graph(data, edge_score, self.ratio)

        causal_x, causal_edge_index, causal_batch, _ = relabel(x, causal_edge_index, data.batch)
        conf_x, conf_edge_index, conf_batch, _ = relabel(x, conf_edge_index, data.batch)

        return (causal_x, causal_edge_index, causal_edge_attr, causal_edge_weight, causal_batch), \
            (conf_x, conf_edge_index, conf_edge_attr, conf_edge_weight, conf_batch), \
            edge_score
if __name__ == "__main__": 
    parser = argparse.ArgumentParser(description='Training for Causal Feature Learning')
    parser.add_argument('--cuda', default=0, type=int, help='cuda device')
    parser.add_argument('--datadir', default='data/', type=str, help='directory for datasets.')
    parser.add_argument('--epoch', default=400, type=int, help='training iterations')
    parser.add_argument('--reg', default=1, type=int)
    parser.add_argument('--seed',  nargs='?', default='[1]', help='random seed')
    parser.add_argument('--channels', default=32, type=int, help='width of network')
    parser.add_argument('--commit', default='', type=str, help='experiment name')
    parser.add_argument('--bias', default='0.333', type=str, help='select bias extend')
    
    parser.add_argument('--pretrain', default=10, type=int, help='pretrain epoch')
    parser.add_argument('--alpha', default=1e-2, type=float, help='invariant loss')
    parser.add_argument('--r', default=0.25, type=float, help='causal_ratio')
    
    parser.add_argument('--batch_size', default=32, type=int, help='batch size')
    parser.add_argument('--net_lr', default=1e-3, type=float, help='learning rate for the predictor')
    args = parser.parse_args()
    args.seed = eval(args.seed)
    
    num_classes = 5
    device = torch.device('cuda:%d' % args.cuda if torch.cuda.is_available() else 'cpu') 
    train_dataset = SPMotif(osp.join(args.datadir, f'CRCG-MOTIF/'), mode='train')
    val_dataset = SPMotif(osp.join(args.datadir, f'CRCG-MOTIF/'), mode='val')
    test_dataset = SPMotif(osp.join(args.datadir, f'CRCG-MOTIF/'), mode='test')
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
    n_train_data, n_val_data = len(train_dataset), len(val_dataset)
    n_test_data = float(len(test_dataset))

    
    datetime_now = datetime.now().strftime("%Y%m%d-%H%M%S")
    all_info = { 'causal_acc':[], 'conf_acc':[], 'train_acc':[], 'val_acc':[], 'test_prec':[], 'train_prec':[], 'test_mrr':[], 'train_mrr':[]}
    experiment_name = f'CRCG-MOTIF.{bool(args.reg)}.{args.commit}.netlr_{args.net_lr}.batch_{args.batch_size}.channels_{args.channels}.pretrain_{args.pretrain}.r_{args.r}.alpha_{args.alpha}.seed_{args.seed}.{datetime_now}'
    exp_dir = osp.join('local/', experiment_name)
    os.makedirs(exp_dir, exist_ok=True)
    logger = Logger.init_logger(filename=exp_dir + '/_output_.log')
    args_print(args, logger)

    for seed in args.seed:
        set_seed(seed)
        
        
        legnn = SPMotifNet_LEGNN(5,num_classes=5).to(device)
        att_net = CausalAttNet(args.r).to(device)
        model_optimizer = torch.optim.Adam(
            list(legnn.parameters()),
            lr=args.net_lr)
        
        CELoss = nn.CrossEntropyLoss(reduction="mean")
        EleCELoss = nn.CrossEntropyLoss(reduction="none")

        def train_mode():
            legnn.train()
            
        def val_mode():
            legnn.eval()

        def test_metrics(loader, att_net):
            def metrics_batch(graph, pred_weight, mrr_k=5):
                _, _, _, num_edges, cum_edges = split_batch(graph)
                ground_truth_mask = graph.edge_gt_att.view(-1)
                precision, mrr = [], []
                for E, C in zip(num_edges.tolist(), cum_edges.tolist()):
                    
                    num_gd = int(ground_truth_mask[C: C + E].sum())
                    pred = pred_weight[C:C + E]
                    _, indices_for_sort = pred.sort(descending=True, dim=-1)
                    idx = indices_for_sort[:num_gd].detach().cpu().numpy()
                    precision.append(ground_truth_mask[C: C + E][idx].sum().float()/num_gd)

                    
                    k = min(pred.size(0), mrr_k)
                    true_sorted_by_preds = torch.gather(
                        graph.edge_gt_att[C: C + E], dim=-1, index=indices_for_sort
                    )
                    true_sorted_by_pred_shrink = true_sorted_by_preds[:k]
                    values, indices = torch.max(true_sorted_by_pred_shrink, dim=0)
                    indices = indices.type_as(values).unsqueeze(dim=0).t()
                    result = torch.tensor(1.0) / (indices + torch.tensor(1.0))
                    zero_sum_mask = values == 0.0
                    result[zero_sum_mask] = 0.0
                    mrr.append(result[0])
                return torch.tensor(precision), torch.tensor(mrr)

            precision_lst, mrr_lst =  torch.FloatTensor([]), torch.FloatTensor([])
            for graph in loader:
                graph.to(device)
                causal_g, conf_g, edge_score = att_net(graph)
                precision, mrr = metrics_batch(graph, edge_score)
                precision_lst = torch.cat([precision_lst, precision])
                mrr_lst = torch.cat([mrr_lst, mrr])
            return torch.mean(precision_lst), torch.mean(mrr_lst)
            
        def test_acc(loader, predictor):
            acc = 0
            for graph in loader: 
                graph.to(device)
                
                
                

                out = predictor(graph.x, graph.edge_index,
                        graph.edge_attr, graph.batch)
                acc += torch.sum(out.argmax(-1).view(-1) == graph.y.view(-1))
            acc = float(acc) / len(loader.dataset)
            return acc
        logger.info(f"# Train: {n_train_data}  #Test: {n_test_data} #Val: {n_val_data}")
        cnt, last_val_acc = 0, 0
        for epoch in range(args.epoch):                
            causal_edge_weights = torch.tensor([]).to(device)
            conf_edge_weights = torch.tensor([]).to(device)
            reg = args.reg
            alpha_prime = args.alpha * (epoch ** 1.6)
            all_loss, n_bw, all_env_loss = 0, 0, 0
            all_causal_loss, all_conf_loss, all_var_loss = 0, 0, 0
            dummy_w = nn.Parameter(torch.Tensor([1.0])).to(device)
            train_mode()
            for graph in train_loader:
                n_bw += 1
                graph.to(device)
                N = graph.num_graphs
                rep = legnn.get_graph_rep(
                    x=graph.x, edge_index=graph.edge_index,
                    edge_attr=graph.edge_attr, batch=graph.batch)
                out = legnn.get_robust_pred(rep)
                
                Hi =legnn.get_node_reps(graph.x, graph.edge_index, graph.edge_attr, graph.batch)                
                
                Lp = 0
                Lp_count = 0
                labels = graph.y
                out = legnn(x=graph.x, edge_index=graph.edge_index,
                    edge_attr=graph.edge_attr, batch=graph.batch)
                predicted_labels = out.argmax(dim=1)
                is_correct = predicted_labels == graph.y
                for c in range(num_classes):
                   
                   Hc_positive_list = [Hi[i].unsqueeze(0) for i, label in enumerate(labels) if label == c and is_correct[i]]
                   Hc_positive = torch.cat(Hc_positive_list, dim=0) if len(Hc_positive_list) > 0 else torch.empty(0)
                   
                   
                   Hc_negative_list = [Hi[i].unsqueeze(0) for i, label in enumerate(labels) if label != c and not is_correct[i]]
                   Hc_negative = torch.cat(Hc_negative_list, dim=0) if len(Hc_negative_list) > 0 else torch.empty(0)
                   
                   if Hc_positive.dim() == 2 and Hc_negative.dim() == 2:
                      if Hc_positive.size(1) == Hc_negative.size(1):
                          similarity_matrix = torch.mm(Hc_positive, Hc_negative.t())
                      else:
                           similarity_matrix = torch.empty((Hc_positive.size(0), Hc_negative.size(0)))
                   else:
                        similarity_matrix = torch.empty((Hc_positive.size(0), Hc_negative.size(0)))
                   
                   valid_indices = torch.nonzero(similarity_matrix > 0.9) 
                   valid_indices = valid_indices.transpose(0, 1)
                   
                   for i, label in enumerate(labels):                                                                                                                                                                              
                      if label == c: 
                         
                          Xz = torch.mean(Hi[valid_indices[0][valid_indices[1] == i]], dim=0)
                          Xz = Xz.detach()
                          
                          Xa = torch.mean(Hi[i], dim=0)
                          
                          similarity = torch.cosine_similarity(Xz.unsqueeze(0), Xa.unsqueeze(0))
                          
                          Lp -= similarity
                          Lp_count += 1
                  
                if Lp_count > 0:
                   Lp /= Lp_count               
                
                Lp = 0
                Lp_count = 0
                Ln = 0
                Ln_count = 0
                
                for c in range(num_classes):
                   Hc_positive_list = [Hi[i].unsqueeze(0) for i, label in enumerate(labels) if label == c and is_correct[i]]
                   Hc_positive = torch.cat(Hc_positive_list, dim=0) if len(Hc_positive_list) > 0 else torch.empty(0)
                   
                   Hc_negative_list = [Hi[i].unsqueeze(0) for i, label in enumerate(labels) if label == c and not is_correct[i]]
                   Hc_negative = torch.cat(Hc_negative_list, dim=0) if len(Hc_negative_list) > 0 else torch.empty(0)
                   
                   
                   Hc_positive = Hc_positive.unsqueeze(0) if len(Hc_positive.shape) == 1 else Hc_positive
                   Hc_negative = Hc_negative.unsqueeze(0) if len(Hc_negative.shape) == 1 else Hc_negative
                
                   
                   if Hc_positive.dim() == 2 and Hc_negative.dim() == 2:
                      if Hc_positive.size(1) == Hc_negative.size(1):
                         similarity_matrix = torch.mm(Hc_positive, Hc_negative.t())
                      else:
                         similarity_matrix = torch.empty((Hc_positive.size(0), Hc_negative.size(0)))
                   else:
                      similarity_matrix = torch.empty((Hc_positive.size(0), Hc_negative.size(0)))
                
                   valid_indices = torch.nonzero(similarity_matrix > 0.9)
                   valid_indices = valid_indices.transpose(0, 1)
                
                for i, label in enumerate(labels):
                    if label == c:
                    
                       Xz = torch.mean(Hi[valid_indices[0][valid_indices[1] == i]], dim=0)
                    
                       Xa = torch.mean(Hi[i], dim=0).detach()
                    
                       similarity = torch.cosine_similarity(Xz.unsqueeze(0), Xa.unsqueeze(0))
                    
                       Ln += similarity
                       Ln_count += 1
                if Ln_count > 0:
                       Ln /= Ln_count
                loss = CELoss(out, graph.y)
                
                all_loss += loss
            all_loss /= n_bw
            L = all_loss + Lp + Ln
            model_optimizer.zero_grad()
            L.backward()
            model_optimizer.step()
            val_mode()
            with torch.no_grad():
                
                
                train_acc = test_acc(train_loader, legnn)
                val_acc = test_acc(val_loader, legnn)
                
                ACC = 0.
                for graph in test_loader: 
                    graph.to(device)
                    
                    
                    
                    out = legnn(graph.x, graph.edge_index,graph.edge_attr, graph.batch)
                    
                    
                    
                    
                    ACC += torch.sum(out.argmax(-1).view(-1) == graph.y.view(-1)) / n_test_data
                    
                
                logger.info("Epoch [{:3d}/{:d}]  all_loss:{:2.3f}=[XE:{:2.3f}  IL:{:2.6f}]  "
                            "Train_ACC:{:.4f} Test_ACC[{:.4f}  ]  Val_ACC:{:.4f}  ".format(
                        epoch, args.epoch, all_loss, all_causal_loss, all_env_loss, 
                        train_acc, ACC, val_acc,
                        ))         
                
              
                if epoch >= args.pretrain:
                    if val_acc < last_val_acc:
                        cnt += 1
                    else:
                        cnt = 0
                        last_val_acc = val_acc
                if cnt >= 5:
                    logger.info("Early Stopping")
                    break  
                          
        all_info['causal_acc'].append(ACC)
        all_info['conf_acc'].append(ACC)
        all_info['train_acc'].append(train_acc)
        all_info['val_acc'].append(val_acc)
        torch.save(legnn.cpu(), osp.join(exp_dir, 'predictor-%d.pt' % seed))
        torch.save(att_net.cpu(), osp.join(exp_dir, 'attention_net-%d.pt' % seed))
        logger.info("=" * 100)

    logger.info("Causal ACC:{:.4f}-+-{:.4f}  Conf ACC:{:.4f}-+-{:.4f}  Train ACC:{:.4f}-+-{:.4f}  Val ACC:{:.4f}-+-{:.4f}  ".format(
                    torch.tensor(all_info['causal_acc']).mean(), torch.tensor(all_info['causal_acc']).std(),
                    torch.tensor(all_info['conf_acc']).mean(), torch.tensor(all_info['conf_acc']).std(),
                    torch.tensor(all_info['train_acc']).mean(), torch.tensor(all_info['train_acc']).std(),
                    torch.tensor(all_info['val_acc']).mean(), torch.tensor(all_info['val_acc']).std()
                ))
            