import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
from dataset_loader import get_encoded_dataloader_from_texts, DatasetManager, get_encoded_dataloader_from_texts, extract_layers_mlp_ios
import argparse

def _build_arg_parser():
    """Create the command-line parser used when this file is run as a script."""
    cli = argparse.ArgumentParser()

    # model
    cli.add_argument(
        '--model_dir',
        type=str,
        default='xxx/llama/llama-2-7b-hf',
        help='model directory',
    )
    cli.add_argument('--batch_size', type=int, default=4, help='batch_size')

    # dataset
    cli.add_argument(
        '--dataset_name',
        type=str,
        default='arxiv-math',
        help='dataset_name, support :: for concatenation',
    )
    cli.add_argument('--downsample_rate', type=float, default=0.1, help='downsample_rate')
    cli.add_argument('--max_len', type=int, default=1024, help='max_tokens')

    # others
    cli.add_argument(
        '--save_path',
        type=str,
        default='xxx/llama_reader_larger/arxiv-math/',
        help='save_path',
    )
    cli.add_argument('--device_num', type=int, default=0, help='device_num')
    return cli


# Module-level parser; parsed only inside the __main__ guard.
parser = _build_arg_parser()


# define hook function for input and output of mlp layers
def get_activation(name, activations):
    """Build a forward hook that records a module's first positional input.

    Args:
        name: Key under which the captured tensor is stored.
        activations: Dict that the hook writes into; it is mutated in place
            on every forward pass of the hooked module.

    Returns:
        A callable with the ``(module, inputs, output)`` signature expected by
        ``torch.nn.Module.register_forward_hook``.
    """
    def hook(module, inputs, output):
        # Detach so the stored tensor never keeps an autograd graph alive,
        # matching what get_output does for the module output.
        activations[name] = inputs[0].detach()
    return hook

def get_output(name, activations):
    """Build a forward hook that records a detached copy of a module's output.

    The returned callable stores ``output.detach()`` in ``activations`` under
    the key ``name`` each time the hooked module runs.
    """
    def hook(module, inputs, result):
        detached = result.detach()
        activations[name] = detached

    return hook


def Nth_prefix(N):
    """Return the key prefix used for the MLP of layer ``N``."""
    return 'layer_{}_mlp'.format(N)

def Nth_input_name(N):
    """Return the dictionary key under which layer ``N``'s MLP input is stored."""
    prefix = Nth_prefix(N)
    return prefix + "_input"

def Nth_output_name(N):
    """Return the dictionary key under which layer ``N``'s MLP output is stored."""
    prefix = Nth_prefix(N)
    return prefix + "_output"



# extract N-th layer MLP input
def extract_layers_mlp_inputs(model, dataloader, device, start_layer_idx=0, end_layer_idx=31):
    """Run ``model`` over ``dataloader`` and collect the input of each layer's MLP.

    A forward hook is attached to ``model.model.layers[i].mlp`` for every
    ``i`` in ``[start_layer_idx, end_layer_idx]``. After each batch the
    captured tensors are moved to CPU and accumulated per layer.

    Args:
        model: Model exposing ``model.model.layers[i].mlp`` sub-modules.
        dataloader: Iterable of batches; each batch is indexed with
            ``'input_ids'`` and ``'attention_mask'``.
        device: Device the batch tensors are moved to before the forward pass.
        start_layer_idx: First layer index to capture (inclusive).
        end_layer_idx: Last layer index to capture (inclusive).

    Returns:
        Dict mapping ``Nth_input_name(i)`` to the per-batch captures for
        layer ``i`` concatenated with ``torch.cat`` (along dim 0).

    Raises:
        RuntimeError: Propagated from ``torch.cat`` when no batch was
            processed, or when batches cannot be concatenated along dim 0.
    """
    layer_indices = range(start_layer_idx, end_layer_idx + 1)
    layer_names = [Nth_input_name(layer_idx) for layer_idx in layer_indices]

    activations = {}
    inputs = {name: [] for name in layer_names}
    handles = []

    model.eval()
    try:
        # Register every hook exactly once and keep its handle. Re-registering
        # per batch would stack a new set of hooks on each module every
        # iteration, all of which would run on every later forward pass.
        for layer_idx, name in zip(layer_indices, layer_names):
            mlp = model.model.layers[layer_idx].mlp
            handles.append(mlp.register_forward_hook(get_activation(name, activations)))

        with torch.no_grad():
            for batch in tqdm(dataloader, desc='Extracting MLP inputs'):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                # forward pass; only the hook side effects are needed, so the
                # model output is not bound to a name and can be freed at once
                model(input_ids, attention_mask=attention_mask)

                # move the captured MLP inputs of all layers to CPU
                for name in layer_names:
                    inputs[name].append(activations[name].cpu())

                # drop the references to the on-device captures before the
                # next batch; the same dict is reused by the hooks
                activations.clear()
                torch.cuda.empty_cache()
    finally:
        # Always leave the model without our hooks, even if a batch failed.
        for handle in handles:
            handle.remove()

    final_inputs = {name: torch.cat(inputs[name]) for name in layer_names}
    for layer_idx, name in zip(layer_indices, layer_names):
        print(f"Final shape for layer {layer_idx}: {final_inputs[name].shape}")
    del inputs

    return final_inputs

# save tensor to file
def save_tensor(tensor, save_path):
    """Serialize ``tensor`` to ``save_path`` with ``torch.save``.

    Missing parent directories of ``save_path`` are created first.

    Args:
        tensor: Object to serialize.
        save_path: Destination file path; may be a bare filename.
    """
    parent_dir = os.path.dirname(save_path)
    # dirname is '' for a bare filename, and os.makedirs('') raises
    # FileNotFoundError, so only create directories when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(tensor, save_path)
    print(f'Tensor saved to {save_path}')


# main function
def read_intermidiates(model_dir, dataset_name, max_len, batch_size, save_path, ratio = 0.1, device = 'cuda:0', start_layer_idx=0, end_layer_idx=31):
    """Extract the MLP inputs of the selected layers and save one file per layer.

    Loads the tokenizer and model from ``model_dir``, builds a dataloader over
    the training texts of ``dataset_name`` and runs
    ``extract_layers_mlp_inputs`` over it. The tensor for layer ``i`` is
    written to ``<save_path>/train_inputs_complete_<i>.pt``.

    Args:
        model_dir: Directory passed to ``from_pretrained`` for tokenizer and model.
        dataset_name: Name handed to ``DatasetManager.get_dataset_texts``.
        max_len: Forwarded to ``get_encoded_dataloader_from_texts``.
        batch_size: Forwarded to ``get_encoded_dataloader_from_texts``.
        save_path: Directory that receives the per-layer tensor files.
        ratio: Forwarded as ``downsample_rate`` to the dataloader builder.
        device: Device the model and batches are placed on.
        start_layer_idx: First layer index to extract (inclusive).
        end_layer_idx: Last layer index to extract (inclusive).
    """
    # Tokenizer: the EOS token is reused as the padding token.
    tokenizer = LlamaTokenizer.from_pretrained(model_dir)
    tokenizer.pad_token = tokenizer.eos_token

    # Model: load, move to the target device, then freeze every parameter.
    model = LlamaForCausalLM.from_pretrained(model_dir)
    model = model.to(device)
    model.requires_grad_(False)
    model.eval()

    # Data: only the first element returned by the dataset manager is used.
    dataset_manager = DatasetManager()
    train_texts, _, _ = dataset_manager.get_dataset_texts(dataset_name)
    train_loader = get_encoded_dataloader_from_texts(
        train_texts,
        batch_size,
        tokenizer,
        max_len,
        downsample_rate=ratio,
        is_val=False,
    )

    # Extract the activations, then write one file per layer.
    with torch.no_grad():
        final_inputs = extract_layers_mlp_inputs(
            model,
            train_loader,
            device,
            start_layer_idx=start_layer_idx,
            end_layer_idx=end_layer_idx,
        )
        for layer_idx in range(start_layer_idx, end_layer_idx + 1):
            target_file = os.path.join(save_path, f'train_inputs_complete_{layer_idx}.pt')
            save_tensor(final_inputs[Nth_input_name(layer_idx)], target_file)


    

def test_read_intermediate_ios(model_dir, dataset_name, max_len, batch_size, save_path, ratio = 0.1, device = 'cuda:0', start_layer_idx=0, end_layer_idx=31):
    """Smoke-test ``extract_layers_mlp_ios`` on the training split.

    Builds the tokenizer, model and dataloader, runs the imported
    ``extract_layers_mlp_ios`` helper and prints a completion message.
    Nothing is written to disk; ``save_path`` is currently unused because the
    saving code below is commented out.
    """
    # Tokenizer: the EOS token is reused as the padding token.
    tokenizer = LlamaTokenizer.from_pretrained(model_dir)
    tokenizer.pad_token = tokenizer.eos_token

    # Model: load onto the target device and freeze every parameter.
    model = LlamaForCausalLM.from_pretrained(model_dir).to(device)
    for weight in model.parameters():
        weight.requires_grad = False
    model.eval()

    # Data: only the first element returned by the dataset manager is used.
    dataset_manager = DatasetManager()
    train_texts, _, _ = dataset_manager.get_dataset_texts(dataset_name)
    train_loader = get_encoded_dataloader_from_texts(
        train_texts,
        batch_size,
        tokenizer,
        max_len,
        downsample_rate=ratio,
        is_val=False,
    )

    # Run the extraction; the results are only held, not saved.
    with torch.no_grad():
        final_inputs, final_outputs = extract_layers_mlp_ios(
            model,
            train_loader,
            device,
            start_layer_idx=start_layer_idx,
            end_layer_idx=end_layer_idx,
        )

        print('finish testing')
        # Saving is intentionally disabled:
        # for layer_idx in range(start_layer_idx, end_layer_idx + 1):
        #     save_tensor(final_inputs[layer_idx], os.path.join(save_path, f'train_inputs_complete_{layer_idx}.pt'))
        #     save_tensor(final_outputs[layer_idx], os.path.join(save_path, f'train_outputs_complete_{layer_idx}.pt'))
        # del final_inputs, final_outputs
        # torch.cuda.empty_cache()


if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the extraction.
    args = parser.parse_args()

    # Make sure the output directory exists before anything is written.
    os.makedirs(args.save_path, exist_ok=True)

    # The layer range is fixed to 0..31 here rather than taken from the CLI.
    run_options = {
        'model_dir': args.model_dir,
        'dataset_name': args.dataset_name,
        'max_len': args.max_len,
        'batch_size': args.batch_size,
        'save_path': args.save_path,
        'ratio': args.downsample_rate,
        'device': f'cuda:{args.device_num}',
        'start_layer_idx': 0,
        'end_layer_idx': 31,
    }
    read_intermidiates(**run_options)
    
    # test_read_intermediate_ios(
    #     model_dir=args.model_dir,
    #     dataset_name=args.dataset_name,
    #     max_len=args.max_len,
    #     batch_size=args.batch_size,
    #     save_path=args.save_path,
    #     ratio=args.downsample_rate,
    #     device=f'cuda:{args.device_num}',
    #     start_layer_idx=0,
    #     end_layer_idx=31
    # )