from typing import List, Tuple, Optional
import torch
from torch.utils.data import DataLoader
import random
import numpy as np

import os
import sys
from time import time  

# Add the project root (the directory two levels above this file's directory)
# to the Python path so the ``src.*`` imports below resolve. ``abspath`` guards
# against a relative ``__file__``, for which repeated ``dirname`` calls can
# collapse to "" (the current working directory) at the wrong depth.
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, project_root)

# NOTE(review): debug output emitted at import time; consider logging instead.
print("Current working directory:", os.getcwd())


from src.loader.checkpoint import load_pretrained_bag
from src.loader.data_format.processors.subprocessors import MonomialProcessorPlus


from src.misc.utils import to_cuda
from src.evaluation.generation import generation
from src.loader.data import load_data


# Add the ``sos/src`` tree (located three directories above this file's
# directory) to the Python path so the ``utils`` / ``data_generation`` imports
# below resolve. ``abspath`` makes the entry independent of a relative
# ``__file__``; the '..' components are left for the OS to resolve.
sos_src_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', '..', 'sos', 'src')
sys.path.insert(0, sos_src_path)
from utils.polynomial import permute_polynomial_tokens, permute_polynomial_object
from data_generation.monomials.monomials import Polynomial, Monomial
from utils.basis_extension import basis_extension


class TransformerOracle:
    """Query a pretrained transformer for polynomial-basis predictions.

    The model, tokenizer and model name come from ``load_pretrained_bag``;
    inputs are encoded with a ``MonomialProcessorPlus`` and decoded through the
    shared ``generation`` helper, so single-string queries follow the same path
    as batched evaluation.
    """

    def __init__(
        self,
        model_path: str,
        num_variables: int,
        max_degree: int,
        max_coef: int,
        continuous_coefficient: bool = True,
        rational_coefficients=None,
        data_format: str = "polynomial_basis",
        batch_size: int = 32,
        device: str = "cuda"
    ):
        """Load the checkpoint and build the monomial processor.

        Args:
            model_path: Checkpoint location passed to ``load_pretrained_bag``.
            num_variables: Number of polynomial variables; also the length of
                the permutations used by ``generate_from_string_with_permutations``.
            max_degree: Forwarded to ``MonomialProcessorPlus``.
            max_coef: Forwarded to ``MonomialProcessorPlus``.
            continuous_coefficient: Forwarded to ``MonomialProcessorPlus``.
            rational_coefficients: Forwarded to ``MonomialProcessorPlus``.
            data_format: Stored on the instance; not used within this class.
            batch_size: Stored on the instance; not used within this class.
            device: Torch device for the model and for single-string inputs.
        """
        bag = load_pretrained_bag(model_path)
        self.model = bag['model'].to(device)
        # This class only runs inference: eval mode keeps layers such as
        # dropout from perturbing generation (a no-op if already in eval mode).
        self.model.eval()
        self.tokenizer = bag['tokenizer']
        self.model_name = bag['model_name']
        self.device = device

        self.monomial_processor = MonomialProcessorPlus(
            num_variables=num_variables,
            max_degree=max_degree,
            max_coef=max_coef,
            rational_coefficients=rational_coefficients,
            continuous_coefficient=continuous_coefficient
        )
        self.data_format = data_format
        self.batch_size = batch_size

        self.num_variables = num_variables

    def generate_from_batch(self, batch, max_length: int = 2048):
        """Generate predictions for an already-collated batch.

        Args:
            batch: Mapping of model inputs, moved with ``to_cuda``. When it
                contains ``'labels'``, the generation length is capped at
                ``batch['labels'].shape[1]``.
            max_length: Upper bound on the generated length.

        Returns:
            The predictions returned by ``generation``.
        """
        # NOTE(review): to_cuda is not given self.device; if it always targets
        # CUDA, a CPU-configured oracle will mismatch here — confirm.
        batch = to_cuda(batch)
        # Only cap when labels are present so label-free batches are accepted.
        if 'labels' in batch:
            max_length = min(max_length, batch['labels'].shape[1])

        # Inference only: no autograd graph needed (matches the tokenizer path).
        with torch.no_grad():
            preds = generation(
                self.model,
                self.model_name,
                batch,
                self.tokenizer,
                monomial_processor=self.monomial_processor,
                max_length=max_length,
            )
        return preds

    def generate_from_string(self, tokenized_poly: str, max_length: int = 2048):
        """Generate a prediction for a single tokenized polynomial string.

        With a monomial processor (the configuration ``__init__`` builds), the
        string is encoded into a batch of size one and decoded through
        ``generation``, the same helper used for batches. If
        ``self.monomial_processor`` has been set to ``None``, the plain
        tokenizer and ``self.model.generate`` are used instead.

        Args:
            tokenized_poly: Whitespace-separated polynomial token string.
            max_length: Upper bound on the generated length.

        Returns:
            The first prediction when ``generation`` returns a list, otherwise
            its return value unchanged; on the tokenizer path, the first
            decoded string.

        Raises:
            ValueError: If the monomial processor yields no monomials.
        """
        if self.monomial_processor is None:
            return self._generate_with_tokenizer(tokenized_poly, max_length)

        batch = self._encode_with_monomial_processor(tokenized_poly)
        with torch.no_grad():
            preds = generation(
                self.model,
                self.model_name,
                batch,
                self.tokenizer,
                monomial_processor=self.monomial_processor,
                max_length=max_length,
            )
        return preds[0] if isinstance(preds, list) else preds

    def _encode_with_monomial_processor(self, tokenized_poly: str) -> dict:
        """Encode one polynomial into a size-1 batch for ``generation``.

        Each processed monomial is either an object with ``tokens`` and
        ``coefficient_value`` attributes or a plain token sequence; all
        monomials must have the same number of tokens.

        Returns:
            Dict with ``input_ids`` (long, shape ``(1, num_monomials,
            tokens_per_monomial)``) and ``attention_mask`` (long, all ones,
            shape ``(1, num_monomials)``). When the processor's
            ``continuous_coefficient`` flag is set, it also has
            ``coefficient_values`` (float, shape ``(1, num_monomials)``);
            ``None`` coefficients and plain token sequences contribute 0.0.

        Raises:
            ValueError: If the processor yields no monomials.
        """
        monomials = self.monomial_processor([tokenized_poly])[0]
        if len(monomials) == 0:
            raise ValueError(
                f"Monomial processor produced no monomials for {tokenized_poly!r}"
            )

        token_rows = []
        coefficients = []
        for monomial in monomials:
            if hasattr(monomial, 'tokens'):  # ProcessedMonomial
                token_rows.append(torch.as_tensor(monomial.tokens, dtype=torch.long))
                value = monomial.coefficient_value
                coefficients.append(0.0 if value is None else float(value))
            else:  # plain token sequence
                token_rows.append(torch.as_tensor(monomial, dtype=torch.long))
                coefficients.append(0.0)

        # Assemble on the CPU and transfer once, instead of one small
        # host-to-device copy per monomial.
        input_ids = torch.stack(token_rows).unsqueeze(0).to(self.device)
        seq_length = input_ids.shape[1]
        batch = {
            'input_ids': input_ids,
            'attention_mask': torch.ones(1, seq_length, dtype=torch.long, device=self.device),
        }
        if self.monomial_processor.continuous_coefficient:
            batch['coefficient_values'] = torch.tensor(
                [coefficients], dtype=torch.float, device=self.device
            )
        return batch

    def _generate_with_tokenizer(self, tokenized_poly: str, max_length: int):
        """Fallback used when no monomial processor is set.

        Tokenizes with ``self.tokenizer``, runs greedy ``self.model.generate``
        under ``torch.no_grad`` and returns the first decoded string.
        """
        inputs = self.tokenizer(
            tokenized_poly,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            preds = self.model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=max_length,
                num_beams=1,
                tokenizer=self.tokenizer,
                # Always None on this path; passed through for generate()
                # implementations that expect the keyword.
                monomial_processor=self.monomial_processor,
                do_sample=False
            )
        decoded = self.tokenizer.batch_decode(preds.long().cpu().numpy(), skip_special_tokens=True)
        return decoded[0] if isinstance(decoded, list) else decoded

    @staticmethod
    def _invert_permutation(permutation: List[int]) -> List[int]:
        """Return ``inv`` such that ``inv[permutation[i]] == i`` for every ``i``."""
        inverse = [0] * len(permutation)
        for index, target in enumerate(permutation):
            inverse[target] = index
        return inverse

    def generate_from_string_with_permutations(
        self,
        tokenized_poly: str,
        max_length: int = 2048,
        num_permutations: int = 5,
        rng: Optional[random.Random] = None,
    ):
        """Query the model on the input and on variable-permuted copies of it.

        The input is decoded once as-is. It is then decoded ``num_permutations``
        more times after applying a random permutation of the variable indices
        via ``permute_polynomial_object``. Each permuted prediction is mapped
        back with the inverse permutation before its basis is extracted.

        Args:
            tokenized_poly: Whitespace-separated polynomial token string.
            max_length: Upper bound on each generated length.
            num_permutations: Number of random permutations to try.
            rng: Optional ``random.Random`` used for shuffling (for reproducible
                runs); defaults to the global ``random`` module state.

        Returns:
            dict with keys:
                'union_basis': sorted union of all extracted bases.
                'intersection_basis': sorted intersection of all extracted bases.
                'original_basis': basis extracted from the unpermuted input.
                'all_permuted_bases': one basis per permutation, mapped back
                    with the inverse permutation.
                'permutations': the permutations applied.
                'inverse_permutations': their inverses.
        """
        shuffle = random.shuffle if rng is None else rng.shuffle
        tokens = tokenized_poly.split()

        original_output = self.generate_from_string(tokenized_poly, max_length)
        # NOTE(review): the original output is used verbatim, whereas permuted
        # outputs are round-tripped through Polynomial.to_sequence below. If that
        # round trip reformats monomials, union/intersection will compare
        # differently formatted strings — confirm the formats agree.
        original_basis = self._extract_basis_from_output(original_output)

        all_permuted_bases = []
        permutations = []
        inverse_permutations = []
        for _ in range(num_permutations):
            permutation = list(range(self.num_variables))
            shuffle(permutation)
            inverse_permutation = self._invert_permutation(permutation)

            # Rebuilt each iteration rather than hoisted, in case
            # permute_polynomial_object mutates its argument.
            poly = Polynomial.from_sequence(tokens)
            permuted_poly = permute_polynomial_object(poly, permutation, num_vars=self.num_variables)
            permuted_string = " ".join(
                permuted_poly.to_sequence(num_vars=self.num_variables, sort_polynomials=True)
            )

            permuted_output = self.generate_from_string(permuted_string, max_length)
            permuted_poly_out = Polynomial.from_sequence(permuted_output.split())
            unpermuted_poly = permute_polynomial_object(
                permuted_poly_out, inverse_permutation, num_vars=self.num_variables
            )
            unpermuted_output = " ".join(unpermuted_poly.to_sequence(num_vars=self.num_variables))

            all_permuted_bases.append(self._extract_basis_from_output(unpermuted_output))
            permutations.append(permutation)
            inverse_permutations.append(inverse_permutation)

        # all_bases always contains the original basis, so intersection has at
        # least one operand. Results are sorted so they do not depend on set
        # iteration order (string hashing is randomized per process).
        all_bases = [set(original_basis)] + [set(b) for b in all_permuted_bases]
        return {
            'union_basis': sorted(set().union(*all_bases)),
            'intersection_basis': sorted(set.intersection(*all_bases)),
            'original_basis': original_basis,
            'all_permuted_bases': all_permuted_bases,
            'permutations': permutations,
            'inverse_permutations': inverse_permutations
        }

    def _extract_basis_from_output(self, output_string: str) -> List[str]:
        """Split a model output string into basis-monomial strings.

        Tokens beginning with ``'C'`` or ``'E'`` are grouped (presumably
        coefficient and exponent tokens — confirm against the tokenizer).
        ``'+'`` ends the current group, every other token is ignored, and
        empty groups are dropped.

        Args:
            output_string: Whitespace-separated model output.

        Returns:
            One space-joined string per non-empty group, in output order.
        """
        basis = []
        current_monomial = []
        for token in output_string.split():
            if token.startswith(('C', 'E')):
                current_monomial.append(token)
            elif token == '+' and current_monomial:
                basis.append(" ".join(current_monomial))
                current_monomial = []
        # The final monomial has no trailing '+'.
        if current_monomial:
            basis.append(" ".join(current_monomial))
        return basis