import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from corrupte import*
import numpy as np
from data import load_data,preprocess_features,preprocess_adj,normalize_S
from utils import masked_loss,acc
import higher
import scipy
import scipy.sparse as sp
import argparse
from model import GCN,ANet
import os

def _to_sparse_tensor(coo_triplet, device):
    """Build a float32 sparse COO tensor on ``device`` from ``(coords, values, shape)``.

    The coordinate array is transposed before construction because
    ``torch.sparse_coo_tensor`` expects indices laid out as ``(ndim, nnz)``.
    ``torch.sparse_coo_tensor`` replaces the deprecated legacy
    ``torch.sparse.FloatTensor`` constructor.
    """
    indices = torch.from_numpy(coo_triplet[0]).long().to(device)
    values = torch.from_numpy(coo_triplet[1]).to(device)
    shape = tuple(int(dim) for dim in coo_triplet[2])
    return torch.sparse_coo_tensor(indices.t(), values, shape).float().to(device)


def main(args):
    """Train a GCN on noisy labels with an aggregation network and report accuracy.

    Per epoch the function:
      1. runs the GCN and derives edge weights from the inverse mean squared
         difference of the outputs at both endpoints of every edge;
      2. propagates the current clean-set labels over the normalised weighted
         graph for ``args.lpa_iters`` rounds, clamping the clean nodes each round;
      3. trains the GCN on training nodes whose (noisy) label agrees with the
         propagated label, and updates the clean set accordingly;
      4. uses a differentiable one-step look-ahead (``higher``) to update the
         aggregation network on the remaining training nodes;
      5. trains the GCN on those remaining nodes against the mixed target
         produced by the aggregation network.

    Args:
        args: Parsed command-line namespace. Attributes read here: ``seed``,
            ``dataset``, ``corruption_type``, ``corruption_prob``, ``hidden``,
            ``ANet_dim``, ``learning_rate``, ``weight_decay``,
            ``A_learning_rate``, ``A_weight_decay``, ``epochs``, ``lpa_iters``.
            The whole namespace is also forwarded to ``load_data``.

    Raises:
        ValueError: If ``args.dataset`` or ``args.corruption_type`` is not one
            of the supported values.
    """
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    # Fall back to CPU so the script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # load the datasets
    if args.dataset not in ('cora', 'citeseer', 'pubmed'):
        raise ValueError(
            "Incorrect dataset name: {!r} (expected 'cora', 'citeseer' or 'pubmed').".format(args.dataset))
    adj, features, y, train_mask, val_mask, test_mask, clean_mask = load_data(args)

    adj = sp.coo_matrix(adj)
    y = torch.from_numpy(y)
    y_label = y.argmax(dim=1).to(device)
    num_nodes, num_classes = y.shape[0], y.shape[1]
    features = preprocess_features(features)
    supports = preprocess_adj(adj)

    train_mask = torch.from_numpy(train_mask).to(device).bool()
    val_mask = torch.from_numpy(val_mask).to(device).bool()
    test_mask = torch.from_numpy(test_mask).to(device).bool()
    clean_mask = torch.from_numpy(clean_mask).to(device).bool()

    # Label-transition matrix: row c is the sampling distribution for class c.
    if args.corruption_type == 'uniform':
        C = uniform_mix_C(args.corruption_prob, num_classes)
    elif args.corruption_type == 'flip':
        C = flip_labels_C(args.corruption_prob, num_classes)
    else:
        raise ValueError(
            "Incorrect corruption type: {!r} (expected 'uniform' or 'flip').".format(args.corruption_type))

    # Corrupt the labels of the first `train_num` nodes.
    # NOTE(review): this (and the `[:train_num]` slices below) assumes the
    # training nodes are exactly the first `train_num` node indices -- confirm
    # against load_data.
    train_num = torch.sum(train_mask).item()
    for node in range(train_num):
        y_label[node] = np.random.choice(num_classes, p=C[y_label[node].item()])

    feature = _to_sparse_tensor(features, device)
    support = _to_sparse_tensor(supports, device)

    num_features_nonzero = feature._nnz()
    feat_dim = feature.shape[1]

    net = GCN(feat_dim, args.hidden, num_classes, num_features_nonzero).to(device)
    A_Net = ANet(2, args.ANet_dim, 1).to(device)
    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    optimizer_ANet = optim.Adam(A_Net.parameters(), args.A_learning_rate, weight_decay=args.A_weight_decay)
    eps = np.finfo(float).eps

    # One-hot labels of the current clean set; rows are filled in every epoch.
    label_onehot = torch.zeros(num_nodes, num_classes).to(device)
    # convert noisy labels to one-hot label
    ones = torch.eye(num_classes).to(device)
    y_label_onehot = torch.zeros(num_nodes, num_classes).to(device)
    y_label_onehot[train_mask] = ones.index_select(0, y_label[train_mask])

    best_acc = 0
    for epoch in range(args.epochs):
        net.train()
        optimizer.zero_grad()
        out, _ = net((feature, support))

        # Edge weight = 1 / mean squared difference of the endpoint outputs
        # (eps avoids division by zero for identical outputs).
        edge_weight = torch.mean((out[adj.row] - out[adj.col]) ** 2, 1)
        edge_weight = 1 / (edge_weight + eps)
        edge_weight = edge_weight.detach().cpu().numpy()
        S = sp.coo_matrix((edge_weight, (adj.row, adj.col)), shape=adj.shape)
        S = normalize_S(S)

        # Propagate clean-set labels over S, re-clamping the clean nodes to
        # their given labels after every round.
        label_onehot[clean_mask] = ones.index_select(0, y_label[clean_mask])
        label_sp = sp.lil_matrix(label_onehot.detach().cpu().numpy())
        clean_mask_np = clean_mask.detach().cpu().numpy()
        Z = label_sp
        for _ in range(args.lpa_iters):
            Z = S.dot(Z)
            Z[clean_mask_np] = label_sp[clean_mask_np]
        Z = torch.from_numpy(Z.toarray()).to(device)
        Z = F.softmax(Z, dim=1)

        # Training nodes whose noisy label matches the propagated label are
        # treated as clean for this epoch and for the next propagation.
        Z_train = Z.argmax(dim=1)[train_mask]
        agree_mask = Z_train == y_label[train_mask]
        clean_mask[:train_num] = agree_mask
        train_sel_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        train_sel_mask[:train_num] = agree_mask
        loss_1 = masked_loss(out, y_label, train_sel_mask)
        loss_1.backward()
        optimizer.step()

        # Remaining training nodes: noisy label disagrees with propagation.
        train_left_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        train_left_mask[:train_num] = ~agree_mask

        # Differentiable one-step look-ahead: update a functional copy of the
        # GCN with the mixed target, then back-propagate its loss on the clean
        # set into the aggregation network.
        with higher.innerloop_ctx(net, optimizer) as (meta_net, meta_opt):
            meta_out, _ = meta_net((feature, support))
            left_logits = meta_out[train_left_mask]
            # Feature 1: per-node cross-entropy against the noisy label.
            meta_cost_1 = F.cross_entropy(left_logits, y_label[train_left_mask], reduction='none')
            meta_cost_1 = torch.reshape(meta_cost_1, (len(meta_cost_1), 1))
            # Feature 2: per-node KL(softmax(logits) || Z).
            meta_cost_2 = torch.sum(
                F.softmax(left_logits, 1) * (F.log_softmax(left_logits, 1) - torch.log(Z[train_left_mask])), 1)
            meta_cost_2 = torch.reshape(meta_cost_2, (len(meta_cost_2), 1))
            meta_cost_a = torch.cat((meta_cost_1, meta_cost_2), 1).float()
            # Per-node mixing weight between the noisy label and Z.
            a_lambda = A_Net(meta_cost_a)
            y_mul = a_lambda * y_label_onehot[train_left_mask] + (1 - a_lambda) * Z[train_left_mask]
            # KL(y_mul || softmax(logits)), averaged over the remaining nodes.
            meta_cost = torch.sum(y_mul * (torch.log(y_mul) - F.log_softmax(left_logits, 1)), 1).mean()
            meta_net.zero_grad()
            meta_opt.step(meta_cost)
            meta_out, _ = meta_net((feature, support))
            meta_loss = masked_loss(meta_out, y_label, clean_mask)
            optimizer_ANet.zero_grad()
            meta_loss.backward()
            optimizer_ANet.step()

        # Clear the gradients left over from loss_1; otherwise the second
        # optimizer step would apply stale loss_1 gradients added to loss_2's.
        optimizer.zero_grad()
        out, _ = net((feature, support))
        with torch.no_grad():
            lambda_new = A_Net(meta_cost_a)

        y_mul = lambda_new * y_label_onehot[train_left_mask] + (1 - lambda_new) * Z[train_left_mask]
        y_mul = y_mul.detach()
        # KL(softmax(out) || y_mul), averaged over the remaining nodes.
        loss_2 = torch.sum(
            F.softmax(out[train_left_mask], 1) * (F.log_softmax(out[train_left_mask], 1) - torch.log(y_mul)),
            1).mean()
        loss_2.backward()
        optimizer.step()

        train_acc = acc(out, y_label, train_mask)
        val_acc = acc(out, y_label, val_mask)
        train_loss = masked_loss(out, y_label, train_mask)
        val_loss = masked_loss(out, y_label, val_mask)

        # Evaluation pass: no autograd graph is needed.
        net.eval()
        with torch.no_grad():
            out, _ = net((feature, support))
            test_acc = acc(out, y_label, test_mask)
        print('Epoch: {:04d}'.format(epoch+1),
            'loss_train: {:.4f}'.format(train_loss.item()),
            'acc_train: {:.4f}'.format(train_acc),
            'loss_val: {:.4f}'.format(val_loss.item()),
            'acc_val: {:.4f}'.format(val_acc),
            'acc_test: {:.4f}'.format(test_acc))
        if test_acc >= best_acc:
            best_acc = test_acc
    print('best accuracy:', best_acc)

if __name__ == '__main__':
    # Command-line entry point: parse the hyper-parameters and run training.
    parser = argparse.ArgumentParser()
    # type=int is required: without it a seed given on the command line is
    # parsed as a str, which np.random.seed / torch.random.manual_seed reject.
    parser.add_argument('--seed', type=int, default=45)
    parser.add_argument('--dataset', default='cora',help='The dataset:cora,citeseer,pubmed')
    parser.add_argument('--learning_rate', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--hidden', type=int, default=16,help='The number of neurons in hidden layer of GNNs')
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--weight_decay', type=float, default=5e-4,help='The weight decay of GNNs')
    parser.add_argument('--A_weight_decay', type=float, default=1e-4,help='The weight decay of Aggregation Net')
    parser.add_argument('--A_learning_rate', type=float, default=1e-4,help='The learning rate of Aggregation Net')
    parser.add_argument('--ANet_dim',type=int,default=64,help='The dimension of middle layer of Aggregation Net')
    parser.add_argument('--corruption_prob',type=float,default=0.4,help='The ratio of label noise')
    parser.add_argument('--corruption_type',type=str,default='uniform',help='The type of label noise:uniform,flip')
    parser.add_argument('--clean_label_num',type=int,default=28,help='The number of nodes in initial clean sets')
    parser.add_argument('--lpa_iters',type=int,default=50)
    main(parser.parse_args())

       
 
   




