import torch
import numpy as np
import json
import torch.nn.functional as F

from utils import batch_by_size
from tqdm import tqdm
from torch.optim import Adam
from models import EmerGNN
from sklearn.metrics import f1_score, cohen_kappa_score
from torch.optim.lr_scheduler import ReduceLROnPlateau
# from sklearn.metrics.pairwise import cosine_similarity
from collections import Counter
import pickle

class BaseModel(object):
    """Training/evaluation wrapper around ``EmerGNN`` plus kNN baselines.

    Besides the supervised ``train``/``evaluate`` loop, the class offers
    nearest-neighbour relation classifiers over learned pair embeddings
    (:meth:`knn`), Morgan fingerprints (:meth:`finger`) and drug description
    embeddings (:meth:`descript`), and path/attention inspection helpers.
    """

    # k values evaluated by the kNN baselines (knn / finger / descript).
    KNN_K_VALUES = (1, 5, 10, 20, 30, 50)
    # Number of nearest neighbours written per test case to the JSON records.
    N_RECORDED_NEIGHBORS = 30

    def __init__(self, eval_ent, eval_rel, args, entity_vocab=None, relation_vocab=None):
        """Build the model, optimizer, LR scheduler and id->name vocabularies.

        Args:
            eval_ent: forwarded to ``EmerGNN``.
            eval_rel: forwarded to ``EmerGNN``; relation ids in
                ``[eval_rel, all_rel)`` are counted as auxiliary KG relations
                by :meth:`visualize` and :meth:`KG_relation_weights`.
            args: config namespace (reads ``load_model``, ``dataset``,
                ``all_rel``, ``lr``, ``lamb`` here; other fields elsewhere).
            entity_vocab, relation_vocab: currently ignored -- the
                vocabularies are always read from ``data/DB_all_node.json``
                and ``data/DB_all_rel.json``.
        """
        self.model = EmerGNN(eval_ent, eval_rel, args)
        if args.load_model:
            state_dict = torch.load(args.dataset+'_saved_model.pt')
            self.model.load_state_dict(state_dict)
        self.model.cuda()

        self.eval_ent = eval_ent
        self.eval_rel = eval_rel
        self.all_rel = args.all_rel
        self.args = args
        print('eval_rel', self.eval_rel, 'all_rel', self.all_rel)

        self.optimizer = Adam(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.lamb)
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max')

        # Map integer ids to the 'name' field of each JSON entry so ids from
        # model outputs can be looked up directly.
        with open('data/DB_all_node.json', 'r', encoding='utf-8') as f:
            raw_entities = json.load(f)
        self.entity_vocab = {int(k): v.get('name', 'Unknown') for k, v in raw_entities.items()}

        with open('data/DB_all_rel.json', 'r', encoding='utf-8') as f:
            raw_relations = json.load(f)
        self.relation_vocab = {int(k): v.get('name', 'Unknown') for k, v in raw_relations.items()}

    # ------------------------------------------------------------------ #
    # Small shared helpers
    # ------------------------------------------------------------------ #
    def _dataset_tag(self):
        """Return 'S1', 'S2' or 'S0' depending on which tag ``args.dataset`` contains."""
        if 'S1' in self.args.dataset:
            return 'S1'
        if 'S2' in self.args.dataset:
            return 'S2'
        return 'S0'

    @staticmethod
    def _knn_vote(sorted_indices, train_labels, k_values):
        """Majority-vote label among the first k ``sorted_indices`` for each k.

        Ties go to the label encountered first, i.e. the one held by the
        nearer neighbour (``Counter.most_common`` keeps insertion order).
        """
        votes = {}
        for k in k_values:
            nearest_labels = train_labels[sorted_indices[:k]]
            votes[k] = Counter(nearest_labels).most_common(1)[0][0]
        return votes

    @staticmethod
    def _accuracy_at_ks(pred_at_ks, label):
        """Fraction of predictions equal to ``label`` for each k (0.0 if empty)."""
        accuracy_at_ks = {}
        for k, preds in pred_at_ks.items():
            preds = np.asarray(preds)
            accuracy_at_ks[k] = np.sum(preds == label) / len(preds) if len(preds) else 0.0
        return accuracy_at_ks

    @staticmethod
    def _write_neighbor_records(neighbors, neighbors_labels, filename):
        """Dump per-test-case neighbour indices and labels to a JSON file."""
        records = [
            {
                'id': idx,
                # Indices/labels refer to rows of the training set used for the search.
                'K-neighbor': nbr_idx.tolist(),
                'K-neighbor_labels': nbr_lab.tolist(),
            }
            for idx, (nbr_idx, nbr_lab) in enumerate(zip(neighbors, neighbors_labels))
        ]
        with open(filename, 'w') as f:
            json.dump(records, f, indent=4)
        print(filename, 'record done')

    def _nearest_neighbor_vote(self, queries, train_feats, train_labels, k_values,
                               n_keep=N_RECORDED_NEIGHBORS):
        """Euclidean kNN classification of each query vector.

        Args:
            queries: iterable of 1-D tensors on the same device as ``train_feats``.
            train_feats: 2-D tensor, one row per training sample.
            train_labels: numpy array of training labels aligned with ``train_feats``.
            k_values: neighbourhood sizes to vote over.
            n_keep: how many nearest neighbours to keep per query for recording.

        Returns:
            ``(pred_at_ks, neighbors, neighbors_labels)`` where ``pred_at_ks``
            maps each k to a list of predicted labels (one per query).
        """
        pred_at_ks = {k: [] for k in k_values}
        neighbors, neighbors_labels = [], []
        for query in queries:
            # Smaller distance = more similar, so an ascending argsort ranks
            # neighbours (a cosine-similarity variant would need descending order).
            dist = F.pairwise_distance(query.reshape(1, -1), train_feats, p=2)
            order = np.argsort(dist.detach().cpu().numpy())
            for k, voted in self._knn_vote(order, train_labels, k_values).items():
                pred_at_ks[k].append(voted)
            neighbors.append(order[:n_keep])
            neighbors_labels.append(train_labels[order[:n_keep]])
        return pred_at_ks, neighbors, neighbors_labels

    # ------------------------------------------------------------------ #
    # Training / embedding export
    # ------------------------------------------------------------------ #
    def train(self, train_pos, train_neg, KG):
        """Run one epoch of softmax cross-entropy training over relation classes.

        Args:
            train_pos: tensor whose columns 0/1/2 are head, tail, relation.
            train_neg: unused.
            KG: graph passed through to ``enc_ht``.

        Returns:
            The summed training loss over the epoch (float).
        """
        head, tail, label = train_pos[:, 0], train_pos[:, 1], train_pos[:, 2]
        n_train = len(head)
        n_batch = self.args.n_batch
        n_steps = n_train // n_batch + int(n_train % n_batch > 0)

        loss_epoch = 0
        self.model.train()
        for h, t, r in tqdm(batch_by_size(n_batch, head, tail, label, n_sample=n_train),
                            ncols=100, leave=False, total=n_steps):
            self.model.zero_grad()
            ht_embed = self.model.enc_ht(h, t, KG)
            scores = self.model.enc_r(ht_embed)
            p_score = scores[torch.arange(len(r)).cuda(), r]
            # Summed softmax cross-entropy: logsumexp(scores) - score of the true class.
            # The previous form added a (B, 1) max term to (B,) vectors, which
            # broadcast to a (B, B) matrix whose sum was exactly B times this
            # loss (batch-size-dependent scale, O(B^2) memory and an effectively
            # B-times weaker weight decay). NOTE: `lamb` tuned under the old
            # loss may need re-tuning.
            loss = (torch.logsumexp(scores, dim=1) - p_score).sum()

            loss.backward()
            self.optimizer.step()
            loss_epoch += loss.item()
        return loss_epoch

    def save_present(self, train_pos, KG):
        """Encode all training pairs and save them to ``{tag}-training_data.npz``.

        The file holds arrays ``h``, ``t``, ``r`` and ``ht_embed`` and is read
        by :meth:`knn`. Embeddings are computed in eval mode without gradients
        so they are comparable to the test embeddings saved by :meth:`evaluate`.
        """
        head, tail, label = train_pos[:, 0], train_pos[:, 1], train_pos[:, 2]
        n_train = len(head)
        n_batch = self.args.n_batch
        n_steps = n_train // n_batch + int(n_train % n_batch > 0)

        h_list, t_list, r_list, ht_embed_list = [], [], [], []
        self.model.eval()
        with torch.no_grad():
            for h, t, r in tqdm(batch_by_size(n_batch, head, tail, label, n_sample=n_train),
                                ncols=100, leave=False, total=n_steps):
                ht_embed = self.model.enc_ht(h, t, KG)
                # Move everything to host memory as numpy arrays.
                h_list.append(h.cpu().numpy())
                t_list.append(t.cpu().numpy())
                r_list.append(r.cpu().numpy())
                ht_embed_list.append(ht_embed.detach().cpu().numpy())

        h_array = np.concatenate(h_list)
        t_array = np.concatenate(t_list)
        r_array = np.concatenate(r_list)
        ht_embed_array = np.concatenate(ht_embed_list, axis=0)
        print(h_array.shape, ht_embed_array.shape)

        np.savez(f'{self._dataset_tag()}-training_data.npz',
                 h=h_array, t=t_array, r=r_array, ht_embed=ht_embed_array)

    # ------------------------------------------------------------------ #
    # kNN baselines
    # ------------------------------------------------------------------ #
    def knn(self, test_pos, KG, record=0, filename=None):
        """kNN relation prediction on saved pair embeddings.

        Reads ``{tag}-training_data.npz`` (from :meth:`save_present`) and
        ``{tag}-test_data.npz`` (from :meth:`evaluate` with ``save_pre``).
        ``test_pos`` and ``KG`` are unused; test labels come from the npz file.

        Args:
            record: if > 0, write the nearest-neighbour records to ``filename``.
            filename: JSON output path used when ``record > 0``.

        Returns:
            dict mapping each k in ``KNN_K_VALUES`` to accuracy.
        """
        dataset = self._dataset_tag()
        self.model.eval()

        with np.load(f'{dataset}-training_data.npz') as training_data:
            train_ht_embed = torch.FloatTensor(training_data['ht_embed']).cuda()
            train_labels = training_data['r']
        with np.load(f'{dataset}-test_data.npz') as test_data:
            ht_embed = torch.FloatTensor(test_data['ht_embed']).cuda()
            label = test_data['r']

        # Iterating a 2-D tensor yields one test embedding (row) at a time.
        pred_at_ks, neighbors, neighbors_labels = self._nearest_neighbor_vote(
            ht_embed, train_ht_embed, train_labels, self.KNN_K_VALUES)
        accuracy_at_ks = self._accuracy_at_ks(pred_at_ks, label)

        if record > 0:
            self._write_neighbor_records(neighbors, neighbors_labels, filename)
        return accuracy_at_ks

    def _feature_knn(self, test_pos, train_file, filename, feat_table):
        """Shared kNN over concatenated per-drug features ``[feat[h], feat[t]]``.

        Args:
            test_pos: tensor whose columns 0/1/2 are head, tail, relation.
            train_file: JSON-lines file; each line's ``'test_triplet'`` holds
                the (head, tail, relation) of one training sample.
            filename: JSON output path for the neighbour records (always written).
            feat_table: indexable by drug id, giving a 1-D feature vector.

        Returns:
            dict mapping each k in ``KNN_K_VALUES`` to accuracy.
        """
        with open(train_file, 'r', encoding='utf-8') as f:
            traindata = [json.loads(line) for line in f if line.strip()]

        train_feats, train_labels = [], []
        for entry in traindata:
            h, t, r = entry['test_triplet']
            train_feats.append(np.concatenate([feat_table[h], feat_table[t]]))
            train_labels.append(r)
        train_labels = np.array(train_labels)
        features_tensor = torch.FloatTensor(np.array(train_feats)).cuda()
        features_tensor = features_tensor.reshape(len(traindata), -1)

        # Per the original experiments, also matching the swapped (tail, head)
        # order did not help on DrugBank, so only (head, tail) is used.
        def test_queries():
            for i in tqdm(range(len(test_pos))):
                h, t, _ = test_pos[i]
                yield torch.FloatTensor(np.concatenate([feat_table[h], feat_table[t]])).cuda()

        pred_at_ks, neighbors, neighbors_labels = self._nearest_neighbor_vote(
            test_queries(), features_tensor, train_labels, self.KNN_K_VALUES)

        label = test_pos[:, 2].data.cpu().numpy()
        accuracy_at_ks = self._accuracy_at_ks(pred_at_ks, label)
        self._write_neighbor_records(neighbors, neighbors_labels, filename)
        return accuracy_at_ks

    def finger(self, test_pos, train_file, filename):
        """kNN relation prediction on concatenated Morgan fingerprints.

        Features come from ``data/DB_molecular_feats.pkl`` (column
        ``'Morgan_Features'``). See :meth:`_feature_knn` for the arguments.
        """
        with open('data/DB_molecular_feats.pkl', 'rb') as f:
            # NOTE(security): pickle.load can execute arbitrary code; only use
            # this on trusted local data files.
            x = pickle.load(f, encoding='utf-8')
        mfeat = np.array(list(x['Morgan_Features']))
        return self._feature_knn(test_pos, train_file, filename, mfeat)

    def descript(self, test_pos, train_file, filename):
        """kNN relation prediction on concatenated drug description embeddings.

        Features come from array ``'emb'`` in ``DB_drug_emb.npz``. See
        :meth:`_feature_knn` for the arguments.
        """
        with np.load('DB_drug_emb.npz') as npz:
            des_emb = npz['emb']
        return self._feature_knn(test_pos, train_file, filename, des_emb)

    # ------------------------------------------------------------------ #
    # Evaluation / inspection
    # ------------------------------------------------------------------ #
    def evaluate(self, test_pos, test_neg, KG, record=0, filename=None, save_pre=0):
        """Classify test pairs and report macro-F1, accuracy and Cohen's kappa.

        Args:
            test_pos: tensor whose columns 0/1/2 are head, tail, relation.
            test_neg: unused.
            KG: graph passed through to ``enc_ht``.
            record: if > 0, write the top-10 predicted relations per test case
                (id and probability truncated to 3 decimals) to ``filename``.
            filename: JSON output path used when ``record > 0``.
            save_pre: if truthy, save the test embeddings to
                ``{tag}-test_data.npz`` (read later by :meth:`knn`).

        Returns:
            Tuple ``(f1, accuracy, kappa)``.
        """
        heads, tails, relas = test_pos[:, 0], test_pos[:, 1], test_pos[:, 2]
        batch_size = self.args.test_batch_size
        num_batch = len(heads) // batch_size + int(len(heads) % batch_size > 0)

        h_list, t_list, r_list, ht_embed_list = [], [], [], []
        rela_probs = []
        self.model.eval()
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for i in range(num_batch):
                start = i * batch_size
                end = min((i + 1) * batch_size, len(heads))
                batch_h = heads[start:end].cuda()
                batch_t = tails[start:end].cuda()
                ht_embed = self.model.enc_ht(batch_h, batch_t, KG)

                if save_pre:
                    h_list.append(batch_h.cpu().numpy())
                    t_list.append(batch_t.cpu().numpy())
                    r_list.append(relas[start:end].cpu().numpy())
                    ht_embed_list.append(ht_embed.detach().cpu().numpy())

                scores = self.model.enc_r(ht_embed)
                rela_probs.append(F.softmax(scores, dim=-1).data.cpu().numpy())

        rela_probs = np.concatenate(rela_probs)
        pred = np.argmax(rela_probs, axis=1)
        label = relas.data.cpu().numpy()
        accuracy = np.sum(pred == label) / len(pred)
        f1 = f1_score(label, pred, average='macro')
        kappa = cohen_kappa_score(label, pred)

        if record > 0:
            records = []
            for idx, prob_dist in enumerate(rela_probs):
                # Relation ids sorted by descending probability, top 10 kept.
                top_10_indices = np.argsort(prob_dist)[::-1][:10]
                predictions = [
                    {
                        'relation_id': int(rel_id),
                        # Truncated (not rounded) to 3 decimals.
                        'probability': int(float(prob_dist[rel_id]) * 1000) / 1000,
                    }
                    for rel_id in top_10_indices
                ]
                records.append({'id': idx, 'predictions': predictions})
            with open(filename, 'w') as f:
                json.dump(records, f, indent=4)

        if save_pre:
            h_array = np.concatenate(h_list)
            t_array = np.concatenate(t_list)
            r_array = np.concatenate(r_list)
            ht_embed_array = np.concatenate(ht_embed_list, axis=0)
            print(h_array.shape, ht_embed_array.shape)
            np.savez(f'{self._dataset_tag()}-test_data.npz',
                     h=h_array, t=t_array, r=r_array, ht_embed=ht_embed_array)

        return f1, accuracy, kappa

    def test_single(self, triplet, KG):
        """Return the arg-max relation id predicted for one (head, tail, ...) triplet.

        NOTE(review): the model mode (train/eval) is not changed here; callers
        are expected to have put the model in eval mode if that matters.
        """
        heads = triplet[0].unsqueeze(0)
        tails = triplet[1].unsqueeze(0)
        with torch.no_grad():
            ht_embed = self.model.enc_ht(heads, tails, KG)
            scores = self.model.enc_r(ht_embed)
        rela_scores = F.softmax(scores, dim=-1).data.cpu().numpy()
        return np.argmax(rela_scores, axis=-1)[0]

    def visualize(self, triplet, KG, head_batch=True):
        """Render the top-weighted reasoning paths for one triplet as text.

        Returns:
            ``(outputs, rel_weights, rel_freq)``:
              * ``outputs``: one string per path, ``'<weight>\\t(h, rel, t),...\\n'``
                with inverse edges printed in their forward direction and
                identity edges (id ``2 * all_rel``) omitted;
              * ``rel_weights``: per auxiliary KG relation
                (ids ``[eval_rel, all_rel)``, forward or inverse) usage counts;
              * ``rel_freq``: always zeros (its update logic is disabled).
        """
        h, t, r = triplet[0], triplet[1], triplet[2:]
        paths, weights = self.model.visualize_forward(
            h.unsqueeze(0), t.unsqueeze(0), r.unsqueeze(0), KG, 15, head_batch)
        outputs = []
        rel_weights = [0] * (self.all_rel - self.eval_rel)
        rel_freq = [0] * self.all_rel
        identity_rel = 2 * self.all_rel
        for path, weight in zip(paths, weights):
            out_str = '%4f\t' % weight
            for edge_h, edge_t, edge_r in path:
                if edge_r == identity_rel:
                    # Self-loop/identity edge: nothing to print.
                    continue
                h_name = self.entity_vocab[edge_h]
                t_name = self.entity_vocab[edge_t]

                if self.all_rel <= edge_r < identity_rel:
                    # Inverse edge: print it in its forward direction.
                    r_name = self.relation_vocab[edge_r - self.all_rel]
                    out_str += f'({t_name}, {r_name}, {h_name}),'
                else:
                    r_name = self.relation_vocab[edge_r]
                    out_str += f'({h_name}, {r_name}, {t_name}),'

                if self.eval_rel <= edge_r < self.all_rel:
                    rel_weights[edge_r - self.eval_rel] += 1
                elif self.all_rel + self.eval_rel <= edge_r < identity_rel:
                    rel_weights[edge_r - self.eval_rel - self.all_rel] += 1
            out_str += '\n'
            outputs.append(out_str)
        return outputs, np.array(rel_weights), np.array(rel_freq)

    def KG_relation_weights(self, triplets, KG):
        """Print mean attention weight of auxiliary KG relations per layer.

        For each of ``args.length`` layers, averages the attention weights
        returned by ``get_attention_weights`` over all pairs, keeps relation
        ids ``[eval_rel, all_rel)`` and averages each with its inverse
        (``+ all_rel``). Prints per-layer values and the mean over layers.
        """
        heads, tails = triplets[:, 0], triplets[:, 1]
        batch_size = self.args.test_batch_size
        num_batch = len(heads) // batch_size + int(len(heads) % batch_size > 0)
        rel_weights = [[] for _ in range(self.args.length)]
        self.model.eval()
        with torch.no_grad():
            for i in range(num_batch):
                start = i * batch_size
                end = min((i + 1) * batch_size, len(heads))
                relations = self.model.get_attention_weights(
                    heads[start:end].cuda(), tails[start:end].cuda(), KG)
                for layer in range(self.args.length):
                    rel_weights[layer].append(relations[layer])

        all_weights = 0
        for layer in range(self.args.length):
            rel_weight = np.mean(np.concatenate(rel_weights[layer], axis=0), axis=0)
            kg_weight = (rel_weight[self.eval_rel:self.all_rel]
                         + rel_weight[self.all_rel + self.eval_rel:2 * self.all_rel]) / 2
            all_weights += kg_weight
            print(layer, list(np.round(kg_weight, 2)))
        print(list(np.round(all_weights / self.args.length, 2)))

    def save_model(self, out_str=''):
        """Save model weights to ``{args.dataset}_saved_model.pt``."""
        torch.save(self.model.state_dict(), self.args.dataset+'_saved_model.pt')
        print(out_str, 'model saved')

    def load_model(self, model_name='_saved_model.pt'):
        """Load model weights from ``model_name``.

        NOTE(review): the default path lacks the ``args.dataset`` prefix that
        :meth:`save_model` uses; pass the full path explicitly.
        """
        self.model.load_state_dict(torch.load(model_name))
        print(model_name, 'model loaded')
