import os
import glob
from tqdm import tqdm
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch
import argparse


# Command-line interface.
# Both options are interpolated directly into filesystem paths further down,
# so they are required: the previous `default=False` silently produced paths
# containing the literal string "False" when an option was omitted.
parser = argparse.ArgumentParser(
    description='Generate one BLIP-2 caption text file per image of a dataset.'
)
parser.add_argument(
    '--data', type=str, required=True,
    help='dataset folder name under DomainBed_GDG/data/ to read images from',
)
parser.add_argument(
    '--text', type=str, required=True,
    help='folder name under DomainBed_GDG/data/ that receives the caption files',
)
args = parser.parse_args()


# Run on the GPU when torch can see one; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# The processor and the captioning model are loaded from the same pretrained
# checkpoint; the model weights are requested in half precision (float16).
_CHECKPOINT = "Salesforce/blip2-opt-2.7b"
processor = AutoProcessor.from_pretrained(_CHECKPOINT)
model = Blip2ForConditionalGeneration.from_pretrained(
    _CHECKPOINT,
    torch_dtype=torch.float16,
)
model.to(device)

# Caption every image found under <dataset_root>/<domain>/<class>/ and write
# the generated text to
# DomainBed_GDG/data/<args.text>/<args.data>/<domain>/<class>/<image stem>.txt
dataset_root = 'DomainBed_GDG/data/{}'.format(args.data)

# Only files with exactly these extensions are captioned; any other file that
# reaches the extension check aborts the run (same contract as before).
_IMAGE_EXTENSIONS = ('.jpg', '.png')

for domain in os.listdir(dataset_root):
    # Entries with a dot or 'text' in their name are not treated as domains;
    # they are reported and skipped.
    if '.' in domain or 'text' in domain:
        print(domain)
        continue
    print('Domain: {} start'.format(domain))

    for data_class in os.listdir(os.path.join(dataset_root, domain)):
        # Entries with a dot in their name are not treated as class folders.
        if '.' in data_class:
            print(data_class)
            continue
        data_path = os.path.join(dataset_root, domain, data_class)
        data_files = sorted(glob.glob(os.path.join(data_path, '*')))

        save_path = 'DomainBed_GDG/data/{}/{}/{}/{}'.format(args.text, args.data, domain, data_class)
        os.makedirs(save_path, exist_ok=True)

        for data_file in tqdm(data_files):
            # NOTE: this substring test is applied to the whole path, not only
            # the file name, exactly as in the original script.
            if 'texts' in data_file or 'txt' in data_file:
                continue

            # Validate the extension *before* running the model so that an
            # unexpected file fails fast instead of after a wasted inference.
            # splitext avoids the old substring check, which mis-handled names
            # such as 'jpg_cat.png' (output kept the '.png' name).
            file_name = os.path.basename(data_file)
            stem, extension = os.path.splitext(file_name)
            if extension not in _IMAGE_EXTENSIONS:
                raise Exception("not expected image format")
            save_file = stem + '.txt'

            try:
                # The context manager closes the underlying file handle as
                # soon as the RGB copy has been made.
                with Image.open(data_file) as raw_image:
                    image = raw_image.convert('RGB')
            except Exception:
                # Best-effort: report the unreadable file and move on.
                # Previously execution fell through and captioned a stale (or
                # undefined) `image` under this file's name.
                print(data_file)
                continue

            inputs = processor(image, return_tensors="pt").to(device, torch.float16)

            generated_ids = model.generate(**inputs, max_new_tokens=20)
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

            # Explicit encoding so the output does not depend on the locale.
            with open(os.path.join(save_path, save_file), 'w', encoding='utf-8') as f:
                f.write(generated_text)