from transformers import AutoModelForCausalLM
import torch
from tqdm import tqdm 
import pickle 
import os 


def create_distillation_dataset_cache(model, train_dl, device, distill_mode, cache_dir, is_train):
    """
    Run a teacher model over a data loader and cache its hidden states on disk.

    One file is written per batch to
    ``<cache_dir>/distill_cache/<train|test>_<batch_idx>.pt`` via ``torch.save``.
    Each file holds a dict with the keys ``teacher_logits`` (placeholder ``0``;
    logits are currently not cached), ``teacher_hidden1``, ``teacher_hidden2``,
    ``layer_idx1`` and ``layer_idx2``.

    Args:
        model: Teacher model. It is moved to ``device``, put in eval mode and
            called with ``input_ids``, ``attention_mask`` and
            ``output_hidden_states=True``; the result must expose an indexable
            ``hidden_states`` attribute.
        train_dl: Sized iterable of batches. Each batch is indexed with the
            keys ``'input_ids'`` and ``'attention_mask'``, whose values are
            sliced as ``[:, :-1]`` before being fed to the model.
        device: Device on which the model runs and to which inputs are moved.
        distill_mode: Distillation mode. A falsy value disables caching;
            the only supported mode is ``'hs_last'``.
        cache_dir: Directory under which the ``distill_cache`` folder is created.
        is_train: Selects the ``train`` (truthy) or ``test`` (falsy) file prefix.

    Returns:
        ``[]`` when ``distill_mode`` is falsy, otherwise ``None``.

    Raises:
        NotImplementedError: If ``distill_mode`` is truthy but unsupported.
    """
    if not distill_mode:
        print('Distillation off, not generating distillation dataset')
        return []

    # Reject unsupported modes before touching the model, so that a bad
    # configuration does not cost a device transfer and a forward pass.
    supported_modes = ('hs_last',)
    if distill_mode not in supported_modes:
        raise NotImplementedError(
            f'Unsupported distill_mode {distill_mode!r} in create_distillation_dataset_cache'
        )

    model = model.to(device).eval()

    # Loop-invariant output location and file prefix.
    write_dir = os.path.join(cache_dir, 'distill_cache')
    os.makedirs(write_dir, exist_ok=True)
    split = 'train' if is_train else 'test'

    progress = tqdm(
        enumerate(train_dl),
        mininterval=5,
        desc='Generating distillation dataset',
        total=len(train_dl),
    )
    for batch_idx, batch in progress:
        # Drop the last position along dim 1 of both tensors.
        # NOTE(review): presumably the final token has no next-token target -- confirm.
        # Inputs go to `device` (where the model was just moved) rather than
        # `model.device`, which plain nn.Module instances do not provide.
        input_ids = batch['input_ids'][:, :-1].to(device)
        attention_mask = batch['attention_mask'][:, :-1].to(device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)

        # 'hs_last': cache the hidden state at the middle index of
        # `outputs.hidden_states` and the one at the last index.
        idx1, idx2 = int(len(outputs.hidden_states) * 0.50), -1
        hidden1 = outputs.hidden_states[idx1].detach().cpu()
        hidden2 = outputs.hidden_states[idx2].detach().cpu()

        distill_data = {
            # Logits are not cached; 0 is stored as a placeholder.
            'teacher_logits': 0,
            'teacher_hidden1': hidden1,
            'teacher_hidden2': hidden2,
            'layer_idx1': idx1,
            'layer_idx2': idx2,
        }

        filename = os.path.join(write_dir, f"{split}_{batch_idx}.pt")
        torch.save(distill_data, filename)

        # Release references to this batch's tensors before the next forward pass.
        del outputs, distill_data, hidden1, hidden2

    torch.cuda.empty_cache()
    print('Finished generating distillation dataset')
    return

if __name__ == '__main__':
    # Manual smoke test: build a small wikitext2 loader, run the teacher over
    # it and write the distillation cache, then drop into the debugger.
    from data_utils import get_dataloaders
    from transformers import AutoTokenizer
    import pdb

    model_name = 'TinyLlama/TinyLlama-1.1B-Chat-v1.0'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # The cache function needs a model object (it calls `.to(device).eval()`),
    # not the model name string.
    model = AutoModelForCausalLM.from_pretrained(model_name)

    train_dl, _ = get_dataloaders(
        tokenizer,
        dataset_name="wikitext2",
        num_train_samples=256,
        num_test_samples=256,
        batch_size=4,
        random_state=42,
        debug=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    create_distillation_dataset_cache(
        model,
        train_dl,
        device=device,
        distill_mode='hs_last',  # the only mode the function supports
        cache_dir='cache',  # NOTE(review): arbitrary local directory -- adjust as needed
        is_train=True,
    )
    pdb.set_trace()

