import argparse
import pprint
import warnings
import os
import dgl.base
import json
import torch
import random

warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=dgl.base.DGLWarning)

import numpy as np
np.set_printoptions(precision=3, suppress=True)
from utils import setup_cuda, logger, set_seed, report_performance

from metagl import MetaGL
from tqdm import tqdm
import matplotlib.pyplot as plt

from transformers import BertTokenizer, BertModel
def get_ori_task_types(merged_list, data_dir="data",
                       file_name="testdev_balanced_questions.json",
                       max_entries=10101):
    """Copy GQA question ``types`` onto matching entries of *merged_list*.

    Loads the JSON object stored at ``data_dir/file_name``. It scans its first
    *max_entries* values, each of which is read for ``imageId``, ``question``
    and ``types``. Every element of *merged_list* whose ``imageId`` and
    ``question`` both match a scanned entry gets ``item["types"]`` set to that
    entry's ``types``. If several scanned entries match the same element, the
    later one wins. Elements with no match are left without a ``types`` key.

    Args:
        merged_list: list of dicts, each with ``imageId`` and ``question``
            keys. Modified in place.
        data_dir: directory holding the GQA question file.
        file_name: name of the GQA question file.
        max_entries: number of GQA entries to scan; ``None`` scans all of them.
            The default reproduces the previous hard-coded early exit.

    Returns:
        The same *merged_list* object, updated in place.
    """
    with open(os.path.join(data_dir, file_name), "r", encoding="utf-8") as fh:
        gqa = json.load(fh)

    # Index merged_list once by (imageId, question) so each GQA entry is an
    # O(1) lookup instead of a full scan of merged_list.
    index = {}
    for item in merged_list:
        index.setdefault((item["imageId"], item["question"]), []).append(item)

    for n, v in enumerate(gqa.values()):
        # Early exit: only the first max_entries GQA entries are considered.
        if max_entries is not None and n >= max_entries:
            break
        image_id, question, types = v["imageId"], v["question"], v["types"]
        for item in index.get((image_id, question), ()):
            item["types"] = types

    return merged_list

def _load_json(path):
    """Return the parsed JSON document stored at *path*."""
    with open(path, "r") as fh:
        return json.load(fh)


def _program_embedding(program, tokenizer, model):
    """Return the first row of BERT's ``pooler_output`` for *program*, on the CPU.

    Inputs are moved to "cuda" because build_dataset places the model there.
    """
    inputs = tokenizer(program, return_tensors="pt", padding=True, truncation=True).to("cuda")
    with torch.no_grad():
        pooled = model(**inputs).pooler_output
    return pooled[0].to(torch.device("cpu"))


def build_dataset(ratio, train_ratio, if_add_program):
    """Build train/val/test tensors for MetaGL model selection on GQA.

    Reads everything from the ``data/`` directory:
      * ``gqa_computation_graph_descrption.json`` provides per-question metadata
        (``program``). Question ``types`` are attached by get_ori_task_types.
      * ``blip_embs_10k.npy`` provides one embedding row per instance.
      * ``gqa_model_selection_instance_results.json`` maps an instance id to a
        list of candidate records. Each record holds four values; the last is
        a label.
      * ``train_random_list_{ratio}.json`` / ``val_random_list_{ratio}.json``
        hold the enumeration indices assigned to train / val. Everything else
        goes to test.

    Only instances with at least one candidate label equal to 1 are kept.

    NOTE(review): assumes instance ids are 1-based positions in the description
    file, and that row ``cnt`` of the embedding matrix belongs to the cnt-th
    entry of the instance results file. Confirm against the data generator.

    Args:
        ratio: suffix selecting which train/val index files to load.
        train_ratio: probability of randomly dropping each kept training
            sample (a sample is kept when ``random.random() <= 1 - train_ratio``).
        if_add_program: if truthy, add the BERT embedding of the question's
            program to its image embedding.

    Returns:
        A 14-tuple with these items, in order:
          * train/val/test embeddings (stacked tensors);
          * train/val/test performance tensors (one row per sample, one column
            per candidate record);
          * train/val/test structure-id tensors;
          * the number of test samples for structure ids 0..4.

    Raises:
        RuntimeError: from ``torch.stack`` if any split ends up empty.
    """
    if if_add_program:
        t_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        t_model = BertModel.from_pretrained("bert-base-uncased").to("cuda")

    structure2idx = {
        "verify": 0,
        "query": 1,
        "logical": 2,
        "choose": 3,
        "compare": 4,
    }

    ori_input_list = _load_json(os.path.join("data", "gqa_computation_graph_descrption.json"))
    ## get original task types
    ori_input_list = get_ori_task_types(ori_input_list)

    X = np.load("data/blip_embs_10k.npy")
    data = _load_json("data/gqa_model_selection_instance_results.json")
    # Sets give O(1) membership tests; the files store plain JSON lists.
    train_ids = set(_load_json("data/train_random_list_{}.json".format(ratio)))
    val_ids = set(_load_json("data/val_random_list_{}.json".format(ratio)))

    embeddings = {"train": [], "val": [], "test": []}
    performances = {"train": [], "val": [], "test": []}
    structures = {"train": [], "val": [], "test": []}
    # Number of kept *test* samples for each structure id.
    test_sub_counts = [0] * len(structure2idx)

    for cnt, (sample_id, item_list) in tqdm(enumerate(data.items())):
        meta_data = ori_input_list[int(sample_id) - 1]
        structure, program = meta_data["types"]["structural"], meta_data["program"]
        structure_id = structure2idx[structure]

        ## only valid path: at least one candidate record has label 1
        if not any(list(item.values())[-1] == 1 for item in item_list):
            continue

        # Split membership is decided by the enumeration index, not sample_id.
        if cnt in train_ids:
            split = "train"
        elif cnt in val_ids:
            split = "val"
        else:
            split = "test"

        embedding = torch.tensor(X[cnt].tolist(), dtype=torch.float32)
        if if_add_program:
            embedding = embedding + _program_embedding(program, t_tokenizer, t_model)
        if torch.isnan(embedding).any():
            print("warning: NaN in embedding of sample {} (row {})".format(sample_id, cnt))

        # One label per candidate record; every record must hold exactly four
        # values, the last of which is the label.
        performance = []
        for item in item_list:
            _vqa, _loc, _time, y = item.values()
            performance.append(y)

        # Exactly one draw per kept sample, whatever its split, so the random
        # stream does not depend on how samples are distributed across splits.
        random_value = random.random()
        if split == "train" and random_value > (1 - train_ratio):
            # Drop this training sample. Its structure id is dropped too, so
            # the train embedding/performance/structure tensors stay aligned.
            continue

        embeddings[split].append(embedding)
        performances[split].append(performance)
        structures[split].append(structure_id)
        if split == "test":
            test_sub_counts[structure_id] += 1

    train_embedding_list = torch.stack(embeddings["train"])
    val_embedding_list = torch.stack(embeddings["val"])
    test_embedding_list = torch.stack(embeddings["test"])
    train_performance_list = torch.tensor(performances["train"])
    val_performance_list = torch.tensor(performances["val"])
    test_performance_list = torch.tensor(performances["test"])
    train_structure_list = torch.tensor(structures["train"])
    val_structure_list = torch.tensor(structures["val"])
    test_structure_list = torch.tensor(structures["test"])

    sub0_num, sub1_num, sub2_num, sub3_num, sub4_num = test_sub_counts
    print("sub0_num:",sub0_num,"sub1_num:",sub1_num,"sub2_num:",sub2_num,"sub3_num:",sub3_num,"sub4_num:",sub4_num)
    return (train_embedding_list, val_embedding_list, test_embedding_list,
            train_performance_list, val_performance_list, test_performance_list,
            train_structure_list, val_structure_list, test_structure_list,
            sub0_num, sub1_num, sub2_num, sub3_num, sub4_num)

def main(args):
    """Build the GQA splits, train MetaGL on them and evaluate on the test split.

    Appends a header with the run's hyper-parameters to ``logs/tmp.txt``, then
    calls ``metagl.train_predict`` with the same log path.

    Args:
        args: parsed command-line namespace. Reads ``ratio``, ``train_ratio``,
            ``if_add_program``, ``epochs``, ``batch_size``, ``seed`` and ``lr``,
            and sets ``args.device``.

    Returns:
        The eight values produced by ``metagl.train_predict``, in the order
        (loss_list, acc_list, acc, sub0_acc, sub1_acc, sub2_acc, sub3_acc,
        sub4_acc). Previously this function returned None; callers that ignore
        the result are unaffected.
    """
    # NOTE(review): this overrides any device that setup_cuda() may have picked
    # from --gpu, and build_dataset also hard-codes "cuda". Confirm intent.
    args.device = torch.device('cuda')
    (train_embedding_list, val_embedding_list, test_embedding_list,
     train_performance_list, val_performance_list, test_performance_list,
     train_structure_list, val_structure_list, test_structure_list,
     sub0_num, sub1_num, sub2_num, sub3_num, sub4_num) = build_dataset(
        args.ratio, args.train_ratio, args.if_add_program)

    # Convert the tensors to numpy before handing them to MetaGL. Performance
    # matrices stay np.matrix as before.
    # NOTE(review): NumPy discourages np.matrix. Check whether train_predict
    # relies on matrix semantics before switching to ndarray.
    train_embedding_list = train_embedding_list.numpy()
    val_embedding_list = val_embedding_list.numpy()
    test_embedding_list = test_embedding_list.numpy()
    train_performance_list = np.asmatrix(train_performance_list.numpy())
    val_performance_list = np.asmatrix(val_performance_list.numpy())
    test_performance_list = np.asmatrix(test_performance_list.numpy())
    print("train",len(train_embedding_list),"val",len(val_embedding_list),"test",len(test_embedding_list))

    # One performance column per candidate model; one feature per embedding dim.
    _, num_models = train_performance_list.shape
    num_meta_feats = train_embedding_list.shape[1]

    metagl = MetaGL(
        num_models=num_models,
        metafeats_dim=num_meta_feats,
        epochs=args.epochs,
        device=args.device,
        batch_size=args.batch_size
    )
    logger.info("Running MetaGL...")
    set_seed(args.seed)

    txt_name = 'logs/tmp.txt'
    # open(..., "a+") creates the file but not its directory.
    os.makedirs(os.path.dirname(txt_name), exist_ok=True)
    rule = "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
    with open(txt_name, mode="a+", encoding="utf-8") as log:
        print(rule, file=log)
        print("if_add_program", args.if_add_program, "seed", args.seed, "batch_size", args.batch_size, "epoch", args.epochs,"lr", args.lr, file=log)
        print(rule, file=log)

    (loss_list, acc_list, acc,
     sub0_acc, sub1_acc, sub2_acc, sub3_acc, sub4_acc) = metagl.train_predict(
        train_embedding_list, val_embedding_list, test_embedding_list,
        train_performance_list, val_performance_list, test_performance_list,
        train_structure_list, val_structure_list, test_structure_list,
        args.lr, 0, txt_name,
        sub0_num, sub1_num, sub2_num, sub3_num, sub4_num)
    return loss_list, acc_list, acc, sub0_acc, sub1_acc, sub2_acc, sub3_acc, sub4_acc

if __name__ == '__main__':
    # Command-line options as (flag, type, default, help). A help of None is
    # identical to not passing help= at all.
    _cli_options = (
        ("--gpu", int, 0, "which GPU to use. Set to -1 to use CPU."),
        ("--seed", int, 1337, "random seed"),
        ("--perf-nan-perc", float, 0.0, "percentage of nans in the performance matrix"),
        ("--k-fold-n-splits", int, 5, "number of splits for k-fold cross validation"),
        ("--epochs", int, 10, "maximum number of training epochs"),
        ("--ratio", float, 0.6, None),
        ("--lr", float, 0.8, None),
        ("--batch_size", int, 64, None),
        ("--train_ratio", float, 0.0, None),
        ("--mask_ratio", float, 0.0, None),
        ("--if_mask_ratio", int, 0, None),
        ("--if_add_program", int, 0, None),
    )

    parser = argparse.ArgumentParser()
    for _flag, _kind, _default, _help in _cli_options:
        parser.add_argument(_flag, type=_kind, default=_default, help=_help)

    args = parser.parse_args()
    setup_cuda(args)
    _settings = pprint.pformat(vars(args))
    print("\n[Settings]\n" + _settings)
    set_seed(args.seed)

    main(args)
