import os
import sys
import numpy as np
import torch
import json
import torch
import torch.utils.data as data
from torchvision import transforms
from PIL import Image

import transformers

sys.path.append('../ecog-multimodal/vil_embeds/SLIP')
from models import SIMCLR_VITB16

# Per-channel RGB mean/std used to standardise image tensors
# (these match the widely used ImageNet statistics).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Shared preprocessing applied to every image: resize so the shorter side is
# 224, center-crop to 224x224, convert the PIL image to a tensor, normalize.
general_transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)

class NLVRDataset(data.Dataset):
    """Map-style dataset of (image pair, sentence, label) examples.

    Each entry of ``json_data`` must be a dict with keys ``'img0'`` and
    ``'img1'`` (image file paths), ``'sentence'`` (text) and ``'label'``.

    Args:
        args: Experiment config; stored on the instance but not read here.
        text_processor: Callable tokenizer invoked as
            ``text_processor(sentence, return_tensors='pt', padding=...,
            truncation=..., max_length=...)`` and indexed with
            ``['input_ids']`` (presumably a Hugging Face tokenizer).
        json_data: Sequence of example dicts as described above.
        split: Split name; stored for reference only.
        max_length: Every sentence is padded/truncated to exactly this many
            tokens so batches can be stacked by the default collate function.

    ``__getitem__`` returns ``(img1, img2, input_ids, sentence, label)``, where
    ``label`` is ``tensor(1)`` for a true statement and ``tensor(0)`` otherwise.
    """

    def __init__(self, args, text_processor, json_data, split='train', max_length=60):
        self.args = args
        self.json_data = json_data
        self.split = split
        self.text_processor = text_processor
        self.max_length = max_length

    def __len__(self):
        return len(self.json_data)

    @staticmethod
    def _load_image(path):
        """Open ``path`` as RGB, apply ``general_transform``, and close the file handle."""
        with Image.open(path) as img:
            return general_transform(img.convert('RGB'))

    @staticmethod
    def _label_to_tensor(label):
        """Map 'True' / 'true' / True to tensor(1) and anything else to tensor(0)."""
        # str() also covers JSON booleans (True -> 'True'); lower() makes the
        # comparison case-insensitive.
        return torch.tensor(1 if str(label).strip().lower() == 'true' else 0)

    def __getitem__(self, idx):
        example = self.json_data[idx]
        img1 = self._load_image(example['img0'])
        img2 = self._load_image(example['img1'])

        sentence = example['sentence']
        # truncation=True is required: padding='max_length' alone does not cut
        # long sentences, which would break stacking in the default collate.
        encoded_tokens = self.text_processor(
            sentence,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )
        # Drop only the leading batch dimension added by return_tensors='pt'.
        input_ids = encoded_tokens['input_ids'].squeeze(0)
        label = self._label_to_tensor(example['label'])

        return img1, img2, input_ids, sentence, label
    
def NLVRLoader(args):
    """Build the train, dev and test DataLoaders for NLVR.

    Reads ``train.json``, ``dev.json`` and ``test.json`` from ``args.data_path``.
    Sentences are tokenized with the tokenizer named by ``args.text_model_str``,
    which is loaded once and shared by all three splits.

    Args:
        args: Namespace providing ``data_path``, ``text_model_str`` and
            ``batch_size``. An optional ``num_workers`` attribute (default 20)
            sets the worker count for every loader.

    Returns:
        ``(train_loader, valid_loader, test_loader)``. The train loader shuffles
        each epoch. The validation loader uses ``args.batch_size``. The test
        loader uses batch size 1.
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.text_model_str)
    num_workers = getattr(args, 'num_workers', 20)

    def build_loader(file_name, split, batch_size, shuffle):
        # One JSON file -> NLVRDataset -> DataLoader.
        with open(os.path.join(args.data_path, file_name), 'r', encoding='utf-8') as f:
            examples = json.load(f)
        dataset = NLVRDataset(args, tokenizer, examples, split=split)
        return data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    # Shuffle only the training split so each epoch sees differently ordered
    # batches; evaluation order stays deterministic.
    train_nlvr_loader = build_loader('train.json', 'train', args.batch_size, shuffle=True)
    valid_nlvr_loader = build_loader('dev.json', 'val', args.batch_size, shuffle=False)
    test_nlvr_loader = build_loader('test.json', 'test', 1, shuffle=False)
    return train_nlvr_loader, valid_nlvr_loader, test_nlvr_loader
    
if __name__ == '__main__':
    # Smoke test:
    #   1. Build the dev dataset and print one example.
    #   2. Load the SimCLR ViT-B/16 checkpoint and the SimCSE text encoder.
    #   3. Iterate the dev DataLoader once to exercise data loading.
    from types import SimpleNamespace
    from tqdm import tqdm

    args = SimpleNamespace()
    args.data_path = 'fusion-model/nlvr-data'
    with open(os.path.join(args.data_path, 'dev.json'), 'r', encoding='utf-8') as f:
        json_data = json.load(f)
    tokenizer = transformers.AutoTokenizer.from_pretrained('princeton-nlp/sup-simcse-bert-base-uncased')
    dataset = NLVRDataset(args, tokenizer, json_data, split='dev')
    dataloader = data.DataLoader(dataset, batch_size=8)
    print(dataset[0])
    weights_path = '../ecog-multimodal/vil_embeds/pretrained_models'

    # For SimCLR testing
    model = SIMCLR_VITB16()
    # map_location='cpu' lets a checkpoint saved from GPU load on CPU-only hosts.
    checkpoint = torch.load(os.path.join(weights_path, 'simclr_base_25ep.pt'), map_location='cpu')
    # Strip only a leading 'module.' prefix (presumably added by a
    # DataParallel/DDP wrapper at save time). str.replace would also rewrite
    # 'module.' occurring anywhere else inside a key.
    prefix = 'module.'
    model_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['state_dict'].items()
    }
    model.load_state_dict(model_state_dict)

    text_model = transformers.AutoModel.from_pretrained('princeton-nlp/sup-simcse-bert-base-uncased')

    for img1, img2, input_ids, sentence, label in tqdm(dataloader, desc='Testing dataloading...'):
        # Iterating the loader is the check. To also exercise the encoders,
        # call model.encode_image(img1) under torch.no_grad() and
        # text_model(input_ids) here.
        pass
