import sys
import logging
import argparse
import csv
import numpy as np
# from .utils import get_log_filename

#logging.basicConfig(filename=get_log_filename('detransliterator.csv_tool'), level=logging.DEBUG)

def main_parse_command_line_args():
    """Build the command-line parser for the CSV tools and parse ``sys.argv``.

    The positional ``option`` selects the sub-command. All remaining flags are
    declared on the same parser; each sub-command reads only the ones it needs.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    # The title is passed as ``description``: the first positional parameter
    # of ArgumentParser is ``prog``, which would replace the script name in
    # the usage line instead of describing the tool.
    parser = argparse.ArgumentParser(description="CSV File Manipulation Tools")
    parser.add_argument("option", choices=[
        "extract-column",
        "insert-column",
        "sample-csv-file",
        "count-records",
        "sort",
        "filter",
        "affix-column",
        "concatenate",
    ])

    parser.add_argument("--skip-lines", type=int, default=0)
    parser.add_argument("--column-ix", type=int)
    parser.add_argument("--input-file")
    parser.add_argument("--output-file")
    parser.add_argument("--sample-count", type=int)
    parser.add_argument("--seed", type=int)
    parser.add_argument("--fieldnames", nargs="*")

    # Given as name/value pairs, see
    # https://docs.python.org/3/library/csv.html#csv-fmt-params
    parser.add_argument("--csv-formatting-params", nargs="*", default=[])

    # args for sort
    parser.add_argument("--sort-descending", default=False, action="store_true")
    parser.add_argument("--sort-numbers", default=False, action="store_true")
    parser.add_argument("--keep-top-k", type=int, default=-1)
    parser.add_argument("--sort-field-name")

    # args for filter
    parser.add_argument("--primary-key-file")
    parser.add_argument("--primary-key-column")
    # The filter command reads ``args.foreign_key_column``; without this
    # declaration the attribute does not exist on the parsed namespace.
    parser.add_argument("--foreign-key-column")

    # args for affix
    parser.add_argument("--column-name")
    parser.add_argument("--prefix-column")
    parser.add_argument("--suffix-column")
    parser.add_argument("--affix-separator")

    # args for insert column
    parser.add_argument("--new-column-name")
    parser.add_argument("--new-column-value")

    # args for concatenate
    parser.add_argument("--input-files", nargs="+")

    return parser.parse_args()


def parse_csv_formatting_params(args):
    """Turn ``--csv-formatting-params`` into keyword arguments for the csv module.

    ``args.csv_formatting_params`` is a flat list of alternating names and
    values (``name1 value1 name2 value2 ...``). Command-line values are always
    strings, so the ones the csv module needs in another type are converted:

    * ``delimiter``: the word ``tab`` stands for the tab character.
    * ``quoting``: an integer, or the name of a ``csv.QUOTE_*`` constant
      (case-insensitive, the ``QUOTE_`` prefix is optional).
    * ``doublequote``, ``skipinitialspace``, ``strict``: ``true``/``1``/``yes``
      and ``false``/``0``/``no`` (case-insensitive) become booleans; any other
      text is passed through unchanged.

    Args:
        args: parsed arguments exposing ``csv_formatting_params``.

    Returns:
        dict: keyword arguments suitable for ``csv.reader`` / ``csv.writer``.

    Raises:
        ValueError: if the list has an odd length or ``quoting`` is not a
            recognised mode.
    """
    raw_params = args.csv_formatting_params
    if len(raw_params) % 2 != 0:
        raise ValueError(
            "--csv-formatting-params must be specified in pairs: "
            "parameter1-name parameter1-value parameter2-name parameter2-value ..."
        )
    # Even positions hold the names, odd positions the matching values.
    csv_reader_kwargs = dict(zip(raw_params[::2], raw_params[1::2]))

    if csv_reader_kwargs.get("delimiter") == "tab":
        csv_reader_kwargs["delimiter"] = "\t"

    if "quoting" in csv_reader_kwargs:
        quoting = csv_reader_kwargs["quoting"]
        try:
            csv_reader_kwargs["quoting"] = int(quoting)
        except ValueError:
            constant_name = quoting.upper()
            if not constant_name.startswith("QUOTE_"):
                constant_name = f"QUOTE_{constant_name}"
            constant = getattr(csv, constant_name, None)
            if not isinstance(constant, int):
                raise ValueError(f"Unknown csv quoting mode: {quoting}") from None
            csv_reader_kwargs["quoting"] = constant

    # Left as strings, "false" and "0" would be truthy and switch the flag on.
    boolean_words = {
        "true": True, "1": True, "yes": True,
        "false": False, "0": False, "no": False,
    }
    for flag in ("doublequote", "skipinitialspace", "strict"):
        if flag in csv_reader_kwargs:
            value = csv_reader_kwargs[flag]
            csv_reader_kwargs[flag] = boolean_words.get(value.lower(), value)

    return csv_reader_kwargs


def main_extract_csv_column(args):
    """Print one column of the CSV read from standard input, one cell per line.

    The first ``args.skip_lines`` rows are dropped; for every later row the
    cell at position ``args.column_ix`` is printed. The csv reader is
    configured from ``--csv-formatting-params``.

    Raises:
        ValueError: if ``--column-ix`` was not given.
    """
    if args.column_ix is None:
        raise ValueError("--column-ix required for extraction")
    formatting = parse_csv_formatting_params(args)
    for position, record in enumerate(csv.reader(sys.stdin, **formatting)):
        # Rows before the skip threshold are consumed but not printed.
        if position >= args.skip_lines:
            print(record[args.column_ix])

def main_insert_csv_column(args):
    """Copy a CSV file, appending a new last column that holds a constant value.

    Reads ``args.input_file`` (its first row is the header) and writes the
    same rows to ``args.output_file`` with ``args.new_column_name`` added as
    the last column, set to ``args.new_column_value`` in every row.

    Raises:
        ValueError: if ``--new-column-name`` is missing or already names a
            column of the input file.
    """
    if args.new_column_name is None:
        raise ValueError("--new-column-name required for column insertion")

    # newline="" is how the csv module expects its files to be opened; without
    # it embedded newlines are misread and extra "\r" appear on some platforms.
    with open(args.input_file, newline="") as f:
        input_reader = csv.DictReader(f)
        # fieldnames is None when the input file is completely empty.
        input_fieldnames = list(input_reader.fieldnames or [])
        if args.new_column_name in input_fieldnames:
            raise ValueError(f"The column {args.new_column_name} already exists in the source file")
        # Read everything before the output is opened, so the output path may
        # be the same file as the input.
        input_rows = list(input_reader)

    with open(args.output_file, "w", newline="") as of:
        output_writer = csv.DictWriter(of, fieldnames=input_fieldnames + [args.new_column_name])
        output_writer.writeheader()
        for output_row in input_rows:
            output_row[args.new_column_name] = args.new_column_value
            output_writer.writerow(output_row)


def main_affix_csv_column(args):
    """Rewrite one column by gluing other columns' values before or after it.

    For every row of ``args.input_file`` the value of ``args.column_name``
    becomes ``<prefix><sep><value><sep><suffix>``, where the prefix and suffix
    are that row's values of ``args.prefix_column`` and ``args.suffix_column``
    (each part is optional) and ``sep`` is ``args.affix_separator``. The
    result, with the same columns, is written to ``args.output_file``.

    Raises:
        ValueError: if the target, prefix or suffix column is not in the
            input file's header.
    """
    # An omitted --affix-separator is None; formatting it into the f-string
    # would put the literal text "None" between the values.
    separator = "" if args.affix_separator is None else args.affix_separator

    with open(args.input_file, newline="") as f:
        input_reader = csv.DictReader(f)
        # fieldnames is None when the input file is completely empty.
        fieldnames = list(input_reader.fieldnames or [])
        if args.column_name not in fieldnames:
            raise ValueError(f"The column {args.column_name} doesn't exist in the source file")
        # Check the optional columns now, before the output file is truncated.
        for affix_column in (args.prefix_column, args.suffix_column):
            if affix_column and affix_column not in fieldnames:
                raise ValueError(f"The column {affix_column} doesn't exist in the source file")
        # Read everything first so the output path may equal the input path.
        input_rows = list(input_reader)

    with open(args.output_file, "w", newline="") as of:
        output_writer = csv.DictWriter(of, fieldnames=fieldnames)
        output_writer.writeheader()
        for output_row in input_rows:
            col_value = output_row[args.column_name]
            if args.prefix_column:
                col_value = f"{output_row[args.prefix_column]}{separator}{col_value}"
            if args.suffix_column:
                col_value = f"{col_value}{separator}{output_row[args.suffix_column]}"
            output_row[args.column_name] = col_value
            output_writer.writerow(output_row)



def main_sample_csv_file(args):
    """Write a random sample (with replacement) of a CSV file's rows.

    ``args.sample_count`` row indices are drawn uniformly, with replacement,
    from the rows of ``args.input_file`` using a ``numpy`` ``RandomState``
    seeded with ``args.seed``. For each drawn row the columns named in
    ``args.fieldnames`` are written to ``args.output_file``; when
    ``--fieldnames`` is omitted, all input columns are written. No header row
    is written to the output.

    Raises:
        ValueError: if ``--sample-count`` is missing, the input has no data
            rows, or a requested field name is not a column of the input.
    """
    if args.sample_count is None:
        raise ValueError("--sample-count required for sampling")

    rs = np.random.RandomState(seed=args.seed)
    with open(args.input_file, newline="") as f:
        input_reader = csv.DictReader(f)
        rows = list(input_reader)
        # fieldnames is None when the input file is completely empty.
        input_fieldnames = list(input_reader.fieldnames or [])

    if not rows:
        raise ValueError(f"{args.input_file} has no data rows to sample from")

    fieldnames = input_fieldnames if args.fieldnames is None else args.fieldnames
    # Check the requested columns before the output file is truncated.
    for fieldname in fieldnames:
        if fieldname not in input_fieldnames:
            raise ValueError(f"The column {fieldname} doesn't exist in the source file")

    # Same draw as always (low inclusive, high exclusive), so a given seed
    # keeps producing the same sample.
    sampled_indices = rs.randint(low=0, high=len(rows), size=args.sample_count)
    with open(args.output_file, "w", newline="") as of:
        output_writer = csv.DictWriter(of, fieldnames=fieldnames)
        output_writer.writerows(
            {fieldname: rows[ix][fieldname] for fieldname in fieldnames}
            for ix in sampled_indices
        )


def main_sort_csv_file(args):
    """Sort the rows of a CSV file by one column and write the result.

    Rows of ``args.input_file`` are ordered by the column
    ``args.sort_field_name``: as floats when ``args.sort_numbers`` is set,
    otherwise as strings. ``args.sort_descending`` reverses the order, and a
    positive ``args.keep_top_k`` keeps only that many leading rows. The header
    and the selected rows are written to ``args.output_file``.

    Raises:
        ValueError: if the sort column is not in the input file's header, or
            (from ``float``) a value cannot be parsed with ``--sort-numbers``.
    """
    with open(args.input_file, newline="") as f:
        input_reader = csv.DictReader(f)
        # fieldnames is None when the input file is completely empty.
        fieldnames = list(input_reader.fieldnames or [])
        if args.sort_field_name not in fieldnames:
            raise ValueError(f"The column {args.sort_field_name} doesn't exist in the source file")
        rows = list(input_reader)

    if args.sort_numbers:
        sort_keys = [float(r[args.sort_field_name]) for r in rows]
    else:
        sort_keys = [r[args.sort_field_name] for r in rows]

    # A stable sort keeps rows with equal keys in input order, so the output
    # is deterministic (numpy's default algorithm does not guarantee that).
    sorted_indices = np.argsort(sort_keys, kind="stable")
    if args.sort_descending:
        # Reversing the ascending order: ties come out in reverse input order.
        sorted_indices = sorted_indices[::-1]
    if args.keep_top_k > 0:
        sorted_indices = sorted_indices[:args.keep_top_k]

    with open(args.output_file, "w", newline="") as of:
        output_writer = csv.DictWriter(of, fieldnames=fieldnames)
        output_writer.writeheader()
        output_writer.writerows(rows[ix] for ix in sorted_indices)


def main_concatenate_csv_files(args):
    """Concatenate several CSV files that share the same set of columns.

    The files in ``args.input_files`` are read in order. The first header
    found fixes the output column order; every later file must have the same
    set of column names (in any order). All rows are written, under one
    header, to ``args.output_file``.

    Raises:
        ValueError: if ``--input-files`` is missing, a file's columns differ
            from those of the earlier files, or no input file has a header.
    """
    if not args.input_files:
        raise ValueError("--input-files required for concatenation")

    all_rows = []
    field_names = None
    for file_path in args.input_files:
        with open(file_path, newline="") as f:
            input_reader = csv.DictReader(f)
            # fieldnames is None when the file is completely empty.
            current_field_names = input_reader.fieldnames
            if field_names:
                if set(field_names) != set(current_field_names or []):
                    raise ValueError(f"{file_path} doesn't have the same set of column as another file to be concatenated")
            else:
                field_names = current_field_names
            all_rows.extend(input_reader)

    if not field_names:
        raise ValueError("None of the input files contains a header row")

    with open(args.output_file, "w", newline="") as of:
        output_writer = csv.DictWriter(of, fieldnames=field_names)
        output_writer.writeheader()
        output_writer.writerows(all_rows)


def main_filter_csv_file(args):
    """Select rows of a CSV file by the keys listed in a second CSV file.

    The rows of ``args.input_file`` are indexed by their value in the foreign
    key column (``args.foreign_key_column``, falling back to
    ``args.primary_key_column`` when it is not set). Then, for every row of
    ``args.primary_key_file``, the indexed row whose key equals that row's
    ``args.primary_key_column`` value is written to ``args.output_file``, in
    the order of the key file. When a key occurs more than once in the input,
    a warning is logged and printed and the last occurrence is used.

    Raises:
        ValueError: if ``--primary-key-column`` is missing, or a key column is
            not in the corresponding file's header.
        KeyError: if a key from the key file has no row in the input file.
    """
    from pathlib import Path  # local: only needed for the warning message

    if args.primary_key_column is None:
        raise ValueError("--primary-key-column required for filtering")
    # The attribute may be absent from the namespace; default to the same
    # column name as the key file's.
    foreign_key_column = getattr(args, "foreign_key_column", None) or args.primary_key_column

    row_index = {}
    with open(args.input_file, newline="") as input_f:
        data_input_reader = csv.DictReader(input_f)
        # fieldnames is None when the input file is completely empty.
        output_fieldnames = list(data_input_reader.fieldnames or [])
        if foreign_key_column not in output_fieldnames:
            raise ValueError(f"The column {foreign_key_column} doesn't exist in the source file")
        for data_row in data_input_reader:
            row_key = data_row[foreign_key_column]
            # Duplicates are detected against the keys already indexed.
            if row_key in row_index:
                warning_msg = f"{foreign_key_column}:{row_key} is not unique in {Path(args.input_file).stem}"
                logging.warning(warning_msg)
                print(warning_msg)
            row_index[row_key] = data_row

    # Resolve every key before touching the output, so a missing key does not
    # leave a truncated, half-written output file behind.
    selected_rows = []
    with open(args.primary_key_file, newline="") as pk_f:
        pk_reader = csv.DictReader(pk_f)
        if args.primary_key_column not in (pk_reader.fieldnames or []):
            raise ValueError(f"The column {args.primary_key_column} doesn't exist in the primary key file")
        for pk_row in pk_reader:
            pk_value = pk_row[args.primary_key_column]
            if pk_value not in row_index:
                raise KeyError(f"{args.primary_key_column}:{pk_value} not found in {args.input_file}")
            selected_rows.append(row_index[pk_value])

    with open(args.output_file, "w", newline="") as output_f:
        output_writer = csv.DictWriter(output_f, fieldnames=output_fieldnames)
        output_writer.writeheader()
        output_writer.writerows(selected_rows)


def main_count_records(args):
    """Print how many data records ``args.input_file`` contains.

    The file is read with ``csv.DictReader``, so its first row is taken as
    the header and is not counted. The result is printed as ``"<n> records"``.
    """
    with open(args.input_file) as source:
        record_total = sum(1 for _ in csv.DictReader(source))
    print(f"{record_total} records")



if __name__ == '__main__':
    # Dispatch table: each value of the positional ``option`` argument maps
    # to the function implementing that sub-command.
    main_function_map = {
        'extract-column': main_extract_csv_column,
        'insert-column': main_insert_csv_column,
        'affix-column': main_affix_csv_column,
        'sample-csv-file': main_sample_csv_file,
        'count-records': main_count_records,
        'sort': main_sort_csv_file,
        'filter': main_filter_csv_file,
        'concatenate': main_concatenate_csv_files
    }
    args = main_parse_command_line_args()
    selected_main = main_function_map.get(args.option)
    # Guard for an option that has no entry in the table above.
    if selected_main is None:
        raise ValueError(f"Unknown command: {args.option}")
    selected_main(args)