import os.path as osp
import sys
import argparse
sys.path.append((osp.abspath(osp.dirname(__file__)).split('src')[0] + 'src'))

from utils import *
from models.GraphHD import GraphHDConfig
from models.GraphHD.model import _similarity
import torch as th
import warnings
warnings.filterwarnings('ignore')
from models.GraphHD.h_cluster import run_hkmeans_faiss


def compute_features(model, dataloader, cf):
    """Embed every graph in ``dataloader`` with ``model`` and return the feature matrix.

    Each batch's outputs are written into row ``data.id`` of a
    ``(cf.data_len, cf.n_hidden)`` buffer, so the result is indexed by graph id
    regardless of loader order (e.g. a shuffled loader).

    Args:
        model: callable ``model(x, edge_index, edge_attr, batch)`` returning one
            ``cf.n_hidden``-dim row per graph in the batch (assumed from the
            buffer shape — confirm against the encoder).
        dataloader: iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch``, ``id`` and a ``to(device)`` method.
        cf: config providing ``data_len``, ``n_hidden`` and ``compute_dev``.

    Returns:
        numpy array of shape ``(cf.data_len, cf.n_hidden)``; rows whose id never
        appears in ``dataloader`` stay zero.
    """
    graph_features = th.zeros(cf.data_len, cf.n_hidden, device=cf.compute_dev)
    # Inference only: without no_grad, a model with trainable parameters makes
    # ``graph_features`` require grad and ``.numpy()`` below raises.
    with th.no_grad():
        for data in dataloader:
            data = data.to(cf.compute_dev)
            x = data.x
            if x is None:
                # Featureless graphs: same constant-one node feature used by the
                # evaluation helpers in this file.
                x = th.ones((data.batch.size(0), 1), dtype=th.float32, device=cf.compute_dev)
            g = model(x, data.edge_index, data.edge_attr, data.batch)
            graph_features[data.id] = g

    return graph_features.cpu().numpy()

def run_cluster(model, dataloader, n_protos, cf):
    """Embed the dataset with ``model`` and cluster the embeddings.

    Args:
        model: encoder passed through to ``compute_features``.
        dataloader: batches to embed.
        n_protos: cluster counts handed to ``run_hkmeans_faiss`` (one entry per
            hierarchy level, as built by the caller from ``cf.n_protos``).
        cf: experiment config.

    Returns:
        The ``'centroids'`` entry of ``run_hkmeans_faiss``'s result (presumably
        one centroid set per level — confirm against h_cluster).
    """
    print('Clustering...')
    embeddings = compute_features(model, dataloader, cf)
    return run_hkmeans_faiss(embeddings, n_protos, cf)['centroids']


def test_teacher(loader, teacher_model, prototypes, augmentor, cf, teacher_proj):
    """Count how often the teacher maps two augmented views to the same prototype.

    For every batch, both augmentations in ``augmentor`` are applied, each view
    is encoded by ``teacher_model`` (optionally followed by ``teacher_proj``),
    and at every hierarchy level the argmax of ``_similarity`` against
    ``prototypes[l]`` is taken as the view's pseudo-label.

    Args:
        loader: iterable of graph batches.
        teacher_model: encoder ``(x, edge_index, edge_attr, batch) -> graph embeddings``.
        prototypes: indexable by level ``0..cf.h_level-1``.
        augmentor: pair ``(aug1, aug2)`` of callables ``(x, edge_index, edge_attr)``
            returning the same triple.
        cf: config providing ``h_level`` and ``compute_dev``.
        teacher_proj: optional projection applied to both embeddings, or ``None``.

    Returns:
        dict mapping level ``l`` to the number (int) of graphs whose two views
        received the same pseudo-label at that level.
    """
    aug1, aug2 = augmentor
    correct_cnt_dict = {l: 0 for l in range(cf.h_level)}

    # Evaluation only; skip autograd bookkeeping.
    with th.no_grad():
        for step, data in enumerate(loader):
            data = data.to(cf.compute_dev)
            if data.x is None:
                # Featureless graphs get a constant one-dimensional node feature.
                num_nodes = data.batch.size(0)
                data.x = th.ones((num_nodes, 1), dtype=th.float32, device=cf.compute_dev)

            x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
            x1, edge_index1, edge_attr1 = aug1(x, edge_index, edge_attr)
            x2, edge_index2, edge_attr2 = aug2(x, edge_index, edge_attr)

            tg1 = teacher_model(x1, edge_index1, edge_attr1, batch)
            tg2 = teacher_model(x2, edge_index2, edge_attr2, batch)
            if teacher_proj is not None:
                tg1, tg2 = teacher_proj(tg1), teacher_proj(tg2)

            for l in range(cf.h_level):
                tp1 = _similarity(tg1, prototypes[l])
                tp2 = _similarity(tg2, prototypes[l])

                t1_pseudo_label = th.argmax(tp1, dim=1)
                t2_pseudo_label = th.argmax(tp2, dim=1)

                # .item() yields a plain int so the log and the returned
                # totals are not device tensors.
                correct_cnt = (t1_pseudo_label == t2_pseudo_label).sum().item()
                correct_cnt_dict[l] += correct_cnt

                print(f'Step {step}, Level {l}, correct count: {correct_cnt}, correct ratio: {correct_cnt / tp1.shape[0]:.4f}')

    return correct_cnt_dict

def test_student(loader, student_model, prototypes, augmentor, cf):
    """Count how often the student maps two augmented views to the same prototype.

    ``student_model`` returns one embedding per hierarchy level; the level-``l``
    embedding of each view is compared against ``prototypes[l]`` with
    ``_similarity`` and the argmax is used as the view's pseudo-label.

    Args:
        loader: iterable of graph batches.
        student_model: callable ``(x, edge_index, edge_attr, batch)`` returning a
            sequence indexable by level ``0..cf.h_level-1``.
        prototypes: indexable by level ``0..cf.h_level-1``.
        augmentor: pair ``(aug1, aug2)`` of callables ``(x, edge_index, edge_attr)``
            returning the same triple.
        cf: config providing ``h_level`` and ``compute_dev``.

    Returns:
        dict mapping level ``l`` to the number (int) of graphs whose two views
        received the same pseudo-label at that level.
    """
    aug1, aug2 = augmentor
    correct_cnt_dict = {l: 0 for l in range(cf.h_level)}

    # Evaluation only; skip autograd bookkeeping.
    with th.no_grad():
        for step, data in enumerate(loader):
            data = data.to(cf.compute_dev)
            if data.x is None:
                # Featureless graphs get a constant one-dimensional node feature.
                num_nodes = data.batch.size(0)
                data.x = th.ones((num_nodes, 1), dtype=th.float32, device=cf.compute_dev)

            x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
            x1, edge_index1, edge_attr1 = aug1(x, edge_index, edge_attr)
            x2, edge_index2, edge_attr2 = aug2(x, edge_index, edge_attr)

            s1_list = student_model(x1, edge_index1, edge_attr1, batch)
            s2_list = student_model(x2, edge_index2, edge_attr2, batch)

            for l in range(cf.h_level):
                sp1 = _similarity(s1_list[l], prototypes[l])
                sp2 = _similarity(s2_list[l], prototypes[l])

                s1_pseudo_label = th.argmax(sp1, dim=1)
                s2_pseudo_label = th.argmax(sp2, dim=1)

                # .item() yields a plain int so the log and the returned
                # totals are not device tensors.
                correct_cnt = (s1_pseudo_label == s2_pseudo_label).sum().item()
                correct_cnt_dict[l] += correct_cnt

                print(f'Step {step}, Level {l}, correct count: {correct_cnt}, correct ratio: {correct_cnt / sp1.shape[0]:.4f}')

    return correct_cnt_dict

def cross_check(loader, teacher_model, student_model, teacher_prototypes, student_prototypes, cf):
    """Count graphs where teacher and student pick the same prototype index per level.

    No augmentation is applied. The teacher embedding is compared against
    ``teacher_prototypes[l]`` and the student's level-``l`` embedding against
    ``student_prototypes[l]``; the two argmax indices are then compared.

    NOTE(review): the two prototype sets come from separate sources, so equal
    indices are only meaningful if the student prototypes are index-aligned
    with the teacher's — confirm before interpreting the ratio.

    Args:
        loader: iterable of graph batches.
        teacher_model: encoder ``(x, edge_index, edge_attr, batch) -> graph embeddings``.
        student_model: callable returning a sequence of per-level embeddings.
        teacher_prototypes: indexable by level ``0..cf.h_level-1``.
        student_prototypes: indexable by level ``0..cf.h_level-1``.
        cf: config providing ``h_level`` and ``compute_dev``.

    Returns:
        dict mapping level ``l`` to the number (int) of graphs whose teacher and
        student pseudo-labels match at that level.
    """
    correct_cnt_dict = {l: 0 for l in range(cf.h_level)}

    # Evaluation only; skip autograd bookkeeping.
    with th.no_grad():
        for step, data in enumerate(loader):
            data = data.to(cf.compute_dev)
            if data.x is None:
                # Featureless graphs get a constant one-dimensional node feature.
                num_nodes = data.batch.size(0)
                data.x = th.ones((num_nodes, 1), dtype=th.float32, device=cf.compute_dev)

            x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
            tg = teacher_model(x, edge_index, edge_attr, batch)
            sg_list = student_model(x, edge_index, edge_attr, batch)

            for l in range(cf.h_level):
                tp = _similarity(tg, teacher_prototypes[l])
                sp = _similarity(sg_list[l], student_prototypes[l])

                t_pseudo_label = th.argmax(tp, dim=1)
                s_pseudo_label = th.argmax(sp, dim=1)

                # .item() yields a plain int so the log and the returned
                # totals are not device tensors.
                correct_cnt = (t_pseudo_label == s_pseudo_label).sum().item()
                correct_cnt_dict[l] += correct_cnt

                print(f'Step {step}, Level {l}, correct count: {correct_cnt}, correct ratio: {correct_cnt / tp.shape[0]:.4f}')

    return correct_cnt_dict


@time_logger
def pretrain_UnbiasGCL(args):
    """Load a frozen pretrained teacher and report its prototype-consistency ratio.

    No training happens here. The teacher encoder is restored from
    ``cf.teacher_file + f"{cf.teacher_model}_rep.pth"`` and its parameters are
    frozen. Prototypes are taken from the checkpoint (key ``'prototypes'`` or
    ``'proto'``) when present, otherwise recomputed with ``run_cluster``.
    ``test_teacher`` then counts, per hierarchy level, graphs whose two augmented
    views share a pseudo-label; the counts are divided by the dataset size and
    printed.

    Args:
        args: parsed command-line namespace; ``gpus``, ``log_on`` and
            ``prt_dataset`` are read directly, the rest is consumed by
            ``GraphHDConfig``.
    """
    # ! Init Arguments
    exp_init(args.gpus, args.log_on)
    # ! Import packages
    import torch as th
    from models.GraphHD.loader import MoleculeDataset
    from torch_geometric.loader import DataLoader
    import GCL.augmentors as A
    from models.GraphHD.model import Teacher_Model, Student_Model, ProjectNet, Student_ProjectNet

    cf = GraphHDConfig(args)
    cf.compute_dev = th.device("cuda:0" if args.gpus >= 0 and th.cuda.is_available() else "cpu")

    dataset = MoleculeDataset("dataset/" + args.prt_dataset, dataset=args.prt_dataset)
    # At least one input dim: featureless graphs get a constant 1-dim node feature.
    cf.feat_dim = max(dataset.num_features, 1)
    cf.n_class = dataset.num_classes
    print(cf)
    data_len = len(dataset)
    cf.data_len = data_len

    ## ! Load the frozen teacher and the two views it is evaluated on
    dataloader = DataLoader(dataset, batch_size=cf.batch_size, shuffle=True, num_workers=4)
    aug1 = A.Identity()
    aug2 = A.RandomChoice([A.NodeDropping(pn=0.2), A.EdgeRemoving(pe=0.2)], 1)
    teacher_model = Teacher_Model(teacher_model=cf.teacher_model, cf=cf).to(cf.compute_dev)

    teacher_file = cf.teacher_file + f"{cf.teacher_model}_rep" + ".pth"
    teacher_checkpoint = th.load(teacher_file, map_location=cf.compute_dev)
    teacher_model.encoder.load_state_dict(teacher_checkpoint['encoder'])
    # teacher_model.encoder.load_state_dict(teacher_checkpoint)

    for p in teacher_model.parameters():
        p.requires_grad = False

    # if 'proj' in teacher_checkpoint.keys():
    #     teacher_proj = ProjectNet(cf.n_hidden).to(cf.compute_dev)
    #     teacher_proj.load_state_dict(teacher_checkpoint['proj'])
    #     for p in teacher_proj.parameters():
    #         p.requires_grad = False
    # else:
    #     teacher_proj = None

    teacher_proj = None

    # Reuse stored prototypes when the checkpoint carries them. The membership
    # test and the lookup must use the same key, so probe the accepted names.
    proto_key = next((k for k in ('prototypes', 'proto') if k in teacher_checkpoint), None)
    if proto_key is not None:
        teacher_prototypes = teacher_checkpoint[proto_key]
        cf.h_level = len(teacher_prototypes)
        for h in range(cf.h_level):
            teacher_prototypes[h].requires_grad = False
    else:
        # cf.n_protos is an underscore-separated list of cluster counts, one per level.
        n_protos = [int(x) for x in cf.n_protos.split('_')]
        cf.h_level = len(n_protos)
        teacher_prototypes = run_cluster(teacher_model, dataloader, n_protos, cf)

    teacher_result = test_teacher(dataloader, teacher_model, teacher_prototypes, augmentor=(aug1, aug2), cf=cf, teacher_proj=teacher_proj)
    # Turn per-level match counts into ratios over the whole dataset.
    for h in range(cf.h_level):
        teacher_result[h] = teacher_result[h] / data_len

    print(teacher_result)

    # NOTE: disabled student evaluation / teacher-student cross check, kept for reference.
    # student_model = Student_Model(h_level=3, cf=cf).to(cf.compute_dev)
    # student_file = cf.student_file  + f"_scp{20}" + ".pth"
    # print(student_file)
    # student_checkpoint = th.load(student_file)
    # student_model.encoders.load_state_dict(student_checkpoint['encoders'])
    # student_model.projs.load_state_dict(student_checkpoint['projs'])
    # student_prototypes = student_checkpoint['protos']
    #
    # for p in student_model.parameters():
    #     p.requires_grad = False

    # student_result = test_student(dataloader, student_model, student_prototypes, augmentor=(aug1, aug2), cf=cf)
    # for h in range(cf.h_level):
    #     student_result[h] = student_result[h]/data_len

    # print(student_result)

    # check_result = cross_check(dataloader, teacher_model, student_model, teacher_prototypes, student_prototypes, cf)
    # print(check_result)


if __name__ == "__main__":
    # Two-stage parsing: the experiment-level flags are read first (tolerating
    # unknown ones) so GraphHDConfig can register its model-specific flags.
    arg_parser = GraphHDConfig.add_exp_setting_args(argparse.ArgumentParser("Training settings"))
    known_args, _ = arg_parser.parse_known_args()
    arg_parser = GraphHDConfig(known_args).add_model_specific_args(arg_parser)
    pretrain_UnbiasGCL(arg_parser.parse_args())
