import time
import os
import sys
from typing import List, Any, Dict
from torch.utils.data import DataLoader
from collections import Counter
from itertools import chain

# Make the ``transformer/src`` tree importable. It sits three levels above the
# directory containing this file. It is prepended (index 0) so its top-level
# packages win over same-named installed ones; presumably it provides the
# ``loader`` / ``misc`` / ``evaluation`` packages imported below (TODO confirm).
transformer_src_path = os.path.join(
    os.path.dirname(__file__), *(['..'] * 3), 'transformer', 'src'
)
sys.path.insert(0, transformer_src_path)

from .oracle_base import OracleBase
from data_generation.monomials.monomials import Polynomial
from utils.basis_extension import basis_extension


from loader.oracle import TransformerOracle as TransformerOracleImpl
from loader.checkpoint import load_pretrained_bag
from loader.data_format.processors.subprocessors import MonomialProcessorPlus
from misc.utils import to_cuda
from evaluation.generation import generation
from loader.data import load_data


class TransformerOracle(OracleBase):
    """Oracle that predicts a polynomial basis with a pretrained transformer.

    Wraps ``loader.oracle.TransformerOracle`` (imported as
    ``TransformerOracleImpl``). It converts the wrapped oracle's string output
    into lists of ``Polynomial`` terms and can optionally post-process them
    with ``basis_extension``.

    Supported ``mode`` values (see ``predict_basis``):

    * ``"single"``: one generation on the tokens as given.
    * ``"permutation_union"`` / ``"permutation_intersection"``: use the
      ``union_basis`` / ``intersection_basis`` entry returned by the wrapped
      oracle's ``generate_from_string_with_permutations``.
    * ``"permutation_all"``: use every entry of ``all_permuted_bases``,
      keeping each permuted basis separately.
    """

    # mode -> key of the dict returned by
    # TransformerOracleImpl.generate_from_string_with_permutations.
    _PERMUTATION_RESULT_KEYS = {
        "permutation_union": "union_basis",
        "permutation_intersection": "intersection_basis",
        "permutation_all": "all_permuted_bases",
    }
    _VALID_MODES = ("single",) + tuple(_PERMUTATION_RESULT_KEYS)

    def __init__(self, use_basis_extension: bool = False, basis_extension_params: Dict = None, permutations: int = 1, mode: str = 'single', **transformer_kwargs):
        """Create the oracle and the underlying transformer.

        Args:
            use_basis_extension: If True, results are post-processed with
                ``basis_extension``.
            basis_extension_params: Keyword arguments forwarded to
                ``basis_extension``. ``None`` means "no extra arguments".
            permutations: Number of permutations requested in the
                ``permutation_*`` modes.
            mode: One of ``_VALID_MODES``. It is validated when
                ``predict_basis`` runs.
            **transformer_kwargs: Forwarded unchanged to ``TransformerOracleImpl``.
        """
        # The first two arguments are handed to OracleBase. They are presumably
        # stored there as self.use_basis_extension / self.basis_extension_params,
        # which this class reads back.
        super().__init__(use_basis_extension, basis_extension_params)
        print(f"Transformer kwargs: {transformer_kwargs}")
        self.transformer_oracle = TransformerOracleImpl(**transformer_kwargs)
        print("Transformer oracle initialized successfully.")
        self.permutations = permutations
        self.mode = mode

    def _extension_kwargs(self) -> Dict:
        """Return the keyword arguments to forward to ``basis_extension``.

        Tolerates ``basis_extension_params=None`` (the constructor default) in
        case OracleBase stores it as-is. Unpacking ``None`` with ``**`` would
        otherwise raise ``TypeError``.
        """
        return self.basis_extension_params or {}

    def _parse_basis_string(self, basis_str: str) -> List:
        """Parse the oracle's output string into a list of basis terms.

        The string is split on whitespace, the tokens are passed to
        ``Polynomial.from_sequence``, and the resulting polynomial's ``terms``
        are returned as a list.
        """
        return list(Polynomial.from_sequence(basis_str.split()).terms)

    @staticmethod
    def _rank_by_frequency(bases: List[List]) -> List:
        """Merge several bases into one list ordered by occurrence count.

        Each distinct term appears once. Terms are sorted by their total number
        of occurrences across all ``bases``, most frequent first. Ties keep the
        order in which the terms were first seen (``Counter.most_common``
        guarantees this). Terms must be hashable.
        """
        counts = Counter(chain.from_iterable(bases))
        return [term for term, _ in counts.most_common()]

    def generate_from_string(self, tokenized_poly: str, max_length: int = 2048):
        """Generate a basis for a whitespace-tokenized polynomial string.

        Args:
            tokenized_poly: Space-separated polynomial tokens, passed to the
                wrapped oracle's ``generate_from_string``.
            max_length: Maximum generation length forwarded to the oracle.

        Returns:
            List of basis terms parsed from the oracle's output. When
            ``use_basis_extension`` is set, the terms are extended with
            ``basis_extension`` using the configured parameters.
        """
        basis_string = self.transformer_oracle.generate_from_string(tokenized_poly, max_length)
        basis_terms = self._parse_basis_string(basis_str=basis_string)

        if self.use_basis_extension:
            # The input is needed as a Polynomial only for the extension step.
            # "[C]" is rewritten to an explicit "C1.0" token before parsing;
            # presumably Polynomial.from_sequence cannot read the bare
            # placeholder (TODO confirm).
            polynomial = Polynomial.from_sequence(tokenized_poly.replace("[C]", "C1.0").split())
            basis_terms = basis_extension(basis_terms, polynomial, **self._extension_kwargs())

        return basis_terms

    def generate_from_batch(self, batch, max_length: int = 2048):
        """Forward ``batch`` to the wrapped oracle's ``generate_from_batch``.

        The wrapped oracle's return value is passed through unchanged; no
        parsing or basis extension is applied.
        """
        return self.transformer_oracle.generate_from_batch(batch, max_length)

    def predict_basis(self, **kwargs) -> Dict:
        """Predict a basis for one polynomial.

        Keyword Args:
            poly: The ``Polynomial`` object. Required; it is passed to
                ``basis_extension``.
            poly_tokens: The polynomial as tokens. Required. A list is joined
                with single spaces; any other value is converted with ``str``.

        Returns:
            Dict with these keys:

            * ``'basis'``: a list of terms. In ``"permutation_all"`` mode
              without basis extension, it is a list of per-permutation term
              lists instead. In ``"permutation_all"`` mode with basis
              extension, each permuted basis is extended separately and the
              distinct extended terms are ordered by occurrence count, most
              frequent first.
            * ``'time'``: seconds spent generating and parsing.
            * ``'basis_extension_time'``: seconds spent in
              ``basis_extension``. Present only when ``use_basis_extension``
              is set.

        Raises:
            ValueError: If ``poly`` or ``poly_tokens`` is missing, or if
                ``mode`` is not one of ``_VALID_MODES``.
        """
        poly = kwargs.get('poly')
        poly_tokens = kwargs.get('poly_tokens')
        if poly is None:
            raise ValueError("TransformerOracle requires the 'poly' argument (Polynomial object)")
        if poly_tokens is None:
            raise ValueError("TransformerOracle requires the 'poly_tokens' argument")
        # Fail clearly instead of hitting an unbound local further down.
        if self.mode not in self._VALID_MODES:
            raise ValueError(
                f"Unknown TransformerOracle mode {self.mode!r}; expected one of {self._VALID_MODES}"
            )

        start_time = time.perf_counter()

        # The wrapped oracle expects a single space-separated token string.
        if isinstance(poly_tokens, list):
            poly_str = " ".join(str(token) for token in poly_tokens)
        else:
            poly_str = str(poly_tokens)

        if self.mode == "single":
            predicted_basis = self._parse_basis_string(
                self.transformer_oracle.generate_from_string(poly_str)
            )
        else:
            permutation_result = self.transformer_oracle.generate_from_string_with_permutations(
                poly_str, num_permutations=self.permutations
            )
            selected = permutation_result[self._PERMUTATION_RESULT_KEYS[self.mode]]
            if self.mode == "permutation_all":
                # ``selected`` is a list of bases, each a list of term strings.
                # Each basis is joined into one "a + b + ..." string so it can
                # be parsed separately.
                predicted_basis = [self._parse_basis_string(" + ".join(basis)) for basis in selected]
            else:
                predicted_basis = self._parse_basis_string(" + ".join(selected))

        result = {
            'basis': predicted_basis,
            'time': time.perf_counter() - start_time,
        }

        if self.use_basis_extension:
            ext_start = time.perf_counter()
            ext_kwargs = self._extension_kwargs()
            if self.mode == "permutation_all":
                # Extend each permuted basis individually, then merge the
                # results ordered by how often each term occurs.
                extended_bases = [basis_extension(basis, poly, **ext_kwargs) for basis in predicted_basis]
                result['basis'] = self._rank_by_frequency(extended_bases)
            else:
                result['basis'] = basis_extension(predicted_basis, poly, **ext_kwargs)
            result['basis_extension_time'] = time.perf_counter() - ext_start

        return result

