import os
import torch
import accelerate
import numpy as np
import random
from dataclasses import dataclass, field
from huggingface_hub.file_download import hf_hub_download
from accelerate import Accelerator
from transformers import (
    AutoConfig,
    AutoModelForMaskedLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    set_seed,
)


@dataclass()
class GenerateConfig:
    """Shared configuration and model state for SSD-LM generation.

    ``__post_init__`` does the following:
    - creates an ``Accelerator`` and seeds the RNGs;
    - loads the tokenizer and masked-LM model from ``model_name_or_path``;
    - loads two auxiliary ``nn.Linear`` layers (``embed_sum_layer.pt`` and
      ``timestep_layer.pt``) from the same Hugging Face Hub repository.

    Every loaded module is passed through ``accelerator.prepare``.
    """

    model_name_or_path: str = field(default="xhan77/ssdlm")
    max_seq_length: int = field(default=64)
    conditional: bool = field(default=True)
    decode_total_gen_len: int = field(default=32)
    one_hot_value: int = field(default=5)
    seed: int = field(default=2022)
    fast_tokenizer: bool = field(default=False)

    def __post_init__(self):
        self.accelerator = Accelerator()
        self.device = self.accelerator.device
        # NOTE(review): this gives each process its own seed (device_specific=True).
        # However, _set_deterministic_mode() at the end of this method re-seeds
        # every process with the plain ``seed`` via transformers.set_seed, so
        # the per-device offset does not survive. Confirm which is intended.
        accelerate.utils.set_seed(self.seed, device_specific=True)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=self.fast_tokenizer)
        self.model = self._get_model()
        # accelerator.prepare() may wrap the model (e.g. DistributedDataParallel
        # under a multi-process launch). Such wrappers do not forward
        # ``get_input_embeddings``, so read the sizes from the unwrapped model.
        input_embeddings = self.accelerator.unwrap_model(self.model).get_input_embeddings().weight
        self.vocab_size = input_embeddings.size(0)
        self.hidden_size = input_embeddings.size(1)
        self.embedding_sum_layer = self._get_embedding_sum_layer()
        self.timestep_layer = self._get_timestep_layer()
        # Re-seeding after all loading means the RNG state seen by later code
        # does not depend on random draws consumed while building the layers.
        self._set_deterministic_mode(self.seed)

    def _set_seed(self, seed: int):
        """Seed Python's ``random``, NumPy and PyTorch (and all CUDA devices if available).

        Not called from ``__post_init__``.

        Args:
            seed: Seed applied to every RNG.
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    def _set_deterministic_mode(self, seed: int):
        """Re-seed the RNGs via ``transformers.set_seed`` and request deterministic cuDNN kernels.

        Args:
            seed: Seed passed to ``transformers.set_seed`` and exported as
                ``PYTHONHASHSEED``.
        """
        set_seed(seed)
        # Python reads PYTHONHASHSEED at interpreter startup, so this only
        # affects subprocesses launched after this point.
        os.environ["PYTHONHASHSEED"] = str(seed)

        # The determinism switches live on torch.backends.cudnn. There is no
        # torch.backends.deterministic / torch.backends.benchmark setting.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def _get_model(self):
        """Load the masked-LM, resize its embeddings to the tokenizer, and prepare it.

        Returns:
            The model as returned by ``accelerator.prepare`` (possibly wrapped).
        """
        model_config = AutoConfig.from_pretrained(self.model_name_or_path)
        model = AutoModelForMaskedLM.from_pretrained(self.model_name_or_path, from_tf=False, config=model_config)
        # Keep the embedding matrix in sync with the tokenizer's vocabulary size.
        model.resize_token_embeddings(len(self.tokenizer))
        return self.accelerator.prepare(model)

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of *state_dict* with one leading ``"module."`` removed from each key.

        Checkpoints saved from a wrapped module (e.g. DistributedDataParallel)
        prefix parameter names with ``"module."``. A bare ``nn.Linear`` expects
        them without it. Keys without the prefix are kept unchanged.
        """
        prefix = "module."
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def _load_layer(self, layer, filename):
        """Download *filename* from the model's Hub repo, load it into *layer*, and prepare it.

        Args:
            layer: Freshly constructed module whose parameters match the checkpoint.
            filename: Name of the state-dict file inside ``model_name_or_path``.

        Returns:
            The layer as returned by ``accelerator.prepare``.
        """
        checkpoint_path = hf_hub_download(self.model_name_or_path, filename)
        # The layer is built on the default (CPU) device. Mapping to CPU lets
        # checkpoints saved from CUDA tensors load on CPU-only hosts too.
        # NOTE(review): torch.load unpickles the downloaded file, so only point
        # model_name_or_path at trusted repositories.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        layer.load_state_dict(self._strip_module_prefix(state_dict))
        return self.accelerator.prepare(layer)

    def _get_embedding_sum_layer(self):
        """Build the bias-free ``vocab_size -> hidden_size`` linear layer and load ``embed_sum_layer.pt``."""
        layer = torch.nn.Linear(self.vocab_size, self.hidden_size, bias=False)
        return self._load_layer(layer, "embed_sum_layer.pt")

    def _get_timestep_layer(self):
        """Build the ``1 -> hidden_size`` linear layer (with bias) and load ``timestep_layer.pt``."""
        layer = torch.nn.Linear(1, self.hidden_size, bias=True)
        return self._load_layer(layer, "timestep_layer.pt")


@dataclass
class UnconstrainedGenerationConfig(GenerateConfig):
    """Settings for generation without a guiding classifier.

    Adds ``decode_log_interval``, ``total_t`` and ``top_p`` to
    ``GenerateConfig``; these are presumably consumed by the decoding loop
    elsewhere. ``controlled`` defaults to ``False`` for this variant.
    """

    decode_log_interval: int = 100
    total_t: int = 1000
    top_p: float = 0.95
    controlled: bool = False

@dataclass
class ControlledGenerationConfig(GenerateConfig):
    """Settings for classifier-guided generation.

    Besides everything ``GenerateConfig`` loads, ``__post_init__`` loads a
    sequence-classification model named by ``ctr_model_name``. It moves that
    model to the accelerator's device and stores it as ``self.ctr_model``.
    """

    decode_log_interval: int = 500
    total_t: int = 5000
    top_p: float = 0.2
    # RoBERTa-based classifiers share SSD-LM's tokenizer. GPT-based classifiers
    # use the same tokenization method but may need their vocab indices re-mapped.
    ctr_model_name: str = "cardiffnlp/twitter-roberta-base-sentiment"
    # Target label index: check the label definitions of the selected classifier.
    # For the default model, 0 -> negative and 2 -> positive.
    ctr_opt_label_idx: int = 0
    decode_ctr_lr: float = 2000.0
    controlled: bool = True

    def __post_init__(self):
        super().__post_init__()
        classifier = AutoModelForSequenceClassification.from_pretrained(self.ctr_model_name)
        # self.device is the accelerator's device, set by the parent __post_init__.
        self.ctr_model = classifier.to(self.device)