import argparse
import os
from transformers import T5Tokenizer, T5ForConditionalGeneration, LlamaTokenizer, LlamaForCausalLM
from val_modules import ValidationDataModule
import torch
import pandas as pd
from datasets import load_from_disk
from tqdm import tqdm

def validate(model, dataloader):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The model is moved to ``"cuda"`` and switched to eval mode for the pass,
    and ``model.to("cpu")`` is called once the pass is finished. Forward
    passes run under ``torch.no_grad()`` because only the scalar loss is
    needed; building autograd graphs would just waste GPU memory.

    Args:
        model: Object exposing ``to(device)`` and ``eval()`` that is called
            as ``model(**batch)`` and returns an object with a ``loss``
            attribute supporting ``.item()``.
        dataloader: Iterable of dict batches; every value must expose
            ``.to(device)``.

    Returns:
        The unweighted mean of the per-batch losses. Note that a smaller
        final batch therefore counts as much as a full one.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model = model.to("cuda")
    model.eval()

    val_losses = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = {k: v.to("cuda") for k, v in batch.items()}
            outputs = model(**batch)
            val_losses.append(outputs.loss.item())

    model = model.to("cpu")

    # An empty loader would otherwise surface as an opaque ZeroDivisionError.
    if not val_losses:
        raise ValueError("Cannot compute validation loss: dataloader yielded no batches")

    return sum(val_losses) / len(val_losses)



def _checkpoint_dir(model_dir, epoch):
    """Return the directory that holds the checkpoint saved for ``epoch``."""
    return os.path.join(model_dir, "epoch-{}".format(epoch))


def run(args):
    """Validate every requested checkpoint and write the losses to a CSV.

    For each epoch in ``args.valid_epochs`` the checkpoint under
    ``<args.model_dir>/epoch-<epoch>`` is loaded in bfloat16 and evaluated on
    the IID, benchmark-observed and benchmark-unobserved dataloaders supplied
    by ``ValidationDataModule``. One row per epoch is written to
    ``<args.output_dir>/results.csv``.

    Args:
        args: Namespace providing ``model_dir``, ``valid_epochs``,
            ``batch_size``, ``iid_data_dir``, ``benchmark_data_observed``,
            ``benchmark_data_unobserved`` and ``output_dir``.

    Raises:
        ValueError: If ``args.valid_epochs`` is empty.
        FileNotFoundError: If any requested checkpoint directory is missing.
    """
    if not args.valid_epochs:
        raise ValueError("No epochs to validate: valid_epochs is empty")

    # Fail fast on missing checkpoints before any expensive loading starts.
    # An explicit raise is used because ``assert`` is stripped under ``-O``.
    for epoch in args.valid_epochs:
        if not os.path.exists(_checkpoint_dir(args.model_dir, epoch)):
            raise FileNotFoundError(
                "Model for epoch {} does not exist in {}".format(epoch, args.model_dir)
            )

    # The tokenizer is loaded from the first requested checkpoint and reused
    # for all epochs.
    tokenizer_path = _checkpoint_dir(args.model_dir, args.valid_epochs[0])
    # The model family is inferred from the checkpoint path itself.
    is_alpaca = "alpaca" in tokenizer_path or "llama" in tokenizer_path
    if is_alpaca:
        tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path, padding_side="left")
        ModelClass = LlamaForCausalLM
    else:
        tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
        ModelClass = T5ForConditionalGeneration

    data_module = ValidationDataModule(
        args.batch_size,
        args.iid_data_dir,
        args.benchmark_data_observed,
        args.benchmark_data_unobserved,
        tokenizer,
        is_alpaca,
    )

    iid_dataloader = data_module.iid_dataloader()
    benchmark_observed_dataloader = data_module.benchmark_observed_dataloader()
    benchmark_unobserved_dataloader = data_module.benchmark_unobserved_dataloader()

    results = []
    results_path = os.path.join(args.output_dir, "results.csv")

    for epoch in args.valid_epochs:

        print("Validating epoch {}".format(epoch))

        model_path = _checkpoint_dir(args.model_dir, epoch)
        print("Loading model from {}".format(model_path))
        model = ModelClass.from_pretrained(model_path, torch_dtype=torch.bfloat16)

        print("Validating on IID data")
        iid_loss = validate(model, iid_dataloader)
        print("IID loss: {}".format(iid_loss))

        print("Validating on benchmark observed instruction")
        benchmark_observed_loss = validate(model, benchmark_observed_dataloader)
        print("Benchmark observed loss: {}".format(benchmark_observed_loss))

        print("Validating on benchmark unobserved instruction")
        benchmark_unobserved_loss = validate(model, benchmark_unobserved_dataloader)
        print("Benchmark unobserved loss: {}".format(benchmark_unobserved_loss))

        results.append(
            {
                "epoch": epoch,
                "iid_loss": iid_loss,
                "benchmark_observed_loss": benchmark_observed_loss,
                "benchmark_unobserved_loss": benchmark_unobserved_loss,
            }
        )

        # Rewrite the CSV after every epoch so a crash late in the sweep does
        # not discard the epochs already evaluated; the final file is the
        # same as writing once at the end.
        pd.DataFrame(results).to_csv(results_path, index=False)
        

def main():
    """Parse command-line options and run the checkpoint validation sweep.

    Creates ``--output_dir`` if it does not exist, then hands the parsed
    arguments to ``run``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="../models/forgetting-experiments/t5",
                        help="Directory containing the epoch-<N> checkpoint folders.")
    # ``type=list`` would split the raw string into single characters
    # ("10" -> ['1', '0']), so multi-digit epochs could never be requested.
    # ``nargs="+"`` with ``type=int`` accepts e.g. ``--valid_epochs 1 5 10``.
    parser.add_argument("--valid_epochs", type=int, nargs="+", default=[1, 2, 3, 4, 5, 10, 15, 20, 25],
                        help="Space-separated epoch numbers to validate.")
    parser.add_argument("--iid_data_dir", type=str, default="./data/flan_iid_val",
                        help="Path to the IID validation data.")
    parser.add_argument("--benchmark_data_observed", type=str, default="./data/flan_mmlu_observed",
                        help="Path to the benchmark data with observed instructions.")
    parser.add_argument("--benchmark_data_unobserved", type=str, default="./data/flan_mmlu_unobserved",
                        help="Path to the benchmark data with unobserved instructions.")

    parser.add_argument("--batch_size", type=int, default=4,
                        help="Batch size passed to the validation data module.")
    parser.add_argument("--output_dir", type=str, default="./results/t5_valid",
                        help="Directory where results.csv is written.")

    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.output_dir, exist_ok=True)
    run(args)


# Run the validation sweep only when executed as a script, not on import.
if __name__ == "__main__":
    main()