import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from datasets import load_dataset
from PIL import Image
from IPython.display import display
from transformers import LlavaForConditionalGeneration, AutoProcessor
import traceback
import random
import numpy as np
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel, PeftConfig
from utils import *
from transformers import TrainingArguments, Trainer, default_data_collator
import os
import argparse
import sys
from pathlib import Path
import json
import re
from datasets import config
import json

# Point the Hugging Face `datasets` cache at ../hf_cache (relative to the
# current working directory).
config.HF_DATASETS_CACHE = "../hf_cache"

# Command-line interface.
#   --name    run name; used as the sub-directory of tinyllava-lora/ that
#             receives checkpoints, logs and the merged model.
#   --no_log  when given, config.json / printout.txt are NOT written and stdout
#             is left untouched.
parser = argparse.ArgumentParser()
parser.add_argument('--name', help='name of trainning save dir of')
parser.add_argument(
    '--no_log',
    action="store_true",
    help='disable writing config.json and mirroring all printout to printout.txt',
)
args = parser.parse_args()

if args.name is None:
    print("Please provide the trainning save dir by --name [name]")
    # Use sys.exit with a non-zero status: the bare `exit()` helper is injected
    # by the `site` module (not guaranteed to exist) and reported success (0)
    # to the calling shell / job scheduler even though nothing was trained.
    sys.exit(1)

# ---------------------------------------------------------------------------
# Run configuration (edit these constants to change the experiment).
# ---------------------------------------------------------------------------

# When True the script takes the "load a previously merged model" branch, which
# is currently disabled further down; keep this False.
from_checkpoint = False

run_config = {}
# Sub-directory of tinyllava-lora/ used for every artifact of this run.
today_dir = args.name
# Passed to TrainingArguments(eval_steps=...).
eval_steps = 1000
# Passed to TrainingArguments(save_steps=...).
# NOTE(review): this differs from eval_steps, so an evaluation at a step that is
# not a multiple of save_steps has no checkpoint-<step> directory on disk. If the
# tracker can report such a step as "best", the post-training loader falls back
# to the in-memory model — confirm this is intended.
save_steps = 2000  # checkpoint interval (independent of eval_steps)
# Number of streamed samples taken for the evaluation split.
val_ds_num = 2000
# Number of streamed samples taken for the metric split handed to the trainer
# as `val_data`.
metric_ds_num = 600
# Upper bound on the number of streamed training samples.
train_ds_num = 160000

epochs = 4
# The four values below are copied onto model.token_mixer after loading.
similarity_threshold = 0.06
think_mode = False
enhance_mode = False
text_enhance_mode = True  # New text enhance mode
# LoRA rank; lora_alpha is set to twice this value.
lora_rank = 32
# Handed to BestCheckpointTracker; presumably measured in training steps — TODO confirm.
early_stopping_patience = 10000  # Early stopping after 10000 steps without improvement
 

if not args.no_log:
    # Run directory that receives the config snapshot and the stdout log.
    log_dir = Path('tinyllava-lora/' + today_dir)
    log_dir.mkdir(parents=True, exist_ok=True)

    # Snapshot of the hyper-parameters used for this run.
    config_save = {
        'today_dir': today_dir,
        'train_ds_num': train_ds_num,
        'epochs': epochs,
        'similarity_threshold': similarity_threshold,
        "think_mode": think_mode,
        "enhance_mode": enhance_mode,
        "text_enhance_mode": text_enhance_mode,
        "lora_rank": lora_rank,
        "early_stopping_patience": early_stopping_patience,
    }

    # Explicit UTF-8 so the files do not depend on the platform's locale encoding.
    with open(log_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(config_save, f, indent=4)

    # This script prints emoji and other non-ASCII characters; with a non-UTF-8
    # default encoding (e.g. a C/POSIX locale or Windows code page) writing them
    # to the log file would raise UnicodeEncodeError in the middle of a run.
    # The file is opened in append mode and intentionally stays open for the
    # lifetime of the process because sys.stdout keeps writing to it.
    log_file = open(log_dir / "printout.txt", "a", encoding="utf-8")
    # Tee is a project helper (star-imported); presumably it forwards every
    # write to both the real stdout and the log file — TODO confirm.
    sys.stdout = Tee(sys.__stdout__, log_file)


if from_checkpoint:
    # Resuming from a previously merged model is disabled for now. The intended
    # loading code is kept here for reference:
    # pretained_path = None
    # # Load the full merged model (LoRA already merged into base)
    # model = LlavaForConditionalGeneration.from_pretrained(
    #     pretained_path,
    #     torch_dtype=torch.float16,
    #     device_map="cuda",
    # )
    #
    # Fail loudly here: silently passing would leave `model` and `processor`
    # undefined and crash a few lines later with an unrelated-looking NameError.
    raise NotImplementedError(
        "from_checkpoint=True is not supported: loading a merged checkpoint is currently disabled"
    )

model_id = "bczhou/tiny-llava-v1-hf"

# Load the modified model (project-specific LLaVA subclass) in half precision
# directly onto the GPU.
model = CustomLlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="cuda"
)

processor = AutoProcessor.from_pretrained(model_id)
# Copy the vision tower's patch size onto the processor.
processor.patch_size = model.config.vision_config.patch_size
print("✓ Tiny‑LLaVA loaded in FP16 on A100")

# Give the model's token mixer access to the processor.
model.set_token_mixer_processor(processor)


# Push the run configuration onto the token mixer.
model.token_mixer.similarity_threshold = similarity_threshold
model.token_mixer.think_mode = think_mode
model.token_mixer.enhance_mode = enhance_mode
model.token_mixer.text_enhance_mode = text_enhance_mode

#########################################################
# LoRA adapter setup

# Module names that will receive LoRA adapters, as selected by the project helper.
target_modules = get_target_models(model)
print(target_modules)

# Flag the model as decoder-only before it is prepared and wrapped by PEFT.
model.config.is_encoder_decoder = False

model = prepare_model_for_kbit_training(model)

# Adapter hyper-parameters: alpha is tied to twice the rank.
lora_kwargs = {
    "r": lora_rank,
    "lora_alpha": 2 * lora_rank,
    "lora_dropout": 0.05,
    "bias": "none",
    "task_type": "CAUSAL_LM",
    "target_modules": target_modules,
}
lora_cfg = LoraConfig(**lora_kwargs)

model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()

# Make sure the config's image placeholder id matches the tokenizer's "<image>" id.
image_token_id = processor.tokenizer.convert_tokens_to_ids("<image>")
model.config.image_token_index = image_token_id


#########################################################
# Dataset preparation
#
# Split sizes are the fixed constants defined at the top of the script. All three
# splits are cut from the same streamed dataset shuffled with the same seed; this
# relies on identical seed/buffer settings producing the same sample order for
# every stream, so that skip()/take() yield a training split that does not
# overlap the held-out samples.

seed = 126


def _shuffled_instruct_stream():
    """Open the LLaVA-Instruct-150K train split as a stream shuffled with `seed`."""
    stream = load_dataset(
        "liuhaotian/LLaVA-Instruct-150K",
        split="train",
        streaming=True,
    )
    return stream.shuffle(seed=seed, buffer_size=1000)


# The first max(val, metric) samples are reserved for evaluation; training
# starts right after them. The metric split is the leading part of the same
# stream as the validation split.
held_out_count = max(val_ds_num, metric_ds_num)
train_ds_iter = _shuffled_instruct_stream().skip(held_out_count).take(train_ds_num)
val_ds_iter = _shuffled_instruct_stream().take(val_ds_num)
metric_ds_iter = _shuffled_instruct_stream().take(metric_ds_num)

# Local directory holding the COCO train2017 images referenced by the dataset.
coco_image_dir = "../coco/train2017/"

train_data = load_llava_instruct_150k_local(
    train_ds_iter,
    processor=processor,
    num_samples_to_show=3,
    image_dir=coco_image_dir,
)

val_data = load_llava_instruct_150k_local(
    val_ds_iter,
    processor=processor,
    num_samples_to_show=2,
    image_dir=coco_image_dir,
)

print("dataset loaded")
print("train_data:", len(train_data))
print("val_data:", len(val_data))
print("***********************************************\n\n")



# Sanity check of the BLEU tooling on a tiny hand-written example: one
# lower-cased hypothesis against one lower-cased reference, scored with equal
# unigram/bigram weights and method1 smoothing.
example_predictions = list(map(str.lower, ["Cat is on mat"]))
example_references = [
    list(map(str.lower, refs)) for refs in [["The cat is sitting on the mat"]]
]

# Whitespace tokenisation; only the first reference of each group is used.
hypothesis_tokens = example_predictions[0].split()
reference_token_lists = [refs[0].split() for refs in example_references]

nltk_bleu = sentence_bleu(
    references=reference_token_lists,
    hypothesis=hypothesis_tokens,
    weights=(0.5, 0.5),
    smoothing_function=SmoothingFunction().method1,
)
print(f"Example BLEU Score: {nltk_bleu:.4f}, Should be ~0.2727")



# (Disabled) one-off BLEU smoke test on the metric split before training starts:
# # test running validation BLEU
# torch.cuda.empty_cache()
# bleu_score = compute_bleu(
#     model=model,
#     processor=processor,
#     val_data=metric_ds_iter,
#     max_new_tokens=128,
#     do_sample=False
# )
# print(f"Test Sample BLEU Score: {bleu_score:.4f}")



# CUDA debugging switches.
# NOTE(review): these are set after the model has already been loaded onto
# "cuda" earlier in this script, so they may be applied too late to affect the
# current process — confirm, and move them above the first CUDA use if they
# are still needed (synchronous launches also slow training down).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"

# Release cached GPU memory held by PyTorch's allocator before training starts
torch.cuda.empty_cache()

# Everything the Trainer writes (checkpoints, trainer state) goes under the run directory.
run_output_dir = f"tinyllava-lora/{today_dir}"

training_args = TrainingArguments(
    output_dir=run_output_dir,
    # --- optimisation ---
    num_train_epochs=epochs,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # 4 x 4 = 16 samples per optimizer step per device
    optim="paged_adamw_8bit",
    learning_rate=1e-4,
    weight_decay=1e-3,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    # --- precision ---
    fp16=True,
    bf16=False,
    # --- logging / checkpointing ---
    logging_steps=1000,
    save_steps=save_steps,
    save_total_limit=None,
    report_to="none",
    # --- evaluation ---
    eval_strategy="steps",
    eval_steps=eval_steps,
    per_device_eval_batch_size=1,
    # Keep every dataset column so the custom collator receives the full samples.
    remove_unused_columns=False,
)

# Batches are padded with the tokenizer's pad token id.
collator = CustomDataCollator(pad_token_id=processor.tokenizer.pad_token_id)

# Callback that records the best evaluation step and carries the early-stopping patience.
best_checkpoint_tracker = BestCheckpointTracker(
    early_stopping_patience=early_stopping_patience,
)

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,  # held-out validation split
    data_collator=collator,
    val_data=metric_ds_iter,  # streamed metric split (a project-specific trainer argument)
    processor=processor,
    callbacks=[best_checkpoint_tracker],
)

# Announce the schedule that was actually configured, then run training.
print("*************Start Training*************")
print(f"Evaluation every {eval_steps} steps")
# Checkpoints are written every `save_steps` steps (TrainingArguments.save_steps),
# which is not the same value as `eval_steps`; report the real interval.
print(f"💾 Saving checkpoints every {save_steps} steps")
print(f"📊 Tracking best checkpoint during training")
print(f"⏱️ Early stopping patience: {early_stopping_patience} steps")

# Train the model
trainer.train()

# Ask the tracker callback which step it recorded as the best one.
best_checkpoint_info = best_checkpoint_tracker.get_best_checkpoint_info()
best_checkpoint_step = best_checkpoint_info['best_checkpoint_step']

# Resolve the directory the adapter should be loaded from.
run_root = "tinyllava-lora/" + today_dir
if best_checkpoint_step is not None:
    print(f"✅ Best checkpoint found at step {best_checkpoint_step}")
    peft_path = f"{run_root}/checkpoint-{best_checkpoint_step}"
else:
    print("⚠️ No best checkpoint found, using the last checkpoint")
    peft_path = run_root

print(f"🔄 Loading best checkpoint from: {peft_path}")

# A usable PEFT checkpoint is a directory that contains adapter_config.json.
adapter_config_file = os.path.join(peft_path, "adapter_config.json")
checkpoint_loaded = os.path.exists(peft_path) and os.path.exists(adapter_config_file)

if checkpoint_loaded:
    # The adapter config records which base model the adapter belongs to.
    peft_config = PeftConfig.from_pretrained(peft_path)
    base_model_name_or_path = peft_config.base_model_name_or_path

    print(f"📋 PEFT config loaded from: {peft_path}")
    print(f"🏗️  Base model: {base_model_name_or_path}")

    # Fresh copy of the base model, configured like the training model.
    base_model = CustomLlavaForConditionalGeneration.from_pretrained(
        base_model_name_or_path,
        torch_dtype=torch.float16,
        device_map="cuda",
    )
    base_model.set_token_mixer_processor(processor)

    # Attach the saved LoRA adapter and re-apply the config patches used for training.
    best_model = PeftModel.from_pretrained(base_model, peft_path)
    best_model.config.is_encoder_decoder = False
    best_model.config.image_token_index = processor.tokenizer.convert_tokens_to_ids("<image>")

    print(f"✅ Successfully loaded best checkpoint using PEFT: {peft_path}")
else:
    # Explain what was (not) found, then fall back to the model still in memory.
    print(f"❌ PEFT checkpoint not found at: {peft_path}")
    print(f"📁 Available files in {peft_path}:")
    if not os.path.exists(peft_path):
        print(f"❌ Directory does not exist: {peft_path}")
    else:
        for entry in os.listdir(peft_path):
            print(f"   - {entry}")

    print("⚠️ No checkpoint loaded - using current model state for merging")
    best_model = model

# Fold the LoRA weights into the base model so it can be used for standalone
# inference, then store the model and the processor side by side.
merged_path = f"tinyllava-lora/{today_dir}/merged"

print("🔗 Merging LoRA weights into base model...")
merged_model = best_model.merge_and_unload()

for artifact in (merged_model, processor):
    artifact.save_pretrained(merged_path)

# Checkpoint cleanup (deleting every checkpoint except the best one to save
# disk space) is currently DISABLED: the implementation below is commented out.
# Say so in the log instead of announcing a cleanup that never happens.
checkpoint_dir = f"tinyllava-lora/{today_dir}"
print(f"🗑️ Checkpoint cleanup is disabled - keeping all checkpoints in: {checkpoint_dir}")
# Reference implementation, kept for when the cleanup is re-enabled:
# if os.path.exists(checkpoint_dir) and best_checkpoint_step is not None:
#     best_checkpoint_name = f"checkpoint-{best_checkpoint_step}"
#     best_checkpoint_path = os.path.join(checkpoint_dir, best_checkpoint_name)
#
#     # Safety check: make sure the best checkpoint exists before deleting others
#     if not os.path.exists(best_checkpoint_path):
#         print(f"⚠️ Warning: Best checkpoint {best_checkpoint_name} not found, skipping cleanup")
#     else:
#         deleted_count = 0
#         total_deleted_size = 0
#
#         for item in os.listdir(checkpoint_dir):
#             item_path = os.path.join(checkpoint_dir, item)
#             if item.startswith("checkpoint-") and os.path.isdir(item_path):
#                 if item != best_checkpoint_name:
#                     # Calculate size before deleting
#                     import shutil
#                     size = sum(f.stat().st_size for f in os.scandir(item_path) if f.is_file())
#                     total_deleted_size += size
#
#                     print(f"   Deleting: {item} ({size / 1024 / 1024:.1f} MB)")
#                     shutil.rmtree(item_path)
#                     deleted_count += 1
#                 else:
#                     print(f"   Keeping: {item} (best checkpoint)")
#
#
#         print(f"✅ Checkpoint cleanup completed - deleted {deleted_count} checkpoints")
#         print(f"💾 Freed up {total_deleted_size / 1024 / 1024:.1f} MB of disk space")
# else:
#     print("⚠️ No checkpoints to clean up or best checkpoint not found")

# Format the best loss defensively: the value comes from the tracker callback,
# and if it is not a number (e.g. None because no evaluation was recorded) the
# ":.4f" format would raise after all the work has already been done.
best_eval_loss = best_checkpoint_info['best_eval_loss']
try:
    best_loss_text = f"{best_eval_loss:.4f}"
except (TypeError, ValueError):
    best_loss_text = str(best_eval_loss)

print(f"🚀 All done – fine‑tuned Tiny‑LLaVA saved to: {merged_path}")
print(f"📊 Best evaluation loss achieved: {best_loss_text} at step {best_checkpoint_step}")

# Print training summary
print("\n" + "="*60)
print("🎯 TRAINING SUMMARY")
print("="*60)
print(f"📁 Output directory: {today_dir}")
print(f"🏆 Best evaluation loss: {best_loss_text}")
print(f"📍 Best loss at step: {best_checkpoint_step}")
print(f"📊 Evaluation frequency: {eval_steps} steps")
print(f"⏱️ Early stopping patience: {early_stopping_patience} steps")
if checkpoint_loaded:
    print(f"💾 Best checkpoint: {peft_path}")
else:
    # Nothing was loaded from peft_path; the merged model was built from the
    # model that was still in memory at the end of training.
    print(f"💾 Best checkpoint: {peft_path} (not loaded - merged model uses the in-memory weights)")
print(f"🔗 Merged model: {merged_path}")
print("="*60)

