import dataclasses
import glob
import os
import sys

import matplotlib.pyplot as plt
import pandas as pd
import torch
from datasets import Dataset
from dotenv import load_dotenv
from pandarallel import pandarallel
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

# import wandb

# Load variables such as HF_TOKEN from a local .env file into os.environ.
load_dotenv()

# Hub authentication is expected to come from the HF_TOKEN environment variable.
# Check it explicitly (instead of discarding the lookup) so a missing token is
# visible up front rather than surfacing later as a download error.
if not os.environ.get('HF_TOKEN'):
    print("Warning: HF_TOKEN is not set; Hugging Face Hub requests will be anonymous.")

# NOTE(review): no pandarallel `parallel_*` call is visible in this script; this
# initialization may be removable.
pandarallel.initialize()

MODEL_NAME = "yerevann/chemlactica-125m"
# Fall back to CPU so the script still runs (slowly) on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)


print("Load Data...")

# NOTE(review): 'PATH' looks like a placeholder directory; point it at the real data.
combined_data = pd.read_parquet('PATH/tdc_dataset.parquet')
print("Dataset size: ", len(combined_data))
# Preview a few random rows; cap the sample size so datasets with < 5 rows don't raise.
random_rows = combined_data.sample(n=min(5, len(combined_data)))
print(tabulate(random_rows, headers='keys', tablefmt='pretty', showindex=False))

# Optional cleanup step, currently disabled: rows with missing values are kept.
# combined_data = combined_data.dropna()
print("Dataset size: ", len(combined_data))

dataset = Dataset.from_pandas(combined_data)

# Register the property/SMILES delimiter tokens used in the prompt templates, then
# grow the embedding matrix so the newly added token ids have embedding rows.
special_tokens = ['[WAVELENGTH]', '[/WAVELENGTH]', '[F_OSC]', '[/F_OSC]', '[QED]', '[/QED]', '[LOGP]', '[/LOGP]', '[START_SMILES]', '[END_SMILES]', '[SEP]']
tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
model.resize_token_embeddings(len(tokenizer))

# Pad on the left, so the real tokens of each padded sequence sit at its end.
tokenizer.padding_side = 'left'

def _build_prompt(examples, i, dataset_id):
    """Render row ``i`` of a batched ``examples`` mapping as one training string.

    Args:
        examples: column name -> list of values for the batch.
        i: row index within the batch.
        dataset_id: the row's ``dataset_id`` value; it selects the template.

    Raises:
        ValueError: if ``dataset_id`` is neither ``'oled'`` nor ``'tdc'``.
    """
    if dataset_id == 'oled':
        # NOTE(review): unlike the 'tdc' template, this one has no leading "</s>"
        # and separates properties from SMILES with [SEP]; confirm this is intended.
        return (
            f"[WAVELENGTH]{examples['wavelength'][i]}[/WAVELENGTH]"
            f"[F_OSC]{examples['f_osc'][i]}[/F_OSC][SEP]"
            f"[START_SMILES]{examples['smiles'][i]}[END_SMILES]"
        )
    if dataset_id == 'tdc':
        return (
            f"</s>[QED]{examples['QED'][i]}[/QED][LOGP]{examples['LOGP'][i]}[/LOGP]"
            f"[START_SMILES]{examples['smiles'][i]}[END_SMILES]"
        )
    raise ValueError(f"Unexpected Dataset ID: {dataset_id}")


def tokenize_function(examples, tok=None, max_length=512):
    """Format a batch of examples as prompt strings and tokenize them.

    Intended for ``dataset.map(tokenize_function, batched=True)``.

    Args:
        examples: column name -> list of values; must contain ``dataset_id`` and
            ``smiles`` plus the property columns of each row's template.
        tok: tokenizer to call; defaults to the module-level ``tokenizer``.
        max_length: every sequence is truncated / padded to this many tokens.

    Returns:
        The tokenizer output plus ``labels``: a separate copy of ``input_ids``
        with padded positions (attention mask 0) set to -100, the index ignored
        by the loss.

    Raises:
        ValueError: for a row whose ``dataset_id`` is not 'oled' or 'tdc'.
    """
    if tok is None:
        # Looked up at call time, after the module-level tokenizer has been built.
        tok = tokenizer

    full_text = [
        _build_prompt(examples, i, dataset_id)
        for i, dataset_id in enumerate(examples['dataset_id'])
    ]

    tokenized_full = tok(
        full_text,
        truncation=True,
        max_length=max_length,
        padding='max_length',
    )

    # Build labels as new lists (not an alias of input_ids) and drop padding from
    # the loss. DataCollatorForLanguageModeling(mlm=False) rebuilds labels anyway,
    # but this keeps the dataset correct if it is used with another collator.
    attention = tokenized_full.get('attention_mask')
    if attention is None:
        tokenized_full['labels'] = [list(ids) for ids in tokenized_full['input_ids']]
    else:
        tokenized_full['labels'] = [
            [token if keep else -100 for token, keep in zip(ids, mask)]
            for ids, mask in zip(tokenized_full['input_ids'], attention)
        ]

    return tokenized_full

print("Tokenize Dataset")
# Drop the raw text/property columns so only the tokenizer outputs remain.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset.column_names,
)

# Hold out 20% for evaluation; the fixed seed keeps the split reproducible.
splits = tokenized_dataset.train_test_split(test_size=0.2, seed=21)
train_dataset = splits['train']
eval_dataset = splits['test']

# mlm=False selects causal language modeling: labels come from input_ids, with
# pad-token positions set to -100.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

RESULTS_DIR = "./results/tdc"
SAVE_DIR = "./fine_tuned_chemlactica_tdc"

# Newer transformers releases accept only `eval_strategy`; older ones only
# `evaluation_strategy`. Use whichever field this installed version defines.
_training_arg_names = {f.name for f in dataclasses.fields(TrainingArguments)}
_eval_strategy_key = (
    'eval_strategy' if 'eval_strategy' in _training_arg_names else 'evaluation_strategy'
)

print("Set training args")
training_args = TrainingArguments(
    output_dir=RESULTS_DIR,
    eval_steps=2000,
    learning_rate=5e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.01,
    max_grad_norm=1.0,
    report_to=["none"],
    logging_steps=2000,
    **{_eval_strategy_key: "steps"},
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

print("Training")
trainer.train()

print("Saving")
# Save the weights and the tokenizer (including the added special tokens) together.
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
