import argparse
import copy

# Command-line options: which slice of the shuffled dataset this process
# handles, which GPUs it may use, and where its checkpoints go.
parser = argparse.ArgumentParser(description='sp')
parser.add_argument('--start', type=int, default=0,
                    help='first position (inclusive) of the shuffled dataset to process')
parser.add_argument('--end', type=int, default=100,
                    help='last position (exclusive) of the shuffled dataset to process')
parser.add_argument('--index', type=int, default=1,
                    help='shard id; checkpoints are written to OUTDIR/INDEX')
parser.add_argument('--gpu_index', type=int, nargs='+', default=[0],
                    help='GPU ids to expose through CUDA_VISIBLE_DEVICES')
parser.add_argument('--outdir', type=str, default='outdir0',
                    help='root directory for the generated checkpoints')
args = parser.parse_args()
# A negative start would wrap around to the end of the dataset, and start > end
# silently selects nothing; reject both instead of producing a surprising shard.
if not 0 <= args.start <= args.end:
    parser.error('--start and --end must satisfy 0 <= start <= end')
import os
from typing import Dict, Optional, Sequence


# Restrict the process to the requested GPUs as a plain comma-separated id list
# (e.g. "0,1"). Building it explicitly avoids depending on the list repr, which
# inserts spaces ("0, 1"). This is set before torch is imported below so it is
# already in place when CUDA initialises.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu) for gpu in args.gpu_index)

import sys
# Location of the EAGLE checkout; appended to the import path so that the
# `eagle.model...` import below can be resolved.
EAGLE_ROOT = '/path/to/EAGLE'
sys.path.append(EAGLE_ROOT)
import torch
import torch.nn.functional as F
from tqdm import tqdm
from eagle.model.modeling_llama_kv import LlamaForCausalLM
from torch.utils.data import Dataset
import json
import numpy as np
import random
from transformers.trainer_pt_utils import LabelSmoother
# Label value that the loss skips, taken from transformers' LabelSmoother.
IGNORE_TOKEN_ID: int = LabelSmoother.ignore_index

# Checkpoint of the target model whose hidden states are dumped.
bigname = "/path/to/model/LlamaGen-3B"

class SupervisedDataset(Dataset):
    """Paired code sequences and class labels loaded from two ``.npy`` files.

    Row ``i`` of ``<data_base_path>/data.npy`` (the code / token-id sequence)
    belongs with row ``i`` of ``<data_base_path>/label.npy`` (its class label).

    Each item prepends the label to the code sequence. With probability
    ``cond_drop_prob`` the label is replaced by ``null_class`` (presumably the
    unconditional class id of the model — confirm against the model config).
    """

    def __init__(self, data_base_path, cond_drop_prob=0.1, null_class=1000):
        """Load ``data.npy`` and ``label.npy`` from ``data_base_path`` into memory.

        Args:
            data_base_path: directory containing ``data.npy`` and ``label.npy``.
            cond_drop_prob: probability that an item's label is replaced by
                ``null_class`` in ``__getitem__``.
            null_class: label id substituted when the label is dropped.

        Raises:
            ValueError: if the two arrays do not have the same number of rows.
        """
        super().__init__()
        self.code_path = os.path.join(data_base_path, "data.npy")
        self.label_path = os.path.join(data_base_path, "label.npy")
        self.code_files = np.load(self.code_path)
        self.label_files = np.load(self.label_path)
        # Every method pairs rows by index, so a length mismatch would silently
        # misalign codes and labels; reject it up front.
        if len(self.code_files) != len(self.label_files):
            raise ValueError(
                f"{self.code_path} has {len(self.code_files)} rows but "
                f"{self.label_path} has {len(self.label_files)}"
            )
        self.cond_drop_prob = cond_drop_prob
        self.null_class = null_class

    def __len__(self):
        """Number of samples currently held (after any shuffle/select)."""
        return len(self.code_files)

    def __repr__(self):
        return f"{type(self).__name__}(num_samples={len(self)})"

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build the training example at index ``i``.

        Returns a dict with:
            cond_idx: int64 tensor holding the (possibly dropped) label;
                shape (1,) when ``label.npy`` is 1-D.
            input_ids: the code row as an int64 tensor.
            attention_mask, loss_mask: all-ones tensors shaped like the
                concatenated ``[label, codes]`` sequence.
            labels: ``IGNORE_TOKEN_ID`` in the label slot followed by the codes.
        """
        # An out-of-range ``i`` raises IndexError here; that is what ends plain
        # ``for item in dataset`` iteration, since this class defines no __iter__.
        input_ids = torch.from_numpy(self.code_files[i]).long()
        # atleast_1d keeps a leading axis so the label can be concatenated in
        # front of the codes; unlike slicing ``[i:i+1]`` it also works for i < 0.
        cond_idx = np.atleast_1d(self.label_files[i])
        if random.random() < self.cond_drop_prob:
            cond_idx = np.full_like(cond_idx, self.null_class)
        cond_idx = torch.from_numpy(cond_idx).long()
        cond_input = torch.cat([cond_idx, input_ids], dim=-1)
        # The label slot (position 0) is never a prediction target.
        target = torch.full_like(cond_input, IGNORE_TOKEN_ID)
        target[1:] = input_ids
        return dict(
            cond_idx=cond_idx,
            input_ids=input_ids,
            attention_mask=torch.ones_like(cond_input),
            labels=target,
            loss_mask=torch.ones_like(target),
        )

    def shuffle(self, seed: Optional[int] = None):
        """Permute codes and labels together, in place; return ``self``.

        With a ``seed``, the permutation comes from a private
        ``np.random.RandomState(seed)``. This gives exactly the permutation that
        seeding the global generator would, but without reseeding NumPy's global
        RNG as a side effect. Without a seed the global generator is used.
        """
        rng = np.random if seed is None else np.random.RandomState(seed)
        perm = rng.permutation(len(self))
        self.code_files = self.code_files[perm]
        self.label_files = self.label_files[perm]
        return self

    def select(self, indices: Sequence[int]):
        """Keep only the rows at ``indices`` (in that order); return ``self``."""
        self.code_files = self.code_files[indices]
        self.label_files = self.label_files[indices]
        return self

def longest_common_prefix(list1, list2):
    """Return the shared leading run of two sequences and its length.

    Elements are compared pairwise with ``==``, stopping at the first mismatch
    or at the end of the shorter sequence. The prefix is a slice of ``list1``,
    so it has ``list1``'s type.

    Returns:
        tuple ``(list1[:n], n)`` where ``n`` is the number of matching elements.
    """
    matched = 0
    for left, right in zip(list1, list2):
        if not left == right:
            break
        matched += 1
    return list1[:matched], matched

# A fixed seed gives the same permutation on every run over the same data, so
# each process can take its own [start, end) slice of that one ordering.
ds = (
    SupervisedDataset('/path/to/dataset/imagenet_code_c2i_flip/sampled_100_000')
    .shuffle(seed=42)
    .select(range(args.start, args.end))
)
print(ds)

# Target model whose hidden states are recorded: fp16 weights, device
# placement left to device_map="auto", inference mode.
bigmodel = LlamaForCausalLM.from_pretrained(
    bigname,
    device_map="auto",
    torch_dtype=torch.float16,
)
bigmodel.eval()











@torch.no_grad()
def ge(data):
    """Run the target model on one sample and package its last hidden state.

    Args:
        data: one item from ``SupervisedDataset``; must provide the
            ``"input_ids"``, ``"cond_idx"`` and ``"loss_mask"`` tensors
            (assumed unbatched — a batch axis of 1 is added here).

    Returns:
        dict with ``cond_idx`` (the CPU tensor from ``data``), ``input_ids`` and
        ``loss_mask`` (batch axis removed again), and ``hidden_state``: the last
        entry of the model's ``hidden_states``, moved to CPU, batch axis removed.
    """
    # The model is called with a leading batch dimension of 1.
    input_ids = data["input_ids"].unsqueeze(0)
    cond_idx = data["cond_idx"]
    loss_mask = data["loss_mask"].unsqueeze(0)
    outs_big = bigmodel(
        cond_idx=cond_idx.cuda(),
        input_ids=input_ids.cuda(),
        output_hidden_states=True,
    )
    # Only the final hidden state is stored. The logits are not part of the
    # output, so no argmax/softmax over the vocabulary is computed.
    hidden_state_big = outs_big.hidden_states[-1]
    return {
        "cond_idx": cond_idx,
        "input_ids": input_ids.cpu()[0],
        "hidden_state": hidden_state_big.cpu()[0],
        "loss_mask": loss_mask.cpu()[0],
    }

# Per-shard output directory, OUTDIR/INDEX. exist_ok makes creation safe if the
# directory already exists or is created concurrently (no check-then-create race).
outdir = f'{args.outdir}/{args.index}'
os.makedirs(outdir, exist_ok=True)

def writedata(name, data_point):
    """Save ``data_point`` with ``torch.save`` as ``<name>/data_<idx>.ckpt``.

    Creates ``name`` if needed. ``idx`` starts at the number of entries already
    in the directory, so consecutive calls on a clean directory write
    ``data_0.ckpt``, ``data_1.ckpt``, ... It is then advanced past any existing
    file, so a previously written checkpoint is never overwritten.
    """
    os.makedirs(name, exist_ok=True)
    idx = len(os.listdir(name))
    path = os.path.join(name, f"data_{idx}.ckpt")
    # The entry count can point at an existing file if earlier checkpoints were
    # deleted or unrelated files live here; probe forward instead of clobbering.
    while os.path.exists(path):
        idx += 1
        path = os.path.join(name, f"data_{idx}.ckpt")
    torch.save(data_point, path)


# Run every sample of this shard through the target model and write one
# checkpoint per sample into the shard's output directory.
for sample in tqdm(ds):
    writedata(outdir, ge(sample))


