import numpy as np
import torch
import subprocess
import os
import time
import sys
from typing import List, Dict, Any


from src.data_pkg.data_generation import unnest_strings
from src.utils.helper_functions import filter_string
from src.utils.print_functions import print_list
from src.gpt_pkg.beam_search import beam_search
from src.eval_pkg.majority_vote import majority_merge, extended_majority_vote


def gpt_prediction(test_examples, attn_mask, **kwargs) -> Dict[str, Any]:
    """Generate one candidate sequence per prompt and trim it to the label length.

    Args:
        test_examples: Batch of prompt token IDs, forwarded unchanged to the
            model / beam search (assumed to be a 2-D ``(batch, prompt_len)``
            tensor, which is what ``GPT_Inference.inference`` passes).
        attn_mask: Attention mask matching ``test_examples``; forwarded unchanged.
        **kwargs: Inference configuration. Required keys: ``model``, ``ctx``,
            ``device``, ``decode``, ``ground_truth_length``, ``temperature``,
            ``greedy``, ``itos`` and ``constrained_generation``. Optional key:
            ``beam_width`` (default 5, only used when ``greedy`` is falsy).
            Extra keys are ignored.

    Returns:
        A dict with ``'candidate_sequences'`` (one string per prompt, at most
        ``ground_truth_length`` characters long) and ``'time_taken'`` (seconds
        spent on generation, decoding and trimming).
    """
    # configs
    model = kwargs['model']
    ctx = kwargs['ctx']
    device = kwargs['device']
    decode = kwargs['decode']
    # The generation budget and the trim length are the same number: we want to
    # generate exactly as many tokens as the ground-truth sequence has.
    ground_len = kwargs['ground_truth_length']
    temperature = kwargs['temperature']
    greedy = kwargs['greedy']
    itos = kwargs['itos']
    constrained = kwargs['constrained_generation']  # only generate ACTG
    beam_width = kwargs.get('beam_width', 5)

    # generate
    start_time = time.time()
    if greedy:
        with torch.no_grad(), ctx:
            Y = model.generate_cpred(
                idx=test_examples,
                attn_mask=attn_mask,
                max_new_tokens=ground_len,
                temperature=temperature,
                top_k=None,
                sampling='greedy',
                constrained_generation=constrained,
                itos=itos,
            )
    else:  # beam_search
        with torch.no_grad(), ctx:
            Y_beams = beam_search(
                model=model,
                beam_width=beam_width,
                sequence_length=ground_len,
                x=test_examples,
                attn_mask=attn_mask,
                device=device,
            )
        # Keep index 0 along the beam axis (presumably the best-scoring beam —
        # TODO confirm the ordering returned by beam_search).
        Y = Y_beams[:, 0, :]

    decoded = [decode(row) for row in Y.tolist()]

    # Strip everything up to the first ':' and trim to label length. If the
    # decoded text has no ':' at all, fall back to the whole text (same
    # fallback gpt_alignment uses) instead of failing the entire batch.
    candidates = []
    for txt in decoded:
        _, sep, generated = txt.partition(':')
        if not sep:
            generated = txt
        candidates.append(generated[:ground_len])

    return {'candidate_sequences': candidates, 'time_taken': time.time() - start_time}



def gpt_alignment(test_examples, attn_mask, alignment_size, **kwargs):
    """
    Generates an alignment (MSA or nested) using the GPT model and applies postprocessing to prepare the output for majority voting.
    Alignment_size (int) is expected number of aligned sequences (equal to cluster size).

    Generation budget:
        - The model is asked for block_size - prompt_len new tokens, where prompt_len is
          test_examples.shape[1] (the padded prompt width, colon included), so the full
          sequence never exceeds block_size.

    MSA decoding (when 'MSA' in target_type):
        - If output contains '|' tokens, they are treated as alignment separators.
        - If not, the output is evenly split into alignment_size chunks as fallback.
        - All sequences are truncated to the length of the shortest one.

    Nested decoding (any other target_type, e.g. 'NESTED'):
        - Output is un-nested into alignment_size (i.e. cluster size) parts.
        - All parts are truncated to the length of the shortest one.

    Merging:
        - std uses majority_merge
        - ex uses extended_majority_vote

    Args:
        test_examples: 2-D tensor of prompt token IDs, (batch, prompt_len); only its
            second dimension is read here, the tensor itself is forwarded to the model.
        attn_mask: Attention mask matching test_examples; forwarded unchanged.
        alignment_size: Number of sequences expected in each alignment.
        **kwargs: Required keys: model, decode, ctx, block_size, temperature, top_k,
            device, target_type, itos, sampling ('greedy' or 'beam_search').
            beam_width is required when sampling == 'beam_search'. Extra keys are ignored.

    Returns:
        dict with 'candidate_sequences' (one merged sequence per prompt),
        'predicted_alignment_list' (the truncated pieces per prompt) and
        'time_taken' (seconds spent on generation and postprocessing).

    Raises:
        ValueError: if sampling is not 'greedy'/'beam_search', or if the prompt already
            fills the whole block so that no token can be generated.
    """

    model = kwargs['model']
    decode = kwargs['decode']
    ctx = kwargs['ctx']
    block_size = kwargs['block_size']
    temperature = kwargs['temperature']
    top_k = kwargs['top_k']
    device = kwargs['device']
    target_type = kwargs['target_type']
    itos = kwargs['itos']
    sampling = kwargs['sampling']  # 'greedy' 'beam_search'

    # Fail early: an unknown mode would otherwise leave Y unbound further down.
    if sampling not in ('greedy', 'beam_search'):
        raise ValueError(f"Unknown sampling mode: {sampling!r} (expected 'greedy' or 'beam_search')")

    # Generate until the block is full. The prompts arrive as one padded batch,
    # so the prompt length is the tensor width rather than a per-string length.
    prompt_len = test_examples.shape[1]
    max_new_tokens = block_size - prompt_len
    if max_new_tokens <= 0:
        raise ValueError(
            f"Prompt length {prompt_len} leaves no room to generate within block_size {block_size}"
        )

    # generate
    start_time = time.time()

    if sampling == 'greedy':
        with torch.no_grad(), ctx:
            Y = model.generate(
                test_examples,
                attn_mask,
                max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                itos=itos,
            )
    else:  # beam_search
        beam_width = kwargs['beam_width']
        with torch.no_grad(), ctx:
            Y_beams = beam_search(
                model=model,
                beam_width=beam_width,
                sequence_length=max_new_tokens,
                x=test_examples,
                attn_mask=attn_mask,
                device=device,
            )
        # Keep index 0 along the beam axis (presumably the best-scoring beam —
        # TODO confirm the ordering returned by beam_search).
        Y = Y_beams[:, 0, :]

    decoded = [decode(row) for row in Y.tolist()]

    out_candidates = []
    out_alignments = []

    for txt in decoded:
        # strip prefix (everything up to the first ':') and EOS/padding ('#');
        # without a ':' the whole text is used as the starting point
        _, sep, generated = txt.partition(':')
        body = (generated if sep else txt).split('#', 1)[0]

        # split into alignment pieces
        if 'MSA' in target_type:
            if '|' in body:
                # NOTE(review): a leading/trailing '|' yields an empty piece, which
                # forces min_len to 0 below — confirm whether empties should be dropped.
                pieces = body.split('|')
            else:
                # even split fallback
                chunk = len(body) // alignment_size
                pieces = [body[i * chunk:(i + 1) * chunk] for i in range(alignment_size)]
        else:  # nested
            pieces = unnest_strings(nested_str=body, num_segments=alignment_size)

        # truncate all to shortest
        min_len = min(len(s) for s in pieces)
        pieces = [s[:min_len] for s in pieces]

        # merge
        if 'ex' in target_type:
            candidate = extended_majority_vote(pieces, length=min_len, check_length=alignment_size)
        else:
            candidate = majority_merge(pieces, weight=0.4)

        out_candidates.append(candidate)
        out_alignments.append(pieces)

    return {
        'candidate_sequences':       out_candidates,
        'predicted_alignment_list':  out_alignments,
        'time_taken':                time.time() - start_time
    }



class GPT_Inference:
    """Batch raw prompt strings and dispatch them to the matching GPT decoding routine."""

    def __init__(self, inference_params):
        # Kept as-is; the whole dict is forwarded as keyword arguments to
        # gpt_alignment / gpt_prediction on every inference() call.
        self.inference_params = inference_params

    @staticmethod
    def _left_pad_batch(enc_prefixes, pad_id, colon_id, device):
        """Return (input IDs, attention mask) for left-padded prefixes, each followed by a colon."""
        width = max(len(ids) for ids in enc_prefixes)

        id_rows = []
        mask_rows = []
        for ids in enc_prefixes:
            n_pad = width - len(ids)
            id_rows.append([pad_id] * n_pad + list(ids) + [colon_id])
            # False = pad, True = real (the trailing colon counts as real)
            mask_rows.append([False] * n_pad + [True] * (len(ids) + 1))

        X = torch.tensor(id_rows, dtype=torch.long, device=device)
        attn_mask = torch.tensor(mask_rows, dtype=torch.bool, device=device)
        return X, attn_mask

    def inference(self, test_examples, alignment_size=None):
        """
        Run batched inference on a list of examples.

        test_examples: List[str] of examples that will be batched; only the text before
            the first ':' of each example is used as the prompt.
        alignment_size: forwarded to gpt_alignment for 'MSA' / 'NESTED' target types;
            not used for 'CPRED'.
        returns the dict produced by gpt_alignment or gpt_prediction (includes
            candidate_sequences (List[str]) and time taken).
        raises ValueError if test_examples is empty or target_type is not recognised.
        """
        params = self.inference_params
        target_type = params['target_type']
        stoi = params['stoi']
        encode = params['encode']
        device = params['device']

        pad_id = stoi['#']
        colon_id = stoi[':']

        # batching
        prefixes = [ex.split(':')[0] for ex in test_examples]
        if not prefixes:
            # Without this guard the failure is an opaque "max() arg is empty" error.
            raise ValueError("test_examples must contain at least one example")
        enc_prefixes = [encode(p) for p in prefixes]

        X, attn_mask = self._left_pad_batch(enc_prefixes, pad_id, colon_id, device)

        if 'MSA' in target_type or 'NESTED' in target_type:
            return gpt_alignment(X, attn_mask, alignment_size, **params)

        if 'CPRED' in target_type:
            return gpt_prediction(X, attn_mask, **params)

        raise ValueError(f"Unknown target_type: {target_type}")
