import os
import copy
from collections import defaultdict, Counter
import numpy as np
import pandas as pd
import scipy.stats as stats
from tqdm import tqdm
from vocabulary import WordVocabulary


# Directory containing the tag/token map (tag_token_map.txt) read by main().
CFG_DIR = "cfgs/"
# Directory prefix of the question-formation split files read by
# load_question_formation_data().
QUES_FORM_DATA_DIR = ""
# QUES_FORM_DATA_DIR = "data_utils/cfg_gen_data"
# Root directory under which main() creates the cfg_gen_data/ output folder.
DATA_DIR = "data_utils/"


def load_question_formation_data(
    split, include_identity=True, include_quest=True, filename=None
):
    """Load the tab-separated question-formation pairs of one data split.

    Every non-blank line of the file must hold exactly two tab-separated
    fields: the input sentence and the output sentence. The last
    whitespace-separated token of the input is dropped before the sentence is
    returned (in the data handled by this script it is the "quest"/"decl"
    marker that the filters below look for).

    Args:
        split: Extension of the file to read, e.g. "train".
        include_identity: When False, skip lines whose input contains "decl".
        include_quest: When False, skip lines whose input contains "quest".
        filename: Base name of the file inside QUES_FORM_DATA_DIR; "question"
            is used when it is None.

    Returns:
        A pandas DataFrame with the columns "input" and "output".

    Raises:
        ValueError: If a non-blank line does not contain exactly one tab.
    """
    base_name = "question" if filename is None else filename
    # os.path.join keeps the path relative when QUES_FORM_DATA_DIR is empty;
    # plain "{dir}/{name}" formatting would point at the filesystem root.
    path = os.path.join(QUES_FORM_DATA_DIR, f"{base_name}.{split}")

    with open(path) as f:
        raw_lines = f.read().split("\n")

    in_sents, out_sents = [], []
    counts = {
        'quest': 0,
        'decl': 0
    }
    for line_no, line in enumerate(raw_lines, start=1):
        # Blank lines (including the one after a trailing newline) carry no pair.
        if not line.strip():
            continue

        fields = line.split("\t")
        if len(fields) != 2:
            raise ValueError(
                f"{path}:{line_no}: expected 2 tab-separated fields, "
                f"got {len(fields)}"
            )
        in_sent, out_sent = fields

        # The statistics describe the whole file, so count before filtering.
        if 'quest' in line:
            counts['quest'] += 1
        else:
            counts['decl'] += 1

        if not include_identity and "decl" in in_sent:
            continue
        if not include_quest and "quest" in in_sent:
            continue

        # Drop the trailing marker token from the input sentence.
        in_sent = " ".join(in_sent.split()[:-1])

        in_sents.append(in_sent)
        out_sents.append(out_sent)
    print(f"{split}: {counts}")

    return pd.DataFrame({"input": in_sents, "output": out_sents})


def token_seq_to_type_seq(token_seq, token2tag):
    """Replace every whitespace-separated token of ``token_seq`` by its tag.

    Tokens missing from ``token2tag`` are kept unchanged. The tags are joined
    with single spaces and leading/trailing whitespace is removed from the
    result.
    """
    tags = []
    for tok in token_seq.split():
        tags.append(token2tag.get(tok, tok))
    type_seq = " ".join(tags)
    return type_seq.strip()

def count_num_auxs_template(text):
    """Return how often the substring "aux" occurs in the template ``text``.

    ``text`` must be a ``str``; anything else trips the assertion.
    """
    assert isinstance(text, str)
    marker = "aux"
    occurrences = text.count(marker)
    return occurrences

def count_num_auxs(text):
    """Return how often the substring "do" occurs in the sentence ``text``.

    The match is a plain substring count, so every word that contains "do"
    (for example "does" or "don't") contributes to the result. ``text`` must
    be a ``str``; anything else trips the assertion.
    """
    assert isinstance(text, str)
    marker = "do"
    occurrences = text.count(marker)
    return occurrences

def main(args):
    """Build a synthetic question-formation training file and write it to disk.

    The tag/token map is read from ``CFG_DIR`` and the original ``train``
    split is loaded twice (questions only, declaratives only). Every sentence
    is converted into a sequence of tags (a "type" template). The function
    then assembles ``args.num_samples`` tab-separated pairs: question pairs
    first, followed by declarative (identity) pairs unless ``args.quest_only``
    is set. Pairs that are not copied from the training data are generated by
    filling templates with tokens drawn at random from the tag/token map.

    The pairs are written to
    ``{DATA_DIR}/cfg_gen_data/cfg_{args.num_types}_types.train``.

    Args:
        args: Namespace providing ``num_types``, ``num_samples``, ``seed``,
            ``quest_only``, ``decl_aux_count``, ``keep_quest_ratio``,
            ``decl_ratio``, ``decl_aux_count_ratio``, ``hier_decl_only`` and
            ``linear_decl_only``. ``args.num_types`` is overwritten with the
            number of available templates when no type limit applies in the
            single-aux-count mode.

    Raises:
        AssertionError: If both aux-count options are given, or the ratio
            string does not hold three values summing to 1.
        ValueError: If samples have to be generated but no template is
            available to generate them from.
    """
    # assert args.num_types % 2 == 0

    np.random.seed(args.seed)

    # Load Type to token Map
    type_to_token_map = defaultdict(list)
    token2tag = {}
    with open(f"{CFG_DIR}/tag_token_map.txt") as f:
        for line in f:
            line = line.strip()
            if line == "":
                continue
            type_, token = line.split("\t")
            type_to_token_map[type_].append(token)
            if token not in token2tag: # special corner case of "read", assume it belongs to v_p_intrans for question_formation
                token2tag[token] = type_
    # Punctuation gets the empty tag so it contributes no tag to a template.
    token2tag["."] = ""
    token2tag["?"] = ""

    train_quest_from_data = load_question_formation_data(
        "train", include_identity=False, include_quest=True,
    )
    train_decl_from_data = load_question_formation_data(
        "train", include_identity=True, include_quest=False
    )
    # val_quest_form_data = load_question_formation_data(
    #     "val", include_identity=False, exclude_simple_interrogatives=True
    # )
    # test_quest_form_data = load_question_formation_data("test", include_identity=False)

    token_seq_to_type_seq_fn = lambda token_seq: token_seq_to_type_seq(
        token_seq, token2tag
    )
    decl_input_types = (
        train_decl_from_data["input"]
        .apply(token_seq_to_type_seq_fn)
        .values.tolist()
    )
    decl_input_types = np.array(decl_input_types)

    quest_input_types = (
        train_quest_from_data['input']
        .apply(token_seq_to_type_seq_fn)
        .values.tolist()
    )
    quest_output_types = (
        train_quest_from_data['output']
        .apply(token_seq_to_type_seq_fn)
        .values.tolist()
    )

    # Debug output: list declaratives whose template contains a "past" tag.
    for i in range(len(decl_input_types)):
        if "past" in decl_input_types[i]:
            # assert "read" in train_quest_form_data_with_decl['input'][i]
            print(i, decl_input_types[i], train_decl_from_data['input'][i])

    unique_decl_input_types = np.unique(decl_input_types)
    # Only used for membership tests below, so a set (O(1) lookups) replaces
    # the numpy array that would be scanned linearly on every test.
    unique_quest_input_types = set(quest_input_types)

    print("total decl types:", len(unique_decl_input_types))

    # determine unique question types
    unique_quest_types = {}
    for i in range(len(quest_input_types)):
        if quest_input_types[i] not in unique_quest_types:
            unique_quest_types[quest_input_types[i]] = quest_output_types[i]

    print("total quest types:", len(unique_quest_types))

    # ---- generating decls ----
    if not args.quest_only:

        # determine num of decl samples to generate
        if not args.keep_quest_ratio: # only add additional decls
            # Clamp at 0: a negative count would make the later
            # ``decl_pairs[:num_decl_samples]`` trim from the end of the list
            # instead of producing no declaratives at all.
            num_decl_samples = max(0, args.num_samples - len(train_quest_from_data))
        else: # scale up decl accordingly
            if args.decl_ratio is None:
                decl_ratio = len(train_decl_from_data) / (len(train_quest_from_data) + len(train_decl_from_data))
            else:
                decl_ratio  = args.decl_ratio # use defined
            num_decl_samples = int(args.num_samples * decl_ratio)
            print(f"iso-ratio generation: original decl ratio: {decl_ratio}, decl samples: {num_decl_samples}")

        if args.decl_aux_count is not None: # perform complex filtering - by aux count for a single aux
            assert args.decl_aux_count_ratio is None
            filtered_unique_decl_input_types = []

            # for inp in unique_decl_input_types: # uniform sampling
            for inp in decl_input_types: # original samping
                aux_count = count_num_auxs_template(inp)
                if aux_count == args.decl_aux_count:
                    if args.hier_decl_only:
                        if inp not in unique_quest_input_types:  # only append hier decl
                            filtered_unique_decl_input_types.append(inp)
                    elif args.linear_decl_only: # only append linear decl
                        if inp in unique_quest_input_types:
                            filtered_unique_decl_input_types.append(inp)
                    else:
                        filtered_unique_decl_input_types.append(inp)

            unique_decl_input_types = filtered_unique_decl_input_types
            print("total decl types to use:", len(unique_decl_input_types))


        elif args.decl_aux_count_ratio is not None: # perform complex filtering - allowing for multiple auxs
            aux_count_ratio = args.decl_aux_count_ratio.split(",")
            aux_count_ratio = [float(ratio) for ratio in aux_count_ratio]
            assert len(aux_count_ratio) == 3
            # Tolerant comparison: e.g. 0.1 + 0.2 + 0.7 is not exactly 1.0 in
            # floating point, so an exact ``== 1`` rejects valid input.
            assert np.isclose(sum(aux_count_ratio), 1.0)
            print("generate decl with aux count ratio:", aux_count_ratio)

            filtered_unique_decl_input_types = {
                1: [],
                2: [],
                3: [],
            }
            for inp in decl_input_types: # original sampling
                aux_count = count_num_auxs_template(inp)
                if aux_count in [1, 2, 3]:
                    if aux_count==2 and args.hier_decl_only:
                        if inp not in unique_quest_input_types:  # only append hier decl
                            filtered_unique_decl_input_types[aux_count].append(inp)
                    elif aux_count==2 and args.linear_decl_only:
                        if inp in unique_quest_input_types: # only append linear decl
                            filtered_unique_decl_input_types[aux_count].append(inp)
                    else:
                        filtered_unique_decl_input_types[aux_count].append(inp)


            # construct final list according to the correct ratio
            unique_decl_input_types_dict = {
                1: [],
                2: [],
                3: [],
            }
            for aux_count, aux_count_list in filtered_unique_decl_input_types.items():
                num_aux_count = int(num_decl_samples * aux_count_ratio[aux_count - 1])
                if num_aux_count <= len(aux_count_list): # truncate if needed
                    unique_decl_input_types_dict[aux_count].extend(aux_count_list[:num_aux_count])
                else:
                    # Repeating an empty list never reaches the target, so
                    # fail loudly instead of looping forever.
                    if not aux_count_list:
                        raise ValueError(
                            f"no declarative templates with {aux_count} aux(s) "
                            f"available, but {num_aux_count} were requested"
                        )
                    while len(unique_decl_input_types_dict[aux_count]) < num_aux_count:
                        unique_decl_input_types_dict[aux_count].extend(aux_count_list)

                    # truncate to correct length
                    unique_decl_input_types_dict[aux_count] = unique_decl_input_types_dict[aux_count][:num_aux_count]

            # flatten into a single list
            unique_decl_input_types = []
            for aux_count, aux_count_list in unique_decl_input_types_dict.items():
                unique_decl_input_types.extend(aux_count_list)
                print("    aux count:", aux_count, "total:", len(unique_decl_input_types_dict[aux_count]))

            # apply diversity control
            if args.num_types > 0 and args.num_types < len(unique_decl_input_types):
                print("applying diversity control...")
                # sort by idx of the second aux
                def aux_count_sort(x):
                    x = x.split()
                    x_aux_idx = [i for i, token in enumerate(x) if "aux" in token]
                    return x_aux_idx[1]

                # sort by aux count
                # unique_decl_input_types = sorted(unique_decl_input_types, key=lambda x: aux_count_sort(x)) #[::-1] for 281, and none for 280
                # or random shuffle
                np.random.shuffle(unique_decl_input_types)

                # apply diversity control - all aux counts
                # unique_decl_count = 0
                # unique_decl_appeared = set()
                # for inp in unique_decl_input_types:
                #     if inp not in unique_decl_appeared:
                #         unique_decl_appeared.add(inp)
                #         unique_decl_count += 1
                #     if unique_decl_count == args.num_types:
                #         break

                # apply diversity control - only 3 aux counts
                # Collect the first ``num_types`` distinct 3-aux templates.
                unique_decl_count = 0
                unique_decl_appeared = set()
                for inp in unique_decl_input_types:
                    aux_count = count_num_auxs_template(inp)
                    if aux_count == 3:
                        if inp not in unique_decl_appeared:
                            unique_decl_appeared.add(inp)
                            unique_decl_count += 1
                        if unique_decl_count == args.num_types:
                            break

                # Templates with 1 or 2 auxs are all kept (no limit applies).
                for inp in unique_decl_input_types:
                    aux_count = count_num_auxs_template(inp)
                    if aux_count in [1, 2]:
                        if inp not in unique_decl_appeared:
                            unique_decl_appeared.add(inp)


                # apply filtering but keep the original data ratio
                final_unique_decl_input_types = []
                for inp in unique_decl_input_types:
                    if inp in unique_decl_appeared:
                        final_unique_decl_input_types.append(inp)
                unique_decl_input_types = final_unique_decl_input_types

                print("total decl types to use:", len(set(unique_decl_appeared)))


        # ---- start filling decl pairs ----
        decl_pairs = []

        if args.decl_aux_count_ratio is None: # only fill in original decl if no ratio control (o/w too hard)
            # directly copying decls from training data - if needed
            for i in range(len(train_decl_from_data['input'])):
                aux_count = count_num_auxs(train_decl_from_data['input'][i])
                if aux_count == args.decl_aux_count:
                    if args.hier_decl_only:
                        inp = decl_input_types[i]
                        if inp not in unique_quest_input_types:  # only append hier decl
                            decl_pairs.append(f"{train_decl_from_data['input'][i]} decl\t{train_decl_from_data['output'][i]}")
                    elif args.linear_decl_only:
                        inp = decl_input_types[i]
                        if inp in unique_quest_input_types:
                            decl_pairs.append(f"{train_decl_from_data['input'][i]} decl\t{train_decl_from_data['output'][i]}")
                    else:
                        decl_pairs.append(f"{train_decl_from_data['input'][i]} decl\t{train_decl_from_data['output'][i]}")

        print("used decl pairs from training data: ", len(decl_pairs))

        # additional generation - if needed
        # determine diversity (i.e., num of sent types)
        if args.decl_aux_count_ratio is None: # enforce number of types for single aux case, ratio num_types is already handled before
            if args.num_types > 0 and args.num_types < len(unique_decl_input_types):
                final_decl_types_list = unique_decl_input_types[: args.num_types]
            else:
                final_decl_types_list = unique_decl_input_types
                args.num_types = len(unique_decl_input_types)
        else:
            final_decl_types_list = unique_decl_input_types
            # args.num_types = len(unique_decl_input_types)

        # Without any template the generation loop below could never reach
        # the target size and would spin forever.
        if len(decl_pairs) < num_decl_samples and len(final_decl_types_list) == 0:
            raise ValueError(
                "no declarative templates left after filtering; cannot "
                f"generate {num_decl_samples - len(decl_pairs)} more declaratives"
            )

        while len(decl_pairs) < num_decl_samples:
            for idx in range(0, len(final_decl_types_list)):
                inp = final_decl_types_list[idx]

                inp_types = inp.split()
                inp_tokens = copy.copy(inp_types)
                for i, type_ in enumerate(inp_types):
                    if type_ in type_to_token_map:
                        inp_tokens[i] = np.random.choice(type_to_token_map[type_])
                decl_pairs.append(f"{' '.join(inp_tokens)} . decl\t{' '.join(inp_tokens)} .")

        # ensure correct length
        decl_pairs = decl_pairs[:num_decl_samples]


    # ---- generating quests ----
    quest_pairs = []
    # re-generate questions with uniform distribution / original distribution
    # determine num of quest samples to generate
    if args.quest_only:
        num_quest_samples =  args.num_samples # just use num_samples
    else:
        # Clamp at 0 so the slices below cannot trim from the end of the list.
        num_quest_samples = max(0, args.num_samples - len(decl_pairs))

    print("question samples:", num_quest_samples)

    # directly copying question from training data - if needed
    for i in range(len(train_quest_from_data['input'])):
        quest_pairs.append(f"{train_quest_from_data['input'][i]} quest\t{train_quest_from_data['output'][i]}")
    quest_pairs = quest_pairs[:num_quest_samples] # ensure correct length

    print("total questions from data: ", len(quest_pairs))

    # Without any question template the loop below would never terminate.
    if len(quest_pairs) < num_quest_samples and not quest_input_types:
        raise ValueError(
            "no question templates available; cannot generate "
            f"{num_quest_samples - len(quest_pairs)} more questions"
        )

    # additional quest generation
    while len(quest_pairs) < num_quest_samples:
        # for inp, out in unique_quest_types.items(): # this is for uniform sampling
        for inp, out in zip(quest_input_types, quest_output_types): # this is for original distribution
            inp_types = inp.split()
            inp_tokens = copy.copy(inp_types)
            for i, type_ in enumerate(inp_types):
                if type_ in type_to_token_map:
                    inp_tokens[i] = np.random.choice(type_to_token_map[type_])
            # Align output to input so same substitutions can be used
            out_types = out.split()
            out_id_to_inp_id = [0 for _ in range(len(out_types))]
            # First token is always the auxiliary, it will be aligned to the first token that differs between the input and output[1:]
            for idx, (inp_type, out_type) in enumerate(
                zip(inp_types, out_types[1:])
            ):
                if inp_type != out_type:
                    out_id_to_inp_id[0] = idx
                    break
                else:
                    out_id_to_inp_id[idx + 1] = idx
            # Align the rest of the tokens
            for idx in range(out_id_to_inp_id[0] + 1, len(out_types)):
                out_id_to_inp_id[idx] = idx
            out_tokens = [
                inp_tokens[out_id_to_inp_id[idx]] for idx in range(len(out_types))
            ]
            quest_pairs.append(f"{' '.join(inp_tokens)} . quest\t{' '.join(out_tokens)} ?")
    quest_pairs = quest_pairs[:num_quest_samples] # ensure correct length


    if args.quest_only:
        sentence_pairs = quest_pairs
    else:
        sentence_pairs = quest_pairs + decl_pairs
        # sentence_pairs = np.random.permutation(sentence_pairs)

    # check vocab
    in_vocab = WordVocabulary(sentence_pairs, split_punctuation=False)
    print("total vocab:", len(in_vocab))

    # finally, ensure correct sample size
    train_pairs = sentence_pairs[:args.num_samples]


    data_dir = f"{DATA_DIR}/cfg_gen_data/"
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    with open(f"{data_dir}/cfg_{args.num_types}_types.train", "w") as f:
        f.write("\n".join(train_pairs))
    # with open(f"{data_dir}/cfg_{args.num_types}_types.val", "w") as f:
    #     f.write("\n".join(test_pairs))
    print("saved data:", f"{data_dir}/cfg_{args.num_types}_types.train")

    # test_pairs = []
    # for _, row in test_quest_form_data.iterrows():
    #     test_pairs.append(
    #         row.input.strip()
    #         + " quest\t"
    #         + row.output.strip()
    #     )

    # with open(f"{data_dir}/cfg_{args.num_types}_types.test", "w") as f:
    #     f.write("\n".join(test_pairs))


if __name__ == "__main__":
    # Command-line entry point: parse the generation options and run main().
    import argparse

    parser = argparse.ArgumentParser(
        description=(
            "Generate a question-formation training file from the tag "
            "templates of the original training split."
        )
    )
    parser.add_argument(
        "--num_types", type=int, default=-1,
        help=(
            "Limit on the number of declarative templates used for "
            "generation (with --decl_aux_count_ratio: on distinct 3-aux "
            "templates). Values <= 0 disable the limit. The value is also "
            "part of the output file name."
        ),
    )
    parser.add_argument(
        "--num_samples", type=int, default=50000,
        help="Total number of sentence pairs to write.",
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Seed for numpy's random number generator.",
    )
    parser.add_argument(
        "--quest_only", action="store_true",
        help="Write question pairs only; skip declarative generation.",
    )
    parser.add_argument(
        "--decl_aux_count", type=int, default=None,
        help=(
            "Use only declaratives with exactly this many auxiliaries. "
            "Cannot be combined with --decl_aux_count_ratio."
        ),
    )
    parser.add_argument(
        "--keep_quest_ratio", action="store_true",
        help=(
            "Size the declarative part as a fraction of --num_samples "
            "(see --decl_ratio) instead of filling what remains after the "
            "questions of the training data."
        ),
    )
    parser.add_argument(
        "--decl_ratio", type=float, default=None,
        help=(
            "Fraction of declaratives used with --keep_quest_ratio. "
            "Defaults to the fraction found in the training data."
        ),
    )
    parser.add_argument(
        "--decl_aux_count_ratio", type=str, default=None,
        help=(
            "Three comma-separated fractions summing to 1 that give the "
            "share of declaratives with 1, 2 and 3 auxiliaries, "
            "e.g. 0.5,0.3,0.2."
        ),
    )
    parser.add_argument(
        "--hier_decl_only", action="store_true",
        help=(
            "Keep only declaratives whose template does not occur among "
            "the question inputs (with --decl_aux_count_ratio this applies "
            "to 2-aux templates only)."
        ),
    )
    parser.add_argument(
        "--linear_decl_only", action="store_true",
        help=(
            "Keep only declaratives whose template occurs among the "
            "question inputs (with --decl_aux_count_ratio this applies to "
            "2-aux templates only). Ignored when --hier_decl_only is set."
        ),
    )
    args = parser.parse_args()
    main(args)
