import os
import subprocess
from dataclasses import dataclass
import lzma
import wandb
import argparse
import shutil 
import pandas as pd
import numpy as np

import data
import match
import evaluation
import embeddings
# from baselines import VecMap

# Glob patterns handed to wandb via its ignore-globs environment variable;
# they match the per-language work directories created under the run dir.
os.environ['WANDB_IGNORE_GLOBS'] = 'lan1/*,lan2/*'

# Thread caps for the numeric backends (shell equivalent: `export NAME=value`).
# NOTE(review): numpy/pandas are imported above this point; these caps may
# need to be set before those imports to take effect -- confirm.
os.environ.update({
    "OMP_NUM_THREADS": "4",
    "OPENBLAS_NUM_THREADS": "4",
    "MKL_NUM_THREADS": "6",
    "VECLIB_MAXIMUM_THREADS": "4",
    "NUMEXPR_NUM_THREADS": "6",
})

# Default run configuration; passed to wandb.init as the initial config.
defaults = {
    # --- data ---
    'lan1': 'enwikishuf',
    'lan2': 'eswikishuf',
    'eval': 'en-es',
    'size1': 30,
    # size2 / skip2 are intentionally not defaulted (previously size2=20, skip2=10)
    'symmetric': 1,
    'width': 5,
    # --- vectorization: fasttext | sim_svd | trunc | word2vec ---
    'vectorize': 'fasttext',
    'dim': 300,
    # --- tokenizer: WordLevel | BPE ---
    'tokentype': 'WordLevel',
    'vocab_size': 5000,
    'limit_alphabet': 100,
    'min_frequency': 5,
    'supervision': 'unsupervised',
    'label': 'none',
}

# Start the wandb run and grab its (possibly overridden) configuration.
run = wandb.init(config=defaults, project='data efficiency')
cfg = wandb.config

# One working directory per language inside the run directory.
# os.makedirs is called without exist_ok, so an existing directory raises.
base1, base2 = (os.path.join(wandb.run.dir, sub) for sub in ('lan1', 'lan2'))
os.makedirs(base1)
os.makedirs(base2)

def make_sized(lan, sizemb, pout, skipmb=0):
    """Write a slice of the corpus for `lan` to the file `pout`.

    Requests `skipmb + sizemb` from `corpus.headmb` and writes everything
    after the first `int(skipmb * 1e6)` characters of the returned text,
    encoded as UTF-8.

    NOTE(review): the skip offset is applied as a character index; if
    `headmb` measures megabytes in bytes the two units differ for
    non-ASCII text -- confirm against `data`.
    """
    source = data.get_data(lan)
    head = source.headmb(lan, skipmb + sizemb)
    with open(pout, 'wt', encoding='utf-8') as sink:
        start = int(skipmb * 1e6)
        sink.write(head[start:])

# Raw-text file for each side.
p1 = os.path.join(base1, 'c.txt')
p2 = os.path.join(base2, 'c.txt')

make_sized(cfg.lan1, cfg.size1, p1)

# Second side: same size as the first when symmetric, otherwise cfg.size2.
# NOTE(review): size2 is commented out in `defaults`, so the asymmetric path
# relies on it being supplied to the run config from elsewhere.
if cfg.symmetric == 1:
    size2 = cfg.size1
else:
    size2 = cfg.size2
# When both sides read the same corpus, skip past the part used for side 1.
if cfg.lan1 == cfg.lan2:
    skip2 = cfg.size1
else:
    skip2 = 0
make_sized(cfg.lan2, size2, p2, skipmb=skip2)

# Both corpora are built with identical tokenizer / vectorizer options.
_corpus_kwargs = dict(
    tokentype=cfg.tokentype,
    vocab_size=cfg.vocab_size,
    limit_alphabet=cfg.limit_alphabet,
    min_frequency=cfg.min_frequency,
    vectorize=cfg.vectorize,
    width=cfg.width,
    dim=cfg.dim,
)
d1 = data.Corpus(p1, base1, **_corpus_kwargs)
d2 = data.Corpus(p2, base2, **_corpus_kwargs)

def get_evaldict():
    """Create the evaluation dictionary file and return its path.

    Always writes `eval_id.dict` in the run directory: one `w<TAB>w` line
    for every token present in both tokenizers' vocabularies.  When the
    two language codes in `cfg.eval` differ, the MUSE dictionary for that
    pair is additionally copied to `eval_dict.dict` and that path is
    returned instead of the identity dictionary's.
    """
    src_code, tgt_code = cfg.eval.split('-')
    muse = data.MUSEEval()
    dictpath = os.path.join(wandb.run.dir, 'eval_id.dict')
    with open(dictpath, 'wt', encoding='utf-8', errors='surrogateescape') as out:
        vocab1 = d1.tokenizer.get_vocab()
        vocab2 = d2.tokenizer.get_vocab()
        shared = set(vocab1).intersection(vocab2)
        print('vocab has overlap of length', len(shared))
        out.writelines(f'{tok}\t{tok}\n' for tok in shared)

    if src_code == tgt_code:
        # Same language on both sides: the identity dictionary is the answer.
        return dictpath

    dictpath = os.path.join(wandb.run.dir, 'eval_dict.dict')
    muse_path = muse.eval_path(f'{src_code}-{tgt_code}', type='full')
    shutil.copyfile(muse_path, dictpath)
    return dictpath

# Path of the dictionary file read by evalf, dict_init_binary and experiment.
dictpath = get_evaldict()

@dataclass
class SearchArgs:
    """Options bundle handed to `match.vecmap` / `match.coocmapt`.

    Every attribute is annotated so that `@dataclass` registers it as a
    field: values can be overridden per instance via keyword arguments
    (e.g. `SearchArgs(maxiter=50)`) and show up in `repr`.  Without the
    annotations the decorator generated an empty `__init__` and the
    values were plain class attributes.
    """
    stochastic_interval: int = 10
    stochastic_initial: float = 1
    stochastic_multiplier: float = 2
    threshold: float = 1e-4
    maxiter: int = 100
    eta: float = 1
    method: str = 'orthogonal'  # 'orthogonal' or 'lstsq'
    match: str = 'vecmap'
    csls: bool = True


# Shared default options used by every matcher call in this script.
args = SearchArgs()

# Directory inside the run dir where experiment() writes its CSV files;
# exist_ok=True makes this a no-op when the directory is already there.
dumpdir = os.path.join(wandb.run.dir, 'dump')
os.makedirs(dumpdir, exist_ok=True)

def evalf(sim):
    """Evaluate `sim` against the dictionary at `dictpath` and print the stats.

    Passed as the `evalf` callback to the matchers; returns nothing.
    """
    # Alternative kept from the original: match.most_diff_match(sim, k=3)
    _, stats = evaluation.report_sim(sim, d1.tokenizer, d2.tokenizer, dictpath)
    print(stats)

def dict_init_binary(tiebreak=1e-3):
    """Build an initial similarity matrix from the evaluation dictionary.

    The matrix has one row per row of `d1.Co` and one column per row of
    `d2.Co`.  It starts as uniform random noise scaled by `tiebreak`;
    every index pair returned by `evaluation.dict_to_inds` is then set
    to 1.
    """
    n_src, n_tgt = d1.Co.shape[0], d2.Co.shape[0]
    pairs = evaluation.dict_to_inds(dictpath, d1.tokenizer, d2.tokenizer, full=False)
    sim = tiebreak * np.random.rand(n_src, n_tgt)
    src_ids, tgt_ids = pairs[0], pairs[1]
    for k, src in enumerate(src_ids):
        sim[src, tgt_ids[k]] = 1
    # Entry (0, 0) is always forced to 1.
    # NOTE(review): the reason is not visible in this file -- confirm.
    sim[0, 0] = 1
    return sim

rows = []  # one info dict per recorded method; grows across experiment() calls
def experiment(drop=20, dim=300, r1=99, r2=99):
    """Run each matching variant once and record its evaluation.

    Variants, in execution order:
      1. vecmap on the `cfg.vectorize` vectors stored on each corpus,
      2. vecmap on vectors produced by `match.sim_vecs` from sqrt(Co),
      3. coocmap on normalized sqrt(Co),
      4. coocmap with `match.clip` applied first,
      5. coocmap after `match.svd_power`, initialised from the result of (4),
      6. coocmap initialised from `dict_init_binary()`.

    Args:
        drop: forwarded as `drop` to `match.svd_power`; also logged.
        dim: forwarded to `match.sim_vecs`; logged as `dim_p`.
        r1, r2: forwarded to `match.clip`.
    """
    print('original dim', d1.Co.shape)

    def record(method_name, sim):
        """Evaluate `sim`, then log it to wandb, `rows` and a CSV in `dumpdir`."""
        print(method_name)
        sim_matched = match.most_diff_match(sim, 10)
        df, stats = evaluation.report_sim(sim_matched, d1.tokenizer, d2.tokenizer, dictpath)
        info = stats
        info.update({'id': run.id, 'drop': drop, 'dim_p': dim, 'method_type': method_name})
        # Merge the run config into the record, warning about key clashes.
        for key, value in cfg.items():
            if key in info:
                print(f'Warning: {key} already exist')
            info[key] = value

        rows.append(info)
        wandb.log({'table': wandb.Table(dataframe=pd.DataFrame.from_records(rows))})
        wandb.log({'basicinfo': info})
        print(info)
        df.to_csv(os.path.join(dumpdir, f'{method_name}-{drop}-{dim}.csv'))

    def sqrt_simvecs(cooc):
        """sqrt(Co) -> match.sim_vecs -> normalize."""
        root = np.sqrt(cooc)
        vecs = match.sim_vecs(root, dim, alpha=1)
        embeddings.normalize(vecs, normproc)
        return vecs

    def sqrt_normalized(cooc):
        """sqrt(Co) -> normalize."""
        mat = np.sqrt(cooc)
        embeddings.normalize(mat, normproc)
        return mat

    def sqrt_clipped(cooc):
        """sqrt(Co) -> normalize -> clip."""
        mat = np.sqrt(cooc)
        embeddings.normalize(mat, normproc)
        match.clip(mat, r1=r1, r2=r2)
        return mat

    def sqrt_svd_clipped(cooc):
        """sqrt(Co) -> normalize -> svd_power -> normalize -> clip."""
        mat = np.sqrt(cooc)
        embeddings.normalize(mat, normproc)
        mat = match.svd_power(mat, beta=1, drop=drop, dim=None)
        embeddings.normalize(mat, normproc)
        match.clip(mat, r1=r1, r2=r2)
        return mat

    # --- vecmap on the stored embedding vectors ---
    src_vecs = d1.vecs[cfg.vectorize]
    tgt_vecs = d2.vecs[cfg.vectorize]
    normproc = ['unit', 'center', 'unit']
    normproc1 = ['unit']
    embeddings.normalize(src_vecs, normproc)
    embeddings.normalize(tgt_vecs, normproc)
    _, _, sim = match.vecmap(src_vecs, tgt_vecs, args, evalf=evalf)
    record(f'vecmap-{cfg.vectorize}', sim)

    # --- vecmap on vectors derived from the co-occurrence matrices ---
    X = sqrt_simvecs(d1.Co)
    Z = sqrt_simvecs(d2.Co)
    _, _, sim = match.vecmap(X, Z, args, evalf=evalf)
    record('vecmap-raw', sim)

    # --- coocmap ---
    X = sqrt_normalized(d1.Co)
    Z = sqrt_normalized(d2.Co)
    _, _, simscoocmap = match.coocmapt(X, Z, args, normproc=normproc1, sim_init=None, evalf=evalf)
    record('coocmap', simscoocmap)

    # --- coocmap with clipping ---
    X = sqrt_clipped(d1.Co)
    Z = sqrt_clipped(d2.Co)
    _, _, simscoocmap = match.coocmapt(X, Z, args, normproc=normproc1, sim_init=None, evalf=evalf)
    record('coocmap-clip', simscoocmap)

    # The clipped result seeds the next variant.
    dropinit = simscoocmap

    # --- coocmap after svd_power, initialised from the clipped result ---
    X = sqrt_svd_clipped(d1.Co)
    Z = sqrt_svd_clipped(d2.Co)
    _, _, simscoocmap = match.coocmapt(X, Z, args, normproc=normproc1, sim_init=dropinit, evalf=evalf)
    record('coocmap-drop', simscoocmap)

    # --- coocmap initialised from the dictionary ---
    dictinit = dict_init_binary()
    X = sqrt_normalized(d1.Co)
    Z = sqrt_normalized(d2.Co)
    _, _, simdictinit = match.coocmapt(X, Z, args, normproc=normproc1, sim_init=dictinit, evalf=evalf)
    record('dict-init', simdictinit)

# Enumerate the (drop, dim) grid and run one experiment per valid pair.
from itertools import product
drops, dims = [20], [300]
grid_plan = product(drops, dims)

for drop, dim in grid_plan:
    # Pairs with drop >= dim are skipped.
    if drop < dim:
        experiment(drop, dim)

# method = VecMap(d1.vecpath, d2.vecpath, dictpath, wandb.run.dir, cfg)
# method.run()
# res = method.eval(dictpath)
# # write the predictions
# wandb.log({'accuracy': res['accuracy'], 'coverage': res['coverage']})