import numpy as np
import torch

import time

from data_pkg.data_generation import unnest_strings
from utils.print_functions import print_list
from gpt_pkg.beam_search import beam_search
from eval_pkg.majority_vote import majority_merge, extended_majority_vote

def gpt_prediction(test_example, **kwargs):

    """
    Generate a candidate sequence directly with the GPT model.

    The prompt is the part of `test_example` before the first ':'. It is
    re-encoded with a trailing ':' and fed to the model. The candidate is the
    text between the first and second ':' of the decoded output, truncated to
    `ground_truth_length` characters.

    Args:
    test_example (str): The test example; only the text before the first ':' is used as prompt.
    **kwargs: Inference parameters. Keys read in every mode: 'sampling', 'model', 'ctx',
        'device', 'encode', 'decode', 'max_new_tokens', 'ground_truth_length'.
        Additionally for sampling == 'greedy': 'temperature', 'constrained_generation', 'itos'.
        Additionally for sampling == 'beam_search': 'beam_width'.

    Returns:
    dict: {'candidate_sequence': str, 'time_taken': float (seconds)}.

    Raises:
    ValueError: If `sampling` is neither 'greedy' nor 'beam_search'.
    """

    sampling = kwargs['sampling']  # 'greedy' or 'beam_search'
    # Fail early with a clear message; otherwise neither branch would run and
    # the code below would hit an unbound `output_text`.
    if sampling not in ('greedy', 'beam_search'):
        raise ValueError(f'sampling not recognized: {sampling!r}')

    model = kwargs['model']
    ctx = kwargs['ctx']  # context manager wrapped around generation
    device = kwargs['device']
    encode = kwargs['encode']
    decode = kwargs['decode']
    max_new_tokens = kwargs['max_new_tokens']
    ground_truth_length = kwargs['ground_truth_length']

    if sampling == 'greedy':
        temperature = kwargs['temperature']
        itos = kwargs['itos']
        constrained_generation = kwargs['constrained_generation']
    else:
        beam_width = kwargs['beam_width']

    input_text = test_example.split(':')[0]

    with torch.no_grad(), ctx:
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments.
        start_time = time.perf_counter()
        input_ids = encode(input_text + ':')
        # Add a leading batch dimension of size 1.
        x = torch.tensor(input_ids, dtype=torch.long, device=device)[None, ...]

        if sampling == 'greedy':
            # NOTE(review): top_k is passed as None, so any kwargs['top_k'] is
            # ignored on this path -- confirm this is intended.
            y = model.generate_cpred(idx=x, max_new_tokens=max_new_tokens, temperature=temperature,
                                     top_k=None, sampling=sampling, ground_truth_length=ground_truth_length,
                                     constrained_generation=constrained_generation, itos=itos)
            output_ids = y[0].tolist()
        else:
            beams = beam_search(model=model, beam_width=beam_width, sequence_length=max_new_tokens,
                                x=x, device=device)
            output_ids = beams[0].tolist()

        output_text = decode(output_ids)

    # Keep only what follows the prompt separator, cut to the expected length.
    candidate = output_text.split(':')[1]
    candidate = candidate[0:ground_truth_length]
    time_taken = time.perf_counter() - start_time

    return_dict = {'candidate_sequence': candidate,
                   'time_taken': time_taken}

    return return_dict

def gpt_alignment(test_example, **kwargs):

    """
    Predict an alignment with the GPT model and derive a candidate sequence from it.

    The prompt is the part of `test_example` before the first ':'. It is
    re-encoded with a trailing ':' and fed to the model. The decoded output is
    cut at the first '#', and the text between the first and second ':' is
    taken as the predicted alignment. That alignment is split into rows, the
    rows are truncated to a common length, and a voting routine selected by
    `target_type` merges them into one sequence.

    Args:
    test_example (str): The test example; only the text before the first ':' is used as prompt.
    **kwargs: Inference parameters. Keys read in every mode: 'sampling', 'target_type',
        'model', 'encode', 'decode', 'ctx', 'max_new_tokens', 'device', 'test_observation_size'.
        Additionally for sampling == 'greedy': 'temperature', 'top_k'.
        Additionally for sampling == 'beam_search': 'beam_width'.
        `target_type` must contain 'MSA' or 'NESTED' (how the alignment is split)
        and 'ex' or 'std' (which voting routine is used).

    Returns:
    dict: {'candidate_sequence': ..., 'time_taken': float (seconds),
        'predicted_alignment_list': list of equally long strings}.

    Raises:
    ValueError: If `sampling` or `target_type` is not recognized.
    """

    sampling = kwargs['sampling']  # 'greedy' or 'beam_search'
    target_type = kwargs['target_type']

    # Validate before running the (expensive) generation, so that bad
    # parameters fail fast instead of surfacing later as unbound locals.
    if sampling not in ('greedy', 'beam_search'):
        raise ValueError(f'sampling not recognized: {sampling!r}')
    if 'MSA' not in target_type and 'NESTED' not in target_type:
        raise ValueError('target_type not recognized')
    if 'ex' not in target_type and 'std' not in target_type:
        raise ValueError('target_type not recognized')

    model = kwargs['model']
    encode = kwargs['encode']
    decode = kwargs['decode']
    ctx = kwargs['ctx']  # context manager wrapped around generation
    max_new_tokens = kwargs['max_new_tokens']
    device = kwargs['device']
    alignment_size = kwargs['test_observation_size']

    if sampling == 'greedy':
        temperature = kwargs['temperature']
        top_k = kwargs['top_k']
    else:
        beam_width = kwargs['beam_width']

    input_text = test_example.split(':')[0]

    with torch.no_grad(), ctx:
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments.
        start_time = time.perf_counter()
        input_ids = encode(input_text + ':')
        # Add a leading batch dimension of size 1.
        x = torch.tensor(input_ids, dtype=torch.long, device=device)[None, ...]

        if sampling == 'greedy':
            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            output_ids = y[0].tolist()
        else:
            beams = beam_search(model=model, beam_width=beam_width, sequence_length=max_new_tokens,
                                x=x, device=device)
            output_ids = beams[0].tolist()

        output_text = decode(output_ids)

    # Drop everything from the first '#' on, then keep what follows the prompt separator.
    output_text = output_text.split('#')[0]
    predicted_alignment = output_text.split(':')[1]

    # 'MSA' takes precedence over 'NESTED' when both appear in target_type.
    if 'MSA' in target_type:
        predicted_alignment_list = predicted_alignment.split('|')[0:alignment_size]
    else:
        predicted_alignment_list = unnest_strings(nested_str=predicted_alignment, num_segments=alignment_size)

    # Truncate every row to the shortest one so all rows have equal length.
    min_length = min(len(s) for s in predicted_alignment_list)
    predicted_alignment_list = [s[:min_length] for s in predicted_alignment_list]

    print('alignment list')
    print_list(predicted_alignment_list)

    # 'ex' takes precedence over 'std' when both appear in target_type.
    if 'ex' in target_type:
        candidate_seq = extended_majority_vote(predicted_alignment_list, length=10000, check_length=alignment_size)
    else:
        candidate_seq = majority_merge(predicted_alignment_list, weight=0.4)

    time_taken = time.perf_counter() - start_time

    return_dict = {'candidate_sequence': candidate_seq,
                   'time_taken': time_taken,
                   'predicted_alignment_list': predicted_alignment_list}

    return return_dict



class GPT_Inference:
    """Dispatch a test example to the inference routine selected by 'target_type'."""

    def __init__(self, inference_params):
        # Mapping of inference parameters; forwarded unchanged as keyword
        # arguments to the selected inference function.
        self.inference_params = inference_params

    def inference(self, test_example):

        """
        Run inference on a single test example.

        Args:
        test_example: The example forwarded to the selected inference function.

        Returns:
        The dict returned by gpt_alignment (target_type contains 'MSA' or 'NESTED')
        or by gpt_prediction (target_type contains 'CPRED').

        Raises:
        ValueError: If target_type contains none of 'MSA', 'NESTED', 'CPRED'.
        """
        params = self.inference_params
        kind = params['target_type']

        # Alignment targets are checked first, so they win over 'CPRED'.
        if any(tag in kind for tag in ('MSA', 'NESTED')):
            return gpt_alignment(test_example, **params)

        if 'CPRED' in kind:
            return gpt_prediction(test_example, **params)

        raise ValueError('target_type not recognized')
    




    
    
