import os
import argparse
import torch
import random
from load_data import DataLoader
import json
from tqdm import tqdm

from base_model import BaseModel
import numpy as np
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials, partial
import setproctitle
# Set this process's title to 'python-two' (as shown by tools such as ps/top).
setproctitle.setproctitle('python-two')

# Command-line interface for the script.
#
# NOTE: for the recognised datasets, run_model() overwrites lr, lamb, n_batch
# and n_dim with per-dataset presets, so values passed for those flags on the
# command line have no effect there.
parser = argparse.ArgumentParser()
parser.add_argument('--task_dir', type=str, default='./', help='the directory to dataset')
parser.add_argument('--dataset', type=str, default='S1_1', help='dataset name: S0, or a name starting with S1 or S2')
parser.add_argument('--lamb', type=float, default=7e-4, help='set weight decay value')
parser.add_argument('--gpu', type=int, default=0, help='GPU id to load.')
parser.add_argument('--n_dim', type=int, default=128, help='set embedding dimension')
# --save_model / --load_model are not read in this script itself.
# NOTE(review): presumably consumed elsewhere through `args` -- confirm.
parser.add_argument('--save_model', action='store_true')
parser.add_argument('--load_model', action='store_true')
parser.add_argument('--lr', type=float, default=0.03, help='set learning rate')
parser.add_argument('--n_epoch', type=int, default=100, help='number of training epochs')
parser.add_argument('--n_batch', type=int, default=512, help='batch size')
parser.add_argument('--epoch_per_test', type=int, default=5, help='frequency of testing')
parser.add_argument('--test_batch_size', type=int, default=64, help='test batch size')
parser.add_argument('--seed', type=int, default=1234)
# Mode switches: any non-zero value enables the corresponding branch in
# run_model() (path dump on the test split, recorded test evaluation, path dump
# on a training subset, model.knn, model.finger, model.descript).
parser.add_argument('--vis', type=int, default=0)
parser.add_argument('--record', type=int, default=0)
parser.add_argument('--sub', type=int, default=0)
parser.add_argument('--knn', type=int, default=0)
parser.add_argument('--finger', type=int, default=0)
parser.add_argument('--des', type=int, default=0)
class options:
    """Empty placeholder class; instances carry no state of their own.

    It is not referenced anywhere else in the visible part of this script.
    """

    def __init__(self):
        # `self` is required: without it, calling options() raises TypeError.
        pass



if __name__ == '__main__':
    args = parser.parse_args()
    torch.cuda.set_device(args.gpu)

    # Load the data and copy the loader's entity/relation attributes onto
    # `args`, which is later handed to BaseModel.
    dataloader = DataLoader(args)
    eval_ent, eval_rel = dataloader.eval_ent, dataloader.eval_rel
    args.all_ent, args.all_rel, args.eval_rel = dataloader.all_ent, dataloader.all_rel, dataloader.eval_rel

    # Graphs passed to the model for training, validation and test respectively.
    KG = dataloader.KG
    vKG = dataloader.vKG
    tKG = dataloader.tKG

    # Positive/negative triplets per split, moved to the current CUDA device.
    # NOTE: run_model() binds its own local training data and graph each
    # epoch, so the module-level train_pos / train_neg / KG are not read by it.
    pos_triplets, neg_triplets = dataloader.pos_triplets, dataloader.neg_triplets
    train_pos, train_neg = torch.LongTensor(pos_triplets['train']).cuda(), torch.LongTensor(neg_triplets['train']).cuda()
    valid_pos, valid_neg = torch.LongTensor(pos_triplets['valid']).cuda(), torch.LongTensor(neg_triplets['valid']).cuda()
    test_pos,  test_neg  = torch.LongTensor(pos_triplets['test']).cuda(),  torch.LongTensor(neg_triplets['test']).cuda()

    args.ent_pair = dataloader.ent_pair
    args.train_ent = list(dataloader.train_ent)

    # exist_ok=True replaces the check-then-create pair, which could fail if
    # the directory appeared between the two calls.
    os.makedirs('results', exist_ok=True)

    def _parse_scored_paths(paths):
        """Turn 'score<TAB>content' strings into a list of score/content dicts.

        Entries that have no tab separator or whose score is not a number are
        skipped. A trailing newline is stripped from the content.
        """
        parsed = []
        for path_str in paths:
            # Split on the first tab only; the content may contain tabs itself.
            score_str, sep, content = path_str.partition('\t')
            if not sep:
                continue  # malformed entry: no score/content separator
            try:
                score = float(score_str)
            except ValueError:
                continue  # skip entries whose score cannot be parsed as float
            parsed.append({"score": score, "content": content.rstrip('\n')})
        return parsed

    def _dump_paths(model, triplets, graph, out_path):
        """Append one JSON line per triplet with the paths from model.visualize.

        Each row of `triplets` is read as (head, tail, relation flags...); the
        indices of the non-zero entries after the first two are stored as the
        relation list.
        """
        # The context manager closes (and flushes) the file even on error.
        with open(out_path, 'a+') as out:
            for i in tqdm(range(len(triplets))):
                triplet = triplets[i]
                h, t, rs = triplet[0], triplet[1], triplet[2:]
                paths, _weights, _freq = model.visualize(triplet, graph)
                indices = rs.nonzero(as_tuple=True)[0]
                record = {
                    'id': i,
                    "test_triplet": [int(h), int(t), [int(r) for r in indices]],
                    "paths": _parse_scored_paths(paths),
                }
                out.write(json.dumps(record) + '\n')

    def run_model(seed):
        """Run the modes selected on the command line, then train and evaluate.

        Args:
            seed: Accepted but not used; the RNGs are seeded from ``args.seed``.

        For the recognised datasets (names starting with ``S1`` or ``S2``, or
        exactly ``S0``) lr, lamb, n_batch, n_dim, length and feat on ``args``
        are overwritten with presets. The --vis, --record and --sub branches
        fall through to the next enabled branch and finally to training;
        --knn, --finger and --des terminate the process via ``exit()``.

        Returns:
            The negated best test recall@10 seen during training (``1`` if no
            evaluation epoch was reached, since the tracker starts at -1).
            NOTE(review): the negation presumably makes this usable as a loss
            for a minimiser such as hyperopt -- confirm.

        Raises:
            ValueError: if a mode that loads a saved model is requested for a
                dataset that has no saved-model path defined.
        """
        print('seed: {}'.format(args.seed))
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

        # Per-dataset presets; model_path stays None for unrecognised names.
        model_path = None
        if args.dataset.startswith('S1'):
            args.lr = 0.003
            args.lamb = 0.00000001
            args.n_batch = 32
            args.n_dim = 64
            args.length = 3
            args.feat = 'E'
            model_path = 'S1_1_saved_model.pt'
        elif args.dataset.startswith('S2'):
            args.lr = 0.014443
            args.lamb = 0.0001
            args.n_batch = 96
            args.n_dim = 32
            args.length = 4
            args.feat = 'M'
            model_path = 'S2_1_saved_model.pt'
        elif args.dataset == 'S0':
            args.lr = 0.009581
            args.lamb = 0.000001
            args.n_dim = 64
            args.n_batch = 64  # 128
            args.length = 2
            args.feat = 'E'
            model_path = 'S0_saved_model.pt'

        # Fail with a clear message instead of an UnboundLocalError later on.
        if (args.vis or args.record or args.sub or args.knn) and model_path is None:
            raise ValueError(
                'no saved-model path is defined for dataset %r' % (args.dataset,))

        eval_log = os.path.join('results', args.dataset+'_eval.txt')
        cfg_str = 'feat:%s lr:%.6f lamb:%.8f n_batch:%d n_dim:%d layer:%d' % (args.feat, args.lr, args.lamb, args.n_batch, args.n_dim, args.length)
        with open(eval_log, 'a+') as f:
            f.write(cfg_str + '\n')

        # Dataset family tag used in the output file names below.
        if 'S1' in args.dataset:
            dataset = 'S1'
        elif 'S2' in args.dataset:
            dataset = 'S2'
        else:
            dataset = 'S0'

        model = BaseModel(eval_ent, eval_rel, args)

        if args.vis:
            # Dump the paths for every test triplet, using the test graph.
            model.load_model(model_path)
            _dump_paths(model, test_pos, tKG, f'TS_{dataset}path.jsonl')
            print('test path write done\n')

        if args.record:
            model.load_model(model_path)
            t_recall, t_ndcg = model.evaluate(test_pos, test_neg, tKG, record=1, filename=f'TS_{dataset}test_pred.json', save_pre=1)
            print(t_recall, t_ndcg)
            print('record done')

        if args.sub:
            model.load_model(model_path)
            dataloader.shuffle_train(ratio=0.6)  # S2: 0.6
            sub_kg = dataloader.KG
            sub_pos, sub_neg = dataloader.train_pos, dataloader.train_neg
            model.save_present(sub_pos, sub_neg, sub_kg)

            # Training triplets must be visualised against the training graph
            # (KG), not the test graph (tKG).
            sub_pos = torch.LongTensor(sub_pos).cuda()
            _dump_paths(model, sub_pos, sub_kg, f'TS_{dataset}train_path.jsonl')
            print('path write done\n')

        if args.knn:
            model.load_model(model_path)
            recall_at_ks = model.knn(test_pos, tKG, record=1, filename=f'TS_{dataset}_KNN_result.json')
            print(' '.join([f'recall@{k}={v:.3f}' for k, v in recall_at_ks.items()]))
            exit()

        if args.finger:
            # NOTE(review): file names are hard-coded to S0 here, unlike the
            # other modes which use the dataset family tag -- confirm intent.
            recall_at_ks = model.finger(test_pos, train_file='TS_S0train_path.jsonl', filename='TS_S0_finger_ex_result.json')
            for k in recall_at_ks.keys():
                print(f'recall@{k}: {recall_at_ks[k]:.4f} ')
            exit()

        if args.des:
            recall_at_ks = model.descript(test_pos, train_file=f'TS_{dataset}train_path.jsonl',filename=f'TS_{dataset}_des_result.json')
            print(' '.join([f'recall@{k}={v:.3f}' for k, v in recall_at_ks.items()]))
            exit()

        best_acc = -1
        # Defined up front so the summary below also works when no evaluation
        # epoch is reached (n_epoch < epoch_per_test).
        best_str = 'n/a (no evaluation epoch was reached)'
        for e in range(args.n_epoch):
            # Re-shuffle each epoch and pick up the loader's refreshed data.
            dataloader.shuffle_train()
            epoch_kg = dataloader.KG
            epoch_pos, epoch_neg = dataloader.train_pos, dataloader.train_neg
            model.train(epoch_pos, epoch_neg, epoch_kg)
            if (e+1) % args.epoch_per_test == 0:
                v_recall, v_ndcg = model.evaluate(valid_pos, valid_neg, vKG)
                t_recall, t_ndcg = model.evaluate(test_pos, test_neg, tKG)
                out_str = 'epoch:%d  vR@10:%.4f  vN@10:%.4f  tR@10:%.4f  tN@10:%.4f' % (e+1, v_recall[10], v_ndcg[10], t_recall[10], t_ndcg[10])
                # NOTE(review): the best epoch is chosen on the *test* recall@10,
                # and the model is saved regardless of --save_model.
                t_pr = t_recall[10]
                if t_pr > best_acc:
                    best_acc = t_pr
                    best_str = out_str
                    model.save_model(best_str)
                print(out_str)
                with open(eval_log, 'a+') as f:
                    f.write(out_str + '\n')
        print('Best results:\t' + best_str)
        with open(eval_log, 'a+') as f:
            f.write('Best results:\t' + best_str + '\n\n')
        return -best_acc

    # The argument is not used by run_model: it seeds the RNGs from args.seed.
    run_model(0)
    

