using GerryChain
using TickTock
using Plots, Shapefile
using Printf
using JLD
using SparseArrays
using Random
using DataFrames
using CSV

### METADATA ###
# Dual-graph file, the node attribute holding population, and the ensemble file prefix.
SHAPEFILE_PATH = "Shapefiles/WI/WI_dual_graph_2020.json"
POPULATION_COL = "TOTPOP"
ENSEMBLE_FILENAME = "../../Redistricting_via_Local_Fairness/audit_by_ensemble/Default_Ensembles/WI_ensemble"
# Number of districts per map and number of maps in the ensemble.
NUM_DISTRICTS = 8
NUM_MAPS = 1000
# Node attributes holding the vote counts of the two parties.
BLUE_VOTES = "PRES12D"
RED_VOTES = "PRES12R"
# A district is counted as competitive when the winner's vote share is at most this value.
COMPETITIVE_THRESHOLD = 0.535

# Range of maps to audit; defaults to the whole ensemble.
START_MAP = 1
END_MAP = NUM_MAPS

# Both limits default to the maximum number of districts, i.e. no early stopping.
# Lower them only if performance becomes an issue.
# Maximum number of districts to sample per audited map before moving on.
SAMPLE_SIZE = NUM_DISTRICTS * (END_MAP - START_MAP + 1)
# Number of deviating groups after which the audit of a map stops.
TERMINATION_THRESHOLD = NUM_DISTRICTS * (END_MAP - START_MAP + 1)

# Build the graph from the dual-graph file.
graph = BaseGraph(SHAPEFILE_PATH, POPULATION_COL)

# Load the stored ensemble (entry "maps" of the JLD file) and report the time taken.
println("Loading the ensemble from file...")
tick()
partitions = load("$(ENSEMBLE_FILENAME).jld", "maps")
tock()
println("Successfully loaded the ensemble from file.")

# Decide, for each district of map `cur_map_index`, which party wins and by what share.
#
# Returns a tuple `(district_color, district_partisanship, red_blue_votes)`:
#   - district_color:        "Blue" or "Red" per district (ties go to "Red")
#   - district_partisanship: the winner's share of the district's blue + red votes
#   - red_blue_votes:        total blue + red votes summed over all districts of the map
function compute_district_colors(cur_map_index::Int)
    winners = Array{String}(undef, NUM_DISTRICTS)
    winner_shares = zeros(NUM_DISTRICTS)
    two_party_total = 0
    for d in 1:NUM_DISTRICTS
        blue_tally = 0
        red_tally = 0
        # Accumulate both parties' votes over the precincts of district d
        for node in partitions[cur_map_index].dist_nodes[d]
            attrs = graph.attributes[node]
            blue_tally += attrs[BLUE_VOTES]
            red_tally += attrs[RED_VOTES]
            two_party_total += attrs[BLUE_VOTES] + attrs[RED_VOTES]
        end

        # Blue needs a strict majority of the two-party vote; otherwise Red wins
        blue_wins = blue_tally > red_tally
        winners[d] = blue_wins ? "Blue" : "Red"
        winner_shares[d] = (blue_wins ? blue_tally : red_tally) / (blue_tally + red_tally)
    end
    return (winners, winner_shares, two_party_total)
end

# Count the unhappy voters of a sampled district with respect to the current map.
#
# For every precinct of district `sampled_district_index` in map `sampled_map_index`, the
# precinct's district in map `cur_map_index` is looked up. If that district is "Blue" in
# `district_color`, the precinct's red votes are unhappy; otherwise its blue votes are.
#
# Returns `(unhappy_democrats, unhappy_republicans, total_population)`, where the last
# entry is the sum of red + blue votes over the sampled district.
function determine_labels_voters(cur_map_index::Int, sampled_map_index::Int, sampled_district_index::Int, district_color)
    blue_losers = 0
    red_losers = 0
    two_party_votes = 0
    cur_assignment = partitions[cur_map_index].assignments

    for node in partitions[sampled_map_index].dist_nodes[sampled_district_index]
        red = graph.attributes[node][RED_VOTES]
        blue = graph.attributes[node][BLUE_VOTES]
        if district_color[cur_assignment[node]] == "Blue"
            red_losers += red
        else
            blue_losers += blue
        end
        two_party_votes += (red + blue)
    end

    return (blue_losers, red_losers, two_party_votes)
end


# Count the unhappy voters of a sampled district with respect to the current map, using
# the interpolation voter model: each precinct's population is split between the parties
# in proportion to its blue/red votes.
#
# For every precinct of district `sampled_district_index` in map `sampled_map_index`, the
# precinct's district in map `cur_map_index` is looked up. If that district is "Blue" in
# `district_color`, the interpolated red population is unhappy; otherwise the interpolated
# blue population is.
#
# Returns `(unhappy_democrats, unhappy_republicans, total_population)`, where the last
# entry is the sum of POPULATION_COL over the sampled district.
function determine_labels_precinct_interpolated(cur_map_index::Int, sampled_map_index::Int, sampled_district_index::Int, district_color)
    sampled_district = partitions[sampled_map_index].dist_nodes[sampled_district_index]

    unhappy_democrats = 0
    unhappy_republicans = 0
    total_population = 0

    for precinct in sampled_district
        distinct_for_cur_precinct_in_cur_map = partitions[cur_map_index].assignments[precinct]
        blue = graph.attributes[precinct][BLUE_VOTES]
        red = graph.attributes[precinct][RED_VOTES]
        pop = graph.attributes[precinct][POPULATION_COL]
        two_party_votes = blue + red

        # A precinct without any blue/red votes has no vote share to interpolate from.
        # Dividing by zero would yield NaN and turn the running totals into NaN, so such
        # a precinct only contributes to the total population.
        if two_party_votes != 0
            if (district_color[distinct_for_cur_precinct_in_cur_map] == "Blue")
                unhappy_republicans += floor((red * pop) / two_party_votes)
            else
                unhappy_democrats += floor((blue * pop) / two_party_votes)
            end
        end
        total_population += pop
    end

    return (unhappy_democrats, unhappy_republicans, total_population)
end


# Map-level statistics: one row is appended for every audited map.
# "DG" stands for deviating group.
map_data_table = DataFrame(
    "Map" => Int[],
    "# DGs" => Int[],
    "# Blue DGs" => Int[],
    "# Red DGs" => Int[],
    "# 55% DGs" => Int[],
    "# 60% DGs" => Int[],
    "# blue districts" => Int[],
    "# red districts" => Int[],
    "avg partisanship" => Float64[],
    "# competitive districts" => Int[],
    "# precincts in DGs" => Int[],
    "# unhappy precincts in DGs" => Int[],
    "% voters in DGs" => Float64[],
    "% voters unhappy in DGs" => Float64[],
    )

# One (initially empty) table of deviating groups per map of the ensemble.
dev_group_table = Array{Union{Nothing, DataFrame}}(nothing, NUM_MAPS)
# Start the timer for the audit; the matching tock() follows the audit loop.
tick()
for i = 1:NUM_MAPS
    dev_group_table[i] = DataFrame(
        "Map" => Int[],
        "District" => Int[],
        "Total_Pop" => Float64[],
        "Type" => String[],
        # Stored in full precision: the 55% / 60% counts compare this column against
        # 0.55 and 0.6, and Float16 rounding (Float16(0.55) ≈ 0.5498, Float16(0.6) ≈ 0.6001)
        # would misclassify groups that lie close to those thresholds.
        "Unhappy_pct" => Float64[]
        )
end

# Audit every map in START_MAP:END_MAP.
# For the current map, districts drawn from the whole ensemble are visited in random
# order. A sampled district is recorded as a deviating group (DG) when more than half of
# its blue + red voters are unhappy under the current map. Afterwards one summary row for
# the current map is appended to map_data_table.
for cur_map_index = START_MAP:END_MAP

    district_color, district_partisanship, red_blue_votes = compute_district_colors(cur_map_index)

    println("Start auditing for map ", cur_map_index, " ...")
    # Random order in which the (map, district) pairs of the ensemble are visited
    total_num_of_districts = NUM_MAPS * NUM_DISTRICTS
    sampling_order = randperm(total_num_of_districts)
    # Never read past the end of the permutation, even if SAMPLE_SIZE is set larger
    max_samples = min(SAMPLE_SIZE, total_num_of_districts)

    # Deviating groups found for the [cur_map_index]-th map are collected here
    cur_dev_groups = dev_group_table[cur_map_index]
    cur_num_of_dev_group = 0
    cur_sample_index = 1

    # Continue sampling until [TERMINATION_THRESHOLD] deviating groups are found or
    # [max_samples] districts have been examined. The index is tested before each draw so
    # that exactly [max_samples] districts are examined, including the last one.
    while (cur_num_of_dev_group < TERMINATION_THRESHOLD) && (cur_sample_index <= max_samples)

        # Decode the linear id in 1:(NUM_MAPS * NUM_DISTRICTS) into 1-based (map, district)
        sampled_map_index, sampled_district_index = fldmod1(sampling_order[cur_sample_index], NUM_DISTRICTS)

        unhappy_democrats, unhappy_republicans, total_population = determine_labels_voters(cur_map_index, sampled_map_index, sampled_district_index, district_color)

        # If this is a deviating group, store info to dataframe
        if (unhappy_republicans > total_population / 2)
            cur_num_of_dev_group += 1
            push!(cur_dev_groups, (sampled_map_index, sampled_district_index, total_population, "Red", unhappy_republicans / total_population))
        end
        if (unhappy_democrats > total_population / 2)
            cur_num_of_dev_group += 1
            push!(cur_dev_groups, (sampled_map_index, sampled_district_index, total_population, "Blue", unhappy_democrats / total_population))
        end

        cur_sample_index += 1
    end

    sort!(cur_dev_groups, [:Unhappy_pct])

    # Prepare statistics and push record into map-level dataframe
    num_blue_dists = 0
    num_red_dists = 0
    avg_partisanship = 0.0
    num_comp_dists = 0

    for district_index = 1:NUM_DISTRICTS
        if (district_color[district_index] == "Blue")
            num_blue_dists += 1
        else
            num_red_dists += 1
        end
        avg_partisanship += (district_partisanship[district_index] / NUM_DISTRICTS)
        # Competitive: the winner's share does not exceed COMPETITIVE_THRESHOLD
        if district_partisanship[district_index] <= COMPETITIVE_THRESHOLD
            num_comp_dists += 1
        end
    end

    # Distinct precincts covered by at least one deviating group, and those among them
    # whose own plurality matches the party of a deviating group containing them
    set_of_precincts_in_dev_groups = Set{Int}()
    set_of_unhappy_precincts_in_dev_groups = Set{Int}()
    for dev_group in eachrow(cur_dev_groups)
        this_dg = partitions[dev_group."Map"].dist_nodes[dev_group."District"]
        union!(set_of_precincts_in_dev_groups, this_dg)
        dg_is_blue = (dev_group."Type" == "Blue")
        for precinct in this_dg
            # Ties between blue and red count towards red
            blue_leads = graph.attributes[precinct][BLUE_VOTES] > graph.attributes[precinct][RED_VOTES]
            if dg_is_blue == blue_leads
                push!(set_of_unhappy_precincts_in_dev_groups, precinct)
            end
        end
    end
    num_precincts_in_dev_groups = length(set_of_precincts_in_dev_groups)
    num_unhappy_precincts_in_dev_groups = length(set_of_unhappy_precincts_in_dev_groups)

    # NOTE(review): the "unhappy" population below is each precinct's minority-party vote
    # count, which differs from the unhappy-voter definition used to detect deviating
    # groups above -- confirm that this is the intended statistic.
    total_pop_in_dev_groups = 0
    total_unhappy_pop_in_dev_groups = 0
    for precinct in set_of_precincts_in_dev_groups
        total_pop_in_dev_groups += graph.attributes[precinct][BLUE_VOTES] + graph.attributes[precinct][RED_VOTES]
        if graph.attributes[precinct][BLUE_VOTES] > graph.attributes[precinct][RED_VOTES]
            total_unhappy_pop_in_dev_groups += graph.attributes[precinct][RED_VOTES]
        else
            total_unhappy_pop_in_dev_groups += graph.attributes[precinct][BLUE_VOTES]
        end
    end

    num_of_blue_dev_groups_in_cur_map = count(group_type -> group_type == "Blue", cur_dev_groups.Type)
    num_of_red_dev_groups_in_cur_map = nrow(cur_dev_groups) - num_of_blue_dev_groups_in_cur_map
    num_of_55_up_dev_groups_in_cur_map = count(pct -> pct >= 0.55, cur_dev_groups.Unhappy_pct)
    num_of_60_up_dev_groups_in_cur_map = count(pct -> pct >= 0.6, cur_dev_groups.Unhappy_pct)

    # Both values are fractions of the map's total blue + red votes
    percent_pop_in_dg = total_pop_in_dev_groups/red_blue_votes
    percent_unhappy_pop_in_dg = total_unhappy_pop_in_dev_groups/red_blue_votes

    push!(map_data_table,
            (
                cur_map_index,
                nrow(cur_dev_groups),
                num_of_blue_dev_groups_in_cur_map,
                num_of_red_dev_groups_in_cur_map,
                num_of_55_up_dev_groups_in_cur_map,
                num_of_60_up_dev_groups_in_cur_map,
                num_blue_dists, num_red_dists,
                avg_partisanship,
                num_comp_dists,
                num_precincts_in_dev_groups,
                num_unhappy_precincts_in_dev_groups,
                percent_pop_in_dg,
                percent_unhappy_pop_in_dg
            )
        )
end
# Report the time elapsed since the tick() issued before this loop
tock()

# Print the map-level summary, then store it as JLD and CSV next to the ensemble file.
show(map_data_table)
println("")
println("Starting writing dataframe to file...")
save(string(ENSEMBLE_FILENAME, "_auditing_summary.jld"), "data", map_data_table)
println("Finished writing dataframe to file.")

println("Starting outputing dataframe to CSV file...")
CSV.write(string(ENSEMBLE_FILENAME, "_auditing_summary.csv"), map_data_table)
println("Finished outputing dataframe to CSV file.")

# Store the per-map deviating-group tables (a vector of DataFrames) as JLD.
println("")
println("Starting writing dataframe to file...")
save(string(ENSEMBLE_FILENAME, "_deviating_group_info.jld"), "data", dev_group_table)
println("Finished writing dataframe to file.")

println("Starting outputing dataframe to CSV file...")
# CSV.write expects a single table, but dev_group_table holds one DataFrame per map.
# Stack them into one table; the added "Audited_Map" column keeps the index of the map
# each deviating group was found for ("Map" is the map the sampled district comes from).
tagged_dev_groups = DataFrame[]
for (audited_map, groups) in enumerate(dev_group_table)
    groups === nothing && continue
    tagged = copy(groups)
    insertcols!(tagged, 1, "Audited_Map" => fill(audited_map, nrow(tagged)))
    push!(tagged_dev_groups, tagged)
end
flat_dev_group_table = isempty(tagged_dev_groups) ? DataFrame() : reduce(vcat, tagged_dev_groups)
CSV.write(string(ENSEMBLE_FILENAME, "_deviating_group_info.csv"), flat_dev_group_table)
println("Finished outputing dataframe to CSV file.")
