from transformers import AutoModelForSequenceClassification
from trlx import trlx
from sklearn.metrics.pairwise import manhattan_distances
from lm_utils import *
from trlx.data.default_configs import default_ppo_config

# Silence every warning for the whole run.
# NOTE(review): `warnings` is not imported explicitly in this file; it is
# presumably re-exported by `from lm_utils import *` -- confirm.
warnings.filterwarnings("ignore")

# Multiplier applied to the classifier score (adversarial term) in reward_fn.
LAM_ADV = 0.5
# Multiplier applied to the embedding-similarity penalty (diversity term) in reward_fn.
LAM_DIV = 500.0
# Model path handed to the PPO config as the policy to train.
EXPLOIT_MODEL = 'gpt2-large'
# Number of classifier checkpoints loaded and averaged by get_classifier_fn.
ENSEMBLE_SIZE = 5
# Value assigned to config.train.total_steps.
N_TRAIN_STEPS = 400000


def get_classifier_fn(classifier_model=CLASSIFIER_MODEL):
    """Build a scoring function backed by an ensemble of classifiers.

    Loads the tokenizer for ``classifier_model`` and ``ENSEMBLE_SIZE``
    sequence-classification checkpoints from
    ``./models/{classifier_model}_classifier_{i}``, each moved to ``DEVICE``.

    Args:
        classifier_model: Name used both for the tokenizer and as the prefix
            of the local checkpoint directories.

    Returns:
        A function ``classify(responses)`` that takes a sequence of strings
        and returns a tensor holding, for every string, the logit at index 1
        averaged over all ensemble members (presumably the score of the
        "positive" class -- confirm against the classifier's label mapping).
    """
    tokenizer = AutoTokenizer.from_pretrained(classifier_model)
    models = [AutoModelForSequenceClassification.from_pretrained(f'./models/{classifier_model}_classifier_{i}').to(DEVICE)
              for i in range(ENSEMBLE_SIZE)]
    sub_batch_size = 512

    def classify(responses):
        # Split into chunks of at most `sub_batch_size` strings. An empty
        # input is still forwarded as a single (empty) chunk so that the
        # tokenizer/model decide how to handle it, exactly as before.
        chunks = [responses[start: start + sub_batch_size]
                  for start in range(0, len(responses), sub_batch_size)] or [responses]
        chunk_means = []
        with torch.no_grad():
            for chunk in chunks:
                # Every ensemble member shares the tokenizer, so each chunk is
                # tokenized once and reused for all models.
                inputs = tokenizer(chunk, padding="max_length", truncation=True,
                                   max_length=MAX_LENGTH, return_tensors='pt').to(DEVICE)
                member_scores = [model(**inputs).logits[:, 1] for model in models]
                # Average over the ensemble dimension -> one score per response.
                chunk_means.append(torch.mean(torch.stack(member_scores), dim=0))
            return torch.cat(chunk_means)
    return classify


def get_encoder_fn():
    """Return a sentence encoder whose output is a NumPy array.

    The returned callable forwards its argument to ``get_gpt2_embedding``,
    moves the result to the CPU and converts it with ``.numpy()``.
    """
    def get_gpt2_embedding_numpy(sentences):
        embedded = get_gpt2_embedding(sentences)
        on_cpu = embedded.cpu()
        return on_cpu.numpy()

    return get_gpt2_embedding_numpy


def get_dist_fn():
    """Return a function computing pairwise L1 distances between embeddings.

    The returned callable passes its argument to scikit-learn's
    ``manhattan_distances`` and wraps the resulting matrix in a torch tensor
    placed on ``DEVICE``.
    """
    # The annotations reflect actual usage in this file: the input is the
    # NumPy array produced by the encoder and the output is a torch tensor.
    def l1_dist(txt_embeddings: np.ndarray) -> torch.Tensor:
        pairwise = manhattan_distances(txt_embeddings)
        return torch.tensor(pairwise).to(DEVICE)
    return l1_dist


def get_reward_fn(classifier_fn, encoder_fn, dist_fn):
    """Build the reward function combining adversarial and diversity terms.

    Args:
        classifier_fn: Maps a list of response strings to a tensor with one
            score per response.
        encoder_fn: Maps a list of sample strings to embeddings accepted by
            ``dist_fn``.
        dist_fn: Maps embeddings to a pairwise-distance tensor (one row per
            sample).

    Returns:
        A function ``reward_fn(samples, **kwargs)`` returning a list with one
        reward per sample. Extra keyword arguments are accepted and ignored.
    """
    def reward_fn(samples, **kwargs) -> list:
        # Lengths are recorded before empty samples are patched, so empty
        # samples still receive the short-sample penalty below.
        sample_lens = [len(s) for s in samples]
        samples = [s or 'The' for s in samples]
        with torch.no_grad():
            # NOTE(review): 50256 is presumably GPT-2's end-of-text token id,
            # used as the pad token -- confirm against the target model.
            generations = target_lm(samples, pad_token_id=50256)
            # Keep only the text produced after the prompt, without leading
            # whitespace.
            responses = [
                remove_leading_whitespace(gen[0]['generated_text'][len(prompt):])
                for gen, prompt in zip(generations, samples)
            ]

            # Diversity reward: penalize samples whose embeddings are close
            # (in L1 distance) to the other samples in the batch.
            embeddings = encoder_fn(samples)
            dist_matrix = dist_fn(embeddings)
            div_reward = -1 * (torch.mean(torch.exp(-dist_matrix), dim=1) * LAM_DIV)
            del dist_matrix  # release the pairwise matrix before classification

            # Adversarial reward: classifier score of the target model's response.
            adv_reward = classifier_fn(responses) * LAM_ADV

            rewards = torch.clip(div_reward + adv_reward, -5, 5)

            # Penalize samples that are too short (<= 10 characters) with the
            # minimum reward, using one masked write instead of a per-index loop.
            too_short = torch.tensor([n <= 10 for n in sample_lens],
                                     dtype=torch.bool, device=rewards.device)
            rewards[too_short] = -5
            return rewards.tolist()
    return reward_fn


def get_config():
    """Return the default trlx PPO config with this script's overrides applied."""
    cfg = default_ppo_config()

    # Policy model settings.
    model_cfg = cfg.model
    model_cfg.model_path = EXPLOIT_MODEL
    model_cfg.num_layers_unfrozen = 1

    # Training-loop settings.
    train_cfg = cfg.train
    train_cfg.tracker = None
    train_cfg.total_steps = N_TRAIN_STEPS
    train_cfg.epochs = 1000
    train_cfg.batch_size = 4096
    train_cfg.checkpoint_interval = 1000
    train_cfg.eval_interval = 500

    # PPO method settings.
    method_cfg = cfg.method
    method_cfg.gen_kwargs.update({'max_new_tokens': 10})
    method_cfg.init_kl_coef = 0.05
    method_cfg.target = 4

    # Optimizer settings.
    cfg.optimizer.kwargs.update({'lr': 1e-6})

    return cfg


if __name__ == '__main__':

    print('Running exploit step...')

    # Assemble the training config and the reward function's three components.
    config = get_config()
    classifier_fn = get_classifier_fn()
    encoder_fn = get_encoder_fn()
    dist_fn = get_dist_fn()
    reward_fn = get_reward_fn(
        classifier_fn,
        encoder_fn,
        dist_fn,
    )

    # Run PPO training with the assembled reward function.
    print(f'Running rl training for {N_TRAIN_STEPS} steps...')
    trainer = trlx.train(
        reward_fn=reward_fn,
        config=config,
    )

    # Save the trained generator.
    print('Saving...')
    trainer.save_pretrained('./models/exploit_generator')
    print('Done :)')
