import csv
import itertools
import numpy as np
import os


"""
    analysis_importance.py --> chosen_concept.py --> find_coalition.py --> phiS_chosen_concept.py
    --> main_board_background.py --> sideView.py
"""


def read_csv_file(filename, skip_rows=3):
    """Read the player_1..player_10 columns of a CSV file as integer rows.

    The first ``skip_rows`` rows (three header rows by default) are discarded.
    For every remaining non-blank row, columns 2..11 are parsed as floats and
    truncated to ``int``.

    Args:
        filename: Path to the CSV file.
        skip_rows: Number of leading rows to skip.

    Returns:
        A list with one ``list[int]`` (at most 10 values) per data row. The list
        is empty when the file has no rows after the skipped ones.
    """
    data = []
    # newline='' is what the csv module expects when reading files.
    with open(filename, 'r', newline='') as file:
        csv_reader = csv.reader(file)
        # islice tolerates files shorter than skip_rows; bare next() would
        # raise StopIteration instead.
        for row in itertools.islice(csv_reader, skip_rows, None):
            if not row:
                # csv.reader yields [] for blank lines. An empty row would make
                # belongs_to() index past its end for every coalition.
                continue
            # Extract values for player_1 to player_10 columns
            data.append([int(float(value)) for value in row[2:12]])
    return data

def generate_coalition(n_players=10, min_size=2, max_size=5):
    """Enumerate every 0/1 membership vector whose member count is in range.

    Vectors come out in ``itertools.product`` order over ``(0, 1)``. With the
    defaults this is every coalition of 2 to 5 players out of 10.

    Args:
        n_players: Length of each membership vector.
        min_size: Minimum number of 1s (inclusive).
        max_size: Maximum number of 1s (inclusive).

    Returns:
        A list of ``list[int]`` vectors containing only 0s and 1s.
    """
    # Filter lazily instead of materializing all 2**n_players tuples first.
    return [
        list(bits)
        for bits in itertools.product((0, 1), repeat=n_players)
        if min_size <= sum(bits) <= max_size
    ]


def belongs_to(arr1, arr2):
    """Tell whether coalition ``arr1`` is contained in ``arr2``.

    ``arr1`` is contained unless some position holds 1 in ``arr1`` and 0 in
    ``arr2``. Only positions where ``arr1`` equals 1 are looked up in ``arr2``,
    and the scan stops at the first failing position.
    """
    return not any(arr2[pos] == 0 for pos, flag in enumerate(arr1) if flag == 1)


def sub_array(arr):
    """Set the first element equal to 1 in ``arr`` to 0, in place.

    ``arr`` is left untouched when it contains no 1. Returns the same (mutated)
    ``arr`` object.
    """
    first_one = next((idx for idx, item in enumerate(arr) if item == 1), None)
    if first_one is not None:
        arr[first_one] = 0
    return arr


def copy(arr1, arr2):
    """Copy every element of ``arr1`` into the same index of ``arr2``, in place.

    ``arr2`` must be at least as long as ``arr1``. Any extra trailing elements
    of ``arr2`` are left unchanged. Returns None.
    """
    for pos, value in enumerate(arr1):
        arr2[pos] = value


def equal_array(arr1, arr2):
    """Return True when ``arr1`` and ``arr2`` have equal length and elements.

    Arrays of different lengths are never equal. The previous version compared
    only the first ``len(arr2)`` items, so it reported a prefix match as equal
    and raised IndexError when ``arr1`` was shorter.
    """
    if len(arr1) != len(arr2):
        return False
    for left, right in zip(arr1, arr2):
        if left != right:
            return False
    return True


def print_array(final_array):
    """Print each element of ``final_array`` on its own line (debug helper)."""
    for entry in final_array:
        print(entry)


# Number of top-ranked coalitions written to each per-sample output CSV.
chosen_num = 30


def find_coalition(filename, out_dir=None, top_k=None):
    """Rank candidate coalitions by how often a salient-concepts CSV contains them.

    Every coalition from ``generate_coalition()`` is scored as follows:

    - count the rows of ``filename`` whose entries are non-zero for all of the
      coalition's players;
    - multiply that count by a weight that depends on the coalition size.

    Coalitions are sorted by score, highest first. Ties are broken by the
    coalition vector itself, in descending order. The first ``top_k`` are written
    to ``<out_dir>/<sample_id>.csv``, replacing any earlier output for that sample.

    Args:
        filename: Path laid out as
            ``<load_dir>/<sample_id>/<reward_way>/<param>/<threshold_dir>/<file>.csv``.
            The ``<sample_id>`` component (fifth from the end) names the output.
        out_dir: Output directory. Defaults to the module-level ``save_dir``.
        top_k: Number of coalitions to write. Defaults to ``chosen_num``. It is
            effectively clamped to the number of candidate coalitions.
    """
    if out_dir is None:
        out_dir = save_dir
    if top_k is None:
        top_k = chosen_num

    # Split on the OS separator: callers build ``filename`` with os.path.join,
    # so a hard-coded "/" would break on Windows.
    sample_id = os.path.normpath(filename).split(os.sep)[-5]

    # Score multiplier indexed by (coalition size - 2), i.e. sizes 2..5.
    # NOTE(review): the rationale for these particular weights is not visible
    # here; confirm with the author.
    size_weights = [2, 5, 12, 12]

    data = read_csv_file(filename)
    coalitions = generate_coalition()

    scored = []
    for coalition in coalitions:
        hits = sum(1 for row in data if belongs_to(coalition, row))
        scored.append((hits * size_weights[sum(coalition) - 2], coalition))
    # Descending tuple sort: higher score first, ties broken by the vector.
    scored.sort(reverse=True)
    ranked = [coalition for _, coalition in scored]

    os.makedirs(out_dir, exist_ok=True)
    # Mode "w": rerunning the pipeline must not append a second header and a
    # duplicate set of rows to an existing output file.
    with open(os.path.join(out_dir, f"{sample_id}.csv"), mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['coalition', 'order', 'player_1', 'player_2', 'player_3', 'player_4', 'player_5',
                         'player_6', 'player_7', 'player_8', 'player_9', 'player_10'])
        for coalition in ranked[:top_k]:
            # The 'order' column holds the coalition size (its number of players).
            writer.writerow(["", sum(coalition)] + [str(flag) for flag in coalition])

# Root directory holding one sub-directory per sample.
salient_load_dir = 'analysis_andor_chosen_S'

# Experiment settings. They select which results sub-directory is read for each sample.
loss = "l1_for_6_10"
reward_way = "gt-log-odds-minus-mean"
qthres = 0.4
lr = 1e-6
weight = 5
trick = "pqa"
lr_way = "a_1"
qstd = "vN_vEmpty_mean"
param = f"after_sparsifying-trick-{trick}-loss-{loss}-lr-{lr}-lr-way-{lr_way}-qthres-{qthres}-qstd-{qstd}-weight-{weight}"
threshold_tau = 0.15

# Output directory for the per-sample coalition CSVs (read by find_coalition).
save_dir = 'analysis_coalitions'


def main():
    """Run find_coalition on the salient-concepts CSV of every sample directory.

    Entries of ``salient_load_dir`` that have no salient_concepts.csv for the
    configured settings (e.g. stray files) are reported and skipped.
    """
    entries = sorted(os.listdir(salient_load_dir))  # sorted for a deterministic run order
    os.makedirs(save_dir, exist_ok=True)
    for entry in entries:
        path = os.path.join(salient_load_dir, entry, reward_way, param,
                            f"threshold_{threshold_tau}", "salient_concepts.csv")
        if not os.path.isfile(path):
            print(f"Skipping {entry}: {path} not found")
            continue
        find_coalition(path)


if __name__ == "__main__":
    main()
