import numpy as np
import random
from utils.data_functions import filter_string
import sys

def generate_insertion_pattern(n):
    """Return ``n`` strings of length ``n`` with one 'I' on the diagonal.

    Row ``k`` has 'I' at position ``k`` and '-' everywhere else, so every
    row claims its own distinct column, e.g.::

        generate_insertion_pattern(3) == ['I--', '-I-', '--I']

    A non-positive ``n`` yields an empty list.
    """
    return [
        ''.join('I' if col == row else '-' for col in range(n))
        for row in range(n)
    ]

def generate_alignment_pattern(sequences):
    """Merge per-read alignment strings into one column-aligned pattern.

    In each input string, 'I' marks an inserted symbol. Every other
    character occupies a column shared by all reads. (IDS_alignment_channel
    passes strings made of ground-truth characters, 'D' and 'I'.)

    The strings are consumed head-first, one step at a time:

    * If no read has 'I' at its head, each read contributes its head
      character, or '-' once it is exhausted, and all reads advance.
    * Otherwise every read whose head is 'I' gets a brand-new column of its
      own (taken from generate_insertion_pattern) and advances. Every other
      read is padded with one '-' per new column and does not advance.

    Simultaneous insertions in different reads are therefore never stacked
    in the same column.

    Args:
        sequences: sequence of alignment strings. It is not modified.

    Returns:
        A new list with one string per input, all of equal length.
    """
    # Work on a private copy. The insertion branch advances individual
    # reads by index assignment, which must never touch the caller's list.
    remaining = list(sequences)
    alignment_pattern = [''] * len(remaining)

    while any(remaining):
        column = [s[0] if s else '-' for s in remaining]

        if 'I' not in column:
            # Plain column: every read emits its head and advances.
            for index, char in enumerate(column):
                alignment_pattern[index] += char
            remaining = [s[1:] for s in remaining]
            continue

        inserting = [char == 'I' for char in column]
        insertion_pattern = generate_insertion_pattern(sum(inserting))
        padding = '-' * len(insertion_pattern)
        # Hand out the diagonal rows in read order, one per inserting read.
        next_row = iter(insertion_pattern)
        for index, is_insertion in enumerate(inserting):
            if is_insertion:
                alignment_pattern[index] += next(next_row)
                remaining[index] = remaining[index][1:]
            else:
                alignment_pattern[index] += padding

    return alignment_pattern

def generate_alignment(alignment_pattern):
    """Fill in an alignment pattern in place and return it.

    Every 'I' (insertion placeholder) is replaced by a base drawn with
    ``random.choice`` from 'A', 'C', 'G', 'T'. Every 'D' (deletion marker)
    becomes a gap '-'. All other characters are kept as-is.

    Rows are processed in list order and each row left to right, so a
    seeded ``random`` module reproduces the same output.

    Args:
        alignment_pattern: mutable list of pattern strings. It is rewritten
            in place.

    Returns:
        The same list object that was passed in.
    """
    bases = ['A', 'C', 'G', 'T']

    def resolve(symbol):
        # Map one pattern symbol to its final alignment character.
        if symbol == 'I':
            return random.choice(bases)
        if symbol == 'D':
            return '-'
        return symbol

    for row_index, row in enumerate(alignment_pattern):
        alignment_pattern[row_index] = ''.join([resolve(symbol) for symbol in row])
    return alignment_pattern

def replace_I(obs, alg):
    """Return ``obs`` with each 'I' replaced by the same-position char of ``alg``.

    Positions in ``obs`` that are not 'I' are kept unchanged.

    Args:
        obs: observed string that may contain 'I' placeholders.
        alg: string of the same length supplying the replacement characters.

    Returns:
        The merged string.

    Raises:
        SyntaxError: if the two inputs differ in length. An error message is
            printed first. NOTE(review): ValueError would be more
            conventional; SyntaxError is kept for existing callers.
    """
    if len(obs) != len(alg):
        print('error: strings are not of equal length')
        raise SyntaxError

    return ''.join(
        fill if symbol == 'I' else symbol
        for symbol, fill in zip(obs, alg)
    )

def IDS_alignment_channel(ground_truth_sequence, channel_statistics, observation_size, target_type, print_flag = False):
    """Simulate several independent IDS-channel reads of one sequence.

    The sequence is sent through the channel once per observation, and each
    read is aligned against the others.

    Args:
        ground_truth_sequence: input string passed through the channel.
        channel_statistics: dict with 'insertion_probability',
            'deletion_probability' and 'substitution_probability'.
        observation_size: number of independent reads to generate.
        target_type: output format. It must contain 'std' or 'ext', e.g.
            'std_MSA', 'ext_MSA', 'std_NESTED' or 'ext_NESTED'. 'std' is
            checked first.

            * 'std': inserted bases are first emitted as 'I' placeholders.
              The merged alignment then has its 'I's replaced by random
              bases and its 'D's by '-' (generate_alignment). Those bases
              are copied back into each observation via replace_I, so
              reads and alignment agree.
            * 'ext': inserted bases are drawn immediately. The alignment
              rows keep their raw 'I'/'D' markers.
        print_flag: currently unused; kept for backward compatibility.

    Returns:
        (observation_list, alignment_list): one observed string per read,
        and the column-aligned rows from generate_alignment_pattern
        (post-processed as described above for 'std').

    Raises:
        SyntaxError: if target_type contains neither 'std' nor 'ext'. An
            error message is printed first.
        ValueError: if insertion_probability >= 1, which would make the
            channel loop forever.
    """
    # Resolve the output format once, up front, so an invalid target_type
    # is always rejected. Checking it only when an insertion is drawn would
    # let it slip through on runs without insertions.
    if 'std' in target_type:
        use_placeholders = True
    elif 'ext' in target_type:
        use_placeholders = False
    else:
        print('error: target_type not defined')
        raise SyntaxError

    def ids_alignment(x, channel_statistics):
        # Send x through the IDS channel once.
        # Returns (observed string, per-read alignment string).
        y = []  # Output sequence
        alignment_seq = []  # Alignment sequence

        t = 0
        alphabet = ['A', 'C', 'G', 'T']
        length = len(x)

        pi = channel_statistics['insertion_probability']
        pd = channel_statistics['deletion_probability']
        ps = channel_statistics['substitution_probability']

        # np.random.rand() draws from [0, 1); with pi >= 1 every step would
        # be an insertion and t would never advance.
        if pi >= 1:
            raise ValueError(
                'insertion_probability must be < 1, otherwise the channel never terminates'
            )

        while t < length:
            rd = np.random.rand()

            if rd < pi:
                # Insertion: the ground-truth position does not advance.
                if use_placeholders:
                    y.append('I')
                else:
                    y.append(random.choice(alphabet))
                alignment_seq.append('I')

            elif rd < (pi + pd):
                # Deletion: nothing is observed; only the alignment records it.
                alignment_seq.append('D')
                t += 1

            elif rd < (pi + pd + ps):
                # Substitution: observe a different base, but align to the
                # ground-truth character.
                sub_list = [letter for letter in alphabet if letter != x[t]]
                y.append(random.choice(sub_list))
                alignment_seq.append(x[t])
                t += 1

            else:
                y.append(x[t])
                alignment_seq.append(x[t])
                t += 1

        return ''.join(y), ''.join(alignment_seq)

    observation_list = []
    alignment_list = []

    for _ in range(observation_size):
        y, alignment_seq = ids_alignment(ground_truth_sequence, channel_statistics)
        observation_list.append(y)
        alignment_list.append(alignment_seq)

    alignment_list = generate_alignment_pattern(alignment_list)

    if use_placeholders:
        alignment_list = generate_alignment(alignment_list)

        for index, (obs, alg) in enumerate(zip(observation_list, alignment_list)):
            # Presumably filter_string strips gap characters so that alg
            # lines up with obs (replace_I requires equal lengths).
            # TODO confirm in utils.data_functions.
            alg = filter_string(alg)
            observation_list[index] = replace_I(obs, alg)

    return observation_list, alignment_list

def IDS_channel(x, channel_statistics):
    """Pass ``x`` through a random insertion/deletion/substitution channel.

    At each step a uniform draw ``rd = np.random.rand()`` selects one event,
    using pi, pd and ps read from ``channel_statistics``:

    * ``rd < pi``: emit a random base; the input position does not advance.
    * ``rd < pi + pd``: delete ``x[t]`` (nothing is emitted).
    * ``rd < pi + pd + ps``: emit a random base different from ``x[t]``.
    * otherwise: copy ``x[t]``.

    Args:
        x: input sequence. Random bases are drawn from 'A', 'C', 'G', 'T'.
        channel_statistics: dict with 'insertion_probability',
            'deletion_probability' and 'substitution_probability'.

    Returns:
        The channel output as a string.

    Raises:
        ValueError: if insertion_probability >= 1.
    """
    alphabet = ['A', 'C', 'G', 'T']

    pi = channel_statistics['insertion_probability']
    pd = channel_statistics['deletion_probability']
    ps = channel_statistics['substitution_probability']

    # np.random.rand() draws from [0, 1); with pi >= 1 every step would be an
    # insertion, t would never advance, and the loop would never end.
    if pi >= 1:
        raise ValueError(
            'insertion_probability must be < 1, otherwise the channel never terminates'
        )

    y = []  # Output sequence
    t = 0
    length = len(x)

    while t < length:
        rd = np.random.rand()

        if rd < pi:
            y.append(random.choice(alphabet))
        elif rd < (pi + pd):
            t += 1
        elif rd < (pi + pd + ps):
            sub_list = [letter for letter in alphabet if letter != x[t]]
            y.append(random.choice(sub_list))
            t += 1
        else:
            y.append(x[t])
            t += 1

    return ''.join(y)


if __name__ == '__main__':

    # Seed both RNGs (the channel draws from `random` and `np.random`) to get
    # reproducible runs:
    # random.seed(42)
    # np.random.seed(42)

    test_size = 1

    length_ground_truth = 10
    observation_size = 5
    print_flag = False
    channel_statistics = {'substitution_probability': 0.1, 'deletion_probability': 0.1, 'insertion_probability': 0.1}

    # 'CPRED' exercises IDS_channel. Any other value is forwarded to
    # IDS_alignment_channel, which expects a target_type containing 'std' or
    # 'ext' (e.g. 'std_MSA', 'ext_NESTED').
    target_type = 'CPRED'

    if target_type == 'CPRED':

        observation_list = []
        ground_truth_sequence = ''.join(random.choices('ACTG', k=length_ground_truth))
        print(ground_truth_sequence)
        print('##################################################')

        for _ in range(test_size):
            obs_seq = IDS_channel(ground_truth_sequence, channel_statistics)
            print(ground_truth_sequence)
            print(obs_seq)
            observation_list.append(obs_seq)

        print('------------------------------------------------------------')
        print('------------------------------------------------------------')

    else:

        for _ in range(test_size):
            ground_truth_sequence = ''.join(random.choices('ACTG', k=length_ground_truth))
            observation_list, alignment = IDS_alignment_channel(
                ground_truth_sequence=ground_truth_sequence,
                channel_statistics=channel_statistics,
                observation_size=observation_size,
                target_type=target_type,
                print_flag=print_flag,
            )
