import os

# from jupyter_server.services.contents import checkpoints

import clip
import torch
from PIL import Image
import numpy as np

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from dataset_diy.modules.utils import getTrainData

# Batch size used by the DataLoader further down (identifier kept as-is for existing references).
BATCHE_SIZE = 64

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Non-JIT load so the returned model's parameters can be fine-tuned below.
model, preprocess = clip.load("RN50x64", device=device, jit=False)
print("download model success")

# TensorBoard event files are written under ./logs.
writer = SummaryWriter("./logs")

class StereotypeDataset(Dataset):
    """Map-style dataset of (image tensor, prompt tokens, stereotype tokens) triples.

    Prompt and stereotype texts are tokenized once, up front, with ``clip.tokenize``;
    images are read from disk lazily in ``__getitem__`` and run through the
    module-level CLIP ``preprocess`` transform.
    """

    def __init__(self, image_list, prompt_list, stereotype_list):
        """Store image paths and tokenize the paired texts.

        Args:
            image_list: sequence of image file paths.
            prompt_list: prompt strings, position-aligned with ``image_list``.
            stereotype_list: stereotype strings, position-aligned with ``image_list``.

        Raises:
            ValueError: if the three sequences do not all have the same length.
        """
        # Samples are paired purely by position, so a length mismatch would silently
        # pair images with the wrong texts (or fail later with an IndexError).
        if not (len(image_list) == len(prompt_list) == len(stereotype_list)):
            raise ValueError(
                "image_list, prompt_list and stereotype_list must have the same length, "
                "got {}, {} and {}".format(len(image_list), len(prompt_list), len(stereotype_list))
            )
        self.image_path = image_list
        # NOTE(review): clip.tokenize may reject texts longer than the model's context
        # length unless truncate=True is passed — confirm prompt lengths.
        self.prompt = clip.tokenize(prompt_list)
        self.stereotype = clip.tokenize(stereotype_list)

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_path)

    def __getitem__(self, idx):
        """Return ``(preprocessed image, prompt tokens, stereotype tokens)`` for sample ``idx``."""
        # Close the file as soon as it has been transformed instead of leaving the
        # handle open until garbage collection. Assumes `preprocess` fully decodes the
        # image into a tensor (true for CLIP's transform) — confirm if it is swapped out.
        with Image.open(self.image_path[idx]) as img:
            image = preprocess(img)
        return image, self.prompt[idx], self.stereotype[idx]


def convert_models_to_fp32(model):
    """Cast, in place, every parameter of ``model`` and any existing gradient to float32.

    The training loop calls this right before ``optimizer.step()`` on GPU so the
    update runs in full precision; it is also applied once on CPU at startup.
    """
    for param in model.parameters():
        param.data = param.data.float()
        grad = param.grad
        if grad is None:
            # Parameters that have not received a gradient yet are left as-is.
            continue
        grad.data = grad.data.float()


# Choose parameter precision by device: on GPU apply CLIP's weight-conversion helper
# (presumably fp16 casting, as in upstream openai/CLIP — confirm for this fork);
# on CPU keep everything in float32.
if device != "cpu":
    clip.model.convert_weights(model)
else:
    convert_models_to_fp32(model)


def trainClip(Epoch, data_loader, model):
    """Fine-tune ``model`` with a contrastive loss over image / prompt / stereotype triples.

    Args:
        Epoch: number of full passes over ``data_loader``.
        data_loader: iterable of ``(image, prompt, stereotype)`` batches, e.g. a
            ``DataLoader`` over ``StereotypeDataset``.
        model: a CLIP variant whose forward takes ``(image, prompt, stereotype)`` and
            returns four logit tensors. The first three are each scored with
            cross-entropy against ``arange(batch_size)`` targets, i.e. item i of the
            batch is treated as the positive for item i.

    Per-batch losses are printed and written to the module-level TensorBoard
    ``writer`` on every device, indexed by a global batch counter.
    """
    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)

    # NOTE(review): the model is not switched to train() mode here; if clip.load
    # returns it in eval mode, BatchNorm statistics stay frozen — confirm intended.

    # Monotonic x-axis for TensorBoard so batches within an epoch do not overwrite
    # each other (logging at `epoch` stacked every batch on the same step).
    global_step = 0
    for epoch in range(Epoch):
        for image, prompt, stereotype in data_loader:
            image = image.to(device)
            prompt = prompt.to(device)
            stereotype = stereotype.to(device)

            # The fourth output (prompt-vs-stereotype logits) is not part of the loss.
            image_logits, prompt_logits, stereotype_logits, _prompt_stereotype = model(image, prompt, stereotype)
            # Diagonal targets, sized from the actual batch so a short final batch works.
            ground_truth = torch.arange(len(image), dtype=torch.long, device=device)

            image_loss = criterion(image_logits, ground_truth)
            prompt_loss = criterion(prompt_logits, ground_truth)
            stereotype_loss = criterion(stereotype_logits, ground_truth)
            total_loss = (image_loss + prompt_loss + stereotype_loss) / 3

            optimizer.zero_grad()
            total_loss.backward()

            if device == "cpu":
                optimizer.step()
            else:
                # On GPU the weights were converted with clip.model.convert_weights at
                # startup: take the optimizer step in fp32, then convert back.
                convert_models_to_fp32(model)
                optimizer.step()
                clip.model.convert_weights(model)

            # Logging previously lived inside the GPU branch, so CPU runs were silent.
            writer.add_scalar("stereotype_loss", stereotype_loss.item(), global_step)
            writer.add_scalar("image_loss", image_loss.item(), global_step)
            writer.add_scalar("prompt_loss", prompt_loss.item(), global_step)
            writer.add_scalar("total_loss", total_loss.item(), global_step)
            print("Epoch: {}, image_loss: {}, prompt_loss: {}, stereotype_loss: {}, total_loss: {}".format(
                epoch, image_loss.item(), prompt_loss.item(), stereotype_loss.item(), total_loss.item()))
            global_step += 1

        # torch.save({
        #     'epoch': epoch,
        #     'model_state_dict': model.state_dict(),
        #     'optimizer_state_dict': optimizer.state_dict(),
        #     'image_loss': image_loss,
        #     'prompt_loss': prompt_loss,
        #     'stereotype_loss': stereotype_loss
        # }, './model.pth')

    # Make sure the last scalars reach disk even though the writer is never closed.
    writer.flush()


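# NOTE(review): disabled draft. As written it loads "ViT-B/32" although training uses
# "RN50x64", and `checkpoints` is never defined here (it would need something like
# torch.load(model_path)). Fix both before re-enabling.
# def load_model(model_path):
#     model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
#
#     checkpoints['model_state_dict']['input_resolution'] = model.input_resolution
#     checkpoints['model_state_dict']['context_length'] = model.context_length
#     checkpoints['model_state_dict']['vocab_size'] = model.vocab_size
#
#     model.load_state_dict(checkpoints['model_state_dict'])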


# Load the labelled SDXL data: prompts, image paths and stereotype texts (position-aligned).
list_prompt, list_image_path, list_stereotype = getTrainData("../dataset_diy/data/2024-01-17-sdxl/data_label.csv")
# Report a summary instead of dumping every path/prompt to stdout.
print("loaded {} samples ({} prompts, {} stereotypes)".format(
    len(list_image_path), len(list_prompt), len(list_stereotype)))

dataset = StereotypeDataset(list_image_path, list_prompt, list_stereotype)
# Shuffle so the in-batch negatives used by the contrastive loss change every epoch;
# without it every epoch sees identical batches in CSV order.
data_loader = DataLoader(dataset, batch_size=BATCHE_SIZE, shuffle=True)

trainClip(1000, data_loader, model)
