# %%
# ! pip install -U accelerate -q
# ! pip install -U transformers -q
# ! pip install datasets -q
# ! pip install wandb -q

# %%
import numpy as np
import os
import transformers
import itertools
import pandas as pd
import math
from transformers import GPTNeoXForCausalLM, AutoTokenizer
from transformers import (
    set_seed,
)
import wandb
import pickle
import string
from datasets import Dataset, DatasetDict, load_dataset
import torch
import torch.nn.functional as F
import logging

# %%
import json

# Load the pre-generated random-string artifacts. Entries are looked up by
# 'seed_id-<n>' keys in the helper functions below.
# The encoding is pinned because open() otherwise falls back to the
# platform's locale encoding, while JSON files are expected to be UTF-8.
with open('llms/artifacts/random_strings/data.json', encoding='utf-8') as f:
    data_rs = json.load(f)

# %%
def get_seed_for_exp(seed_id):
    """Return the value stored under ``'seed'`` for ``seed_id`` in ``data_rs``.

    Raises ``KeyError`` if ``data_rs`` has no ``'seed_id-<seed_id>'`` entry.
    """
    experiment_entry = data_rs[f'seed_id-{seed_id}']
    return experiment_entry['seed']

def get_string_for_exp(seed_id):
    """Return the string stored for ``seed_id`` in ``data_rs``.

    The value is read from the fixed key path
    ``data -> alphabet_size-26 -> data -> uniform_distribution -> data ->
    num_tokens-1024`` below the ``'seed_id-<seed_id>'`` entry. A missing key
    anywhere along the path raises ``KeyError``.
    """
    key_path = (
        'data',
        'alphabet_size-26',
        'data',
        'uniform_distribution',
        'data',
        'num_tokens-1024',
    )
    node = data_rs[f'seed_id-{seed_id}']
    for key in key_path:
        node = node[key]
    return node

# Run one complete fine-tuning experiment per seed id (0 through 9). The loop
# variable itself is the seed id used throughout the body.
for PRETRAINED_CKPT_SEED_ID in range(0, 10):
# %%
    # --- Fine-tuning data and optimisation settings ---
    NUM_SEQUENCES = 2           # number of slices cut from the stored string
    SEQUENCE_LENGTH = 512       # characters per slice
    LEARNING_RATE = 1e-5
    EVALUATION_STRATEGY = 'epoch'
    SAVE_STRATEGY = 'no'
    NUM_TRAIN_EPOCHS = 40
    WEIGHT_DECAY = 0
    WARMUP_RATIO = 0.05
    LR_SCHEDULER = 'linear'

    # --- Description of the checkpoint that fine-tuning starts from ---
    PRETRAINED_CKPT_NS = 1
    PRETRAINED_CKPT_SL = 1024
    PRETRAINED_CKPT = 100       # used as 'epoch_<n>' in the checkpoint path
    PRETRAINED_CKPT_CHARACTERS = 26
    PRETRAINED_CKPT_SEED = get_seed_for_exp(PRETRAINED_CKPT_SEED_ID)

    # NOTE(review): 't-1024' is hard-coded here; presumably it should track
    # PRETRAINED_CKPT_SL — confirm before changing either value.
    dat_id = f'sid-{PRETRAINED_CKPT_SEED_ID}_a-{PRETRAINED_CKPT_CHARACTERS}_t-1024_p-{PRETRAINED_CKPT_NS}'
    PRE_TRAINING_CHECKPOINT = f'llms/artifacts/random_strings/models/pythia-1b/{dat_id}/epoch_{PRETRAINED_CKPT}/'

    # Hyperparameters that are pushed into the wandb run config further down.
    params_dict = {
        'NUM_SEQUENCES': NUM_SEQUENCES,
        'SEQUENCE_LENGTH': SEQUENCE_LENGTH,
        'LEARNING_RATE': LEARNING_RATE,
        'SEED': PRETRAINED_CKPT_SEED,
        'NUM_TRAIN_EPOCHS': NUM_TRAIN_EPOCHS,
        'PRE-TRAINING-CHECKPOINT': PRE_TRAINING_CHECKPOINT,
        'LR_SCHEDULER': LR_SCHEDULER,
        'WARMUP_RATIO': WARMUP_RATIO,
        'PRE_TRAINED_CKPT_NS': PRETRAINED_CKPT_NS,
        'PRE_TRAINED_CKPT_SL': PRETRAINED_CKPT_SL,
    }

    MODEL_SIZE = '1b'
    WANDB_PROJECT_NAME = "SDDD-Experiments-Reverse-For-Paper"
    OUTPUT_DIR = 'custom_ft/saves'

    # The run name encodes every setting that distinguishes this run. It is
    # also used as the output sub-directory name, so the exact text (including
    # the 'reversed' suffix appended without a separator) is kept unchanged.
    RUN_NAME = (
        f"{MODEL_SIZE}-ns-{NUM_SEQUENCES}-sl-{SEQUENCE_LENGTH}"
        f"-lr-{LEARNING_RATE}-lr_scheduler-{LR_SCHEDULER}"
        f"-warmup-ratio-{WARMUP_RATIO}"
        f"-seed-{PRETRAINED_CKPT_SEED}-epochs-{NUM_TRAIN_EPOCHS}"
        f"-dat_id-{dat_id}reversed"
    )
    OUTPUT_DIR = os.path.join(OUTPUT_DIR, RUN_NAME)

    # Make sure output directory exists. exist_ok avoids the race between an
    # existence check and the creation when several runs start concurrently.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Identifier of the base Pythia model; passed to
    # AutoTokenizer.from_pretrained further down.
    MODEL_NAME = f"EleutherAI/pythia-{MODEL_SIZE}"

    # Configure wandb and the tokenizers library through environment
    # variables. These are set before wandb.init() is called below.
    os.environ["WANDB_PROJECT"] = WANDB_PROJECT_NAME
    os.environ["WANDB_WATCH"] = "all"
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    os.environ["WANDB__SERVICE_WAIT"] = "300"
    # Set API key (for wandb)
    # NOTE(review): "KEY" looks like a placeholder, and this assignment
    # overwrites any WANDB_API_KEY already present in the environment —
    # confirm that a real key is supplied before running.
    os.environ["WANDB_API_KEY"] = "KEY"

    # Start a wandb run for this loop iteration, name it after RUN_NAME and
    # record the hyperparameters in its config.
    wandb.init(project=WANDB_PROJECT_NAME)
    wandb.run.name = RUN_NAME
    wandb.config.update(params_dict)

    # %%
    # Seed via transformers.set_seed with the seed stored for this experiment.
    set_seed(PRETRAINED_CKPT_SEED)

    # %%
    # Only the tokenizer of the base model is loaded here; the model weights
    # are loaded further down from PRE_TRAINING_CHECKPOINT.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # %%
    logger = logging.getLogger(__name__)

    # %%
    # Reuse the end-of-sequence token as the padding token.
    # NOTE(review): presumably needed because this tokenizer defines no pad
    # token of its own — confirm.
    tokenizer.pad_token = tokenizer.eos_token

    # %%
    def generate_random_strings(seed_id):
        """Build the train/test dataset from the string stored for ``seed_id``.

        The stored string is cut into ``NUM_SEQUENCES`` consecutive slices of
        ``SEQUENCE_LENGTH`` characters each. The same slices are used for both
        the ``"train"`` and the ``"test"`` split.

        Args:
            seed_id: Experiment seed id passed to ``get_string_for_exp``.

        Returns:
            A ``DatasetDict`` with ``"train"`` and ``"test"`` splits, each
            holding a single ``"text"`` column, with its format set to
            ``"torch"``.

        Raises:
            ValueError: If any slice is shorter than ``SEQUENCE_LENGTH``, i.e.
                the stored string has fewer than
                ``NUM_SEQUENCES * SEQUENCE_LENGTH`` characters.
        """
        sequence = get_string_for_exp(seed_id)
        sequences = [
            sequence[i * SEQUENCE_LENGTH:(i + 1) * SEQUENCE_LENGTH]
            for i in range(NUM_SEQUENCES)
        ]
        print(sequences)
        print(len(sequences))

        # Slicing past the end of a string silently yields a short or empty
        # result, so every slice is checked, not just the first one.
        short_slices = [
            index
            for index, chunk in enumerate(sequences)
            if len(chunk) != SEQUENCE_LENGTH
        ]
        if short_slices:
            raise ValueError(
                f"Stored string for seed_id={seed_id} has {len(sequence)} "
                f"characters, too few for {NUM_SEQUENCES} sequences of length "
                f"{SEQUENCE_LENGTH}; short slices at indices {short_slices}"
            )

        dataset = Dataset.from_dict({"text": sequences})
        # Train and test deliberately share the same data object.
        splits = DatasetDict({"train": dataset, "test": dataset})
        splits.set_format("torch")
        return splits

    def generate_strings(num_sequences, sequence_length, rng):
        """Draw ``num_sequences`` random lowercase strings of equal length.

        All characters are sampled in a single ``rng.choice`` call over the
        lowercase ASCII letters and then cut into consecutive chunks of
        ``sequence_length`` characters. The chunks are printed before being
        joined.

        Args:
            num_sequences: Number of strings to produce.
            sequence_length: Number of characters per string.
            rng: Object exposing ``choice(sequence, size=...)``.

        Returns:
            A list with ``num_sequences`` strings.
        """
        alphabet = list(string.ascii_lowercase)
        drawn = rng.choice(alphabet, size=(1, sequence_length * num_sequences))

        # Row 0 holds every sampled character; slice it chunk by chunk.
        chunks = [
            drawn[0, k * sequence_length:(k + 1) * sequence_length]
            for k in range(num_sequences)
        ]
        print(chunks)
        return ["".join(chunk) for chunk in chunks]

    def encode_character_wise(tokenizer, dataset):
        """Tokenize the ``"text"`` column one character at a time.

        Every character of a sequence is passed to ``tokenize`` as its own
        input with ``max_length=1``. Sequences shorter than the longest one
        in the batch are padded on the left with ``tokenizer.pad_token_id``
        (and attention mask 0).

        Args:
            tokenizer: Tokenizer forwarded to ``tokenize``.
            dataset: Object exposing ``map``; mapped in batches of
                ``NUM_SEQUENCES`` examples.

        Returns:
            The result of ``dataset.map`` with ``"input_ids"`` and
            ``"attention_mask"`` produced for every batch.
        """
        def characterwise_encoding(example):
            texts = example["text"]
            longest = max(len(text) for text in texts)
            all_ids = []
            all_masks = []
            for text in texts:
                # One tokenizer input per character.
                per_char = tokenize(tokenizer, list(text), max_length=1)

                # Left padding up to the longest sequence in the batch.
                pad_count = longest - len(text)
                id_padding = torch.tensor(
                    [tokenizer.pad_token_id] * pad_count, dtype=torch.long
                )
                mask_padding = torch.tensor([0] * pad_count, dtype=torch.long)

                # squeeze(1) drops the per-character length axis
                # (assumes shape (num_chars, 1) — TODO confirm).
                char_ids = per_char.input_ids.squeeze(1)
                char_mask = per_char.attention_mask.squeeze(1)

                all_ids.append(torch.cat((id_padding, char_ids)))
                all_masks.append(torch.cat((mask_padding, char_mask)))

            return {
                "input_ids": torch.stack(all_ids),
                "attention_mask": torch.stack(all_masks),
            }

        return dataset.map(
            characterwise_encoding,
            batched=True,
            batch_size=NUM_SEQUENCES,
        )

    def tokenize(tokenizer, text, max_length):
        """Call ``tokenizer`` on ``text`` with left padding and truncation.

        If the tokenizer is configured for right padding, a warning is logged
        and its ``padding_side`` is switched to ``"left"`` in place.

        Args:
            tokenizer: Callable tokenizer with a ``padding_side`` attribute.
            text: Input forwarded unchanged to the tokenizer.
            max_length: Target length. ``None`` pads to the longest input,
                any other value pads to ``max_length``.

        Returns:
            Whatever the tokenizer call returns (requested as ``"pt"``
            tensors, without token type ids).
        """
        if tokenizer.padding_side == "right":
            logger.warning("Padding side is right, setting it to left")
            tokenizer.padding_side = "left"

        padding_mode = "longest" if max_length is None else "max_length"
        call_options = {
            "return_tensors": "pt",
            "return_token_type_ids": False,
            "truncation": True,
            "padding": padding_mode,
            "max_length": max_length,
        }
        return tokenizer(text, **call_options)


    # %%
    import string

    # Lookup table from each lowercase ASCII letter to the first token id the
    # tokenizer produces for that single letter.
    map_tok = {
        letter: tokenizer(letter)['input_ids'][0]
        for letter in string.ascii_lowercase
    }

    def doCharLevelTokenization(input_string):
        """Encode ``input_string`` character by character through ``map_tok``.

        Args:
            input_string: Iterable of characters, each of which must be a key
                of ``map_tok`` (otherwise ``KeyError`` is raised).

        Returns:
            A ``BatchEncoding`` whose ``'input_ids'`` and ``'attention_mask'``
            are tensors with a leading axis of size 1. The mask is all ones.
        """
        token_ids = [map_tok[char] for char in input_string]
        mask = [1] * len(token_ids)

        encoding = transformers.tokenization_utils_base.BatchEncoding()
        # unsqueeze(0) adds the leading axis of size 1.
        encoding['input_ids'] = torch.tensor(token_ids).unsqueeze(0)
        encoding['attention_mask'] = torch.tensor(mask).unsqueeze(0)
        return encoding

    # %%
    import numpy as np
    import string

    # %%
    # Build the dataset for this seed id and tokenize it per character.
    dataset = generate_random_strings(seed_id=PRETRAINED_CKPT_SEED_ID)
    encoded_dataset = encode_character_wise(tokenizer, dataset)

    # %%
    # The raw text column is dropped so only the encoded columns remain.
    training_dataset = encoded_dataset.remove_columns(["text"])
    # Lazy %-style arguments: the message is only formatted if it is emitted.
    logger.info("Generated %s sequences", len(encoded_dataset['test']))

    # %%
    from transformers import DataCollatorForLanguageModeling

    tokenizer.pad_token = tokenizer.eos_token
    # mlm=False selects causal (next-token) language modelling rather than
    # masked language modelling.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    from transformers import AutoModelForCausalLM, TrainingArguments, Trainer

    # Fine-tuning starts from the checkpoint selected for this seed id.
    model = AutoModelForCausalLM.from_pretrained(PRE_TRAINING_CHECKPOINT)

    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy` — confirm against the installed version.
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        evaluation_strategy=EVALUATION_STRATEGY,
        learning_rate=LEARNING_RATE,
        # WEIGHT_DECAY was defined with the other hyperparameters but never
        # passed on; wiring it in makes the constant effective. Its current
        # value 0 equals the library default, so behaviour is unchanged.
        weight_decay=WEIGHT_DECAY,
        lr_scheduler_type=LR_SCHEDULER,
        warmup_ratio=WARMUP_RATIO,
        num_train_epochs=NUM_TRAIN_EPOCHS,
        save_strategy=SAVE_STRATEGY,
        run_name=RUN_NAME,
        report_to=["wandb"],
        # push_to_hub=True,
    )

    # Train and evaluation splits hold the same sequences.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=training_dataset["train"],
        eval_dataset=training_dataset["test"],
        data_collator=data_collator,
        # resume_from_checkpoint = True,
        # callbacks=[NextTokenProbabilityCallback(dataset["test"].to_pandas()['text'].to_list(), OUTPUT_DIR)],
    )

    trainer.train()

    # %%
    import math

    # Final evaluation pass; perplexity is exp of the reported eval loss.
    eval_results = trainer.evaluate()
    perplexity = math.exp(eval_results['eval_loss'])
    print(f"Perplexity: {perplexity:.2f}")

    # %%
    # Raw training strings, written to disk next to the run for later reference.
    inp_strs = dataset['train'].to_pandas()['text'].to_list()

    # %%
    str_save_path = 'custom_ft/strings/'
    # The directory is not created anywhere else. Without this call a missing
    # directory would raise FileNotFoundError only after training has
    # finished, and wandb.finish() below would never be reached.
    os.makedirs(str_save_path, exist_ok=True)

    # save inp_str in a text file, one string per line ('w' mode already
    # truncates an existing file, so no separate clearing step is needed)
    inp_str_file = os.path.join(str_save_path, f'inp_str_{RUN_NAME}.txt')
    with open(inp_str_file, 'w') as f:
        f.writelines(f"{item}\n" for item in inp_strs)

    # Close the wandb run so the next loop iteration starts a fresh one.
    wandb.finish()

