import numpy as np
from robustness.tools.breeds_helpers import setup_breeds, ClassHierarchy
import matplotlib
import random
import sys
from torchvision.models import resnet50, ResNet50_Weights

import os
import json
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from PIL import Image

import time

# Run configuration. Three positional command-line arguments are read:
#   sys.argv[1] -> ARCH_TYPE, sys.argv[2] -> LEVEL, sys.argv[3] -> MIN_SUBCLASSES
ARCH_TYPE = sys.argv[1]                 # 'vit', 'resnet50', or 'clip'
BATCH_SIZE = 32                         # batch size for training
NUM_EPOCHS = 100                        # number of epochs to train the model
LEVEL = int(sys.argv[2])                # level of the class hierarchy for BREEDS
TRAIN_RATIO = 0.8                       # ratio of training data
MIN_SUBCLASSES = int(sys.argv[3])       # number of subclasses in a domain
LR = 1e-6                               # learning rate for the optimizer
MAIN_DIR = ''                                                                   # path to the main directory
TRAIN_DIR = os.path.join(MAIN_DIR, 'data/train')                                # path to the training data
INFO_DIR = os.path.join(MAIN_DIR, 'imagenet_class_hierarchy/modified/')         # path to the class hierarchy information. Download it from BREEDS GitHub repo
SUBCLASS_INFO = os.path.join(MAIN_DIR, f'best_domain_classifiers_{LEVEL}_{MIN_SUBCLASSES}/ref')          # path to the reference subclass information to keep the subclasses for each architecture the same
SAVE_DIR = os.path.join(MAIN_DIR, f'best_domain_classifiers_{LEVEL}_{MIN_SUBCLASSES}/{ARCH_TYPE}')     # path to save the trained model

# exist_ok=True replaces the exists()/makedirs() pair: it is a single call and
# does not fail if another process creates the directory in between.
os.makedirs(SAVE_DIR, exist_ok=True)

class ImageDataset(Dataset):
    """Dataset that pairs image file paths with their labels.

    Images are opened on access (not up front) and converted to RGB.
    If a ``transform`` is given it is applied to the loaded image before
    the sample is returned.
    """

    def __init__(self, image_paths, labels, transform=None):
        # image_paths and labels are indexed with the same position in
        # __getitem__, so they are expected to be parallel sequences.
        self.transform = transform
        self.labels = labels
        self.image_paths = image_paths

    def __len__(self):
        # One sample per image path.
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        rgb_image = Image.open(self.image_paths[idx]).convert("RGB")
        # Only transform when one was supplied; otherwise hand back the image as loaded.
        image = self.transform(rgb_image) if self.transform else rgb_image
        return image, self.labels[idx]

def sample_images(subdir_path, ratio):
    """Randomly pick a fraction of the image files found in a directory.

    Args:
        subdir_path: Directory whose immediate entries are scanned. Only
            names ending in ``.jpg``, ``.png`` or ``.JPEG`` are considered.
        ratio: Fraction of the matching files to keep; the sample size is
            ``int(number_of_images * ratio)``.

    Returns:
        A list of sampled file paths (``subdir_path`` joined with the file
        name), without repeats, in random order.

    Raises:
        ValueError: Propagated from ``random.sample`` when the computed
            sample size is negative or larger than the number of images.
    """
    image_extensions = ('.jpg', '.png', '.JPEG')

    # os.listdir() returns entries in arbitrary, filesystem-dependent order.
    # Sorting first makes the outcome depend only on the RNG state, so a
    # seeded run selects the same files on every machine.
    all_images = [
        os.path.join(subdir_path, name)
        for name in sorted(os.listdir(subdir_path))
        if name.endswith(image_extensions)
    ]

    sample_size = int(len(all_images) * ratio)
    return random.sample(all_images, sample_size)

def label_from_path(file_path, subclass_to_superclass):
    """Look up the superclass of an image from the name of its parent folder.

    Args:
        file_path: Path to an image file; the directory that directly
            contains the file is taken as the subclass name.
        subclass_to_superclass: Mapping from subclass folder name to
            superclass.

    Returns:
        The value stored in ``subclass_to_superclass`` for the parent folder.

    Raises:
        KeyError: If the parent folder name is not a key of the mapping.
    """
    # Use os.path instead of splitting on "/" so the lookup also works with
    # the platform separator produced by os.path.join and with repeated
    # separators in the path.
    subclass = os.path.basename(os.path.dirname(file_path))
    return subclass_to_superclass[subclass]

# Define the training function
def train_model(model, train_loader, val_loader, criterion, optimizer, device,
                num_epochs=10, save_dir=None, max_val_batches=9):
    """Train ``model``, validate after every epoch and write checkpoints.

    Two kinds of checkpoint (state dicts) are written into ``save_dir``:
    ``domain_classifier.pth`` whenever the validation loss improves on the
    best value seen so far, and ``model_{epoch}_{val_loss}_{val_acc}.pth``
    after every epoch.

    Args:
        model: Network to train; called as ``model(images)``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``train_loader``.
        save_dir: Directory for checkpoints. ``None`` falls back to the
            module-level ``SAVE_DIR``.
        max_val_batches: Upper bound on the number of validation batches
            evaluated per epoch; ``None`` evaluates the whole loader. The
            default of 9 keeps the cap this function has always applied.

    Returns:
        The same ``model`` object, holding the weights of the last epoch.
    """
    if save_dir is None:
        save_dir = SAVE_DIR
    best_checkpoint_path = os.path.join(save_dir, "domain_classifier.pth")

    # Tracked across epochs. Resetting this inside the loop would make every
    # epoch look like an improvement and overwrite the best checkpoint.
    best_val_loss = float('inf')

    num_train_batches = len(train_loader)
    num_val_batches = len(val_loader)
    if max_val_batches is not None:
        num_val_batches = min(num_val_batches, max_val_batches)

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        print(f"epoch {epoch+1}/{num_epochs}")

        for batch_num, (images, labels) in enumerate(train_loader, start=1):
            print(f"{batch_num}/{num_train_batches}")
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Track metrics; the loss is weighted by batch size so the epoch
            # value is a per-sample average.
            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        # Normalise by the samples actually seen; max() guards an empty loader.
        epoch_loss = running_loss / max(total, 1)
        train_accuracy = correct / max(total, 1) * 100

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for batch_num, (images, labels) in enumerate(val_loader, start=1):
                if max_val_batches is not None and batch_num > max_val_batches:
                    break
                print(f"val {batch_num}/{num_val_batches}")
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * images.size(0)
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()

        # Divide by the number of samples evaluated, not the dataset size:
        # with the batch cap only part of the validation set is visited.
        val_loss /= max(val_total, 1)
        val_accuracy = val_correct / max(val_total, 1) * 100

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            print(f"Best validation loss: {best_val_loss:.4f}")
            # save the model
            torch.save(model.state_dict(), best_checkpoint_path)

        print(f"Epoch [{epoch+1}/{num_epochs}]")
        print(f"  Train Loss: {epoch_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%")
        print(f"  Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%")
        torch.save(model.state_dict(), os.path.join(save_dir, f"model_{epoch}_{val_loss:.3f}_{val_accuracy:.3f}.pth"))
    return model


# read class information
# Fetch the hierarchy files first when INFO_DIR is missing or empty, so the
# JSON opened below can be present on a fresh checkout. setup_breeds is
# expected to place dataset_class_info.json in INFO_DIR (see the BREEDS helpers).
if not (os.path.exists(INFO_DIR) and len(os.listdir(INFO_DIR))):
    print("Downloading class hierarchy information into `INFO_DIR`")
    setup_breeds(INFO_DIR)

foldername_to_class = {}    # map label name to label index
class_to_foldername = {}    # map label index to label name
# Resolve the file through INFO_DIR so a changed MAIN_DIR is honoured here too.
with open(os.path.join(INFO_DIR, "dataset_class_info.json"), "r") as file:
    dataset_class_info = json.load(file)
for info in dataset_class_info:
    # Each entry is indexed positionally: info[0] is the index, info[1] the name.
    foldername_to_class[info[1]] = info[0]
    class_to_foldername[info[0]] = info[1]

# Choose the superclasses (domains) at the requested hierarchy level and draw
# MIN_SUBCLASSES random leaf classes for each one that has enough of them.
hier = ClassHierarchy(INFO_DIR)
superclasses = hier.get_nodes_at_level(LEVEL)
print(f"{len(superclasses)} superclasses at level {LEVEL}\n")

subclass_to_superclass = {}
chosen_superclasses = []
chosen_superclasses_names = []
chosen_subclasses = {}
for superclass_node in list(superclasses):
    superclass_name = hier.HIER_NODE_NAME[superclass_node]
    subclasses = hier.leaves_reachable(superclass_node)
    print(f"{len(subclasses)} subclasses for superclass {superclass_name}")
    if len(subclasses) >= MIN_SUBCLASSES:
        # Large enough to serve as a domain: keep it and sample its subclasses.
        picked = random.sample(list(subclasses), k=MIN_SUBCLASSES)
        chosen_superclasses.append(superclass_node)
        chosen_superclasses_names.append(superclass_name)
        chosen_subclasses[superclass_name] = picked
        subclass_to_superclass.update(dict.fromkeys(picked, superclass_name))
print(f"{len(chosen_subclasses)} domains with {MIN_SUBCLASSES} classes each")
# Every chosen superclass gets the same sampling ratio of 1.
distribution = [1] * len(chosen_superclasses)

# For every chosen domain, record the sampling ratio of each of its subclasses.
subclass_ratio = {}
for superclass_name, domain_ratio in zip(chosen_superclasses_names, distribution):
    subclasses = chosen_subclasses[superclass_name]
    print(f"number of subclasses in {superclass_name}: {len(subclasses)}")
    # All subclasses of a domain share that domain's ratio.
    subclass_ratio[superclass_name] = {subclass: domain_ratio for subclass in subclasses}

# Integer label of a domain = its position in subclass_ratio.
domain_name_to_label = {domain_name: idx for idx, domain_name in enumerate(subclass_ratio)}

# Keep the subclass selection identical across architectures: the first run
# writes it to SUBCLASS_INFO as a reference, later runs load that reference
# instead of using their own random draw.
ref_ratio_path = os.path.join(SUBCLASS_INFO, 'subclass_ratio.json')
ref_label_path = os.path.join(SUBCLASS_INFO, 'domain_name_to_label.json')

# Test for the files rather than the directory: a directory left behind
# without both JSON files must not send us down the "load" path.
if os.path.exists(ref_ratio_path) and os.path.exists(ref_label_path):
    # load the reference subclass_ratio and domain_name_to_label
    with open(ref_ratio_path) as json_file:
        subclass_ratio = json.load(json_file)
    with open(ref_label_path) as json_file:
        domain_name_to_label = json.load(json_file)
else:
    # No complete reference yet: (re)write both files so they stay consistent.
    os.makedirs(SUBCLASS_INFO, exist_ok=True)
    with open(ref_ratio_path, 'w') as f:
        json.dump(subclass_ratio, f)
    with open(ref_label_path, 'w') as f:
        json.dump(domain_name_to_label, f)

# Store a copy of the selection in use next to this architecture's checkpoints.
with open(os.path.join(SAVE_DIR, 'subclass_ratio.json'), 'w') as f:
    json.dump(subclass_ratio, f)
with open(os.path.join(SAVE_DIR, 'domain_name_to_label.json'), 'w') as f:
    json.dump(domain_name_to_label, f)

# Invert subclass_ratio: subclass folder name -> domain that contains it.
subclass_to_superclass = {
    subclass: domain
    for domain, members in subclass_ratio.items()
    for subclass in members
}

# Number of domains.
K = len(domain_name_to_label)

# Collect samples from each subclass directory according to its ratio.
rng = np.random.default_rng()
selected_images = []
for ratios in subclass_ratio.values():
    for subdir, ratio in ratios.items():
        subdir_images = sample_images(os.path.join(TRAIN_DIR, subdir), ratio)
        selected_images.extend(subdir_images)
print(f"{len(selected_images)} total images chosen")
rng.shuffle(selected_images)

# Integer domain label of every selected image, derived from its parent folder.
labels = [
    domain_name_to_label[label_from_path(file_path, subclass_to_superclass)]
    for file_path in selected_images
]

# Preprocessing: the transforms bundled with the ResNet-50 weights for
# 'resnet50', otherwise a plain 224x224 resize followed by tensor conversion.
if ARCH_TYPE != 'resnet50':
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
else:
    transform = ResNet50_Weights.DEFAULT.transforms()

dataset = ImageDataset(image_paths=selected_images, labels=labels, transform=transform)

# Random train/validation split; the validation part takes the remainder so
# the two sizes always add up to the dataset length.
num_train = int(TRAIN_RATIO * len(dataset))
num_val = len(dataset) - num_train
train_dataset, val_dataset = random_split(dataset, [num_train, num_val])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# create the model
if ARCH_TYPE == 'vit':
    model = timm.create_model('vit_base_patch16_224', pretrained=True)
elif ARCH_TYPE == 'clip':
    model = timm.create_model('vit_base_patch16_clip_224.openai_ft_in1k', pretrained=True)
elif ARCH_TYPE == 'resnet50':
    model = resnet50(weights=ResNet50_Weights.DEFAULT)
else:
    # Without this branch an unknown value would surface later as a NameError
    # on `model`, which hides the actual cause.
    raise ValueError(
        f"Unsupported ARCH_TYPE {ARCH_TYPE!r}; expected 'vit', 'clip' or 'resnet50'"
    )

# Freeze the pretrained weights. The head created below is added after this
# loop, so its parameters keep the default requires_grad=True.
for param in model.parameters():
    param.requires_grad = False

# The classification layer is `fc` on the ResNet and `head` on the timm models.
head_attr = 'fc' if ARCH_TYPE == 'resnet50' else 'head'
num_ftrs = getattr(model, head_attr).in_features

# Replace the original classifier with an MLP that outputs K domain logits.
setattr(model, head_attr, nn.Sequential(
    nn.Linear(num_ftrs, 2048),
    nn.ReLU(),
    nn.Linear(2048, 1024),
    nn.ReLU(),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Linear(512, K),
))

# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)
trained_model = train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    num_epochs=NUM_EPOCHS,
)

# Persist the result twice: the whole model object (relative path, so it lands
# in the current working directory) and its state dict inside SAVE_DIR.
final_weights_path = os.path.join(SAVE_DIR, "model_final.pth")
torch.save(trained_model, "entire_model.pth")
torch.save(trained_model.state_dict(), final_weights_path)