import os
import json
import random
from tqdm import tqdm
import torch
import numpy as np
from PIL import Image
from datasets import load_dataset


def get_wikitext2(n_samples, tokenizer, seqlen, seed=42):
    """Sample random fixed-length token windows from the WikiText-2 train split.

    The whole train split is joined with blank lines, tokenized once, and
    ``n_samples`` windows of ``seqlen`` consecutive tokens are taken at random
    start offsets.

    Args:
        n_samples: Number of windows to draw.
        tokenizer: Tokenizer called with ``return_tensors='pt'``; its result is
            assumed to expose ``input_ids`` of shape ``(1, num_tokens)`` (the
            indexing below relies on that) — TODO confirm for custom tokenizers.
        seqlen: Length, in tokens, of each window.
        seed: Seed applied to Python's global ``random`` module so the sampled
            offsets are reproducible. Note this reseeds the global RNG.

    Returns:
        dict with key ``'input_ids'`` holding a tensor of shape
        ``(n_samples, seqlen)``.

    Raises:
        ValueError: If the tokenized corpus is too short for even one window.
    """
    traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
    all_ids = trainenc.input_ids
    num_tokens = all_ids.shape[1]

    # random.randint is inclusive on both ends, so this bound keeps
    # start + seqlen <= num_tokens - 1 (kept as-is for reproducibility).
    max_start = num_tokens - seqlen - 1
    if max_start < 0:
        raise ValueError(
            f"Tokenized corpus has {num_tokens} tokens; need at least "
            f"{seqlen + 1} to sample windows of length {seqlen}."
        )

    random.seed(seed)
    windows = []
    for _ in range(n_samples):
        start = random.randint(0, max_start)
        windows.append(all_ids[:, start:start + seqlen])

    if windows:
        input_ids = torch.cat(windows, dim=0)
    else:
        # torch.cat rejects an empty list; return an empty batch of the
        # documented 2-D shape instead.
        input_ids = torch.zeros((0, seqlen), dtype=all_ids.dtype, device=all_ids.device)

    dataloader = {
        'input_ids': input_ids,  # (n_samples, seqlen)
    }
    return dataloader


def load_image(image_path):
    """Load the image at ``image_path`` with PIL and return it in RGB mode.

    Only plain filesystem paths are supported (tcs_loader is not). The file
    is opened in a ``with`` block, so its handle is released as soon as the
    RGB copy exists instead of whenever the garbage collector runs.
    ``convert`` fully loads the pixels into a new image, so the returned
    object does not depend on the closed file.
    """
    with Image.open(image_path) as img:
        return img.convert('RGB')

def get_multimodal_calib_dataset(model, dev, args):
    """Collect multimodal calibration samples as pre-computed input embeddings.

    Reads a ``.json`` or ``.jsonl`` dataset and shuffles it with a fixed seed.
    It then samples items with replacement (``torch.randint``, seeded with 42)
    until ``args.n_samples`` samples have been accepted. Each item's images
    (if any) are loaded from ``args.image_folder``. The item is then passed
    through ``model.preprocess_data``, ``model.data_collator`` and
    ``model.generate_input``. All per-token outputs are truncated to
    ``args.seqlen``.

    ``model.model.model.embed_tokens`` and ``model.model.visual`` are moved to
    ``dev`` during collection and moved back to CPU afterwards. This also
    happens if an error is raised part-way.

    Args:
        model: Wrapper exposing ``model.model.model.embed_tokens``,
            ``model.model.visual``, ``model.model.config.hidden_size``,
            ``model.model.dtype``, ``preprocess_data``, ``data_collator`` and
            ``generate_input``.
        dev: Device on which modules run and output tensors are allocated.
        args: Namespace providing ``data_path``, ``n_samples``,
            ``image_folder`` and ``seqlen``.

    Returns:
        dict with ``input_ids`` (n_samples, seqlen), ``image_grid_thw``
        (n_samples, 3), ``inputs_embeds`` (n_samples, seqlen, hidden_size),
        ``vision_mask`` (n_samples, seqlen) and ``answer_mask``
        (n_samples, seqlen).

    Raises:
        ValueError: If the data file type is unsupported, or the file contains
            no items while samples are requested.
    """
    data_path = args.data_path
    n_samples = args.n_samples
    seqlen = args.seqlen
    image_folder = args.image_folder
    # Token id the original code treats as an image position when checking the
    # truncation point; presumably the Qwen2-VL ``<|image_pad|>`` id — verify
    # against the tokenizer in use.
    image_pad_token_id = 151655

    if data_path.endswith(".jsonl"):
        with open(data_path, "r", encoding="utf-8") as json_file:
            # Skip blank lines (e.g. a trailing newline) instead of crashing.
            dataset = [json.loads(line) for line in json_file if line.strip()]
    elif data_path.endswith(".json"):
        with open(data_path, "r", encoding="utf-8") as json_file:
            dataset = json.load(json_file)
    else:
        raise ValueError(f"Unsupported file type: {data_path}")
    if n_samples > 0 and len(dataset) == 0:
        raise ValueError(f"No calibration items found in {data_path}")

    rng = np.random.default_rng(seed=42)
    rng.shuffle(dataset)

    torch.manual_seed(seed=42)
    try:
        model.model.model.embed_tokens = model.model.model.embed_tokens.to(dev)
        model.model.visual = model.model.visual.to(dev)

        hidden_size = model.model.config.hidden_size
        inputs_embeds = torch.zeros(
            (n_samples, seqlen, hidden_size), dtype=model.model.dtype, device=dev
        )
        input_ids = torch.zeros((n_samples, seqlen), dtype=torch.int64, device=dev)
        vision_mask = torch.zeros((n_samples, seqlen), dtype=torch.bool, device=dev)
        answer_mask = torch.zeros((n_samples, seqlen), dtype=torch.bool, device=dev)
        image_grid_thw = torch.zeros((n_samples, 3), dtype=torch.int, device=dev)

        with tqdm(total=n_samples, desc="Collecting calib data", unit="sample") as pbar:
            cnt = 0
            # Checking the count before each draw also terminates correctly
            # when n_samples == 0.
            while cnt < n_samples:
                idx = torch.randint(0, len(dataset), (1,)).item()
                data_item = dataset[idx]

                raw_images = data_item.get('image')
                if raw_images:
                    # 'image' may be a single relative path or a list of them.
                    paths = raw_images if isinstance(raw_images, list) else [raw_images]
                    images = [load_image(os.path.join(image_folder, p)) for p in paths]
                else:
                    images = None

                data_dict = model.preprocess_data(images, data_item)
                ids = data_dict['input_ids']
                # Require more than `seqlen` tokens so every sample fills a full
                # row. Also reject samples whose last kept token is an image
                # position, i.e. where truncation would cut into an image.
                if ids.numel() <= seqlen or ids[seqlen - 1] == image_pad_token_id:
                    continue

                input_ids[cnt:cnt + 1] = ids.unsqueeze(0)[:, :seqlen]
                image_grid_thw[cnt:cnt + 1] = data_dict['image_grid_thw']
                prompt_inputs, prompt_kwargs = model.generate_input(
                    model.data_collator([data_dict])
                )
                inputs_embeds[cnt:cnt + 1] = prompt_inputs['inputs_embeds'][:, :seqlen, :]
                vision_mask[cnt:cnt + 1] = prompt_kwargs['vision_mask'][:, :seqlen]
                answer_mask[cnt:cnt + 1] = prompt_kwargs['caption_mask'][:, :seqlen]

                cnt += 1
                pbar.update(1)
    finally:
        # Always give the device memory back, even if collection failed.
        model.model.model.embed_tokens = model.model.model.embed_tokens.cpu()
        model.model.visual = model.model.visual.cpu()
        torch.cuda.empty_cache()

    dataloader = {
        'input_ids': input_ids,                 # (n_sample, seqlen)
        'image_grid_thw': image_grid_thw,       # (n_sample, 3)
        'inputs_embeds': inputs_embeds,         # (n_sample, seqlen, hidden_size)
        'vision_mask': vision_mask,             # (n_sample, seqlen)
        'answer_mask': answer_mask,             # (n_sample, seqlen)
    }
    return dataloader

def get_multimodal_calib_dataset_llava(model, dev, args):
    """Collect LLaVA-style calibration samples as pre-computed input embeddings.

    Reads a ``.json`` or ``.jsonl`` dataset and shuffles it with a fixed seed.
    It then samples items with replacement (``torch.randint``, seeded with 42)
    until ``args.n_samples`` samples have been accepted. Each item's images
    (if any) are loaded from ``args.image_folder``. The item is then passed
    through ``model.preprocess_data``, ``model.data_collator`` and
    ``model.generate_input``. Embeddings and masks are truncated to
    ``args.seqlen``.

    ``model.model.model.embed_tokens`` and ``model.model.model.vision_tower``
    are moved to ``dev`` during collection and moved back to CPU afterwards.
    This also happens if an error is raised part-way.

    Args:
        model: Wrapper exposing ``model.model.model.embed_tokens``,
            ``model.model.model.vision_tower``, ``model.model.config.hidden_size``,
            ``model.model.dtype``, ``preprocess_data``, ``data_collator`` and
            ``generate_input``.
        dev: Device on which modules run and output tensors are allocated.
        args: Namespace providing ``data_path``, ``n_samples``,
            ``image_folder`` and ``seqlen``.

    Returns:
        dict with ``inputs_embeds`` (n_samples, seqlen, hidden_size),
        ``vision_mask`` (n_samples, seqlen) and ``answer_mask``
        (n_samples, seqlen). ``input_ids`` and ``image_grid_thw`` are ``None``
        (not collected for this model family).

    Raises:
        ValueError: If the data file type is unsupported, or the file contains
            no items while samples are requested.
    """
    data_path = args.data_path
    n_samples = args.n_samples
    seqlen = args.seqlen
    image_folder = args.image_folder
    # Items whose preprocessed ``input_ids`` have at most this many elements are
    # skipped. This is the original heuristic; presumably chosen so that, once
    # image features are spliced in, the embedding sequence reaches ``seqlen``
    # — TODO confirm against the LLaVA preprocessing.
    min_input_ids = 256

    if data_path.endswith(".jsonl"):
        with open(data_path, "r", encoding="utf-8") as json_file:
            # Skip blank lines (e.g. a trailing newline) instead of crashing.
            dataset = [json.loads(line) for line in json_file if line.strip()]
    elif data_path.endswith(".json"):
        with open(data_path, "r", encoding="utf-8") as json_file:
            dataset = json.load(json_file)
    else:
        raise ValueError(f"Unsupported file type: {data_path}")
    if n_samples > 0 and len(dataset) == 0:
        raise ValueError(f"No calibration items found in {data_path}")

    rng = np.random.default_rng(seed=42)
    rng.shuffle(dataset)

    torch.manual_seed(seed=42)
    try:
        model.model.model.embed_tokens = model.model.model.embed_tokens.to(dev)
        model.model.model.vision_tower = model.model.model.vision_tower.to(dev)

        hidden_size = model.model.config.hidden_size
        inputs_embeds = torch.zeros(
            (n_samples, seqlen, hidden_size), dtype=model.model.dtype, device=dev
        )
        vision_mask = torch.zeros((n_samples, seqlen), dtype=torch.bool, device=dev)
        answer_mask = torch.zeros((n_samples, seqlen), dtype=torch.bool, device=dev)

        with tqdm(total=n_samples, desc="Collecting calib data", unit="sample") as pbar:
            cnt = 0
            # Checking the count before each draw also terminates correctly
            # when n_samples == 0.
            while cnt < n_samples:
                idx = torch.randint(0, len(dataset), (1,)).item()
                data_item = dataset[idx]

                raw_images = data_item.get('image')
                if raw_images:
                    # 'image' may be a single relative path or a list of them.
                    paths = raw_images if isinstance(raw_images, list) else [raw_images]
                    images = [load_image(os.path.join(image_folder, p)) for p in paths]
                else:
                    images = None

                data_dict = model.preprocess_data(images, data_item)
                if data_dict['input_ids'].numel() <= min_input_ids:
                    continue

                prompt_inputs, prompt_kwargs = model.generate_input(
                    model.data_collator([data_dict])
                )
                embeds = prompt_inputs['inputs_embeds']
                if embeds.shape[1] < seqlen:
                    # Too short to fill a full (seqlen, hidden) row; assigning it
                    # would fail with a shape mismatch, so draw another item.
                    continue

                inputs_embeds[cnt:cnt + 1] = embeds[:, :seqlen, :]
                # The original author observed 730 vision tokens per sample
                # here (vision_mask.sum() == 730); this is not enforced.
                vision_mask[cnt:cnt + 1] = prompt_kwargs['vision_mask'][:, :seqlen]
                answer_mask[cnt:cnt + 1] = prompt_kwargs['caption_mask'][:, :seqlen]

                cnt += 1
                pbar.update(1)
    finally:
        # Always give the device memory back, even if collection failed.
        model.model.model.embed_tokens = model.model.model.embed_tokens.cpu()
        model.model.model.vision_tower = model.model.model.vision_tower.cpu()
        torch.cuda.empty_cache()

    dataloader = {
        'input_ids': None,                      # (n_sample, seqlen)
        'image_grid_thw': None,                 # (n_sample, 3)
        'inputs_embeds': inputs_embeds,         # (n_sample, seqlen, hidden_size)
        'vision_mask': vision_mask,             # (n_sample, seqlen)
        'answer_mask': answer_mask,             # (n_sample, seqlen)
    }
    return dataloader