
import argparse
import gc
import os
import time

import torch
import transformers
from peft import get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType
from tqdm import trange

def run_peft(args, verbose=False, model_id="./llama3.1-8B"):
    """Train K soft-prompt token(s) so the frozen LM regenerates ``args.text``.

    Loads the causal LM at ``model_id`` and wraps it with a randomly initialised
    PEFT prompt-tuning adapter (K virtual tokens). Only the trainable adapter
    parameter is optimised. Its L2 norm is held fixed at ``args.norm`` after every
    step. The loss is next-token cross-entropy on
    ``<|begin_of_text|> + args.text + 3 * <|end_of_text|>``.

    Validation runs when both of these hold:

    * the training loss reaches a new minimum that is ``<= args.loss_thresh``;
    * at least ``args.valid_patience`` steps have passed since the previous
      validation.

    Validation greedily decodes ``len(train_ids) + 10`` tokens from
    ``<|begin_of_text|>``. The first time the decoded text contains
    ``args.text`` verbatim, the adapter is saved and training stops.

    Args:
        args: Namespace providing ``text``, ``save_dir``, ``lr``, ``norm``,
            ``max_steps``, ``loss_thresh`` and ``valid_patience``.
        verbose: Show a tqdm progress bar with the current loss and learning rate.
        model_id: Path or hub id of the base model / tokenizer.

    Returns:
        The 1-based training step at which an exact match was found, or ``None``
        if no match occurred within ``args.max_steps`` steps.

    Side effects:
        On success, calls ``model.save_pretrained(args.save_dir)`` and writes
        ``info.txt`` (step count and elapsed wall-clock seconds) there.
    """
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
    )
    model = pipeline.model
    tokenizer = pipeline.tokenizer
    tokenizer.pad_token_id = tokenizer.eos_token_id

    K = 1  # number of virtual (soft-prompt) tokens
    peft_config = PromptTuningConfig(
        task_type=TaskType.CAUSAL_LM,
        prompt_tuning_init=PromptTuningInit.RANDOM,
        num_virtual_tokens=K,
        tokenizer_name_or_path=model_id,
    )
    model = get_peft_model(model, peft_config)

    # Set the device to GPU if available.
    # NOTE(review): the model was already placed with device_map="auto"; moving it
    # again assumes it fits on one device -- confirm for multi-GPU / offload setups.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Training sequence. The trailing EOS tokens are prediction targets too.
    # NOTE(review): if the tokenizer also auto-prepends BOS (add_special_tokens is
    # on by default), the sequence starts with two BOS tokens. The validation
    # prompt in _greedy_generate is built the same way, so train and decode agree.
    train_text = "<|begin_of_text|>" + args.text + ("<|end_of_text|>" * 3)
    train_ids = tokenizer(train_text, return_tensors="pt").input_ids.to(device)
    train_mask = torch.ones_like(train_ids)  # already on train_ids' device
    max_gen_tokens = train_ids.shape[1] + 10

    # Trainable parameters. The norm constraint is applied to params[0],
    # presumably the soft-prompt embedding (the only parameter PEFT leaves
    # trainable for prompt tuning -- confirm if the config changes).
    params = [p for p in model.parameters() if p.requires_grad]
    _renormalize_(params[0], args.norm)
    # torch.optim.AdamW replaces the deprecated/removed transformers.AdamW. Its
    # defaults differ, so eps and weight_decay are pinned to the old class's
    # values (eps=1e-6, weight_decay=0.0).
    optimizer = torch.optim.AdamW(params, lr=args.lr, eps=1e-6, weight_decay=0.0)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)

    start = time.time()
    best_loss = float("inf")
    # Step of the most recent validation. Start far enough back that the first
    # validation is never throttled.
    prev_valid = -args.valid_patience
    loss_fn = torch.nn.CrossEntropyLoss()
    progress_bar = trange(args.max_steps) if verbose else range(args.max_steps)
    for i in progress_bar:
        model.train()

        # The adapter output has K extra leading positions. Logits at
        # positions K .. K+T-2 predict train tokens 1 .. T-1. Flatten explicitly
        # rather than squeeze(), which would also drop the sequence axis if T-1 == 1.
        outputs = model(train_ids, attention_mask=train_mask)
        logits = outputs.logits[:, K:-1]
        labels = train_ids[:, 1:]
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        loss_value = loss.item()
        if verbose:
            progress_bar.set_postfix({
                "loss": f"{loss_value:.4f}",
                "lr": f"{optimizer.param_groups[0]['lr']:.6f}",
            })

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Project the prompt parameter back to the fixed norm after each update.
        _renormalize_(params[0], args.norm)

        if loss_value < best_loss:
            best_loss = loss_value
            should_validate = (
                loss_value <= args.loss_thresh
                and i - prev_valid >= args.valid_patience
            )
            if should_validate:
                # Record the attempt so the next validation waits valid_patience steps.
                prev_valid = i
                generated_text = _greedy_generate(model, tokenizer, device, max_gen_tokens)
                if args.text in generated_text:
                    model.save_pretrained(args.save_dir)
                    info_path = os.path.join(args.save_dir, 'info.txt')
                    with open(info_path, 'w') as f:
                        f.write(f'steps {i+1} time {time.time()-start}')
                    return i + 1
    return None


def _renormalize_(param, norm):
    """Rescale ``param`` (by rebinding ``.data``) so its L2 norm equals ``norm``."""
    param.data = param.data / param.data.norm() * norm


def _greedy_generate(model, tokenizer, device, max_new_tokens):
    """Greedily decode ``max_new_tokens`` tokens after ``<|begin_of_text|>``.

    Switches ``model`` to eval mode and re-runs the whole sequence each step (no
    KV cache and no early stop at EOS). Returns the decoded text with special
    tokens skipped.
    """
    model.eval()
    with torch.no_grad():
        ids = tokenizer("<|begin_of_text|>", return_tensors="pt").input_ids.to(device)
        for _ in range(max_new_tokens):
            mask = torch.ones_like(ids)
            next_token_logits = model(ids, attention_mask=mask).logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1)
            ids = torch.cat([ids, next_token.unsqueeze(-1)], dim=-1)
    return tokenizer.decode(ids[0], skip_special_tokens=True)

def main():
    """Run ``run_peft`` on each text ``../data/tok{N}/{idx}.txt`` for idx in [start, end).

    Each index trains a fresh soft prompt. Successful adapters are saved to
    ``./ckpt/tok{N}/{idx}``. One result line is printed per index: the step at
    which the text was reproduced, or ``None``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=1e-1)
    parser.add_argument('--norm', type=float, default=40.0)
    parser.add_argument('--max_steps', type=int, default=2000)
    parser.add_argument('--loss_thresh', type=float, default=0.015)
    parser.add_argument('--valid_patience', type=int, default=100)
    parser.add_argument('--num_tokens', type=int, default=256)
    parser.add_argument('--start_idx', type=int, default=0)
    parser.add_argument('--end_idx', type=int, default=500)
    parser.add_argument('--verbose', action='store_true',
                        help='show a per-run tqdm progress bar')

    args = parser.parse_args()
    for idx in range(args.start_idx, args.end_idx):
        data_path = f'../data/tok{args.num_tokens}/{idx}.txt'
        # Explicit encoding: the platform default is not guaranteed to be UTF-8.
        with open(data_path, 'r', encoding='utf-8') as f:
            args.text = f.read()
        args.save_dir = f'./ckpt/tok{args.num_tokens}/{idx}'
        steps = run_peft(args, verbose=args.verbose)
        print(f'==> [tok{args.num_tokens}] idx {idx} | steps {steps}')
        # run_peft loads a fresh model every call. Release the previous run's
        # objects and cached CUDA blocks before the next load.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


if __name__ == '__main__':
    main()
