import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
import random
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from torch.utils.data import DataLoader
import time
import torch
from sacrebleu import corpus_bleu
import torch._dynamo
torch._dynamo.config.suppress_errors = True
from peft import LoraConfig, get_peft_model, TaskType
from pytorch_lightning.callbacks import LearningRateMonitor
from finetune_model import TextCompressor
from data_finetune_ import WikiTextDataModule, TextReconstructionDataModule, WikimediaDataModule,WikimediaRobustFineTuneDataModule, Coco2017DataModule, PixmoCapDataModule, WikiTextDataset
import pickle
from transformers import AutoTokenizer

# 'medium' lets float32 matrix multiplications use reduced internal precision
# (bfloat16, or TF32 as a fallback, where the GPU supports it) for speed at the
# cost of some numerical accuracy.
torch.set_float32_matmul_precision('medium')
def set_seed(seed: int = 42):
    """Seed every RNG used by training and force deterministic cuDNN kernels.

    Seeds Python's ``random``, NumPy, and torch (CPU plus all CUDA devices),
    disables cuDNN autotuning in favour of deterministic algorithms, then lets
    Lightning apply its own global seeding.

    Note: assigning ``PYTHONHASHSEED`` here does not change string hashing in
    the already-running interpreter; it only affects Python processes started
    afterwards that inherit this environment.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    pl.seed_everything(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

def test_dataloader_speed(dataloader: DataLoader, num_batches: int = 100):
    """Measure and print the mean time needed to fetch one batch.

    Exactly ``num_batches`` batches are pulled from *dataloader* (fewer if it
    is exhausted first), and the average is taken over the batches actually
    fetched. The first fetch includes iterator start-up cost (for a
    ``DataLoader`` with ``num_workers > 0`` that means starting the worker
    processes), so small ``num_batches`` values over-estimate steady-state
    latency.

    Args:
        dataloader: Any iterable of batches; typically a ``DataLoader``.
        num_batches: Maximum number of batches to fetch; must be >= 0.

    Returns:
        Average seconds per batch, or ``float('nan')`` if no batch was fetched.

    Raises:
        ValueError: If ``num_batches`` is negative.
    """
    import itertools

    if num_batches < 0:
        raise ValueError(f"num_batches must be >= 0, got {num_batches}")

    fetched = 0
    start = time.perf_counter()  # monotonic clock, unlike time.time()
    # islice stops after num_batches items without requesting another one;
    # breaking on ``i >= num_batches`` would load one extra batch first.
    for _ in itertools.islice(dataloader, num_batches):
        fetched += 1
    elapsed = time.perf_counter() - start

    if fetched == 0:
        print("No batches were loaded; average batch loading time is undefined.")
        return float("nan")

    avg_time = elapsed / fetched
    print(f"Average batch loading time: {avg_time*1000:.2f}ms")
    return avg_time


# The name starts with ``test_`` but this is a benchmarking utility, not a unit
# test; stop pytest from collecting it (it would fail on the missing
# ``dataloader`` fixture).
test_dataloader_speed.__test__ = False

def _build_data_module(max_length):
    """Create the WikiText data module used for fine-tuning.

    Worker count is capped at 24 and leaves one CPU for the main process.
    ``os.cpu_count()`` may return ``None`` when the count is unknown; that case
    falls back to a single CPU, i.e. zero workers.
    """
    cpu_count = os.cpu_count() or 1
    return WikiTextDataModule(
        batch_size=256,
        max_length=max_length,
        num_workers=max(0, min(24, cpu_count - 1)),
    )


def _build_model(data_module, max_length):
    """Instantiate the TextCompressor with the data module's tokenizer and vocab."""
    return TextCompressor(
        vocab_size=data_module.vocab_size,
        latent_dim=256,
        hidden_dim=512,  # alt: 768
        num_layers=10,  # alt: 8
        num_heads=8,
        dropout=0.0,
        pooling_strategy="cls",  # "mean" or "cls"
        teacher_forcing_start_ratio=0.99,
        teacher_forcing_end_ratio=0.99,
        tokenizer=data_module.tokenizer,
        lr=3e-4,
        new_lr=2e-6,  # second learning rate; its use is defined in TextCompressor
        scheduler_type="cosine",
        noise_sigma=0.0125,
        max_length=max_length,
    )


def _apply_lora(model):
    """Wrap ``model.modern_bert`` with LoRA adapters and freeze its other weights."""
    peft_config = LoraConfig(
        target_modules=["attn.Wqkv", "attn.Wo", "mlp.Wi", "mlp.Wo"],
        r=32,  # alt: 16
        lora_alpha=64,  # alt: 32
    )

    print("trying to freeze modernBert")
    model.modern_bert = get_peft_model(model.modern_bert, peft_config)
    # PEFT normally leaves only adapter weights trainable already; this loop
    # makes the guarantee explicit for every non-LoRA encoder parameter.
    for name, param in model.modern_bert.named_parameters():
        if 'lora' not in name.lower():
            param.requires_grad = False


def _build_trainer():
    """Create the Lightning trainer with best/last checkpointing, LR monitoring and W&B logging."""
    checkpoint_callback = ModelCheckpoint(
        monitor='val_bleu',
        dirpath='/home/user/autoencoder_ckpts/wikitext_ft',
        # The tags (noise 0.0125, cls pooling, 30 tokens) are hard-coded and
        # must be kept in sync with the model/data settings by hand.
        filename='text-compressor--robust_0.0125--cls--30tok-wikitext_ft-{epoch:02d}-{val_bleu:.2f}',
        save_top_k=1,
        mode='max',
        save_last=True,  # additionally keep the most recent checkpoint
    )
    wandb_logger = WandbLogger(project='text-compressor')

    return pl.Trainer(
        max_epochs=50,
        accelerator='gpu',
        devices=1,
        # NOTE(review): with devices=1, DDP only adds process-group setup and
        # find_unused_parameters=True adds a per-step graph traversal; a
        # single-device strategy would suffice unless `devices` is increased.
        strategy=DDPStrategy(find_unused_parameters=True),
        callbacks=[checkpoint_callback, LearningRateMonitor(logging_interval='step')],
        logger=wandb_logger,
        # precision="16-mixed",  # mixed precision currently disabled
        gradient_clip_val=1.0,
        val_check_interval=1.0,
    )


def _load_weights_only(model, ckpt_path):
    """Load model weights from *ckpt_path* without restoring optimizer/trainer state.

    Accepts either a Lightning checkpoint (a dict with a ``'state_dict'`` key)
    or a bare state dict. Loading is non-strict, so key mismatches are printed
    instead of raising. Check them: this runs after the LoRA wrap, so a
    checkpoint saved from an unwrapped encoder will likely not match and would
    otherwise load nothing silently.

    NOTE(review): torch>=2.6 defaults ``torch.load`` to ``weights_only=True``;
    Lightning checkpoints holding non-tensor hyper-parameters may need
    ``weights_only=False`` (only for trusted files).
    """
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    result = model.load_state_dict(state_dict, strict=False)
    missing = getattr(result, 'missing_keys', [])
    unexpected = getattr(result, 'unexpected_keys', [])
    if missing or unexpected:
        print(
            f"Non-strict load from {ckpt_path}: {len(missing)} missing keys, "
            f"{len(unexpected)} unexpected keys"
        )
    return model


def _run_challenge_test(trainer, model, challenge_set_path, tokenizer, max_length,
                        batch_size=32, num_workers=4):
    """Run ``trainer.test`` on a pickled challenge set.

    Assumes ``WikiTextDataset`` accepts ``data``, ``tokenizer`` and
    ``max_length`` keywords — TODO confirm against data_finetune_.
    """
    print(f"Loading challenge test set from: {challenge_set_path}")
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load challenge sets you produced yourself.
    with open(challenge_set_path, 'rb') as f:
        challenge_data = pickle.load(f)

    challenge_dataset = WikiTextDataset(
        data=challenge_data,
        tokenizer=tokenizer,
        max_length=max_length,
    )
    challenge_dataloader = DataLoader(
        challenge_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    print("\n--- Starting test on the Challenge Set ---")
    trainer.test(model=model, dataloaders=challenge_dataloader)
    print("\n--- Test complete! ---")


def train(init_weights_path=None, challenge_set_path=None):
    """Fine-tune the LoRA-wrapped TextCompressor on WikiText.

    Args:
        init_weights_path: Optional checkpoint whose weights are loaded
            (non-strictly, without optimizer state) before training.
        challenge_set_path: Optional pickled challenge set (e.g.
            ``"challenge_set_top_5percent.pkl"``) evaluated with
            ``trainer.test`` before training starts. When ``None`` (default)
            no file is read, so training does not depend on it existing.
    """
    MAX_LENGTH = 30
    set_seed(42)

    data_module = _build_data_module(MAX_LENGTH)
    model = _build_model(data_module, MAX_LENGTH)
    _apply_lora(model)
    trainer = _build_trainer()

    if init_weights_path is not None:
        _load_weights_only(model, init_weights_path)

    if challenge_set_path is not None:
        # Use the tokenizer the model was built with (its vocab_size also
        # comes from data_module) so token ids line up with the embeddings.
        _run_challenge_test(
            trainer,
            model,
            challenge_set_path,
            tokenizer=data_module.tokenizer,
            max_length=MAX_LENGTH,
        )

    trainer.fit(model, data_module)

# Script entry point: run the training pipeline only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    train()
