import argparse
import random
import sys
import numpy as np
from gensim.models import KeyedVectors as W2Vec
from perturbations_store import PerturbationsStorage

def balanced_sampling(lines, count):
    """Draw ``count`` items from ``lines`` with near-equal usage of every entry.

    The result is a concatenation of independent random permutations of
    ``lines``: each complete pass uses every entry exactly once and only the
    final pass may be cut short.  The usage counts of any two positions of
    ``lines`` therefore differ by at most one.  Shuffling uses the module-level
    ``random`` generator, so seed it beforehand for reproducible output.

    Args:
        lines: List of items to sample from. It is not modified.
        count: Number of items to return.

    Returns:
        A new list with ``count`` items. The list is empty when ``count`` is
        not positive or when ``lines`` is empty (nothing to sample from).
    """
    # Without this guard an empty input can never fill the quota and the loop
    # below would spin forever.
    if not lines:
        return []

    sampling_order = []
    while len(sampling_order) < count:
        # One balanced round: every entry once, in random order.
        round_order = list(lines)
        random.shuffle(round_order)
        # Take the whole round, or only what is still missing on the last one.
        sampling_order.extend(round_order[:count - len(sampling_order)])

    return sampling_order

# Argument parsing
parser = argparse.ArgumentParser()
# -e and -p are used unconditionally further down, so let argparse reject a
# missing or malformed value up front instead of failing later on a None.
parser.add_argument('-e', action="store", dest="embed", required=True,
                    help="path of the embedding file passed to load_word2vec_format")
parser.add_argument("-p", action="store", dest="prob", type=float, required=True,
                    help="probability with which each character is replaced")
parser.add_argument('-s', action="store", dest="seed", type=int, default=42,
                    help="seed for the random and numpy.random generators")
parser.add_argument('-c', action="store", dest="count", type=int, required=True,
                    help="number of lines to sample from stdin")
parser.add_argument('--perturbations-file', action="store", dest='perturbations_file',
                    help="path handed to PerturbationsStorage")
parser.add_argument('--transformed-file', action="store", dest='transformed_file', default="transformed_words.txt",
                    help="output file receiving one perturbed line per sampled line")
parser.add_argument('--linked-file', action="store", dest='linked_file', default="linked_words.txt",
                    help="output file receiving 'original||perturbed' pairs")

parsed_args = parser.parse_args(sys.argv[1:])

# Seed both generators: the script draws from `random` as well as `np.random`.
seed = parsed_args.seed
random.seed(seed)
np.random.seed(seed)

emb = parsed_args.embed
prob = parsed_args.prob  # already converted to float by argparse
count = parsed_args.count

perturbations_file = PerturbationsStorage(parsed_args.perturbations_file)

transformed_lines_file = parsed_args.transformed_file
linked_lines_file = parsed_args.linked_file

# Create/truncate both output files so results from an earlier run are
# discarded before the main loop appends to them.
open(transformed_lines_file, "w", encoding="utf-8").close()
open(linked_lines_file, "w", encoding="utf-8").close()

# Load the embeddings from the path supplied with -e
model = W2Vec.load_word2vec_format(emb)

# Neighbour-list settings: how many nearest neighbours to request per
# character, and whether to keep only the odd-indexed or only the
# even-indexed entries of that list (both switches are off here).
topn = 20
isOdd = False
isEven = False

# Per-character cache, filled lazily by the main loop
mydict = {}

# Read stdin, trimming every line and dropping those that are empty afterwards
lines = list(filter(None, map(str.strip, sys.stdin)))

sampled_lines = balanced_sampling(lines, count)


# Perturb every sampled line character by character and write the results.
# Both output files were truncated during start-up, so appending yields fresh
# files; each is opened once for the whole run instead of once per line.
with open(transformed_lines_file, "a", encoding="utf-8") as transformed_file, \
        open(linked_lines_file, "a", encoding="utf-8") as linked_file:
    for original_line in sampled_lines:
        transformed_line = []

        for word in original_line.split():
            transformed_word = []

            for char in word:
                if char not in mydict:
                    try:
                        similar = model.most_similar(char, topn=topn)
                    except KeyError:
                        # The character is not in the embedding vocabulary, so
                        # there are no neighbours to draw from; keep it as is
                        # rather than aborting the whole run.
                        print(f"warning: {char!r} not in embedding vocabulary; left unchanged",
                              file=sys.stderr)
                        similar = []
                    if isOdd:
                        similar = similar[1::2]
                    elif isEven:
                        similar = similar[0::2]
                    # `similar` holds (neighbour, similarity) pairs.
                    words = [x[0] for x in similar]
                    probs = np.array([x[1] for x in similar])
                    if words:
                        # Normalise the similarities into sampling weights.
                        # NOTE(review): assumes the similarity scores are
                        # non-negative; np.random.choice rejects negative
                        # probabilities - confirm for the embeddings in use.
                        probs /= np.sum(probs)
                    mydict[char] = (words, probs)
                else:
                    words, probs = mydict[char]

                # Replace the character with probability `prob`, provided it
                # has at least one candidate; otherwise keep it unchanged.
                if words and random.random() < prob:
                    transformed_char = np.random.choice(words, 1, replace=True, p=probs)[0]
                else:
                    transformed_char = char
                perturbations_file.add(char, transformed_char)
                transformed_word.append(transformed_char)

            transformed_line.append("".join(transformed_word))

        # Words are re-joined with single spaces, whatever whitespace
        # separated them in the input line.
        transformed_line_str = " ".join(transformed_line)

        transformed_file.write(transformed_line_str + "\n")
        linked_file.write(f"{original_line}||{transformed_line_str}\n")


perturbations_file.maybe_write()