import torch
import transformers
from peft import LoraConfig, TaskType, get_peft_model
from torch.utils.data import DataLoader

from tqdm import tqdm

from llama3 import model, tokenizer, EOS
from data import get_loaders

@torch.no_grad()
def validation(vali_model: transformers.PreTrainedModel, vali_loader: DataLoader):
    """Evaluate ``vali_model`` on every batch of ``vali_loader``.

    Each batch's ``"Text"`` strings get ``EOS`` appended, are tokenized with
    padding, and are scored with the model's built-in loss (``labels=``).
    Padding positions are excluded from the loss.

    Args:
        vali_model: model to evaluate; it is switched to eval mode.
        vali_loader: loader whose entries expose a ``"Text"`` sequence of str.

    Returns:
        ``(avg_loss, avg_ppl)`` as floats. These are the sample-weighted means
        of the per-batch loss and of the per-batch perplexity
        (``exp(batch_loss)``). Both are ``nan`` if the loader yields no
        samples.
    """
    total_loss, total_ppl, n_samples = 0.0, 0.0, 0

    vali_model.eval()
    for entry in tqdm(vali_loader, desc="Validating", leave=False):
        # Build a new list: rebinding a loop variable would leave the batch unchanged.
        texts = [text + EOS for text in entry["Text"]]

        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(vali_model.device)
        # -100 is the ignore index of the HF causal-LM loss; this keeps
        # padding positions out of the validation loss.
        labels = inputs.input_ids.masked_fill(inputs.attention_mask == 0, -100)

        outputs = vali_model(**inputs, labels=labels)

        # outputs.loss is already a mean over the batch. Weight it by the batch
        # size so that dividing by the total sample count yields a mean rather
        # than mean / batch_size.
        batch_size = len(texts)
        total_loss += outputs.loss.item() * batch_size
        # torch.exp saturates to inf instead of raising on very large losses.
        total_ppl += torch.exp(outputs.loss).item() * batch_size
        n_samples += batch_size

    if n_samples == 0:
        return float("nan"), float("nan")

    avg_loss = total_loss / n_samples
    avg_ppl = total_ppl / n_samples
    return avg_loss, avg_ppl

def train(train_model: transformers.PreTrainedModel, train_loader: DataLoader, 
          optimizer: torch.optim.Optimizer, 
          schedular: torch.optim.lr_scheduler.LRScheduler):
    """Run one pass of next-token training over ``train_loader``.

    For each batch, ``EOS`` is appended to every ``"Text"`` string and the
    batch is tokenized with padding. The loss is cross-entropy between the
    logits at position ``t`` and the token at ``t + 1``, counted only where
    both positions are real (non-padding) tokens. The optimizer and the
    scheduler are each stepped once per batch.

    Args:
        train_model: model to train; it is switched to train mode.
        train_loader: loader whose entries expose a ``"Text"`` sequence of str.
        optimizer: optimizer over the trainable parameters.
        schedular: learning-rate scheduler, stepped once per batch.
    """
    loss_func = torch.nn.CrossEntropyLoss()
    train_model.train()
    for entry in tqdm(train_loader, desc="Training", leave=True):
        # Build a new list: rebinding a loop variable would leave the batch unchanged.
        texts = [text + EOS for text in entry["Text"]]

        # Tokenize the whole batch (previously only the last string was used).
        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(train_model.device)

        outputs = train_model(**inputs)

        # Shift by one: logits at position t predict the token at t + 1.
        logits = outputs.logits[:, :-1]
        targets = inputs.input_ids[:, 1:]
        # Keep only (source, target) pairs where both are real tokens. Under
        # right padding this equals the per-sample [0:len-1] / [1:len] slices,
        # and it also stays correct if the tokenizer pads on the left.
        real = inputs.attention_mask.bool()
        pair_mask = real[:, :-1] & real[:, 1:]

        loss = loss_func(logits[pair_mask], targets[pair_mask])

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        schedular.step()



def main():
    """Fine-tune the module-level ``model`` with LoRA adapters.

    Wraps ``model`` with ``get_peft_model``, which rebinds the module global.
    It prints validation ``(loss, perplexity)`` before training and after each
    epoch. After every epoch it calls ``save_pretrained`` into
    ``Tiny-Shakespear-r16-all-{epoch_index}``.
    """
    global model
    batch_size = 4
    lr = 1e-5
    epochs = 2
    trainloader, validloader = get_loaders(batch_size)

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", 
                        "up_proj", "down_proj", "gate_proj"],
        lora_dropout=0.01,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # NOTE(review): eps=2e-4 is far above AdamW's default (1e-8); presumably
    # chosen for low-precision numerical stability. Confirm it is intentional.
    optimizer = torch.optim.AdamW(model.parameters(), lr, eps=2e-4)
    # ConstantLR defaults to factor=1/3 for the first 5 steps. Using factor=1.0
    # makes the learning rate genuinely constant.
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)

    print(validation(model, validloader))
    for i in range(epochs):
        print(f"Epoch {i + 1} / {epochs}...")
        train(model, trainloader, optimizer, scheduler)
        print(validation(model, validloader))
        model.save_pretrained(f"Tiny-Shakespear-r16-all-{i}")



# Run training only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
