import os
import argparse
import torch
import random
from load_data import DataLoader
import json, re
from tqdm import tqdm

from base_model import BaseModel
import numpy as np
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials, partial
import setproctitle
setproctitle.setproctitle('python-l')

# Command-line interface.
# NOTE: for every supported dataset (name starting with S0, S1 or S2) run_model
# overwrites --lr, --lamb, --n_batch, --n_dim and --n_epoch with a built-in
# per-dataset preset, so those flags currently have no effect. --seed is also
# unused: the script calls run_model(0).
# NOTE(review): --load_model, --test_batch_size, --out_file_info and --finger
# are not read in this file; presumably DataLoader/BaseModel consume them via
# ``args`` -- confirm.
parser = argparse.ArgumentParser()
parser.add_argument('--task_dir', type=str, default='./', help='directory containing the dataset')
parser.add_argument('--dataset', type=str, default='S1_1',
                    help='dataset name; must start with S0, S1 or S2, which selects a built-in hyperparameter preset')
parser.add_argument('--lamb', type=float, default=7e-4, help='set weight decay value')
parser.add_argument('--gpu', type=int, default=0, help='GPU id to load.')
parser.add_argument('--n_dim', type=int, default=128, help='set embedding dimension')
parser.add_argument('--lr', type=float, default=0.03, help='set learning rate')
parser.add_argument('--save_model', action='store_true',
                    help='save a checkpoint whenever the best test F1 improves')
parser.add_argument('--load_model', action='store_true')
parser.add_argument('--n_epoch', type=int, default=100, help='number of training epochs')
parser.add_argument('--n_batch', type=int, default=512, help='batch size')
parser.add_argument('--epoch_per_test', type=int, default=5,
                    help='evaluate on the validation and test sets every this many epochs')
parser.add_argument('--test_batch_size', type=int, default=32, help='test batch size')
parser.add_argument('--out_file_info', type=str, default='', help='extra string for the output file name')
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--vis', type=int, default=0,
                    help='if nonzero, load the saved model and dump reasoning paths for the test triplets')
parser.add_argument('--record', type=int, default=0,
                    help='if nonzero, load the saved model and record its test-set predictions')
parser.add_argument('--sub', type=int, default=0,
                    help='if nonzero, load the saved model and dump reasoning paths for the training triplets')
parser.add_argument('--knn', type=int, default=0,
                    help='if nonzero, load the saved model and run model.knn on the test triplets')
parser.add_argument('--finger', type=int, default=0)
parser.add_argument('--des', type=int, default=0,
                    help='if nonzero, run model.descript on the test triplets and exit without training')
class options:
    """Simple attribute container: keyword arguments become instance attributes.

    NOTE(review): not referenced elsewhere in this file.
    """

    def __init__(self, **kwargs):
        # ``self`` was missing from the signature, which made ``options()``
        # raise TypeError on every instantiation.
        self.__dict__.update(kwargs)

    def __repr__(self):
        fields = ', '.join(f'{key}={value!r}' for key, value in vars(self).items())
        return f'{type(self).__name__}({fields})'



if __name__ == '__main__':
    args = parser.parse_args()
    torch.cuda.set_device(args.gpu)
    dataloader = DataLoader(args)
    eval_ent, eval_rel = dataloader.eval_ent, dataloader.eval_rel
    args.all_ent, args.all_rel, args.eval_rel = dataloader.all_ent, dataloader.all_rel, dataloader.eval_rel
    KG = dataloader.KG
    vKG = dataloader.vKG
    tKG = dataloader.tKG
    triplets = dataloader.triplets
    train_pos, train_neg = torch.LongTensor(triplets['train']).cuda(), None
    valid_pos, valid_neg = torch.LongTensor(triplets['valid']).cuda(), None
    test_pos,  test_neg  = torch.LongTensor(triplets['test']).cuda(), None

    if not os.path.exists('results'):
        os.makedirs('results')

    def run_model(seed):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if args.dataset.startswith('S1'):
            args.lr = 0.001
            args.lamb = 0.0000001
            args.weight = 0.
            args.length = 3
            args.n_batch = 32
            args.n_dim = 64
            args.feat = 'M'
            args.n_epoch = 60
            model_path = 'S1_1_saved_model.pt'
        elif args.dataset.startswith('S2'):
            args.lr = 0.0005
            args.lamb = 0.00000001
            args.weight = 0.
            args.length = 3
            args.n_batch = 32
            args.n_dim = 96
            args.feat = 'M'
            args.n_epoch = 70
            model_path = 'S2_1_saved_model.pt'
        elif args.dataset.startswith('S0'):
            args.lr = 0.001
            args.lamb = 0.000001
            args.weight = 0
            args.n_batch = 128
            args.length = 3
            args.n_dim = 64
            args.feat = 'E'
            args.n_epoch = 60
            model_path = 'S0_saved_model.pt'

        cfg_str = 'feat:%s lr:%.6f lamb:%.8f n_batch:%d n_dim:%d layer:%d' % (args.feat, args.lr, args.lamb, args.n_batch, args.n_dim, args.length)
        with open(os.path.join('results', args.dataset+'_eval.txt'), 'a+') as f:
            f.write(cfg_str + '\n')
        
        if 'S1' in args.dataset:
            dataset = 'S1'
        elif 'S2' in args.dataset:
            dataset = 'S2'
        else:
            dataset = 'S0'
        
        model = BaseModel(eval_ent, eval_rel, args)

        if args.vis:
            data_list = []
            model.load_model(model_path)
            f = open(f'DB_{dataset}path.jsonl', 'a+')
            for i in tqdm(range(len(test_pos))):
                h, t, r = test_pos[i]
                paths, weights, freq = model.visualize(test_pos[i], tKG)
                
                path_list = []
                for path_str in paths:
                    parts = path_str.split('\t', 1)  # 限制分割一次
                    score_str, content_with_newline = parts
                    content = content_with_newline.rstrip('\n')  # 去掉末尾的换行符
                    try:
                        score = float(score_str)
                    except ValueError:
                        continue  # 跳过无法转换为浮点数的分数
                    path_dict = {
                        "score": score,
                        "content": content
                    }
                    path_list.append(path_dict)
                
                data_dict = {
                    'id': i,
                    "test_triplet": [int(h), int(t), int(r)],
                    "paths": path_list
                }
                data_list.append(data_dict)

                f.write(json.dumps(data_dict) + '\n')
            print('test path write done\n')
            # exit()
        
        if args.record:
            model.load_model(model_path)
            t_f1, t_acc, t_kap = model.evaluate(test_pos, test_neg, tKG, record=1, filename=f'DB_{dataset}test_pred.json', save_pre=1)
            print(t_f1, t_acc, t_kap)
            print('record done')
            # exit()
        
        if args.sub:
            model.load_model(model_path)
            dataloader.shuffle_train()
            KG = dataloader.KG
            train_pos = torch.LongTensor(dataloader.train_data).cuda()
            print(train_pos.shape)  # 3057
            model.save_present(train_pos, KG)
            
            data_list = []
            f = open(f'DB_{dataset}train_path.jsonl', 'a+')
            for i in tqdm(range(len(train_pos))):
                h, t, r = train_pos[i]
                paths, weights, freq = model.visualize(train_pos[i], KG)  # 这不该用tkg，应该用kg
                
                path_list = []
                for path_str in paths:
                    parts = path_str.split('\t', 1)  # 限制分割一次
                    score_str, content_with_newline = parts
                    content = content_with_newline.rstrip('\n')  # 去掉末尾的换行符
                    try:
                        score = float(score_str)
                    except ValueError:
                        continue  # 跳过无法转换为浮点数的分数
                    path_dict = {
                        "score": score,
                        "content": content
                    }
                    path_list.append(path_dict)
                
                data_dict = {
                    'id': i,
                    "test_triplet": [int(h), int(t), int(r)],
                    "paths": path_list
                }
                data_list.append(data_dict)

                f.write(json.dumps(data_dict) + '\n')
            print('path write done\n')
            # exit()'''

        if args.knn:
            model.load_model(model_path)
            accuracy_at_ks = model.knn(test_pos, tKG, record=1, filename=f'DB_{dataset}_KNN_result.json')
            print(' '.join([f'acc@{k}={v:.3f}' for k, v in accuracy_at_ks.items()]))

        if args.des:
            accuracy_at_ks = model.descript(test_pos, train_file=f'DB_{dataset}train_path.jsonl',filename=f'DB_{dataset}_des_result.json')
            print(' '.join([f'acc@{k}={v:.3f}' for k, v in accuracy_at_ks.items()]))
            exit()

        best_acc = -1
        for e in range(args.n_epoch):
            dataloader.shuffle_train()
            KG = dataloader.KG
            train_pos = torch.LongTensor(dataloader.train_data).cuda()
            model.train(train_pos, None, KG)
            if (e+1) % args.epoch_per_test == 0:
                v_f1, v_acc, v_kap = model.evaluate(valid_pos, valid_neg, vKG)
                t_f1, t_acc, t_kap = model.evaluate(test_pos,  test_neg,  tKG)
                model.scheduler.step(v_f1)
                out_str = f'{e:<3d}' + '[Valid] f1:%.4f acc:%.4f kap:%.4f\t[Test] f1:%.4f acc:%.4f kap:%.4f' % (v_f1, v_acc, v_kap, t_f1, t_acc, t_kap)
                if t_f1 > best_acc:
                    best_acc = t_f1
                    best_str = out_str
                    if args.save_model:
                        model.save_model(best_str)
                print(out_str)
                with open(os.path.join('results', args.dataset+'_eval.txt'), 'a+') as f:
                    f.write(out_str + '\n')
        print('Best results:\t' + best_str)
        with open(os.path.join('results', args.dataset+'_eval.txt'), 'a+') as f:
            f.write('Best results:\t' + best_str + '\n\n')
        return -best_acc

    run_model(0)#args.seed)

