import random
import numpy as np

import string
from collections import Counter

import os
import sys
import gc
import json

from data_pkg.IDS_channel import IDS_alignment_channel, IDS_channel

from utils.data_functions import write_data_to_file
from utils.sys_functions import get_available_memory
from utils.helper_functions import create_folder
from utils.print_functions import print_list
from utils.wandb_utils import wandb_kwargs_via_cfg

import hydra
from omegaconf import DictConfig, OmegaConf
        
def nest_strings(input_str: str) -> str:
    """
    Interleave the '|'-separated segments of a string character by character.

    Example: 'ABC|DEF|GHI' becomes 'ADGBEHCFI'. If the segments have different
    lengths, only the first min-length characters of every segment are
    interleaved; the remaining tail of each longer segment is then appended
    in segment order.

    Args:
    input_str (str): Segments joined by '|'.

    Returns:
    nested_string (str): The interleaved (nested) string.
    """
    parts = input_str.split('|')
    shortest = min(len(part) for part in parts)

    pieces = []

    # zip() stops at the shortest segment, so each tuple is one "column"
    # holding the i-th character of every segment, in segment order.
    for column in zip(*parts):
        pieces.extend(column)

    # Whatever did not fit into the interleaved prefix is appended unchanged.
    pieces.extend(part[shortest:] for part in parts if len(part) > shortest)

    return ''.join(pieces)

def unnest_strings(nested_str: str, num_segments: int) -> list:
    """
    Split an interleaved string back into its segments by dealing characters
    round-robin: character i goes to segment i % num_segments.

    Example: ('ADGBEHCFI', 3) yields ['ABC', 'DEF', 'GHI'].

    Args:
    nested_str (str): The nested (interleaved) string.
    num_segments (int): The number of segments to deal the characters into.

    Returns:
    segments (list): One string per segment, in segment order.
    """
    buckets = [[] for _ in range(num_segments)]

    for position, symbol in enumerate(nested_str):
        buckets[position % num_segments].append(symbol)

    return [''.join(bucket) for bucket in buckets]

def generate_ground_truth_sequence(length):
    """
    Build a random DNA string by drawing each base independently from 'ATGC'.

    Uses the module-level `random` generator, so the result is reproducible
    after `random.seed(...)`.

    Args:
    length (int): Number of bases to draw.

    Returns:
    sequence (str): The generated sequence of 'A', 'T', 'G' and 'C'.
    """
    alphabet = 'ATGC'
    # One random.choice call per base keeps the order of random draws fixed.
    bases = [random.choice(alphabet) for _ in range(length)]
    return ''.join(bases)

def sample_sequences(file_name, n):
    """
    Draw n lines from a text file without replacement.

    The whole file is read into memory, n lines are picked with
    `random.sample`, and surrounding whitespace (including the trailing
    newline) is stripped from each picked line.

    Args:
    file_name (str): Path of the file to read, one sequence per line.
    n (int): Number of lines to sample.

    Returns:
    sampled_sequences (list): The n sampled, stripped lines.
    """
    with open(file_name, 'r') as handle:
        candidates = handle.readlines()

    return [chosen.strip() for chosen in random.sample(candidates, n)]


def data_generation(data_set_size, observation_size, length_ground_truth, channel_statistics, target_type, data_type):
    """
    Generate a data set of examples, one per random ground-truth sequence.

    All ground-truth sequences are drawn first; afterwards one data example
    is built per sequence by `test_data_generation`, which contains the
    per-example logic this function previously duplicated.

    Args:
    data_set_size (int): Number of ground-truth sequences / examples.
    observation_size (int): Number of observations generated per ground truth.
    length_ground_truth (int): Length of every ground-truth sequence.
    channel_statistics: Passed through unchanged to the IDS channel functions.
    target_type (str): 'CPRED', or a string containing 'MSA' or 'NESTED'
        together with 'std' or 'ext'.
    data_type (str): Only 'ids_data' is supported.

    Returns:
    data_pairs (list): [['ground_truth', <list of ground truths>],
        [target_type, <list of data examples>]].

    Raises:
    ValueError: If data_type is not 'ids_data', or if target_type is not
        recognized or not fully specified.

    Note:
    The process is terminated via sys.exit() inside `test_data_generation`
    if an observation returned by the alignment channel contains 'I'.
    """
    if data_type != 'ids_data':
        raise ValueError('Data type not recognized!')

    # Draw every ground truth before any observation is generated so the
    # order of random draws (and hence seeded output) stays the same.
    ground_truth_sequence_list = [
        generate_ground_truth_sequence(length_ground_truth)
        for _ in range(data_set_size)
    ]
    data_list = []

    for i, ground_truth_sequence in enumerate(ground_truth_sequence_list):
        if i % int(1e3) == 0 and i != 0:
            print(f'data generation: {i:.2e}')

        data_example = test_data_generation(
            ground_truth_sequence,
            observation_size,
            channel_statistics,
            target_type,
            data_type,
        )
        data_list.append(data_example)

        if i % int(1e5) == 0 and i != 0:
            print(f'data generation - batch {i:.2e}: finished')
            print('Available RAM (GB):', get_available_memory())
            # Must be *called*: a bare `gc.collect` only names the function
            # and never triggers a collection.
            gc.collect()

    return [['ground_truth', ground_truth_sequence_list], [target_type, data_list]]

def test_data_generation(ground_truth_sequence, observation_size, channel_statistics, target_type, data_type):
    """
    Build a single data example of the form '<observations>:<target>'.

    The observations are joined with '|'. The target depends on target_type:
    - 'CPRED': the ground-truth sequence itself.
    - contains 'MSA': the alignment sequences joined with '|'.
    - contains 'NESTED': the alignment sequences interleaved by `nest_strings`.
    If target_type contains both 'MSA' and 'NESTED', the 'MSA' form is used.

    Args:
    ground_truth_sequence (str): Sequence the observations are derived from.
    observation_size (int): Number of observations to generate.
    channel_statistics: Passed through unchanged to the IDS channel functions.
    target_type (str): 'CPRED', or a string containing 'MSA' or 'NESTED'
        together with 'std' or 'ext'.
    data_type (str): Only 'ids_data' is supported.

    Returns:
    data_example (str): The assembled example string.

    Raises:
    ValueError: If data_type is not 'ids_data', or if target_type is not
        recognized or not fully specified.
    """
    if data_type != 'ids_data':
        # Same error data_generation raises. Without this check an unknown
        # data type silently produced an example with no observations.
        raise ValueError('Data type not recognized!')

    alignment_sequence_list = []

    if target_type == 'CPRED':
        observation_sequence_list = [
            IDS_channel(ground_truth_sequence, channel_statistics)
            for _ in range(observation_size)
        ]

    elif 'MSA' in target_type or 'NESTED' in target_type:
        if 'std' not in target_type and 'ext' not in target_type:
            print('target_type: ', target_type)
            raise ValueError('data_generation.py: target type not fully specified!')
        observation_sequence_list, alignment_sequence_list = IDS_alignment_channel(
            ground_truth_sequence=ground_truth_sequence,
            channel_statistics=channel_statistics,
            observation_size=observation_size,
            target_type=target_type,
            print_flag=False,
        )
        # NOTE(review): presumably 'I' is an alignment-only marker that must
        # never leak into raw observations - confirm against
        # IDS_alignment_channel. The run is aborted if it does.
        if any('I' in s for s in observation_sequence_list):
            print('I in observation sequence list')
            print_list(observation_sequence_list)
            sys.exit()
    else:
        raise ValueError('Target type not recognized!')

    concatenated_observation_sequences = '|'.join(observation_sequence_list)

    if target_type == 'CPRED':
        return concatenated_observation_sequences + ":" + ground_truth_sequence

    concatenated_alignments = '|'.join(alignment_sequence_list)

    if 'MSA' in target_type:
        return concatenated_observation_sequences + ":" + concatenated_alignments

    # Only 'NESTED' targets can reach this point: every other target type was
    # either handled above or rejected by the dispatch on target_type.
    nested_alignment = nest_strings(concatenated_alignments)
    return concatenated_observation_sequences + ":" + nested_alignment


@hydra.main(config_path="../hydra/data_config", config_name="data_config", version_base=None)
def generate_data_set(cfg: DictConfig) -> None:
    """
    Generate a data set from the Hydra config and optionally save it.

    Reads from cfg: seed_number, observation_size, ground_truth_length,
    data_type, data_set_size, target_type, the lower/upper bounds
    (`*_probability_lb` / `*_probability_ub`) for the substitution, insertion
    and deletion probabilities, save_flag and (when saving) folder_name.

    For every example the three channel probabilities are drawn uniformly
    from their configured ranges. When cfg.save_flag is set, the following
    files are written to <repo_path>/data/<folder_name>:
    - '<target_type>_data.txt': the data examples,
    - 'ground_truth.txt': the ground-truth sequences,
    - 'reads.txt': all reads, with a separator line after each example,
    - 'data_generation_config.json': the config plus 'max_len_reads'.

    Args:
    cfg (DictConfig): The composed Hydra configuration.
    """

    # region dir
    script_dir = os.path.dirname(__file__)
    print("script_dir: ", script_dir)
    n = 2  # number of directory levels to go up from the script's folder
    repo_path = script_dir
    for _ in range(n):
        repo_path = os.path.dirname(repo_path)

    print("repo_path: ", repo_path)
    data_pkg_dir = os.path.join(repo_path, 'src', 'data_pkg')
    print("data_pkg_dir: ", data_pkg_dir)

    config_dict = wandb_kwargs_via_cfg(cfg)

    # Seed both generators so a run is reproducible for a given seed_number.
    seed_number = cfg.seed_number
    random.seed(seed_number)
    np.random.seed(seed_number)

    observation_size = cfg.observation_size
    ground_truth_length = cfg.ground_truth_length
    data_type = cfg.data_type
    data_set_size = cfg.data_set_size
    target_type = cfg.target_type

    substitution_probability_lb = cfg.substitution_probability_lb
    substitution_probability_ub = cfg.substitution_probability_ub

    insertion_probability_lb = cfg.insertion_probability_lb
    insertion_probability_ub = cfg.insertion_probability_ub

    deletion_probability_lb = cfg.deletion_probability_lb
    deletion_probability_ub = cfg.deletion_probability_ub

    data_list = []
    ground_truth_list = []
    reads_list = []
    separator = '==============================='

    # Ground truths are all drawn before any reads; merging the two loops
    # would change the order of random draws and therefore the seeded output.
    for i in range(data_set_size):
        ground_truth_list.append(generate_ground_truth_sequence(ground_truth_length))

        if i % 1000 == 0 and i != 0:
            gc.collect()
            print('Available RAM (GB):', get_available_memory())
            print(f'ground truth data generation: {i:.2e}')

    # Longest actual read seen so far. Tracked separately from reads_list so
    # the separator lines do not count, and so an empty data set yields 0
    # instead of max() raising on an empty sequence.
    max_len = 0

    for index, ground_truth_sequence in enumerate(ground_truth_list):

        if index % 1000 == 0 and index != 0:
            gc.collect()
            print('Available RAM (GB):', get_available_memory())
            print(f'reads data generation: {index:.2e}')

        substitution_probability = random.uniform(substitution_probability_lb, substitution_probability_ub)
        insertion_probability = random.uniform(insertion_probability_lb, insertion_probability_ub)
        deletion_probability = random.uniform(deletion_probability_lb, deletion_probability_ub)

        channel_statistics = {'substitution_probability': substitution_probability,
                              'insertion_probability': insertion_probability,
                              'deletion_probability': deletion_probability}

        data_example = test_data_generation(ground_truth_sequence, observation_size, channel_statistics, target_type, data_type)
        data_list.append(data_example)

        # Everything before the first ':' is the '|'-joined list of reads.
        reads = data_example.split(':')[0].split('|')
        max_len = max(max_len, max(len(read) for read in reads))
        reads_list += reads
        reads_list.append(separator)

    print('max len reads:', max_len)
    config_dict['max_len_reads'] = max_len

    save_flag = cfg.save_flag
    if save_flag:
        folder_name = cfg.folder_name
        folder_path = os.path.join(repo_path, 'data', folder_name)

        create_folder(folder_path)
        print("folder_path: ", folder_path)
        write_data_to_file(filepath = f'{folder_path}/{target_type}_data.txt', data = data_list)
        write_data_to_file(filepath = f'{folder_path}/ground_truth.txt', data = ground_truth_list)
        write_data_to_file(filepath = f'{folder_path}/reads.txt', data = reads_list)

        json_file_path = f'{folder_path}/data_generation_config.json'
        with open(json_file_path, 'w') as f:
            json.dump(config_dict, f, indent = 4)
        
        
if __name__ == "__main__":

    # Script entry point: announce the script name, then start the run.
    # generate_data_set is called without arguments; its @hydra.main
    # decorator is expected to supply the composed config.
    print('data_generation.py')
    generate_data_set()
    