import os
import copy
from collections import defaultdict, Counter
import numpy as np
import pandas as pd
import scipy.stats as stats
from tqdm import tqdm
from vocabulary import WordVocabulary

""" 
This is used to generate data according to the original dataset distribution, by replicating sentence types
"""
# Directory holding tag_token_map.txt, which main() reads.
CFG_DIR = "cfgs/"
# Directory prefix for the tense-inflection files opened by
# load_tense_inflection_data() as f"{QUES_FORM_DATA_DIR}/<name>.<split>".
# NOTE(review): with the empty string these paths begin with "/" (e.g.
# "/tense.train") — confirm this is intended rather than the value below.
QUES_FORM_DATA_DIR = ""
# QUES_FORM_DATA_DIR = "data_utils/cfg_gen_data"
# Parent directory under which main() creates and fills cfg_gen_data/.
DATA_DIR = "data_utils/"


def load_tense_inflection_data(
    split, include_identity=True, include_present=True, filename=None
):
    """Load one split of the tense-inflection data as a DataFrame.

    The file ``{QUES_FORM_DATA_DIR}/{filename or 'tense'}.{split}`` is read
    as one ``input<TAB>output`` pair per line; a single trailing empty line
    is ignored.

    Args:
        split: File extension selecting the split (e.g. ``"train"``).
        include_identity: When False, pairs whose input contains ``"PAST"``
            are skipped.
        include_present: When False, pairs whose input contains
            ``"PRESENT"`` are skipped.
        filename: Base name of the file; ``"tense"`` is used when None.

    Returns:
        A ``pandas.DataFrame`` with columns ``"input"`` and ``"output"``.
        The last whitespace-separated token of every input (presumably the
        PAST/PRESENT marker) is removed.

    Raises:
        ValueError: If a line does not split into exactly two tab-separated
            fields.
    """
    stem = "tense" if filename is None else filename
    filename = f"{QUES_FORM_DATA_DIR}/{stem}.{split}"

    with open(filename) as f:
        lines = f.read().split("\n")
    if lines[-1] == "":
        lines.pop()

    # Tallied over every line of the file, before any filtering.
    counts = {'PAST': 0, 'PRESENT': 0}
    kept_inputs = []
    kept_outputs = []
    for line in lines:
        in_sent, out_sent = line.split("\t")

        counts['PRESENT' if 'PRESENT' in line else 'PAST'] += 1

        skip_identity = (not include_identity) and "PAST" in in_sent
        skip_present = (not include_present) and "PRESENT" in in_sent
        if skip_identity or skip_present:
            continue

        # Drop the final token of the input before storing it.
        kept_inputs.append(" ".join(in_sent.split()[:-1]))
        kept_outputs.append(out_sent)
    print(f"{split}: {counts}")

    return pd.DataFrame({"input": kept_inputs, "output": kept_outputs})


def token_seq_to_type_seq(token_seq, token2tag):
    """Map every whitespace-separated token of *token_seq* to its tag.

    Tokens missing from *token2tag* are kept unchanged. The tags are joined
    with single spaces and the result is stripped of leading/trailing
    whitespace (so a token mapped to ``""`` at either end leaves no gap).
    """
    tags = []
    for token in token_seq.split():
        tags.append(token2tag.get(token, token))
    return " ".join(tags).strip()

# Past-tense verb -> [plural present form, singular present form].
# main() picks index 0 for "v_p" output types and index 1 for "v_s".
tense_map = {
    past: [plural, singular]
    for past, plural, singular in (
        ("giggled", "giggle", "giggles"),
        ("smiled", "smile", "smiles"),
        ("slept", "sleep", "sleeps"),
        ("swam", "swim", "swims"),
        ("waited", "wait", "waits"),
        ("moved", "move", "moves"),
        ("changed", "change", "changes"),
        ("read", "read", "reads"),
        ("ate", "eat", "eats"),
        ("entertained", "entertain", "entertains"),
        ("amused", "amuse", "amuses"),
        ("high_fived", "high_five", "high_fives"),
        ("applauded", "applaud", "applauds"),
        ("confused", "confuse", "confuses"),
        ("admired", "admire", "admires"),
        ("accepted", "accept", "accepts"),
        ("remembered", "remember", "remembers"),
        ("comforted", "comfort", "comforts"),
    )
}


def _load_type_token_maps(path):
    """Parse a ``type<TAB>token`` file into type->tokens and token->type maps.

    Verb types (first character ``"v"``) go to the past maps when the type
    contains ``"past"``, and to the present maps when it contains ``"v_s"``
    or ``"v_p"``; other verb types are ignored. Every non-verb type is added
    to both the past and the present maps. ``"."`` is mapped to the empty
    tag in both token->type maps.

    Returns:
        ``(type_to_token_present, type_to_token_past,
        token2tag_present, token2tag_past)``.
    """
    type_to_token_present = defaultdict(list)
    type_to_token_past = defaultdict(list)
    token2tag_present = {}
    token2tag_past = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line == "":
                continue
            type_, token = line.split("\t")
            if type_[0] == "v":
                if "past" in type_:
                    type_to_token_past[type_].append(token)
                    token2tag_past[token] = type_
                elif "v_s" in type_ or "v_p" in type_:
                    # present map: include plural and singular
                    type_to_token_present[type_].append(token)
                    token2tag_present[token] = type_
            else:
                type_to_token_present[type_].append(token)
                type_to_token_past[type_].append(token)
                token2tag_past[token] = type_
                token2tag_present[token] = type_

    token2tag_present["."] = ""
    token2tag_past["."] = ""
    return (
        type_to_token_present,
        type_to_token_past,
        token2tag_present,
        token2tag_past,
    )


def _sample_tokens(type_seq, type_to_token_map):
    """Return the tokens of *type_seq* with each known type replaced by a
    token drawn via ``np.random.choice``; unknown entries are kept as-is."""
    types = type_seq.split()
    tokens = list(types)
    for i, type_ in enumerate(types):
        if type_ in type_to_token_map:
            tokens[i] = np.random.choice(type_to_token_map[type_])
    return tokens


def _repeat_to_length(items, length, label):
    """Return the first *length* elements of *items* repeated cyclically.

    Raises ValueError when elements are requested from an empty list
    (*label* names the category in the message). *items* is not modified.
    """
    if length <= 0:
        return []
    if not items:
        raise ValueError(
            f"cannot draw {length} '{label}' samples: the training data "
            "contains no sentence types of that kind"
        )
    repeats = -(-length // len(items))  # ceiling division
    return (items * repeats)[:length]


def main(args):
    """Generate a tense-inflection training file and write it to disk.

    Past ("PAST") pairs are taken from the training data and, if more are
    needed, topped up by one pass over the past sentence types with freshly
    sampled tokens. Present ("PRESENT") pairs are generated by sampling a
    past-tense sentence for a sentence type and rewriting its verbs with
    ``tense_map`` according to the matching present-tense output type.

    Args:
        args: Namespace with the attributes
            ``seed`` (NumPy RNG seed),
            ``num_samples`` (total number of pairs to write),
            ``num_types`` (only used in the output file name),
            ``present_only`` (skip past pairs entirely),
            ``keep_present_ratio`` / ``past_ratio`` (size the past portion
            as ``num_samples * past_ratio``; when ``past_ratio`` is None the
            ratio observed in the training data is used),
            ``present_type_ratio`` (optional ``"v,prep,rc"`` comma-separated
            fractions summing to 1; when None the present sentence types are
            replicated in their training-data order and frequency).

    Raises:
        ValueError: On a malformed ``present_type_ratio``, an unknown
            present sentence category, or when present samples are requested
            but no matching sentence types exist.

    Side effects:
        Seeds NumPy's global RNG, prints progress information and writes
        ``{DATA_DIR}/cfg_gen_data//ti_{num_types}_types.train``.
    """
    np.random.seed(args.seed)

    # Load Type to token Map (the present type->token map is not needed here)
    _, type_to_token_map_past, token2tag_present, token2tag_past = (
        _load_type_token_maps(f"{CFG_DIR}/tag_token_map.txt")
    )

    # Load Tense Inflection Data
    train_present_from_data = load_tense_inflection_data(
        "train", include_identity=False, include_present=True
    )
    train_past_from_data = load_tense_inflection_data(
        "train", include_identity=True, include_present=False
    )

    past_input_types = [
        token_seq_to_type_seq(sent, token2tag_past)
        for sent in train_past_from_data["input"]
    ]
    # The input of a PRESENT pair is still a past-tense sentence.
    present_input_types = [
        token_seq_to_type_seq(sent, token2tag_past)
        for sent in train_present_from_data["input"]
    ]
    # Only the outputs of PRESENT pairs are in present tense.
    present_output_types = [
        token_seq_to_type_seq(sent, token2tag_present)
        for sent in train_present_from_data["output"]
    ]

    unique_past_input_types = np.unique(past_input_types)
    unique_present_input_types = np.unique(present_input_types)

    # Sanity checks: inputs must not contain present-verb tags, and present
    # outputs must not contain past tags.
    present_verb_tags = ("v_p_intrans", "v_s_intrans", "v_p_trans", "v_s_trans")
    for type_seq in unique_past_input_types:
        assert not any(tag in type_seq for tag in present_verb_tags)
    for type_seq in unique_present_input_types:
        assert not any(tag in type_seq for tag in present_verb_tags)
    for type_seq in present_output_types:
        assert "past" not in type_seq

    print("total past types:", len(unique_past_input_types))

    # determine unique present types (first output seen for each input type)
    unique_present_types = {}
    for inp_type, out_type in zip(present_input_types, present_output_types):
        unique_present_types.setdefault(inp_type, out_type)

    print("total quest types:", len(unique_present_types))

    # ---- generate past ----
    # Bound up front so the present_only path can still concatenate below.
    past_pairs = []
    if not args.present_only:
        # determine num of past samples to generate
        if not args.keep_present_ratio:  # only add additional decls
            num_past_samples = args.num_samples - len(train_past_from_data)
        else:  # scale up decl accordingly
            if args.past_ratio is None:
                past_ratio = len(train_past_from_data) / (
                    len(train_present_from_data) + len(train_past_from_data)
                )
            else:
                past_ratio = args.past_ratio  # use defined
            num_past_samples = int(args.num_samples * past_ratio)
            print(f"iso-ratio generation: original past ratio: {past_ratio}, past samples: {num_past_samples}")

        # start filling past pairs with the pairs from the training data
        past_pairs = [
            f"{inp} PAST\t{out}"
            for inp, out in zip(
                train_past_from_data["input"], train_past_from_data["output"]
            )
        ]
        print("used past pairs from training data: ", len(past_pairs))

        # add more past samples if needed (a single pass over the types)
        if len(past_pairs) < num_past_samples:
            for inp in past_input_types:
                sentence = " ".join(_sample_tokens(inp, type_to_token_map_past))
                past_pairs.append(f"{sentence} . PAST\t{sentence} .")
        past_pairs = past_pairs[:num_past_samples]

    # ---- generate present ----
    num_present_samples = args.num_samples - len(past_pairs)

    if args.present_type_ratio is not None:
        # split by types; the third type of the sequence decides the category
        inputs_by_category = {"V": [], "PREP": [], "RC": []}
        outputs_by_category = {"V": [], "PREP": [], "RC": []}
        for inp_type, out_type in zip(present_input_types, present_output_types):
            third = inp_type.split()[2]
            if "v" in third:
                category = "V"
            elif third == "rel":
                category = "RC"
            elif third == "prep":
                category = "PREP"
            else:
                raise ValueError("unknown type")
            inputs_by_category[category].append(inp_type)
            outputs_by_category[category].append(out_type)

        # determine the number of each type to generate
        present_type_ratio = [
            float(ratio) for ratio in args.present_type_ratio.split(",")
        ]
        if len(present_type_ratio) < 3:
            raise ValueError(
                "present_type_ratio needs three comma-separated values "
                "(v,prep,rc)"
            )
        v_ratio, prep_ratio, rc_ratio = present_type_ratio[:3]
        # Tolerant comparison: e.g. 0.1 + 0.2 + 0.7 is not exactly 1.0.
        if not np.isclose(v_ratio + prep_ratio + rc_ratio, 1.0):
            raise ValueError("present_type_ratio values must sum to 1")
        num_v_samples = int(num_present_samples * v_ratio)
        num_prep_samples = int(num_present_samples * prep_ratio)
        num_rc_samples = num_present_samples - num_prep_samples - num_v_samples

        # now construct the final list: V, then PREP, then RC
        final_present_input_types = []
        final_present_output_types = []
        for category, wanted in (
            ("V", num_v_samples),
            ("PREP", num_prep_samples),
            ("RC", num_rc_samples),
        ):
            chosen_inputs = _repeat_to_length(
                inputs_by_category[category], wanted, category
            )
            chosen_outputs = _repeat_to_length(
                outputs_by_category[category], wanted, category
            )
            print(f"{category} samples: ", len(chosen_inputs))
            final_present_input_types += chosen_inputs
            final_present_output_types += chosen_outputs
    else:
        # No per-category ratio requested: replicate the sentence types in
        # the order and frequency in which they occur in the training data.
        final_present_input_types = present_input_types
        final_present_output_types = present_output_types

    if num_present_samples > 0 and not final_present_input_types:
        # Without this guard the loop below would never terminate.
        raise ValueError(
            "present samples requested but no present sentence types found"
        )

    present_pairs = []
    while len(present_pairs) < num_present_samples:
        for inp, out in zip(final_present_input_types, final_present_output_types):
            inp_tokens = _sample_tokens(inp, type_to_token_map_past)
            # align with the present output: swap each past verb for the
            # plural/singular present form demanded by the output type
            out_tokens = list(inp_tokens)
            for i, type_ in enumerate(out.split()):
                if "v_p" in type_:
                    out_tokens[i] = tense_map[out_tokens[i]][0]
                elif "v_s" in type_:
                    out_tokens[i] = tense_map[out_tokens[i]][1]

            present_pairs.append(f"{' '.join(inp_tokens)} . PRESENT\t{' '.join(out_tokens)} .")
    present_pairs = present_pairs[:num_present_samples]

    train_pairs = (past_pairs + present_pairs)[:args.num_samples]

    data_dir = f"{DATA_DIR}/cfg_gen_data/"
    os.makedirs(data_dir, exist_ok=True)

    out_path = f"{data_dir}/ti_{args.num_types}_types.train"
    with open(out_path, "w") as f:
        f.write("\n".join(train_pairs))
    print("saved data:", out_path)


# Command-line entry point: parse the generation options and run main().
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_types", type=int, default=-1,
        help="only used in the name of the output file",
    )
    parser.add_argument(
        "--num_samples", type=int, default=50000,
        help="total number of pairs to write",
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="seed for NumPy's random number generator",
    )
    parser.add_argument(
        "--present_only", action="store_true",
        help="generate PRESENT pairs only, no PAST pairs",
    )
    parser.add_argument(
        "--keep_present_ratio", action="store_true",
        help="flag to control past/present ratio: size the past portion as "
             "num_samples * past_ratio",
    )
    parser.add_argument(
        "--past_ratio", type=float, default=None,
        help="fraction of past pairs used with --keep_present_ratio; "
             "defaults to the ratio found in the training data",
    )
    # main() reads three values, in the order V, PREP, RC.
    parser.add_argument(
        "--present_type_ratio", type=str, default=None,
        help="comma-separated fractions 'v,prep,rc' that sum to 1, "
             "e.g. 0.5,0.25,0.25",
    )
    args = parser.parse_args()
    main(args)

