import os
import sys
import random

import pickle

from data_pkg.data_generation import data_generation
from utils.data_functions import load_data_from_file
from utils.helper_functions import get_repo_path, extract_elements
from utils.noisy_dna import remove_trailing_Cs

# Resolve project locations once, at import time. Each value is echoed to
# stdout so a mis-resolved path is visible in the run log.

# Directory that contains this module.
script_dir = os.path.dirname(__file__)
print("script_dir_utils: ", script_dir)

# Repository root as returned by get_repo_path.
# NOTE(review): the literal 2 is presumably the number of directory levels
# to climb from script_dir - confirm against get_repo_path.
repo_path = get_repo_path(script_dir, 2)
print("repo_path: ", repo_path)

# <repo>/src/data_pkg; model_loader looks for the meta_*.pkl vocabulary
# files in this directory.
data_pkg_dir = os.path.join(repo_path, 'src', 'data_pkg')
print("data_pkg_dir: ", data_pkg_dir)


def _generate_test_data(cfg):
    """Generate ``cfg.data.test_size`` test examples on the fly.

    For every example the substitution / insertion / deletion probabilities
    are drawn uniformly from the ``*_lb`` / ``*_ub`` bounds in ``cfg.data``
    and a single data pair is requested from ``data_generation``.

    Returns:
        (test_data, test_data_ground_truth): two parallel lists holding, per
        example, the first element that ``extract_elements`` yields for
        ``cfg.data.target_type`` and for ``'ground_truth'`` respectively.
    """
    test_data = []
    test_data_ground_truth = []

    print('-------------------------------------------Generate Data---------------------------------------------')
    for _ in range(cfg.data.test_size):
        # Draw order (substitution, insertion, deletion) is kept fixed so a
        # seeded ``random`` produces the same channel for the same seed.
        channel_statistics = {
            'substitution_probability': random.uniform(cfg.data.substitution_probability_lb,
                                                       cfg.data.substitution_probability_ub),
            'insertion_probability': random.uniform(cfg.data.insertion_probability_lb,
                                                    cfg.data.insertion_probability_ub),
            'deletion_probability': random.uniform(cfg.data.deletion_probability_lb,
                                                   cfg.data.deletion_probability_ub),
        }

        data_pairs = data_generation(data_set_size = 1,
                                     observation_size = cfg.data.test_observation_size,
                                     length_ground_truth = cfg.data.ground_truth_length,
                                     channel_statistics = channel_statistics,
                                     target_type = cfg.data.target_type,
                                     data_type = cfg.data.data_type)

        data_ex = extract_elements(data_pairs, cfg.data.target_type)[0]
        test_data.append(data_ex[0])

        ground_truth_ex = extract_elements(data_pairs, 'ground_truth')[0]
        test_data_ground_truth.append(ground_truth_ex[0])
        print('generated data done')

    return test_data, test_data_ground_truth


def _load_test_data(cfg):
    """Load test examples from ``cfg.data.eval_data_file`` and normalize them.

    Each record is treated as ``"obs_1|obs_2|...:ground_truth"``. Records
    with fewer than ``cfg.data.test_observation_size`` observations are
    dropped; the others are cut down to exactly that many observations.

    Returns:
        (test_data, test_data_ground_truth): the rebuilt
        ``"obs_1|...|obs_k:ground_truth"`` strings and, in the same order,
        their ground-truth parts.

    Raises:
        IndexError: if a record contains no ``':'`` separator.
    """
    print('-------------------------------------------Load Data---------------------------------------------')
    eval_data_path = cfg.data.eval_data_file
    print('LOADING DATA FROM: ', eval_data_path)
    raw_examples = load_data_from_file(eval_data_path)

    # Optional cap on the number of records read. It is applied BEFORE the
    # observation-count filter below, so fewer than test_size may survive.
    test_size = getattr(cfg.data, 'test_size', None)
    if test_size is not None:
        raw_examples = raw_examples[:test_size]

    strip_trailing_cs = bool(getattr(cfg.data, 'remove_trailing_Cs', False))
    observation_size = cfg.data.test_observation_size

    test_data = []
    test_data_ground_truth = []
    for test_ex in raw_examples:
        # Split once: field 0 holds the observations, field 1 the ground truth.
        fields = test_ex.split(':')
        ground_truth_seq = fields[1]
        observations = fields[0].split('|')

        if strip_trailing_cs:
            print('Removing trailing Cs...')
            observations = [remove_trailing_Cs(seq) for seq in observations]

        # Not enough observations for the requested cluster size: skip it.
        if len(observations) < observation_size:
            continue

        kept = observations[:observation_size]
        test_data.append('|'.join(kept) + ':' + ground_truth_seq)
        test_data_ground_truth.append(ground_truth_seq)

    print('loading data done')
    return test_data, test_data_ground_truth


def data_loader(cfg):
    """Return the evaluation examples and their ground-truth sequences.

    The source is selected by ``cfg.data.data_generation``:

    * ``'online'``  - examples are generated on the fly.
    * ``'offline'`` - examples are read from ``cfg.data.eval_data_file``.

    Args:
        cfg: configuration object exposing the ``cfg.data.*`` attributes
            read by the selected mode.

    Returns:
        tuple: ``(test_data, test_data_ground_truth)``, two lists of equal
        length.

    Raises:
        SystemExit: if ``cfg.data.data_generation`` is neither ``'online'``
            nor ``'offline'``. The message goes to stderr and the exit
            status is non-zero, so a misconfigured run is not reported as
            a success.
    """
    # region load or generate test data
    mode = cfg.data.data_generation
    if mode == 'online':
        return _generate_test_data(cfg)
    if mode == 'offline':
        return _load_test_data(cfg)

    # Passing a string to sys.exit prints it to stderr and exits with status 1.
    sys.exit(f'eval_pkg/loader.py: ERROR - data_generation not recognized: {mode!r}.')

def model_loader(cfg):
    """Build the inference wrapper for the model selected in ``cfg``.

    Only ``cfg.model.model_type == 'gpt'`` is supported. For it, the
    vocabulary tables are read from ``meta_<sequence_type>.pkl`` in the data
    package directory, the checkpoint is loaded with
    ``load_transformer_model`` and everything is bundled into a
    ``GPT_Inference`` object.

    Args:
        cfg: configuration object. Reads ``cfg.model.model_type``,
            ``cfg.model.gpt_params.*``, ``cfg.project``, several
            ``cfg.data.*`` fields and, optionally, ``cfg.model.compile``
            (defaults to ``True`` when absent).

    Returns:
        The ``GPT_Inference`` instance created from the loaded model.

    Raises:
        ValueError: if ``cfg.model.model_type`` is not a supported type.
    """
    model_type = cfg.model.model_type

    # region load model
    print('-------------------------------------------Load Models---------------------------------------------')
    if model_type != 'gpt':
        # Fail with a clear message instead of an UnboundLocalError at the
        # final ``return``.
        raise ValueError(f"eval_pkg/loader.py: model_type {model_type!r} not recognized (supported: 'gpt').")

    # gpt package (imported lazily, only when a gpt model is requested)
    # GPTConfig / GPT are not referenced directly in this function; the
    # import is retained so gpt_pkg.model is still imported here as before.
    from gpt_pkg.model import GPTConfig, GPT  # noqa: F401
    from gpt_pkg.load_gpt import load_transformer_model
    from eval_pkg.GPT_Inference import GPT_Inference

    # The flag lives on cfg.model, so that is the object to probe for it.
    # Named compile_model so the builtin ``compile`` is not shadowed.
    compile_model = getattr(cfg.model, 'compile', True)

    meta_path = os.path.join(data_pkg_dir, f'meta_{cfg.data.sequence_type}.pkl')

    print(f"Loading meta from {meta_path}...")
    # NOTE(security): pickle.load can execute arbitrary code. meta_path must
    # only ever point at a trusted, project-generated file.
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    # Only the raw lookup tables are taken from the meta file; encode, decode
    # and vocab_size come from load_transformer_model below.
    stoi, itos = meta['stoi'], meta['itos']

    checkpoint_folder = cfg.model.gpt_params.checkpoint_folder
    checkpoint_name   = cfg.model.gpt_params.checkpoint_name

    checkpoint_dir = os.path.join(repo_path, 'model_checkpoints', cfg.project, checkpoint_folder)
    print(checkpoint_dir)

    # gpt params
    max_new_tokens = cfg.model.gpt_params.max_new_tokens
    temperature    = cfg.model.gpt_params.temperature     # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
    top_k          = cfg.model.gpt_params.top_k  # retain only the top_k most likely tokens, clamp others to have 0 probability
    device         = cfg.model.gpt_params.device
    sampling       = cfg.model.gpt_params.sampling
    beam_width     = cfg.model.gpt_params.beam_width
    constrained_generation = cfg.model.gpt_params.constrained_generation

    # Built from cfg.data.observation_size, not test_observation_size.
    # NOTE(review): presumably the name of the training dataset - confirm.
    train_data_set = f'{cfg.data.sequence_type}_{cfg.data.target_type}_{cfg.data.observation_size}'

    gpt_model, encode, decode, vocab_size, ctx = load_transformer_model(
        model_name = checkpoint_name,
        device = device,
        dataset = train_data_set,
        checkpoint_dir = checkpoint_dir,
        data_pkg_dir = data_pkg_dir,
        compile = compile_model,
    )

    num_params = gpt_model.get_num_params()
    label_smoothing = gpt_model.config.label_smoothing

    inference_params = {
        'test_observation_size': cfg.data.test_observation_size,
        'ground_truth_length': cfg.data.ground_truth_length,
        'target_type': cfg.data.target_type,
        'model': gpt_model,
        'model_name': checkpoint_name,
        'encode': encode,
        'decode': decode,
        'vocab_size': vocab_size,
        'temperature': temperature,
        'device': device,
        'top_k': top_k,
        'ctx': ctx,
        'num_params': num_params,
        'label_smoothing': label_smoothing,
        'sampling': sampling,
        'beam_width': beam_width,
        'max_new_tokens': max_new_tokens,
        'itos': itos,
        'stoi': stoi,
        'constrained_generation': constrained_generation,
    }
    model = GPT_Inference(inference_params)

    print('-------------------------------------------Model Loaded---------------------------------------------')
    return model