import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"

import torch
import torch.nn as nn # nn is already imported, but DataParallel is part of it
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader # Keep Dataset, DataLoader
import os
import time
import copy
import matplotlib.pyplot as plt
# import json # No longer strictly needed if using folder names as class IDs
from PIL import Image # Keep for loading images
import glob # Keep for finding files
from tqdm import tqdm # Import tqdm


# --- Device Setup ---
# Prefer the first visible CUDA device; fall back to the CPU when CUDA is unavailable.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

# --- Custom Dataset Class for Training (Based on user's MultiFolderDataset) ---
class ImageNet100TrainDataset(Dataset):
    """
    Custom Dataset for ImageNet-100 training data.

    Images are read from ``<root_dir>/train/<class_id>/`` when a ``train``
    directory exists; otherwise from every ``<root_dir>/train.X*/<class_id>/``
    folder (train.X1, train.X2, ...). Sub-directory names are used as class IDs
    (e.g. n01440764) and are mapped to integer labels in sorted order.

    Attributes:
        samples: list of ``(image_path, class_id)`` tuples, in sorted order.
        class_ids: set of all class folder names that were found.
        sorted_class_ids: ``class_ids`` as a sorted list (index -> class ID).
        class_to_idx: dict mapping class ID -> integer label.
        num_classes: number of distinct class IDs.
    """

    # Lower-case file suffixes treated as images; matching ignores case.
    IMAGE_EXTENSIONS = ('.jpeg', '.jpg', '.png')

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory containing 'train' or train.X* folders.
            transform (callable, optional): Optional transform to be applied on a sample.

        Raises:
            FileNotFoundError: if no training directory is found in ``root_dir``.
            RuntimeError: if the training directories contain no image files.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.class_ids = set() # To collect unique class IDs (folder names)

        train_dirs = self._find_train_dirs(self.root_dir)
        if not train_dirs:
            print(f"Error: No 'train' or 'train.X*' directories found in {self.root_dir}")
            raise FileNotFoundError("Training directories not found.")

        print(f"Found training directories: {train_dirs}")

        for folder in train_dirs:
            # List subdirectories (expected to be class IDs like 'n01440764').
            # Sorted so the sample order does not depend on the filesystem.
            try:
                subdirs = sorted(
                    d for d in os.listdir(folder)
                    if os.path.isdir(os.path.join(folder, d))
                )
            except OSError:
                print(f"Warning: Could not list directories in {folder}. Skipping.")
                continue

            if not subdirs:
                print(f"Warning: No subdirectories (class folders) found in {folder}. Skipping.")
                continue

            for label_id in subdirs:
                # A class folder counts as a class even if it holds no images.
                self.class_ids.add(label_id)
                class_folder_path = os.path.join(folder, label_id)
                try:
                    files = self._list_images(class_folder_path)
                except OSError:
                    print(f"Warning: Error accessing files in {class_folder_path}. Skipping.")
                    continue

                # Add (image_path, class_id) tuples to samples
                self.samples.extend((f, label_id) for f in files)

        if not self.samples:
            raise RuntimeError(f"No image samples found in any training directories within {root_dir}")

        print(f"Found {len(self.samples)} training images across {len(self.class_ids)} classes.")

        # Create mapping from class ID (folder name) to integer index
        self.sorted_class_ids = sorted(self.class_ids)
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.sorted_class_ids)}
        self.num_classes = len(self.sorted_class_ids)

    @staticmethod
    def _find_train_dirs(root_dir):
        """Return the training directories: ['<root>/train'] if present, else sorted '<root>/train.X*'."""
        single_dir = os.path.join(root_dir, 'train')
        if os.path.isdir(single_dir):
            return [single_dir]
        # glob.escape keeps special characters in root_dir from being read as wildcards.
        pattern = os.path.join(glob.escape(root_dir), 'train.X*')
        return sorted(d for d in glob.glob(pattern) if os.path.isdir(d))

    @classmethod
    def _list_images(cls, class_folder_path):
        """Return the sorted paths of the image files directly inside ``class_folder_path``."""
        with os.scandir(class_folder_path) as entries:
            return sorted(
                entry.path for entry in entries
                if entry.is_file()
                and not entry.name.startswith('.')  # skip hidden files, as glob('*') does
                and entry.name.lower().endswith(cls.IMAGE_EXTENSIONS)
            )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Load the sample at ``idx``.

        Returns:
            tuple: ``(image, label_idx)`` where ``image`` is the RGB image after
            ``transform`` (if any) and ``label_idx`` is the integer class label.

        Raises:
            RuntimeError: if the image cannot be loaded or transformed.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path, label_id = self.samples[idx]

        try:
            # The context manager closes the underlying file once the pixel
            # data has been copied by convert(), so handles are not leaked.
            with Image.open(img_path) as img:
                image = img.convert('RGB') # Ensure image is RGB
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            raise RuntimeError(f"Could not load image: {img_path}") from e

        label_idx = self.class_to_idx[label_id] # Map class ID string to integer index

        if self.transform:
            try:
                image = self.transform(image)
            except Exception as e:
                print(f"Error applying transform to image {img_path}: {e}")
                raise RuntimeError(f"Could not transform image: {img_path}") from e

        return image, label_idx



# --- Training Function ---
def _unwrap_data_parallel(model):
    """Return the wrapped module if ``model`` is an nn.DataParallel, else ``model`` itself."""
    return model.module if isinstance(model, nn.DataParallel) else model


def train_model(model, criterion, optimizer, num_epochs=25, dataloaders=None,
        dataset_sizes=None):
    """
    Train ``model`` for ``num_epochs`` epochs and keep the weights with the best validation accuracy.

    Every epoch runs a 'train' phase followed by a 'val' phase. Batches are
    moved to the module-level ``device`` before the forward pass.

    Args:
        model: the network to train (optionally wrapped in nn.DataParallel).
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per training batch.
        num_epochs (int): number of epochs to run.
        dataloaders (dict): iterables of ``(inputs, labels)`` batches under the
            keys 'train' and 'val'.
        dataset_sizes (dict): number of samples per phase under the same keys;
            used to average the loss and accuracy.

    Returns:
        tuple: ``(model, history)``. ``model`` has the best validation weights
        loaded; ``history`` maps 'train_loss', 'val_loss', 'train_acc' and
        'val_acc' to lists holding one float per epoch.

    Raises:
        ValueError: if ``dataloaders`` or ``dataset_sizes`` is not provided.
    """
    if dataloaders is None or dataset_sizes is None:
        raise ValueError("train_model requires both 'dataloaders' and 'dataset_sizes'.")

    since = time.time()
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    # Snapshot/restore the state dict of the original model, not the DataParallel wrapper
    base_model = _unwrap_data_parallel(model)
    best_model_wts = copy.deepcopy(base_model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0  # plain int, so an empty loader still yields a valid accuracy

            # Wrap the dataloader with tqdm for a progress bar
            dataloader_with_progress = tqdm(dataloaders[phase], desc=f'{phase} Epoch {epoch}/{num_epochs - 1}')

            # Iterate over data.
            for inputs, labels in dataloader_with_progress:
                inputs = inputs.to(device)
                labels = labels.to(device)
                batch_size = inputs.size(0)

                # Gradients only exist (and need clearing) in the training phase
                if is_train:
                    optimizer.zero_grad()

                # forward; track history only in train
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics (computed once per batch and reused for the progress bar)
                batch_loss = loss.item()
                _, preds = torch.max(outputs, 1)
                batch_corrects = torch.sum(preds == labels).item()

                # The loss is multiplied by the batch size so the epoch average is per sample
                running_loss += batch_loss * batch_size
                running_corrects += batch_corrects

                # Update the progress bar with the current batch loss and accuracy
                dataloader_with_progress.set_postfix(loss=batch_loss, acc=batch_corrects / batch_size)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if is_train:
                history['train_loss'].append(epoch_loss)
                history['train_acc'].append(epoch_acc)
            else:
                history['val_loss'].append(epoch_loss)
                history['val_acc'].append(epoch_acc)

                # deep copy the weights if this is the best validation accuracy so far
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(base_model.state_dict())
                    print(f'  New best validation accuracy: {best_acc:.4f}')

        print() # Newline after each epoch

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # load best model weights back into the original (unwrapped) model
    base_model.load_state_dict(best_model_wts)

    return model, history


def _pretrained_resnet_builders():
    """Return a dict mapping each supported model name to a zero-argument constructor of the pre-trained ResNet."""
    return {
        "resnet18": lambda: models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1),
        "resnet34": lambda: models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1),
        "resnet50": lambda: models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2),
        "resnet101": lambda: models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1),
        "resnet152": lambda: models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1),
    }


def _align_val_labels(train_dataset, val_dataset, val_dir_path):
    """Make the ImageFolder validation targets use the training dataset's class indices.

    Raises:
        ValueError: if the validation set contains a class folder the training set lacks.
    """
    train_map = train_dataset.class_to_idx
    if val_dataset.class_to_idx == train_map:
        print("Validation classes seem consistent with training classes.")
        return

    unknown = sorted(set(val_dataset.classes) - set(train_map))
    if unknown:
        raise ValueError(
            f"Validation directory '{val_dir_path}' contains {len(unknown)} class folder(s) "
            f"missing from the training set (e.g. {unknown[:5]})."
        )

    print(f"Warning: Training set has {len(train_map)} classes, but validation set (ImageFolder) "
          f"found {len(val_dataset.classes)} classes in {val_dir_path}. "
          f"Remapping validation labels to the training class indices.")
    # ImageFolder stores integer targets inside `samples` at construction time,
    # so replacing `class_to_idx` alone would not change any label.
    remapped = [(path, train_map[val_dataset.classes[target]]) for path, target in val_dataset.samples]
    val_dataset.samples = remapped
    val_dataset.imgs = remapped
    val_dataset.targets = [target for _, target in remapped]
    val_dataset.classes = list(train_dataset.sorted_class_ids)
    val_dataset.class_to_idx = dict(train_map)


def _plot_training_history(history, output_plot_name):
    """Plot loss and accuracy curves side by side, save them to ``output_plot_name`` and show them."""
    epochs = range(1, len(history['train_loss']) + 1)
    fig = plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, history['train_loss'], 'bo-', label='Training Loss')
    plt.plot(epochs, history['val_loss'], 'ro-', label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.subplot(1, 2, 2)
    plt.plot(epochs, history['train_acc'], 'bo-', label='Training Accuracy')
    plt.plot(epochs, history['val_acc'], 'ro-', label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(output_plot_name)
    print(f"Plot saved to {output_plot_name}")
    plt.show()
    # Release the figure so repeated calls (one per model) do not accumulate open figures.
    plt.close(fig)


def main(model_name):
    """
    Fine-tune the final layer of a pre-trained ResNet on ImageNet-100 and save the results.

    The backbone parameters are frozen, the ``fc`` layer is replaced by a new
    linear layer sized to the number of training classes, and only that layer
    is optimised. The best model state dict, the ``fc`` state dict and a plot
    of the training metrics are written under
    ``<root_path>/<model_name>/clean/``.

    Args:
        model_name (str): one of "resnet18", "resnet34", "resnet50",
            "resnet101", "resnet152".

    Raises:
        SystemExit: with status 1 if the model name is invalid, the data cannot
            be loaded, or training fails.
    """
    # Validate the model name before the (expensive) dataset scan.
    builders = _pretrained_resnet_builders()
    if model_name not in builders:
        print("Invalid model name, exiting...")
        raise SystemExit(1)

    # --- Configuration ---
    data_dir = '/YOUR_ROOT_PATH/data/imagenet100/data'

    # Model parameters
    root_path = "/YOUR_ROOT_PATH/checkpoints/final/resnet_final"
    # num_classes is determined from the dataset folders further below
    batch_size = 512*16
    num_epochs = 5
    learning_rate = 0.001
    momentum = 0.9

    # Output file names
    output_dir = f'{root_path}/{model_name}/clean'
    output_model_name = f'{output_dir}/{model_name}_finetuned_imagenet100_best.pth'
    output_layer_name = f'{output_dir}/{model_name}_finetuned_imagenet100_fc_layer.pth'
    output_plot_name = f'{output_dir}/{model_name}_finetuned_imagenet100_metrics.png'

    os.makedirs(output_dir, exist_ok=True)

    # --- Data Preprocessing ---
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((224, 224)), # Resize first
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]),
        'val': transforms.Compose([
            transforms.Resize(256), # Standard validation: resize larger then center crop
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]),
    }

    # --- Data Loading ---
    print("Initializing Datasets and Dataloaders...")

    try:
        # Training Dataset (Custom)
        train_dataset = ImageNet100TrainDataset(data_dir, transform=data_transforms['train'])
        num_classes = train_dataset.num_classes # Get number of classes from the dataset
        print(f"Determined number of classes: {num_classes}")

        # Validation Dataset (ImageFolder): expects '<data_dir>/val' to contain
        # subdirectories named by class ID.
        val_dir_path = os.path.join(data_dir, 'val')
        if not os.path.isdir(val_dir_path):
            raise FileNotFoundError(f"Validation directory '{val_dir_path}' not found. ImageFolder requires this.")

        val_dataset = datasets.ImageFolder(val_dir_path, transform=data_transforms['val'])

        # Validation labels must use the same class -> index mapping as training,
        # otherwise the validation loss and accuracy are meaningless.
        _align_val_labels(train_dataset, val_dataset, val_dir_path)

        image_datasets = {'train': train_dataset, 'val': val_dataset}

        # Only the training split is shuffled.
        dataloaders = {
            x: DataLoader(image_datasets[x], batch_size=batch_size, shuffle=(x == 'train'), num_workers=4)
            for x in ['train', 'val']
        }
        dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

        print(f"Training dataset size: {dataset_sizes['train']}")
        print(f"Validation dataset size: {dataset_sizes['val']}")

    except (FileNotFoundError, ValueError, KeyError, RuntimeError) as e:
        print(f"Error initializing dataset/dataloader: {e}")
        raise SystemExit(1)
    except Exception as e:
        print(f"An unexpected error occurred during data loading: {e}")
        raise SystemExit(1)

    # --- Model Loading and Modification ---
    print(f"Loading pre-trained model: {model_name}")
    model_ft = builders[model_name]()

    print("Freezing base model parameters...")
    for param in model_ft.parameters():
        param.requires_grad = False

    # The replacement layer is created after freezing, so it is the only trainable part.
    num_ftrs = model_ft.fc.in_features
    print(f"Replacing the final layer for {num_classes} classes...")
    model_ft.fc = nn.Linear(num_ftrs, num_classes)

    # Move model to the selected device
    model_ft = model_ft.to(device)

    # --- Apply DataParallel if multiple GPUs are available ---
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs with DataParallel")
        model_ft = nn.DataParallel(model_ft)
    else:
        print("Only one GPU available or using CPU, DataParallel not applied.")

    # --- Optimizer Setup ---
    # Only parameters that still require gradients are handed to the optimizer.
    params_to_update = []
    print("Parameters to train:")
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print(f"\t{name}")
    optimizer_ft = optim.SGD(params_to_update, lr=learning_rate, momentum=momentum)

    # --- Loss Function ---
    criterion = nn.CrossEntropyLoss()

    # --- Start Training ---
    print("Starting training...")
    try:
        model_ft, history = train_model(model_ft, criterion, optimizer_ft, num_epochs=num_epochs,
                                        dataloaders=dataloaders, dataset_sizes=dataset_sizes)
    except Exception as e:
        print(f"An error occurred during training: {e}")
        raise SystemExit(1)

    # --- Save the fine-tuned model and the final layer ---
    # Save the state dict of the original model, not the DataParallel wrapper
    base_model = model_ft.module if isinstance(model_ft, nn.DataParallel) else model_ft
    print(f"Saving the best model state dict to {output_model_name}")
    torch.save(base_model.state_dict(), output_model_name)
    print(f"Saving the final layer state dict to {output_layer_name}")
    torch.save(base_model.fc.state_dict(), output_layer_name)
    print("Models saved.")

    # --- Plotting Training History ---
    print("Plotting training metrics...")
    _plot_training_history(history, output_plot_name)



if __name__ == "__main__":

    # Architectures to fine-tune, one after another. main() also accepts
    # "resnet50", which is currently left out of this run.
    selected_models = ("resnet18", "resnet34", "resnet101", "resnet152")

    for model_name in selected_models:
        print(f'Strat training of {model_name}')
        main(model_name)
        # Release cached GPU memory before the next model is trained.
        torch.cuda.empty_cache()
