import argparse
import copy

parser = argparse.ArgumentParser(description='sp')
parser.add_argument('--start', type=int, default=0,
                    help='first index (inclusive) into the shuffled dataset')
parser.add_argument('--end', type=int, default=100,
                    help='last index (exclusive) into the shuffled dataset')
parser.add_argument('--index', type=int, default=1,
                    help='name of the output sub-directory under --outdir')
parser.add_argument('--gpu_index', type=int, nargs='+', default=[0],
                    help='GPU ids exported via CUDA_VISIBLE_DEVICES')
parser.add_argument('--outdir', type=str, default='outdir0',
                    help='root directory for the generated checkpoints')
args = parser.parse_args()
import os
from typing import Dict, Optional, Sequence


# CUDA_VISIBLE_DEVICES expects a plain comma-separated id list such as "0,1"
# (no spaces), and it must be set before torch initialises CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in args.gpu_index)

import sys
# Make the local EAGLE checkout importable (needed for eagle.model.* below).
# NOTE(review): placeholder path — point this at your EAGLE repository.
sys.path.append('/path/to/EAGLE')
import torch
import torch.nn.functional as F
from tqdm import tqdm
from eagle.model.modeling_llama_kv import LlamaForCausalLM
from torch.utils.data import Dataset
import json
import numpy as np
import random
from transformers.trainer_pt_utils import LabelSmoother
# Hugging Face's "ignore this label" id; appears unused in this script.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index



# Checkpoint of the target model loaded via from_pretrained further down.
# NOTE(review): placeholder path — point this at the LlamaGen checkpoint.
bigname="/path/to/LlamaGen-T2I-2"
# Stale alternative paths, kept for reference only:
# bigname = "/home/lyh/weights/hf/llama/7B/"
# smallname = "/home/lyh/weights/hf/llama/7B/"

class SupervisedDataset(Dataset):
    """Paired (image-code, text-feature) samples loaded from ``.npy`` files.

    ``data_base_path`` must contain two sibling directories, ``code`` and
    ``text_features``, whose sorted listings pair up one-to-one by file name.
    Each item returns the conditioning tensor, the token ids and the matching
    attention/loss masks (each with a leading batch dimension of 1).
    """

    def __init__(
        self,
        data_base_path,
        uncond_embedding,
        cls_token_num: int = 120,
        seq_len: int = 1144,
        cfg_drop_prob: float = 0.1,
    ):
        """
        Args:
            data_base_path: Root directory holding ``code/`` and ``text_features/``.
            uncond_embedding: Tensor substituted for the text features on
                dropped samples; assumed to be (cls_token_num, feature_dim)
                — TODO confirm against the model's embedder.
            cls_token_num: Number of conditioning positions; shorter text
                features are left-padded to this length and the padding masked.
            seq_len: Length of the attention/loss masks (conditioning
                positions followed by the token positions).
            cfg_drop_prob: Probability of replacing the text features with
                ``uncond_embedding`` (classifier-free-guidance dropout).
        """
        super().__init__()
        self.code_base_path = os.path.join(data_base_path, "code")
        self.text_base_path = os.path.join(data_base_path, "text_features")
        self.code_path = sorted(os.listdir(self.code_base_path))
        self.text_path = sorted(os.listdir(self.text_base_path))
        self.uncond_embedding = uncond_embedding
        self.cls_token_num = cls_token_num
        self.seq_len = seq_len
        self.cfg_drop_prob = cfg_drop_prob

    def __len__(self):
        return len(self.code_path)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Load sample ``i``.

        Raises:
            IndexError: ``i`` is out of range (ends plain iteration over the dataset).
            ValueError: the paired files have different names, or the text
                features are longer than ``cls_token_num``.
        """
        # Indexing first keeps IndexError as the out-of-range signal.
        code_name, text_name = self.code_path[i], self.text_path[i]
        if code_name != text_name:
            raise ValueError(
                f"mismatched sample files: {code_name!r} vs {text_name!r}"
            )
        input_ids = np.load(os.path.join(self.code_base_path, code_name))
        input_ids = torch.from_numpy(input_ids).long()

        attention_mask = torch.ones((1, self.seq_len))
        if random.random() < self.cfg_drop_prob:
            # Drop the text condition: use the unconditional embedding instead.
            # Detach before copying (no autograd history) and move to CPU so
            # both branches hand back CPU tensors.
            cond_idx = self.uncond_embedding.detach().cpu().clone().unsqueeze(0)
        else:
            cond_idx = torch.from_numpy(
                np.load(os.path.join(self.text_base_path, text_name))
            )
            pad_len = self.cls_token_num - cond_idx.shape[1]
            if pad_len < 0:
                raise ValueError(
                    f"{text_name!r} has {cond_idx.shape[1]} condition tokens; "
                    f"at most {self.cls_token_num} are supported"
                )
            # Left-pad to cls_token_num and mask the padded positions out.
            attention_mask[0, :pad_len] = 0
            cond_padding = torch.zeros(
                (1, pad_len, cond_idx.shape[2]), dtype=cond_idx.dtype
            )
            cond_idx = torch.cat([cond_padding, cond_idx], dim=1)

        # No loss on the conditioning positions.
        loss_mask = torch.ones((1, self.seq_len))
        loss_mask[:, :self.cls_token_num] = 0

        return dict(
            cond_idx=cond_idx,
            input_ids=input_ids,
            attention_mask=attention_mask,
            loss_mask=loss_mask,
        )

    def shuffle(self, seed: Optional[int] = None):
        """Permute the paired file lists in place and return ``self``.

        With a seed, a private ``RandomState`` yields the same permutation as
        ``np.random.seed(seed); np.random.permutation(n)`` without resetting
        NumPy's global RNG. Without a seed, the global RNG is used.
        """
        rng = np.random if seed is None else np.random.RandomState(seed)
        perm = rng.permutation(len(self.code_path))
        # Apply one permutation to both lists so the pairs stay aligned.
        self.code_path = [self.code_path[i] for i in perm]
        self.text_path = [self.text_path[i] for i in perm]
        return self

    def select(self, indices: Sequence[int]):
        """Keep only the pairs at ``indices`` (in that order) and return ``self``."""
        self.code_path = [self.code_path[i] for i in indices]
        self.text_path = [self.text_path[i] for i in indices]
        return self

def longest_common_prefix(list1, list2):
    """Return ``(prefix, length)`` for the longest shared leading run.

    ``prefix`` is a slice of ``list1`` (so it keeps ``list1``'s sequence type)
    and ``length`` is the number of leading positions where the two inputs
    compare equal. Comparison stops at the shorter input.
    """
    shared = 0
    for left, right in zip(list1, list2):
        if left == right:
            shared += 1
        else:
            break
    return list1[:shared], shared


# Target model in fp16; device_map="auto" lets HF place it on the visible GPUs.
bigmodel = LlamaForCausalLM.from_pretrained(bigname, device_map="auto", torch_dtype=torch.float16)
bigmodel.eval()
# Tensor the dataset substitutes for the text features on dropped samples.
uncond_embedding = bigmodel.model.cls_embedding.uncond_embedding

ds = SupervisedDataset('/path/to/laion-coco-train-2', uncond_embedding)
# Fixed seed: every worker sees the same permutation, so disjoint
# [--start, --end) ranges give disjoint shards.
ds = ds.shuffle(seed=42)
# Clamp --end so the last shard may over-shoot the dataset size.
end = min(args.end, len(ds))
ds = ds.select(range(args.start, end))
print(f"selected {len(ds)} samples [{args.start}, {end})")

@torch.no_grad()
def ge(data):
    """Run the target model on one sample and package a training record.

    Args:
        data: One item from ``SupervisedDataset`` — a dict with ``cond_idx``,
            ``input_ids``, ``attention_mask`` and ``loss_mask`` tensors.
            ``input_ids`` is assumed to carry a leading batch dim of 1
            — TODO confirm against the .npy layout.

    Returns:
        Dict of CPU tensors: ``cond_idx`` (fp16, batch dim kept), and
        ``input_ids``, ``hidden_state`` (last entry of the model's
        ``hidden_states``), ``loss_mask``, ``attention_mask`` with the leading
        batch dim removed.
    """
    input_ids = data["input_ids"]
    cond_idx = data["cond_idx"].to(torch.float16)
    loss_mask = data["loss_mask"]
    attention_mask = data["attention_mask"]
    outs_big = bigmodel(
        cond_idx=cond_idx.cuda(),
        input_ids=input_ids.cuda(),
        attention_mask=attention_mask.cuda(),
        output_hidden_states=True,
    )
    hidden_state_big = outs_big.hidden_states[-1]
    # Everything goes to CPU. cond_idx may come from the model's (GPU-resident)
    # unconditional embedding, and saving it as-is would tie the checkpoint to
    # a CUDA device.
    return {
        "cond_idx": cond_idx.cpu(),
        "input_ids": input_ids.cpu()[0],
        "hidden_state": hidden_state_big.cpu()[0],
        "loss_mask": loss_mask.cpu()[0],
        "attention_mask": attention_mask.cpu()[0],
    }

# Per-worker output directory: <--outdir>/<--index>.
outdir = f'{args.outdir}/{args.index}'
# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs(outdir, exist_ok=True)

def writedata(name, data_point):
    """Save ``data_point`` to ``{name}/data_{idx}.ckpt`` under a fresh index.

    The index starts at the number of entries already in ``name``. Files are
    therefore numbered 0, 1, 2, ... when the directory holds only this
    script's output. If that file already exists (e.g. an earlier checkpoint
    was deleted, leaving a gap), the index is bumped, so an existing
    checkpoint is never overwritten.

    Args:
        name: Output directory; created if missing.
        data_point: Object passed to ``torch.save``.
    """
    os.makedirs(name, exist_ok=True)
    idx = len(os.listdir(name))
    while os.path.exists(f'{name}/data_{idx}.ckpt'):
        idx += 1
    torch.save(data_point, f'{name}/data_{idx}.ckpt')


# Push each selected sample through the target model and append the
# resulting record as the next numbered checkpoint in ``outdir``.
for sample in tqdm(ds, total=len(ds)):
    record = ge(sample)
    writedata(outdir, record)


