import torch
from PIL import Image
import os
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast
from lavis.models import load_model_and_preprocess
import torch.nn as nn
import torch.optim as optim
import time
from torch.nn.functional import normalize
from tqdm import tqdm
from modules import transform, contrastive_loss

# Always hold a torch.device (never a bare string) so that every consumer of
# `device` sees the same type on both the GPU and the CPU path.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# BLIP-2 feature extractor plus its visual/text preprocessors.
# is_eval=False: the model is not requested in evaluation mode.
blip2_model, vis_processors, txt_processors = load_model_and_preprocess(
    name="blip2_feature_extractor",
    model_type="pretrain",
    is_eval=False,
    device=device,
)

# Dataset root; label files are read from, and checkpoints written to,
# paths built by appending to this string (hence the trailing slash).
dataset_path = "../DATA/Caltech101/"


class MyDataset(Dataset):
    """Image dataset that pairs every image with a placeholder text prompt.

    Args:
        image_paths: Image file paths, each relative to ``root_dir``.
        labels: Labels aligned index-by-index with ``image_paths``.
        root_dir: Directory that every entry of ``image_paths`` is joined onto.
        vis_transform: Optional callable applied to the loaded RGB image.
        txt_transform: Optional callable applied to the placeholder text.
    """

    def __init__(self, image_paths, labels, root_dir, vis_transform=None, txt_transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.root_dir = root_dir
        self.vis_transform = vis_transform
        self.txt_transform = txt_transform

    def __len__(self):
        """Return the number of samples, as given by the length of ``labels``."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, text, label)`` for the sample at position ``idx``."""
        img_path = os.path.join(self.root_dir, self.image_paths[idx])

        # The context manager releases the file handle deterministically;
        # convert() yields an independent, fully loaded image, so it stays
        # usable after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.vis_transform is not None:
            image = self.vis_transform(image)

        # No caption is available: a single space is the text prompt.
        text = " "
        if self.txt_transform is not None:
            text = self.txt_transform(text)

        return image, text, self.labels[idx]

# Image and text preprocessing for the training and evaluation pipelines.
# Both image pipelines normalise with the same per-channel statistics.
_NORM_MEAN = (0.48145466, 0.4578275, 0.40821073)
_NORM_STD = (0.26862954, 0.26130258, 0.27577711)

train_image_transform = transform.Transforms(
    size=224, s=0.5, mean=_NORM_MEAN, std=_NORM_STD, blur=True
)
# Only the `test_transform` attribute of a (blur-less) Transforms object is
# used for evaluation images.
test_image_transform = transform.Transforms(
    size=224, s=0.5, mean=_NORM_MEAN, std=_NORM_STD
).test_transform

train_text_transform = txt_processors['train']
test_text_transform = txt_processors['eval']

class Network(nn.Module):
    """Instance- and cluster-level projection heads on a shared feature extractor.

    Args:
        feature_extractor: Callable backbone. ``forward`` calls it with a dict
            holding the keys ``"image"`` and ``"text_input"``; its output is fed
            straight into both heads (assumed to be ``(batch, input_dim)`` —
            confirm against the backbone in use).
        feature_dim: Output width of the instance head.
        class_num: Number of clusters, i.e. output width of the cluster head.
        input_dim: Width of the backbone output. Defaults to 768, the value
            that was previously hard-coded.
    """

    def __init__(self, feature_extractor, feature_dim, class_num, input_dim=768):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.feature_dim = feature_dim
        self.cluster_num = class_num
        self.input_dim = input_dim
        # Layer positions (ReLU at 0, Linear at 1) are kept so that the
        # state_dict keys of existing checkpoints still match.
        self.instance_projector = nn.Sequential(
            nn.ReLU(),
            nn.Linear(input_dim, self.feature_dim),
        )
        self.cluster_projector = nn.Sequential(
            nn.ReLU(),
            nn.Linear(input_dim, self.cluster_num),
            nn.Softmax(dim=1),
        )

    def forward(self, x_i, x_j, text):
        """Project two views of a batch through both heads.

        Args:
            x_i: First view of the image batch.
            x_j: Second view of the same image batch.
            text: Text input passed to the backbone alongside each view.

        Returns:
            ``(z_i, z_j, c_i, c_j)``: instance embeddings normalised along
            dim 1, followed by the softmax cluster assignments of each view.
        """
        h_i = self.feature_extractor({"image": x_i, "text_input": text})
        h_j = self.feature_extractor({"image": x_j, "text_input": text})

        z_i = normalize(self.instance_projector(h_i), dim=1)
        z_j = normalize(self.instance_projector(h_j), dim=1)

        c_i = self.cluster_projector(h_i)
        c_j = self.cluster_projector(h_j)

        return z_i, z_j, c_i, c_j

    def forward_cluster(self, x):
        """Return the hard cluster index for each sample in ``x``.

        ``x`` is handed to the backbone unchanged. The result of ``argmax``
        carries no gradient, so the pass runs under ``torch.no_grad()`` to
        avoid building an autograd graph that could never be used.
        """
        with torch.no_grad():
            h = self.feature_extractor(x)
            c = self.cluster_projector(h)
            return torch.argmax(c, dim=1)
    

# Freeze every BLIP-2 parameter, then switch gradients back on for `ctx` only.
blip2_model.requires_grad_(False)
blip2_model.ctx.requires_grad = True

class_num = 200
feature_extractor = blip2_model
model = Network(blip2_model, 128, class_num).to(device)

# Print the name and size of each parameter that will receive gradients.
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    print(name, param.size())

# Build the train / validation / test datasets and their loaders.
training_batch_size = 256
train_image_paths = np.load(f"{dataset_path}labels/train_image_pth.npy")
train_labels = np.load(f"{dataset_path}labels/train_label.npy")
trainset = MyDataset(
    train_image_paths,
    train_labels,
    dataset_path,
    vis_transform=train_image_transform,
    txt_transform=train_text_transform,
)
trainloader = DataLoader(
    trainset, batch_size=training_batch_size, shuffle=True, num_workers=2
)

# NOTE(review): validation images go through the *training* image transform
# while the text uses the evaluation processor — confirm this is intended.
val_image_paths = np.load(f"{dataset_path}labels/val_image_pth.npy")
val_labels = np.load(f"{dataset_path}labels/val_label.npy")
valset = MyDataset(
    val_image_paths,
    val_labels,
    dataset_path,
    vis_transform=train_image_transform,
    txt_transform=test_text_transform,
)
valloader = DataLoader(valset, batch_size=128, shuffle=False, num_workers=2)

test_image_paths = np.load(f"{dataset_path}labels/test_image_pth.npy")
test_labels = np.load(f"{dataset_path}labels/test_label.npy")
testset = MyDataset(
    test_image_paths,
    test_labels,
    dataset_path,
    vis_transform=test_image_transform,
    txt_transform=test_text_transform,
)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

# Optimizer and the two contrastive objectives (instance-level, cluster-level).
optimizer = optim.Adam(model.parameters(), lr=0.0003, weight_decay=0.0)
instance_temperature = 0.5
cluster_temperature = 1.0
criterion_instance = contrastive_loss.InstanceLoss(
    training_batch_size, instance_temperature, device
).to(device)
criterion_cluster = contrastive_loss.ClusterLoss(
    class_num, cluster_temperature, device
).to(device)

num_epochs = 150
best_val_acc = 0.0
best_val_loss = 10
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    train_progress_bar = tqdm(trainloader, desc=f"Epoch {epoch+1}/{num_epochs} - Training")

    for inputs, text, _ in train_progress_bar:
        # The first two entries of `inputs` are used as the two views of the
        # batch (presumably two augmentations of each image — verify against
        # the training image transform).
        x_i, x_j = inputs[0].to(device), inputs[1].to(device)

        optimizer.zero_grad()
        z_i, z_j, c_i, c_j = model(x_i, x_j, text)

        # Rebuilt on every step with the size of the current batch.
        criterion_instance = contrastive_loss.InstanceLoss(
            x_i.shape[0], instance_temperature, device
        ).to(device)
        loss_instance = criterion_instance(z_i, z_j)
        loss_cluster = criterion_cluster(c_i, c_j)
        loss = loss_instance + loss_cluster

        loss.backward()
        optimizer.step()

        # Weight the batch loss by its sample count so that the division
        # below yields a per-sample average.
        batch_loss = loss.item()
        train_loss += batch_loss * x_i.shape[0]
        train_progress_bar.set_postfix({"loss": batch_loss})

    train_loss /= len(trainset)
    if epoch % 25 == 0:
        # Checkpoint file name carries the epoch index and the rounded loss.
        torch.save(
            model.state_dict(),
            f"{dataset_path}model_param_{epoch}_{round(train_loss, 4)}.pth",
        )