import os
import sys
from argparse import ArgumentParser
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
from PIL import Image
from datasets import load_dataset
from reward_distill.model.clip_mlp_model import MLP
from reward_models.reward_interface import UnifiedReward
from tqdm import tqdm
import clip

# Prefer torchvision's InterpolationMode enum for bicubic resizing; if that
# import fails, fall back to the PIL constant.
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize


# differentiable transform for CLIP
def _transform(n_px):
    """Build the preprocessing pipeline applied to images before CLIP encoding.

    The input is converted to a tensor first, so the resize, centre-crop and
    normalisation steps all operate on tensors.

    Args:
        n_px: Size passed to both ``Resize`` and ``CenterCrop``.

    Returns:
        A ``Compose`` chaining ``ToTensor``, bicubic ``Resize``,
        ``CenterCrop`` and ``Normalize``.
    """
    # Per-channel normalisation statistics.
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)

    steps = [
        ToTensor(),
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        Normalize(channel_mean, channel_std),
    ]
    return Compose(steps)



if __name__ == "__main__":
    # Script entry point. Two stages:
    #   1. (optional) score every ImageRewardDB "1k" training sample with the
    #      ImageReward model and cache its CLIP image/text features to disk;
    #   2. fit an MLP on the cached image features to regress those scores.
    parser = ArgumentParser()
    parser.add_argument("--do_clip_preproc", action="store_true")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    # meta-parameters
    device = torch.device(args.device)
    features_path = "reward_distill/data/imagereward_1K_features.pt"

    # load the dataset and do preprocessing (forced by the flag, or when no
    # cached feature file exists yet)
    if args.do_clip_preproc or not os.path.exists(features_path):
        ground_truth = UnifiedReward("imagereward")

        dataset = load_dataset("THUDM/ImageRewardDB", "1k")
        train_split = dataset['train']

        # load clip model
        model2, _ = clip.load("ViT-L/14", device=device) # using the same clip as aesthetic
        preprocess = _transform(224)

        # get the features
        features = []
        for i in tqdm(range(len(train_split))):
            # Sample access stays inside the try so that a sample that fails
            # to load is skipped like any other per-sample failure.
            try:
                data = train_split[i]
                image = data['image'].convert("RGB")

                # teacher score that the MLP will be trained to reproduce
                score = ground_truth.score(ToTensor()(image), data['prompt']).detach().cpu()

                image_input = preprocess(image).unsqueeze(0).to(device)
                text_input = clip.tokenize(data['prompt'], truncate=True).to(device)
                with torch.no_grad():
                    image_features = model2.encode_image(image_input).cpu()
                    text_features = model2.encode_text(text_input).cpu()

                features.append({
                    'image': image_features,
                    'text': text_features,
                    'score': score
                })

            except Exception as e:
                # best-effort: report the failing sample and keep going
                print(e)
                print(f"Error at {i}")
                continue

        # Do not persist an empty cache: it would be picked up on the next
        # run and fail later with a much less helpful error.
        if not features:
            raise RuntimeError("Feature extraction failed for every sample; nothing to save.")

        # save the features (create the target directory if it is missing)
        os.makedirs(os.path.dirname(features_path), exist_ok=True)
        torch.save(features, features_path)

    # load the features
    features = torch.load(features_path)
    if not features:
        raise RuntimeError(f"No features found in {features_path}; rerun with --do_clip_preproc.")

    # construct the dataset
    # CLIP loaded on a CUDA device returns half-precision features, while the
    # MLP below has float32 parameters; cast here so that both freshly
    # extracted and previously cached features match the model dtype.
    x1 = torch.stack([f['image'] for f in features]).float()
    x2 = torch.stack([f['text'] for f in features]).float()

    y = torch.stack([f['score'] for f in features]).float()

    dataset = torch.utils.data.TensorDataset(x1, x2, y)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

    # train the model
    model = MLP(768)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(10):
        # NOTE(review): text features are batched but not fed to the model,
        # which only receives image features -- confirm this is intended.
        for image_feats, _text_feats, targets in dataloader:
            image_feats = image_feats.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            y_hat = model(image_feats)
            # mse_loss silently broadcasts mismatched shapes, which yields a
            # wrong loss; align the targets with the predictions explicitly.
            # Assumes one prediction per target -- reshape raises otherwise.
            targets = targets.reshape(y_hat.shape)
            loss = torch.nn.functional.mse_loss(y_hat, targets)
            loss.backward()
            optimizer.step()

            print(f"Epoch {epoch}, Loss: {loss.item()}")


    