import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
import copy
import numpy as np
import os
import json
import timm
import sys
from PIL import Image

'''
    This script creates the features required for Multi-Domain Temperature Scaling (MDTS).
'''

########################################################################################

class ImageDataset(Dataset):
    """Map-style dataset that reads images from disk and pairs them with labels.

    Args:
        image_paths: Sequence of image file paths; item ``idx`` is read from
            ``image_paths[idx]``.
        labels: Sequence of labels aligned index-for-index with ``image_paths``.
        transform: Optional callable applied to the RGB ``PIL.Image`` before it
            is returned (e.g. a torchvision transform pipeline).
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for index ``idx``.

        The image is decoded and converted to RGB; ``transform`` is applied
        when one was supplied.
        """
        # The context manager closes the underlying file deterministically.
        # ``convert`` loads the pixel data and returns a new image, so nothing
        # is read from the file after the ``with`` block exits.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, self.labels[idx]

################################################################################

# The domain classifier with the corresponding architecture, level, and number
# of classes must be trained first; this script only loads it.
# Usage: python <script> {vit,resnet50,clip} LEVEL NUM_CLASSES
_SUPPORTED_ARCHS = ('vit', 'resnet50', 'clip')
if len(sys.argv) < 4 or sys.argv[1] not in _SUPPORTED_ARCHS:
    # Fail fast with a usage message instead of an IndexError here or a
    # NameError much later when no model branch matches.
    sys.exit(f"usage: {sys.argv[0]} {{{','.join(_SUPPORTED_ARCHS)}}} LEVEL NUM_CLASSES")

ARCH_TYPE = sys.argv[1]         # 'vit', 'resnet50', 'clip'
LEVEL = int(sys.argv[2])        # level of the class hierarchy for BREEDS
NUM_CLASSES = int(sys.argv[3])  # number of subclasses in a domain

# Base directory prepended to every path below ('' = current working directory).
main_dir = ''
MDTS_dir = os.path.join(main_dir, 'data/MDTS_data')     # directory where the MDTS data is stored
model_dir = os.path.join(main_dir, f'best_domain_classifiers_{LEVEL}_{NUM_CLASSES}/{ARCH_TYPE}')    # directory where the domain classifier is stored
# No leading '/': os.path.join drops every earlier component when a later one
# is absolute, which would ignore main_dir and write under the filesystem root.
save_dir = os.path.join(main_dir, f'MDTS_features/{ARCH_TYPE}/{LEVEL}_{NUM_CLASSES}')  # directory where the features will be saved
os.makedirs(save_dir, exist_ok=True)

# Read the two label files saved next to the trained domain classifier:
#   subclass_ratio.json       -> maps each domain name to its subclasses
#   domain_name_to_label.json -> maps each domain name to its label
with open(os.path.join(model_dir, "subclass_ratio.json")) as ratio_fp:
    subclass_ratio = json.load(ratio_fp)
with open(os.path.join(model_dir, "domain_name_to_label.json")) as label_fp:
    domain_to_label = json.load(label_fp)

# One classifier output per domain.
K = len(subclass_ratio)
print(f"{K} domains")

# Each subclass (an image folder under MDTS_dir) inherits its domain's label.
subclass_to_label = {
    subclass: domain_to_label[domain_name]
    for domain_name, subclasses in subclass_ratio.items()
    for subclass in subclasses
}
# Gather every image of every subclass folder, labelled with its domain label.
selected_images = []
selected_labels = []
for folder, domain_label in subclass_to_label.items():
    folder_path = os.path.join(MDTS_dir, folder)
    # os.listdir order is arbitrary; sort so the dataset index -> file mapping
    # is reproducible across runs and machines.
    filenames = sorted(os.listdir(folder_path))
    selected_images += [os.path.join(folder_path, filename) for filename in filenames]
    selected_labels += [domain_label] * len(filenames)

if ARCH_TYPE == 'resnet50':
    # Inference preprocessing bundled with the torchvision ResNet-50 weights.
    transform = ResNet50_Weights.DEFAULT.transforms()
else:
    # NOTE(review): no mean/std normalisation for the timm ViT/CLIP models;
    # presumably this mirrors how the domain classifier was trained — confirm
    # before changing, since preprocessing must match training.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

# Create the dataloader. Shuffling keeps labels aligned with their images
# (each batch carries its own labels); it only randomises the row order of the
# saved arrays. NOTE(review): confirm whether downstream code relies on that.
dataset = ImageDataset(image_paths=selected_images, labels=selected_labels, transform=transform)
MDTS_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

# load the models

def _build_domain_head(in_features, num_domains):
    """Return the MLP head used by the trained domain classifier.

    The layer layout must match the saved checkpoint exactly, because
    ``load_state_dict`` is strict by default and rejects missing or
    unexpected keys.
    """
    return nn.Sequential(
        nn.Linear(in_features, 2048),
        nn.ReLU(),
        nn.Linear(2048, 1024),
        nn.ReLU(),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Linear(512, num_domains),
    )


# `model`  : backbone + MLP head producing K domain logits; weights are
#            replaced below by the trained domain-classifier checkpoint.
# `model1` : pretrained backbone (timm/torchvision weights, not the checkpoint)
#            with its classification head removed, used to extract features.
if ARCH_TYPE == 'vit':
    model = timm.create_model('vit_base_patch16_224', pretrained=True)
    model1 = timm.create_model('vit_base_patch16_224', num_classes=0, pretrained=True)
elif ARCH_TYPE == 'clip':
    model = timm.create_model('vit_base_patch16_clip_224.openai_ft_in1k', pretrained=True)
    model1 = timm.create_model('vit_base_patch16_clip_224.openai_ft_in1k', num_classes=0, pretrained=True)
elif ARCH_TYPE == 'resnet50':
    model = resnet50(weights=ResNet50_Weights.DEFAULT)
    model1 = resnet50(weights=ResNet50_Weights.DEFAULT)
    model1.fc = torch.nn.Identity()
else:
    raise ValueError(f"unsupported ARCH_TYPE {ARCH_TYPE!r}; expected 'vit', 'clip' or 'resnet50'")

# torchvision's ResNet names its classifier `fc`; the timm ViTs name it `head`.
head_attr = 'fc' if ARCH_TYPE == 'resnet50' else 'head'
feature_size = getattr(model, head_attr).in_features
setattr(model, head_attr, _build_domain_head(feature_size, K))
# map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only host;
# the model is moved to the compute device before feature extraction.
model.load_state_dict(torch.load(os.path.join(model_dir, "domain_classifier.pth"),
                                 map_location='cpu', weights_only=True))
model.eval()
model1.eval()
print(f"feature size: {feature_size}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# For every image, collect:
#   Z      - backbone features from model1
#   Y, G   - the batch labels (both hold the same domain label here;
#            NOTE(review): confirm that is intended)
#   Logits - domain-classifier logits
#   Softmax_output - softmax of those logits
# Per-batch arrays go into lists and are concatenated once at the end. Calling
# np.vstack/np.hstack inside the loop recopies everything gathered so far on
# every batch, which is quadratic in the dataset size.
z_chunks, label_chunks, logit_chunks, prob_chunks = [], [], [], []
batch_num = 0
with torch.no_grad():
    print("creating features")
    print(f"device: {device}")
    model.to(device)
    model1.to(device)
    for inputs, labels in MDTS_loader:
        batch_num += 1
        # Labels are only saved, never fed to a model, so they stay on the CPU.
        inputs = inputs.to(device)
        print(f"batch {batch_num}/{len(MDTS_loader)}")
        z_chunks.append(model1(inputs).cpu().numpy())
        # Run the classifier once per batch and derive probabilities from
        # those same logits.
        logits = model(inputs)
        logit_chunks.append(logits.cpu().numpy())
        prob_chunks.append(F.softmax(logits, dim=1).cpu().numpy())
        label_chunks.append(labels.numpy())

# Each concatenation starts from an empty float64 array. This keeps the
# float64 dtype and the (0, ...) shape even when the loader yields nothing.
Z_ind = np.concatenate([np.empty((0, feature_size), dtype=float)] + z_chunks)
Y_ind = np.concatenate([np.empty((0,), dtype=float)] + label_chunks)
G_ind = Y_ind.copy()
Logits_ind = np.concatenate([np.empty((0, K), dtype=float)] + logit_chunks)
Softmax_output_ind = np.concatenate([np.empty((0, K), dtype=float)] + prob_chunks)

# Write each extracted array to its own .npy file under save_dir.
for filename, array in (
    ('Z_ind.npy', Z_ind),
    ('Y_ind.npy', Y_ind),
    ('G_ind.npy', G_ind),
    ('Logits_ind.npy', Logits_ind),
    ('Softmax_output_ind.npy', Softmax_output_ind),
):
    np.save(os.path.join(save_dir, filename), array)