import os
import torch
import logging
import sys
import warnings
from datetime import datetime
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    set_seed,
    TrainerCallback,
)
from datasets import load_dataset
import datasets
import math


# Quiet the `datasets` library, first through its own verbosity switch and
# then through the standard logging hierarchy.
datasets.logging.set_verbosity_error()
logging.getLogger("datasets").setLevel(logging.CRITICAL)

# Ignore warnings whose message matches the use_cache / gradient-checkpointing
# pattern, and FutureWarnings raised from torch.utils.checkpoint.
warnings.filterwarnings(
    "ignore",
    message=".*use_cache=True.*incompatible with gradient checkpointing.*",
)
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    module="torch.utils.checkpoint",
)

# Each run writes to its own timestamped file inside the log directory,
# which is created on import if it does not exist yet.
log_dir = "./Others/logs"
os.makedirs(log_dir, exist_ok=True)
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = os.path.join(log_dir, f"training_{current_time}.log")

# Every record is mirrored to the log file and to stdout.
file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8")
console_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[file_handler, console_handler],
)
logger = logging.getLogger(__name__)


# Central configuration for the fine-tuning run; read by setup_training()
# and train_model().
HYPERPARAMS = {
    # Location of the pretrained model and tokenizer to start from.
    "model_path": "./Others/TinyLlama-1.1B-Chat-v1.0",
    # JSON files for the three dataset splits.
    "train_file": "./Fine_tunning/Data_processing_and_data/train.json",
    "val_file": "./Fine_tunning/Data_processing_and_data/val.json",
    "test_file": "./Fine_tunning/Data_processing_and_data/test.json",
    # Trainer output directory, and where the final model is saved.
    "output_dir": "./Others/tinyllama-finetuned",
    "final_model_dir": "./Others/final_model",
    # Per-device batch size, used for both training and evaluation.
    "batch_size": 16,
    "gradient_accumulation_steps": 2,
    "learning_rate": 2e-5,
    "num_epochs": 5,
    "warmup_steps": 200,
    # Maximum number of tokens kept per example during tokenization.
    "max_length": 2048,
    "weight_decay": 0.01,
}


class EpochLossCallback(TrainerCallback):
    """Log the mean training loss per epoch and the validation loss per evaluation.

    Training losses reported through ``on_log`` are buffered and averaged when
    the epoch ends. Validation losses are tagged with the epoch number recorded
    by the most recent ``on_epoch_end`` call.
    """

    def __init__(self):
        super().__init__()
        # Training losses reported since the last epoch boundary.
        self.epoch_losses = []
        # Number of the most recently finished epoch (0 before the first one).
        self.current_epoch = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Buffer the training loss whenever a log record carries one."""
        if logs and "loss" in logs:
            self.epoch_losses.append(logs["loss"])

    def on_epoch_end(self, args, state, control, **kwargs):
        """Record the finished epoch and log its mean training loss, if any."""
        # Always refresh the epoch number, even when no loss was logged during
        # this epoch; otherwise on_evaluate would report a stale epoch.
        if state.epoch is not None:
            self.current_epoch = math.ceil(state.epoch)

        if not self.epoch_losses:
            return

        avg_train_loss = sum(self.epoch_losses) / len(self.epoch_losses)
        logger.info(f"Epoch {self.current_epoch} - Training loss: {avg_train_loss:.6f}")
        self.epoch_losses = []

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Log the validation loss when the metrics contain ``eval_loss``."""
        if metrics and "eval_loss" in metrics:
            logger.info(f"Epoch {self.current_epoch} - Validation loss: {metrics['eval_loss']:.6f}")

def setup_training():
    """Seed the RNGs, log the environment, and load tokenizer, datasets and model.

    Returns:
        tuple: ``(model, tokenizer, tokenized_datasets)``. ``tokenized_datasets``
        contains the "train", "validation" and "test" splits with the
        "messages" column replaced by the tokenizer outputs.
    """
    set_seed(42)

    logger.info(f"Training log saved to: {log_file}")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"GPU available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")

    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(HYPERPARAMS["model_path"])
    # Padding reuses the end-of-sequence token.
    tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        """Render each conversation with the chat template and tokenize it."""
        texts = [
            tokenizer.apply_chat_template(msg, tokenize=False)
            for msg in examples["messages"]
        ]
        # Only truncate here. Padding and label creation are left to the
        # DataCollatorForLanguageModeling used in train_model(), which pads
        # each batch to its own longest sequence and derives the labels from
        # input_ids. Padding every example to max_length up front would
        # inflate the cached dataset and every batch, and a stored "labels"
        # column would be overwritten by that collator anyway.
        return tokenizer(
            texts,
            truncation=True,
            max_length=HYPERPARAMS["max_length"],
        )

    logger.info("Loading dataset...")
    data_files = {
        "train": HYPERPARAMS["train_file"],
        "validation": HYPERPARAMS["val_file"],
        "test": HYPERPARAMS["test_file"],
    }
    raw_datasets = load_dataset("json", data_files=data_files)

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=4,
        remove_columns=["messages"],
        desc="Tokenizing datasets",
    )

    logger.info(f"Training set size: {len(tokenized_datasets['train'])}")
    logger.info(f"Validation set size: {len(tokenized_datasets['validation'])}")
    logger.info(f"Test set size: {len(tokenized_datasets['test'])}")

    logger.info("Loading pretrained model...")
    model = AutoModelForCausalLM.from_pretrained(
        HYPERPARAMS["model_path"],
        torch_dtype=torch.float32,
    )
    model.gradient_checkpointing_enable()

    return model, tokenizer, tokenized_datasets

def train_model(model, tokenizer, tokenized_datasets):
    """Fine-tune ``model`` on the train split and save the result.

    Args:
        model: Causal language model to fine-tune.
        tokenizer: Tokenizer used by the data collator and saved next to the
            final model.
        tokenized_datasets: Mapping with tokenized "train" and "validation"
            splits.

    Returns:
        The ``Trainer`` instance after training has finished.
    """
    # Causal-LM collator (mlm=False).
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    batch_size = HYPERPARAMS["batch_size"]
    gradient_accumulation_steps = HYPERPARAMS["gradient_accumulation_steps"]
    num_epochs = HYPERPARAMS["num_epochs"]
    warmup_steps = HYPERPARAMS["warmup_steps"]

    # The batch size is per device, so one optimizer step consumes
    # batch_size * devices * accumulation examples. Ignoring the device count
    # overestimates steps_per_epoch on multi-GPU machines, which makes
    # logging_steps too large and leaves some epochs without a logged loss.
    num_devices = max(1, torch.cuda.device_count())
    examples_per_step = batch_size * num_devices * gradient_accumulation_steps
    steps_per_epoch = max(1, len(tokenized_datasets["train"]) // examples_per_step)
    total_steps = steps_per_epoch * num_epochs

    logger.info(f"Steps per epoch: {steps_per_epoch}")
    logger.info(f"Total training steps: {total_steps}")
    if warmup_steps >= total_steps:
        # With warm-up spanning the whole run the learning rate never reaches
        # its configured value; make that visible instead of failing silently.
        logger.warning(
            f"warmup_steps ({warmup_steps}) is not smaller than the estimated total "
            f"training steps ({total_steps}); the learning rate will not reach "
            f"{HYPERPARAMS['learning_rate']}."
        )

    training_args = TrainingArguments(
        output_dir=HYPERPARAMS["output_dir"],
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=HYPERPARAMS["learning_rate"],
        num_train_epochs=num_epochs,
        weight_decay=HYPERPARAMS["weight_decay"],
        save_strategy="epoch",
        save_total_limit=2,
        # Log roughly once per epoch so EpochLossCallback has a loss to report.
        logging_steps=steps_per_epoch,
        eval_strategy="epoch",
        warmup_steps=warmup_steps,
        fp16=False,
        report_to="none",
        # Keep the checkpoint with the lowest validation loss.
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        gradient_checkpointing=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        callbacks=[EpochLossCallback()],
    )

    logger.info("Starting training...")
    trainer.train()

    logger.info("Saving final model...")
    trainer.save_model(HYPERPARAMS["final_model_dir"])
    tokenizer.save_pretrained(HYPERPARAMS["final_model_dir"])
    logger.info(f"Model saved to: {HYPERPARAMS['final_model_dir']}")

    return trainer

def evaluate_model(trainer, tokenizer, tokenized_datasets):
    """Report test-set loss and perplexity, then log sample generations.

    Args:
        trainer: Trainer holding the fine-tuned model.
        tokenizer: Tokenizer used to build prompts and decode generations.
        tokenized_datasets: Mapping that contains the tokenized "test" split.

    Generation uses sampling (temperature 0.7, top-p 0.9), so the logged
    outputs are not deterministic.
    """
    logger.info("Starting model evaluation...")
    # Use a "test" metric prefix: with the default "eval" prefix the metrics
    # contain "eval_loss", which EpochLossCallback would log as an
    # "Epoch N - Validation loss" line for what is really the test set.
    eval_results = trainer.evaluate(
        eval_dataset=tokenized_datasets["test"],
        metric_key_prefix="test",
    )
    test_loss = eval_results["test_loss"]
    logger.info(f"Test set loss: {test_loss:.6f}")
    logger.info(f"Test set perplexity: {torch.exp(torch.tensor(test_loss)):.6f}")

    logger.info("Generating text examples...")
    model = trainer.model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    # NOTE(review): the series below starts with "1021" and has a closing "]"
    # without an opening one; the prompt may have lost a leading "[0." -- confirm
    # against the training data before changing it.
    test_inputs = [
        "1021, 0.1154, 0.1317, 0.1479, 0.1391, 0.1420, 0.1967, 0.1864, 0.1908, 0.2115, 0.1893, 0.1893, 0.2352, 0.2707, 0.2678, 0.3136, 0.2899, 0.3861, 0.4379, 0.4512, 0.4201, 0.4571, 0.4689, 0.5577, 0.5754, 0.6198, 0.7604, 0.7278, 0.8802, 0.9083, 0.9053, 0.8994, 0.8876, 0.8299, 0.7234, 0.6731, 0.7441, 0.7574, 0.8107, 0.5843, 0.5947, 0.6967, 0.5666, 0.5533, 0.4305, 0.5355, 0.5976, 0.5518, 0.5533, 0.6997, 0.4822, 0.7012, 0.5680, 0.6953, 0.5947, 0.5947, 0.5636, 0.5799, 0.5621, 0.6317, 0.6036, 0.6213, 0.6435, 0.6524, 0.7559, 0.6479, 0.6938, 0.6361, 0.6095, 0.6021, 0.5873, 0.5503, 0.5488, 0.4970, 0.5059, 0.5710, 0.5459, 0.5607, 0.5296, 0.4985, 0.5355, 0.5000, 0.5207, 0.5296, 0.5547, 0.5592, 0.5296, 0.5518, 0.5266, 0.5651, 0.5592, 0.5399, 0.5680, 0.5577, 0.5074, 0.5414, 0.5947, 0.5429, 0.4985, 0.5148, 0.5976, 0.5533, 0.5740, 0.5799, 0.5414, 0.5695, 0.5296, 0.5740, 0.5399, 0.5710, 0.5814, 0.6139, 0.5666, 0.5251, 0.5932, 0.5902, 0.5325, 0.5607, 0.5192, 0.6686, 0.6820, 0.6686, 0.6257, 0.5843, 0.6243, 0.6938, 0.7101, 0.7322, 0.6376, 0.6346, 0.6583, 0.7041, 0.7781, 0.7633, 0.6494, 0.7485, 0.7263, 0.8358, 0.8462, 0.8698, 0.7618, 0.7441, 0.7766, 0.7944, 0.8861, 0.8609, 0.8047, 0.7027, 0.8314, 0.8151, 0.8772, 0.8269, 0.8358, 0.7929, 0.7441, 0.8121, 0.7396, 0.5695, 0.6509, 0.5740, 0.6080, 0.6346, 0.5843, 0.6361, 0.6228, 0.6213, 0.5858, 0.6405, 0.6598, 0.6716, 0.5888, 0.6095, 0.6183, 0.6272, 0.6154, 0.5577, 0.5370, 0.5281, 0.4689, 0.4349, 0.4527, 0.4216, 0.3964, 0.3920, 0.3905, 0.3447, 0.3373, 0.3580, 0.3151, 0.3018, 0.3092, 0.3299, 0.3920, 0.3254, 0.2988, 0.3018, 0.3284, 0.3018, 0.3033, 0.3062, 0.2530, 0.2944, 0.2589, 0.2870, 0.3314, 0.3092, 0.2840, 0.3033, 0.2678, 0.3003, 0.2559, 0.2145, 0.2456, 0.2130, 0.2515, 0.2041, 0.2367, 0.2278, 0.1879, 0.1642, 0.1494, 0.1701, 0.1612, 0.1760, 0.1405, 0.1331, 0.1272, 0.1243, 0.1213, 0.0976, 0.1154, 0.1213, 0.1154, 0.0888, 0.1139, 0.0962, 0.0740, 0.0725, 0.0651], Predict the traffic flow in the next 12 time steps.",
    ]
    for test_input in test_inputs:
        # Training examples were rendered with the tokenizer's chat template
        # (see setup_training), so the prompt is rendered the same way and
        # ends with the generation prompt for the assistant turn.
        # Assumes the prompt corresponds to a "user" turn -- TODO confirm
        # against the roles used in the training data.
        prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": test_input}],
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        prompt_length = inputs["input_ids"].shape[1]
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=50,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )
        # Decode only the newly generated tokens; the prompt is logged
        # separately as "Input".
        generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        logger.info(f"Input: {test_input}")
        logger.info(f"Output: {generated_text}")
        logger.info("-" * 40)

def main():
    """Run the whole pipeline: setup, fine-tuning, then test-set evaluation."""
    banner = "=" * 50

    for line in (banner, "Starting TinyLlama fine-tuning training", banner):
        logger.info(line)

    # Each stage feeds the next one.
    model, tokenizer, tokenized_datasets = setup_training()
    trainer = train_model(model, tokenizer, tokenized_datasets)
    evaluate_model(trainer, tokenizer, tokenized_datasets)

    for line in (banner, "Training and evaluation completed", banner):
        logger.info(line)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()