import numpy as np
import random
import sys
import os
import json
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
import copy
import pickle
import time
import shutil

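'''
    This script pre-computes, for every architecture (ViT, ResNet-50, CLIP), the
    classifier softmax outputs, the backbone embeddings, and the domain-classifier
    softmax outputs on the validation set, and pickles each of them as a dict
    keyed by image file name under TEST_DIR.
'''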

# Number of images per inference batch.
BATCH_SIZE = 32
# NOTE(review): SEED is not referenced in the visible part of this script.
SEED = 10
# LEVEL and NUM_CLASSES are kept as strings because they are only interpolated
# into directory / file names (DOMAIN_CLASSIFIER_DIR and the weights pickle).
LEVEL = '3'
NUM_CLASSES = '3'
# NOTE(review): MDTS_DIR is not referenced in the visible part of this script.
MDTS_DIR = 'MDTS_regressor'
# Root folder of the validation images; the output pickles are written here too.
TEST_DIR = 'data/val_data'
# Folder holding one sub-directory per architecture with the fine-tuned domain
# classifier checkpoint and its JSON metadata.
DOMAIN_CLASSIFIER_DIR = f'best_domain_classifiers_{LEVEL}_{NUM_CLASSES}'


class ImageDataset(Dataset):
    """Dataset over a list of image files that also yields each file's name.

    Args:
        image_paths: Sequence of paths to the image files.
        labels: Sequence of labels, index-aligned with ``image_paths``.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(file_name, image, label)`` for the sample at ``idx``.

        ``file_name`` is the last path component of the image path, ``image``
        is the RGB image after ``transform`` (if one was given).
        """
        img_path = self.image_paths[idx]

        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it as soon as convert() has produced a loaded copy.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        # Compare against None explicitly: a transform object may define
        # __len__ (e.g. an empty nn.Sequential) and would otherwise be falsy.
        if self.transform is not None:
            image = self.transform(image)

        label = self.labels[idx]

        # os.path.basename honours the platform separator, unlike split('/'),
        # and gives the same result on POSIX paths.
        return os.path.basename(img_path), image, label


# Pick the best available accelerator. MPS is only used when the backend is
# actually available; otherwise fall back to the CPU instead of failing later
# when tensors are moved to a non-existent "mps" device.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# File suffixes that are treated as images when scanning the validation folders.
_IMAGE_EXTENSIONS = ('.jpg', '.png', '.JPEG')

# timm model identifiers for the non-torchvision architectures.
_TIMM_MODEL_NAMES = {
    'vit': 'vit_base_patch16_224',
    'clip': 'vit_base_patch16_clip_224.openai_ft_in1k',
}


def _build_models(arch):
    """Create the three pretrained networks used for ``arch``.

    Returns:
        ``(model, domain_classifier, embedding_model)`` where ``model`` is the
        pretrained classifier, ``domain_classifier`` is a second pretrained
        copy whose head the caller replaces, and ``embedding_model`` is the
        same network with its classification layer removed (``num_classes=0``
        for timm, ``fc`` replaced by ``Identity`` for ResNet-50).

    Raises:
        ValueError: If ``arch`` is not one of the supported architectures.
    """
    if arch == 'resnet50':
        model = resnet50(weights=ResNet50_Weights.DEFAULT)
        domain_classifier = resnet50(weights=ResNet50_Weights.DEFAULT)
        embedding_model = resnet50(weights=ResNet50_Weights.DEFAULT)
        embedding_model.fc = torch.nn.Identity()
    elif arch in _TIMM_MODEL_NAMES:
        timm_name = _TIMM_MODEL_NAMES[arch]
        model = timm.create_model(timm_name, pretrained=True)
        domain_classifier = timm.create_model(timm_name, pretrained=True)
        embedding_model = timm.create_model(timm_name, num_classes=0, pretrained=True)
    else:
        raise ValueError(f"unknown architecture: {arch!r}")
    return model, domain_classifier, embedding_model


def _make_domain_head(in_features, num_domains):
    """Return the MLP head of the domain classifier (``num_domains`` outputs)."""
    return nn.Sequential(
        nn.Linear(in_features, 2048),
        nn.ReLU(),
        nn.Linear(2048, 1024),
        nn.ReLU(),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Linear(512, num_domains),
    )


def _list_images(folder_path):
    """Return the paths of all image files located directly in ``folder_path``."""
    return [
        os.path.join(folder_path, file_name)
        for file_name in os.listdir(folder_path)
        if file_name.endswith(_IMAGE_EXTENSIONS)
    ]


for arch in ['vit', 'resnet50', 'clip']:

    # create the models
    model, domain_classifier, embedding_model = _build_models(arch)
    model.eval()
    embedding_model.eval()

    # read domain and subclass info from file
    with open(os.path.join(DOMAIN_CLASSIFIER_DIR, arch, "domain_name_to_label.json")) as json_file:
        domain_to_label = json.load(json_file)
    with open(os.path.join(DOMAIN_CLASSIFIER_DIR, arch, "subclass_ratio.json")) as json_file:
        subclass_ratio = json.load(json_file)

    # NOTE(review): these two lookups are not used further down in the visible
    # code; kept in case another part of the project relies on them.
    subclass_to_superclass = {}
    subclass_to_label = {}
    for domain, subclasses in subclass_ratio.items():
        for subclass in subclasses:
            subclass_to_superclass[subclass] = domain
            subclass_to_label[subclass] = domain_to_label[domain]
    # number of domains == number of outputs of the domain classifier
    K = len(subclass_ratio)

    # Replace the classification layer of the domain classifier by the MLP head
    # so that the architecture matches the checkpoint loaded below.
    if arch == 'resnet50':
        num_ftrs = domain_classifier.fc.in_features
        domain_classifier.fc = _make_domain_head(num_ftrs, K)
        transform = ResNet50_Weights.DEFAULT.transforms()
    else:
        num_ftrs = domain_classifier.head.in_features
        domain_classifier.head = _make_domain_head(num_ftrs, K)
        # NOTE(review): unlike the ResNet preset, no mean/std normalisation is
        # applied here; presumably this matches how the domain classifier was
        # trained -- confirm it is also intended for `model`/`embedding_model`.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    print(f"feature size {num_ftrs}")

    domain_classifier_model_dir = os.path.join(DOMAIN_CLASSIFIER_DIR, arch, "domain_classifier.pth")
    # NOTE: torch.load unpickles the checkpoint; only load files from a trusted source.
    domain_classifier.load_state_dict(torch.load(domain_classifier_model_dir, map_location=device))
    domain_classifier.eval()

    # load validation image names (every class folder whose name starts with 'n')
    file_list = []
    for subclass in os.listdir(TEST_DIR):
        folder_path = os.path.join(TEST_DIR, subclass)
        # skip stray files such as previously written pickles
        if not subclass.startswith('n') or not os.path.isdir(folder_path):
            continue
        file_list += _list_images(folder_path)
    dummy_labels = [0] * len(file_list)

    # load validation image names for domain classifier (only the subclasses
    # that belong to one of the domains)
    file_list_domain_classifier = []
    for subclasses in subclass_ratio.values():
        for subclass in subclasses:
            file_list_domain_classifier += _list_images(os.path.join(TEST_DIR, subclass))
    dummy_labels_domain_classifier = [0] * len(file_list_domain_classifier)

    dataset = ImageDataset(image_paths=file_list, labels=dummy_labels, transform=transform)
    val_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

    dataset_domain_classifier = ImageDataset(image_paths=file_list_domain_classifier, labels=dummy_labels_domain_classifier, transform=transform)
    val_loader_domain_classifier = torch.utils.data.DataLoader(dataset_domain_classifier, batch_size=BATCH_SIZE, shuffle=False)

    model.to(device)
    domain_classifier.to(device)
    embedding_model.to(device)

    # Each result dict maps image file name -> numpy vector for that image.
    prediction_list = {}
    embedding_list = {}
    weights_list = {}

    # classifier softmax and embeddings over the full validation set
    for batch_num, (img_names, img, _) in enumerate(val_loader, start=1):
        print(f"batch {batch_num}/{len(val_loader)}")
        img = img.to(device)
        with torch.no_grad():
            prediction = torch.nn.functional.softmax(model(img), dim=1).cpu().numpy()
            embedding = embedding_model(img).cpu().numpy()

        for img_name, img_prediction, img_embedding in zip(img_names, prediction, embedding):
            prediction_list[img_name] = img_prediction
            embedding_list[img_name] = img_embedding

    # domain-classifier softmax over the domain subset; the counter restarts at
    # 1 so the progress line is consistent with this loader's length
    for batch_num, (img_names, img, _) in enumerate(val_loader_domain_classifier, start=1):
        print(f"batch {batch_num}/{len(val_loader_domain_classifier)}")
        img = img.to(device)
        with torch.no_grad():
            weights = torch.nn.functional.softmax(domain_classifier(img), dim=1).cpu().numpy()

        for img_name, img_weights in zip(img_names, weights):
            weights_list[img_name] = img_weights

    print(len(prediction_list))
    print(len(embedding_list))
    print(len(weights_list))
    # Pause so the counts can be inspected before anything is written, but only
    # in an interactive session: without a terminal input() would block forever
    # or raise EOFError and discard the results computed above.
    if sys.stdin is not None and sys.stdin.isatty():
        input()

    with open(os.path.join(TEST_DIR, f"predictions_{arch}.pkl"), "wb") as f:
        pickle.dump(prediction_list, f)
    with open(os.path.join(TEST_DIR, f"embeddings_{arch}.pkl"), "wb") as f:
        pickle.dump(embedding_list, f)
    with open(os.path.join(TEST_DIR, f"weights_{arch}_{LEVEL}_{NUM_CLASSES}.pkl"), "wb") as f:
        pickle.dump(weights_list, f)
