from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from abstract_cf.text_generation.utils import sample_from_model
from abstract_cf.text_generation.utils import load_model


# Select the torch backend once at import time.
# Priority order: Apple MPS, then CUDA, then plain CPU as the fallback.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


@dataclass
class State:
    ''' 
    Represents the state as intended in the paper. 
    Typically, the state is a tuple S=(X, θ) 
    where X is the input (i.e. usually a prompt)
    and θ is the model's parameters. 
    For the sake of generality of our implementation, θ is a callable function,
    that is expected to implement the model and associated parameters. 
    '''
    policy_function: callable
    inputs: any

    inputs_text: str | list[str] = None       # not strictly necessary, used as metadata and for evaluation
    parallel_sampling: bool = True
    # this could be a 'generate' argument instead
    sampling_batch_size: int | None = 5
    sampling_max_length: int | None = None

    def call_policy(self, n_samples: int = 1):
        # TODO specify the return type of this function, possibly cast to torch.tensor
        if self.parallel_sampling:
            samples = self.policy_function(
                self.inputs, 
                n_samples=n_samples, 
                batch_size=self.sampling_batch_size, 
                max_length=self.sampling_max_length
            )
        else: 
            samples = [
                self.policy_function(self.inputs) 
                for _ in range(n_samples)
            ]
        return samples


def construct_state(
    prompt: str,
    n_prompt_tokens: int,
    max_length: int,
    model_name: str | None = None,
    model: AutoModelForCausalLM | None = None,
    tokenizer: AutoTokenizer | None = None,
    sampling_batch_size: int = 5,
    sampling_max_length: int | None = None,
) -> State:
    """
    Build a ``State`` whose policy samples from a causal language model.

    Args:
        prompt: Text to tokenize and use as the state's inputs.
        n_prompt_tokens: Maximum number of prompt tokens kept; the tokenizer
            is called with ``truncation=True`` and this value as ``max_length``.
        max_length: Passed as ``model_max_length`` when the tokenizer is
            loaded from ``model_name``. Unused if ``tokenizer`` is supplied.
        model_name: Identifier used to load the model and/or tokenizer when
            they are not passed explicitly.
        model: Already-loaded model. If ``None``, ``load_model(model_name)``
            is used.
        tokenizer: Already-loaded tokenizer. If ``None``, one is loaded with
            ``AutoTokenizer.from_pretrained(model_name, ...)``.
        sampling_batch_size: Stored on the state and forwarded to the policy
            as ``batch_size``.
        sampling_max_length: Stored on the state and forwarded to the policy
            as ``max_length``.

    Returns:
        A ``State`` holding the tokenized prompt (moved to the module-level
        ``device``), its decoded text, and a policy wrapping
        ``sample_from_model``.

    Raises:
        ValueError: If ``model_name`` is missing while ``model`` or
            ``tokenizer`` is also missing.

    Side effects:
        Sets ``tokenizer.pad_token`` to the EOS token and
        ``model.config.pad_token_id`` to the EOS token id, including on
        objects passed in by the caller.
    """
    # Explicit raises rather than `assert`: asserts are stripped under
    # `python -O`, which would let invalid arguments through.
    if model is None and model_name is None:
        raise ValueError("Either model or model_name must be provided")
    if tokenizer is None and model_name is None:
        raise ValueError("Either tokenizer or model_name must be provided")

    if model is None:
        model = load_model(model_name)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            model_max_length=max_length,
            padding_side="right",
            use_fast=False,
            trust_remote_code=True,
        )

    # Tokenize the prompt, truncated to n_prompt_tokens, and move the
    # resulting tensors to the selected device.
    inputs = tokenizer(
        prompt,
        return_tensors='pt',
        truncation=True,
        max_length=n_prompt_tokens,
    ).to(device)

    # Reuse EOS as the padding token on both the tokenizer and the model config.
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.eos_token_id

    def policy(inputs, n_samples, batch_size, max_length):
        # Parameter names must match the keywords used by State.call_policy.
        # Only the first element of sample_from_model's result is kept
        # (presumably the sampled sequences — confirm against the helper).
        return sample_from_model(
            model,
            tokenizer,
            inputs,
            n_samples,
            batch_size=batch_size,
            max_length=max_length,
        )[0]

    return State(
        policy_function=policy,
        inputs=inputs,
        # Decoded after truncation, so this reflects what the model sees.
        inputs_text=tokenizer.decode(inputs['input_ids'][0]),
        sampling_batch_size=sampling_batch_size,
        sampling_max_length=sampling_max_length,
    )

