import torch
import torch.nn.functional as F
import numpy as np
import utils
import time
from tqdm import tqdm
from utils import *
from colors import *

# word_effects = calculate_word_effects(model, data)
# np.savez_compressed(model.local("word_effects"), effects=word_effects.detach().cpu().numpy())


def analogy_dataset(path="analogy_datasets/google-questions-words.txt"):
    """Load an analogy question file into a ``{category: [tokens, ...]}`` dict.

    Any line containing ":" is treated as a category header; the category
    name is the line with surrounding whitespace, colons and spaces stripped,
    lower-cased. Every other non-blank line is lower-cased, split on
    whitespace and appended to the most recent category.

    Args:
        path: Location of the question file. Defaults to the file the
            original implementation hard-coded.

    Returns:
        dict mapping each category name to a list of token lists, in file
        order. A header that appears twice resets that category's list.

    Raises:
        KeyError: if a question line appears before any category header.
    """
    dataset = {}
    category = ""
    with open(path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) used to be stored as
                # a bogus [""] question; skip them instead.
                continue
            if ":" in line:
                category = line.strip(": ").lower()
                dataset[category] = []
            else:
                # split() without an argument tolerates repeated spaces/tabs
                # and never yields empty tokens.
                dataset[category].append(line.lower().split())
    return dataset

def evaluate_analogy(model, data, multiplicative=False):
    """
    Evaluate the model on every analogy category and aggregate the results.

    Each category returned by analogy_dataset() is passed to
    model.evaluate_analogy_category(category_data, data). The per-category
    statistics are then summed into one "final" entry, and accuracies are
    computed as summed hits divided by the summed "total" of all categories.

    Args:
        model: object exposing `embeddings` (only its length is used here)
            and `evaluate_analogy_category(category_data, data)`.
        data: passed through unchanged to model.evaluate_analogy_category.
        multiplicative: currently unused; kept for interface compatibility.

    Returns:
        dict with one key per category plus "final":
        - <category>  (as returned by the model; read here as)
            - embedding<i>
              - exclude_original
                - <method>
                  - top1 / top5 / top10
            - total
        - final
            - embedding<i>
              - exclude_original
                - <method>
                  - top1 / top5 / top10          (sums over categories)
                  - accuracy_top1 / accuracy_top5 / accuracy_top10
            - total                              (sum over categories)
    """
    ana = analogy_dataset()
    analogy_statistics = {}

    for category, category_data in ana.items():
        print(f"evaluating analogy category {category}")
        stat = model.evaluate_analogy_category(category_data, data)
        analogy_statistics[category] = stat
        print(stat)

    exclusions = ["exclude_original"]  # "include_original" is currently disabled
    counters = ("top1", "top5", "top10")
    embedding_keys = ["embedding" + str(i) for i in range(len(model.embeddings))]

    # The method names are read from the first evaluated category rather than
    # from a hard-coded category name, so other datasets work as well. Every
    # category is indexed with the same method names below, so any category
    # can serve as the reference.
    if analogy_statistics:
        reference_stat = next(iter(analogy_statistics.values()))
        methods = list(reference_stat["embedding0"]["exclude_original"].keys())
    else:
        methods = []

    final_stat = {}
    # NOTE(review): `register` is not defined in this file (presumably from
    # utils); it is assumed to create the nested keys with the given initial
    # value, as implied by how the entries are read back below.
    register(final_stat, "total", 0)
    for embedding_key in embedding_keys:
        for exclusion in exclusions:
            for method in methods:
                for counter in counters:
                    register(final_stat, embedding_key, exclusion, method, counter, 0)

    print(final_stat)

    # Sum the totals exactly once. Accumulating inside the embedding/method
    # loops inflated the total by (number of embeddings * number of methods).
    for category_stat in analogy_statistics.values():
        final_stat["total"] += category_stat["total"]
    total = final_stat["total"]

    for embedding_key in embedding_keys:
        for exclusion in exclusions:
            for method in methods:
                summed = final_stat[embedding_key][exclusion][method]
                for category_stat in analogy_statistics.values():
                    per_category = category_stat[embedding_key][exclusion][method]
                    for counter in counters:
                        summed[counter] += per_category[counter]
                # Accuracy uses the summed counters, not the counters of
                # whichever category happened to be iterated last.
                for counter in counters:
                    summed["accuracy_" + counter] = summed[counter] / total if total else 0.0

    result = { "final" : final_stat, **analogy_statistics }
    print(result)
    return result


