import os
import math
import pandas as pd
import numpy as np
import torch
from copy import deepcopy
import random
import argparse
import linecache
import json
from tqdm import tqdm
from transformers import AutoTokenizer
from selftokmodel import SelftokModel

def _seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and torch RNGs (CPU, and CUDA when available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        # manual_seed covers the current device; manual_seed_all covers every visible GPU.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


# Fixed seed so sampling is repeatable across runs (modulo nondeterministic CUDA kernels).
_seed_everything(42)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Example")

    parser.add_argument("--name", type=str, default="")
    parser.add_argument("--task", type=int, default=2)
    parser.add_argument('--cfg_type', type=str, default='adaptive')
    parser.add_argument("--cfg", type=float, default=10.0)
    parser.add_argument("--oripath", type=str, default="")
    parser.add_argument("--path", type=str, default="")
    parser.add_argument("--ckpt", type=int, default=200)
    parser.add_argument("--num_gen", type=int, default=1)
    parser.add_argument("--reverse", action="store_true")
    parser.add_argument('--test_path', type=str, default='')
    parser.add_argument("--max_completion_length", type=int, default=512)
    
    args = parser.parse_args()


    if args.task==0:
        prompts_file = args.test_path
        pdframe = pd.read_csv(prompts_file)
        prompt_list = pdframe["Prompts"].values.tolist()

    if args.task==1:
        ff = args.test_path
        f = open(ff)
        prompt_list = f.read().split('\n')
    
    if args.task==2:
        prompt_list = []
        numprompt = len(linecache.getlines(args.test_path))
        for i in range(numprompt):
            curstr = linecache.getline(args.test_path, i+1)
            txt = json.loads(curstr)
            prompt_list.append(txt['prompt'])
        

    if args.task == 3:
        prompt_list = []
        df = pd.read_csv(args.test_path)
        for i, line in df.iterrows():
            prompt = line['text']
            if prompt in prompt_list:
                continue
            else:
                prompt_list.append(prompt)
        print(len(prompt_list))
        
    input_ids  = []

    TEXT_VOCAB_SIZE = 128256
    IMG_VOCAB_SIZE = 32768
    special_tokens_id = {
        "bos": 128000,
        "eos": 128001,
        "pad": 128002,
        "boi": TEXT_VOCAB_SIZE + IMG_VOCAB_SIZE,
        "eoi": TEXT_VOCAB_SIZE + IMG_VOCAB_SIZE + 1,
        "cfg": TEXT_VOCAB_SIZE + IMG_VOCAB_SIZE + 2,
        "rep": TEXT_VOCAB_SIZE + IMG_VOCAB_SIZE + 3,
        "ignore": -100,
    }

    if args.ckpt==0:
        mymodel = SelftokModel.from_pretrained(args.oripath, device_map=torch.device("cuda:0"), torch_dtype=torch.bfloat16, use_cache=True,revision='main')
    
    else:
        mymodel = SelftokModel.from_pretrained(os.path.join(args.path, f'checkpoint-{args.ckpt}/'), device_map=torch.device("cuda:0"), torch_dtype=torch.bfloat16, use_cache=True,revision='main')

    mymodel.eval()

    tokenizer_path = '' # selftok tokenizer path
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    numeach = 2

    tokenizer.pad_token_id = special_tokens_id["pad"]
    oribos = torch.LongTensor([special_tokens_id['bos']]).unsqueeze(0).cuda()
    orieos = torch.LongTensor([special_tokens_id['eos']]).unsqueeze(0).cuda()
    oriboi = torch.LongTensor([special_tokens_id['boi']]).unsqueeze(0).cuda()
    orieoi = torch.LongTensor([special_tokens_id['eoi']]).unsqueeze(0).cuda()

    upper = math.ceil(len(prompt_list)/32) 
    cur=0
    
    for i in tqdm(range(0, upper)):
        st, ed = i*32, min(i*32+32, len(prompt_list))
        instruction = tokenizer(
            prompt_list[st:ed],
            return_tensors="pt",
            padding='longest',
            add_special_tokens=False,
        ).to('cuda:0')

        bsz, L, dtype = instruction['input_ids'].size(0), instruction['input_ids'].size(1), instruction['input_ids'].dtype

        cfg_content = torch.ones((bsz, L), dtype=dtype) * special_tokens_id["cfg"]
        cfg_content = cfg_content.cuda()

        boi = oriboi.repeat_interleave(bsz, dim=0)
        eoi = orieoi.repeat_interleave(bsz, dim=0)
        bos = oribos.repeat_interleave(bsz, dim=0)
        print(instruction['input_ids'].size(), boi.size(), eoi.size())

        prompt_ids = torch.cat([bos, instruction['input_ids'], boi], dim=1)

        if args.cfg is not None:
            cfg_ids = torch.cat([bos, cfg_content, boi], dim=1)

        attention_mask1 = torch.ones((bsz, 1), dtype=torch.int32, device=prompt_ids.device)
        attention_mask2 = instruction['attention_mask']
        attention_mask3 = torch.ones((bsz, 1), dtype=torch.int32, device=prompt_ids.device)
        
        prompt_mask = torch.cat((attention_mask1, attention_mask2, attention_mask3), dim=1)
        if args.cfg is not None:
            prompt_mask = prompt_mask.repeat(2,1)
        if args.num_gen >1:
            prompt_mask = torch.repeat_interleave(prompt_mask, args.num_gen, dim=0) 
            if args.cfg is not None:
                cfg_ids = torch.repeat_interleave(cfg_ids, args.num_gen, dim=0) 
            prompt_ids = torch.repeat_interleave(prompt_ids, args.num_gen, dim=0) 

        if args.cfg is not None:
            prompt_ids = torch.cat([prompt_ids, cfg_ids],dim=0)
        
        image_ids = mymodel.generate_image(
            input_ids=prompt_ids, 
            ori_attention_mask=prompt_mask, 
            use_past=True, 
            top_k=4096, 
            top_p=0.9, 
            guidance_scale=args.cfg,
            cfg_type = args.cfg_type,
            img_seq_len=args.max_completion_length,
            image_vocab_slice=(TEXT_VOCAB_SIZE, TEXT_VOCAB_SIZE + IMG_VOCAB_SIZE)
            )

        npy = image_ids.cpu().numpy()
        if args.reverse:
            npy = npy[:,::-1]

        print(npy.shape)
        numnumpy = npy.shape[0] // args.num_gen
        for kk in range(numnumpy):
            name = f'{cur}.npy'
            os.makedirs(args.name,exist_ok=True)
            np.save(f'{args.name}/{name}', npy[kk*args.num_gen:(kk+1)*args.num_gen])
            cur+=1
