# region imports

# evaluation package
from eval_pkg.distance_functions import hamming_distance
from eval_pkg.loader import model_loader, data_loader

# helper package
from utils.data_functions import filter_string
from utils.helper_functions import create_folder, append_fasta_seqs, get_repo_path
from utils.print_functions import print_list
from utils.wandb_utils import wandb_kwargs_via_cfg

# standard and third-party libraries
import numpy as np
import Levenshtein

import random
import wandb
import json
import sys
import os
import json
from typing import Any

from collections import defaultdict
import numpy as np

import hydra
from omegaconf import DictConfig, OmegaConf
# endregion

# Fixed seed for both global RNGs so runs are reproducible, e.g. the random
# A/T/G/C padding applied to short candidates during post-processing.
_GLOBAL_SEED = 23
random.seed(_GLOBAL_SEED)
np.random.seed(_GLOBAL_SEED)

@hydra.main(config_path="../hydra/inference_config", config_name="inference_config", version_base=None)
def trace_reconstruction(cfg: DictConfig) -> None:
    """Run trace-reconstruction inference over the test set and report distance metrics.

    For every test example, the observed traces (the part of the example string
    before the first ``':'``, split on ``'|'``) are scored against the ground
    truth. The model then produces a candidate sequence, whose Hamming and
    Levenshtein distances to the ground truth are recorded. Ground truths,
    observations and candidates are written to FASTA files (via
    ``append_fasta_seqs``) under ``<repo>/<cfg.data.folder_name_results>/data``.
    The run config and a metric summary are dumped as JSON into the results
    folder and, if enabled, logged to Weights & Biases.

    Args:
        cfg: Hydra config. Reads the ``general``, ``data``, ``model`` and
            ``wandb`` groups, ``test_experiment`` and the optional
            ``additional_tags`` list.

    Raises:
        ValueError: if ``cfg.model.model_type == 'gpt4_o_mini'`` but
            ``test_experiment`` is not one of the comparison experiments that
            supply an in-context example pool.
    """

    def mean_or_nan(values):
        """Return the arithmetic mean of ``values``, or NaN when it is empty."""
        return sum(values) / len(values) if values else float('nan')

    config_dict = wandb_kwargs_via_cfg(cfg)

    if hasattr(cfg, 'additional_tags'):
        additional_tags = list(cfg.additional_tags)
        print('additional_tags: ', additional_tags)
        config_dict['additional_tags'] = additional_tags
    else:
        additional_tags = []

    print(config_dict)

    # region general
    # ------------------------------------ DIR ------------------------------------
    script_dir = os.path.dirname(__file__)
    print("script_dir: ", script_dir)

    # NOTE(review): presumably resolves the repository root two levels above
    # this script — confirm against get_repo_path.
    repo_path = get_repo_path(script_dir, 2)

    data_pkg_dir = os.path.join(repo_path, 'data_pkg')
    print("data_pkg_dir: ", data_pkg_dir)

    # ------------------------------------GENERAL------------------------------------
    now_str = cfg.general.start_time
    print('now_str: ', now_str)

    # When truthy, candidates are padded/truncated to ground_truth_length
    # before the Hamming distance is computed (see `evaluation`).
    additional_postprocessing = cfg.general.additional_postprocessing

    # ------------------------------------DATA------------------------------------
    test_observation_size = cfg.data.test_observation_size
    sequence_type = cfg.data.sequence_type
    ground_truth_length = cfg.data.ground_truth_length
    data_type = cfg.data.data_type
    target_type = cfg.data.target_type

    # ------------------------------------MODEL------------------------------------
    test_experiment = cfg.test_experiment
    model_type = cfg.model.model_type

    # ----------------------------Weights & Biases Logging-----------------------------
    wandb_log = cfg.wandb.wandb_log
    wandb_project = cfg.wandb.wandb_project
    wandb_run_name = cfg.wandb.wandb_run_name
    group = f'{data_type}_{sequence_type}_{target_type}_observation_size_{test_observation_size}_ground_truth_{ground_truth_length}'

    if wandb_log:
        print('wandb_project: ', wandb_project)
        print('wandb_run_name: ', wandb_run_name)
        print('group: ', group)
    # endregion

    # region eval directories
    # -----------------------------------------EVAL DIRECTORIES-----------------------------------------
    folder_name_results = cfg.data.folder_name_results
    folder_path_results = os.path.join(repo_path, folder_name_results)
    folder_path_results_data = os.path.join(folder_path_results, 'data')
    create_folder(folder_path_results)
    create_folder(folder_path_results_data)
    print('folder_path_results: ', folder_path_results)

    # Persist the exact run configuration next to the results.
    json_file_path = os.path.join(folder_path_results, f'testing_params_{now_str}.json')
    with open(json_file_path, 'w') as f:
        json.dump(config_dict, f, indent=4)

    if wandb_log:
        wandb.init(project=wandb_project, group=group, tags=[test_experiment] + additional_tags,
                   config=config_dict, name=wandb_run_name, job_type='eval', dir=folder_path_results)
    # endregion

    # region eval loop
    test_data, test_data_ground_truth = data_loader(cfg)

    # In-context example pool for the gpt4_o_mini model. Only the comparison
    # experiments provide one: the first 1000 examples are evaluated and the
    # remainder becomes the pool.
    examples = None
    if test_experiment in ('eval_ids60_gpt4_o_mini_comparison',
                           'eval_microsoft_data_gpt4_o_mini_comparison'):
        if model_type == 'gpt4_o_mini':
            examples = test_data[1000:]
        test_data = test_data[:1000]
        test_data_ground_truth = test_data_ground_truth[:1000]

    # Fail before the (possibly expensive) model load instead of hitting an
    # unbound `examples` after it.
    if model_type == 'gpt4_o_mini' and examples is None:
        raise ValueError(
            f"model_type 'gpt4_o_mini' requires an example pool, but test_experiment "
            f"'{test_experiment}' does not provide one."
        )

    observation_path = os.path.join(folder_path_results_data, 'observations.fasta')
    ground_truth_sequence_path = os.path.join(folder_path_results_data, 'ground_truth_sequences.fasta')
    candidate_path = os.path.join(folder_path_results_data, 'candidate_sequences.fasta')

    def write_sequences(ground_truth_sequence, observed_sequences, candidate_sequence, test_index):
        """Write one example's ground truth, observations and candidate to their FASTA files."""
        append_fasta_seqs(sequences=[ground_truth_sequence], index=test_index, data_path=ground_truth_sequence_path)
        append_fasta_seqs(sequences=observed_sequences, index=test_index, data_path=observation_path)
        append_fasta_seqs(sequences=[candidate_sequence], index=test_index, data_path=candidate_path)

    def evaluation(candidate_sequence: str, ground_truth_sequence):
        """Filter a candidate and score it; return ``(candidate, hamming, levenshtein)``.

        The Levenshtein distance is computed on the filtered candidate *before*
        any length post-processing. The Hamming distance is computed afterwards.
        With ``additional_postprocessing``, a short candidate is padded with
        random A/T/G/C characters and a long one is truncated to
        ``ground_truth_length``.
        """
        candidate_sequence = filter_string(candidate_sequence)  # remove non-ATGC characters
        lev_dist = Levenshtein.distance(candidate_sequence, ground_truth_sequence)

        if additional_postprocessing:
            if len(candidate_sequence) < ground_truth_length:
                diff_length = ground_truth_length - len(candidate_sequence)
                end_str = ''.join(random.choice('ATGC') for _ in range(diff_length))
                candidate_sequence = candidate_sequence + end_str
            elif len(candidate_sequence) > ground_truth_length:
                candidate_sequence = candidate_sequence[:ground_truth_length]

        ham_dist = hamming_distance(s1=candidate_sequence, s2=ground_truth_sequence)
        return candidate_sequence, ham_dist, lev_dist

    model = model_loader(cfg)
    print('model loaded')

    if model_type == 'gpt4_o_mini':
        print('Set examples list for gpt4_o_mini')
        model.examples = examples
        print(len(model.examples))

    n_tests = len(test_data)
    ham_dist_arr = np.zeros(n_tests)
    lev_dist_arr = np.zeros(n_tests)
    len_cand_arr = np.zeros(n_tests)
    time_taken_array = np.zeros(n_tests)

    len_obs_list = []
    ham_dist_obs_list = []
    lev_dist_obs_list = []

    # Distances of the raw (unfiltered) candidate, bucketed by candidate length.
    # NOTE(review): collected but never reported below — does this make sense?
    ham_dist_dict = defaultdict(list)
    lev_dist_dict = defaultdict(list)

    print('start_loop')
    for test_index, (test_ex, gt_seq) in enumerate(zip(test_data, test_data_ground_truth)):

        print('test_index: ', test_index)
        # Observed traces: the '|'-separated part of the example before the first ':'.
        observed_sequences = test_ex.split(':')[0].split('|')
        print_list(observed_sequences)

        # Baseline: how far each raw observation is from the ground truth.
        for obs_seq in observed_sequences:
            len_obs_list.append(len(obs_seq))
            ham_dist_obs_list.append(hamming_distance(s1=obs_seq, s2=gt_seq))
            lev_dist_obs_list.append(Levenshtein.distance(s1=obs_seq, s2=gt_seq))

        return_dict = model.inference(test_example=test_ex)

        candidate_sequence = return_dict['candidate_sequence']
        time_taken = return_dict['time_taken']
        write_sequences(ground_truth_sequence=gt_seq, observed_sequences=observed_sequences,
                        candidate_sequence=candidate_sequence, test_index=test_index)

        len_cand_arr[test_index] = len(candidate_sequence)
        ham_dist_dict[len(candidate_sequence)].append(hamming_distance(s1=candidate_sequence, s2=gt_seq))
        lev_dist_dict[len(candidate_sequence)].append(Levenshtein.distance(s1=candidate_sequence, s2=gt_seq))

        if model_type == 'gpt4_o_mini' and candidate_sequence == "":
            # An empty gpt4_o_mini answer is recorded as NaN, so the nan-aware
            # means below skip it; it also counts as a non-success.
            ham_dist = np.nan
            lev_dist = np.nan
        else:
            candidate_sequence, ham_dist, lev_dist = evaluation(candidate_sequence=candidate_sequence,
                                                                ground_truth_sequence=gt_seq)

        ham_dist_arr[test_index] = ham_dist
        lev_dist_arr[test_index] = lev_dist
        time_taken_array[test_index] = time_taken

        print('cand:')
        print(candidate_sequence)
        print('gt:')
        print(gt_seq)
        print('-----------------------------------------------------------------------------')
    # endregion

    # region print and save results
    print('RUN SUMMARY:')
    avg_time_taken = np.nanmean(time_taken_array)
    print('avg_time_taken: ', avg_time_taken)

    avg_len_obs = mean_or_nan(len_obs_list)
    print('avg_len_obs: ', avg_len_obs)

    avg_len_cand = np.nanmean(len_cand_arr)
    print('avg_len_cand: ', avg_len_cand)

    avg_ham_obs = mean_or_nan(ham_dist_obs_list)
    avg_lev_obs = mean_or_nan(lev_dist_obs_list)

    avg_ham = np.nanmean(ham_dist_arr)
    avg_lev = np.nanmean(lev_dist_arr)

    print('avg_ham_obs: ', avg_ham_obs)
    print('avg_ham: ', avg_ham)

    print('avg_lev_obs: ', avg_lev_obs)
    print('avg_lev: ', avg_lev)

    # Fraction of examples reconstructed exactly (NaN entries count as failures).
    ham_success_rate = np.sum(ham_dist_arr == 0) / n_tests if n_tests else float('nan')
    print('ham_success_rate: ', ham_success_rate)

    summary_dict = {
        'avg_len_cand': avg_len_cand,
        'avg_len_obs': avg_len_obs,
        'avg_hamming_obs': avg_ham_obs,
        'avg_levenshtein_obs': avg_lev_obs,
        'avg_hamming': avg_ham,
        'avg_levenshtein': avg_lev,
        'ham_success_rate': ham_success_rate,
        'avg_time_taken': avg_time_taken,
    }

    summary_file_path = os.path.join(folder_path_results, f'eval_run_summary_{now_str}.json')
    with open(summary_file_path, 'w') as f:
        json.dump(summary_dict, f, indent=4)
    # endregion

    # region wandb logging and finish
    if wandb_log:
        wandb.log(summary_dict)
        wandb.finish()
    # endregion

if __name__ == "__main__":
    # Echo the raw CLI arguments (including any Hydra overrides) before the
    # Hydra-decorated entry point parses them.
    cli_args = sys.argv
    print("Command line arguments:", cli_args)
    trace_reconstruction()
    