import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import math
import matplotlib.pyplot as plt
from torchvision import datasets, transforms, models
import random
from imblearn.over_sampling import SMOTE

# Load Imagenette dataset
# Preprocessing: resize to 224x224 pixels, convert to a tensor, then normalize
# each channel with the ImageNet means and standard deviations.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

imagenette_train = datasets.Imagenette(root='./data', split='train', download=False, transform=transform)
imagenette_test = datasets.Imagenette(root='./data', split='val', download=False, transform=transform)
batch_size = 100
k = 10  # Number of classes
# Fraction of each class's training images that is kept: class i keeps unbalance_factors[i]
unbalance_factors = [1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]

# Bucket every transformed training image by its class index
class_data = [[] for _ in range(k)]
for image, target in imagenette_train:
    class_data[target].append(image)

# Randomly subsample each class according to its unbalance factor and stack
# the survivors into a single tensor per class
unbalanced_class_data = [
    torch.stack(random.sample(images, int(len(images) * unbalance_factors[idx])))
    for idx, images in enumerate(class_data)
]

# One TensorDataset per class, pairing its images with a constant label tensor
tensor_datasets = []
for label, class_images in enumerate(unbalanced_class_data):
    class_labels = torch.full((class_images.size(0),), label, dtype=torch.long)
    tensor_datasets.append(TensorDataset(class_images, class_labels))
imagenette_train = torch.utils.data.ConcatDataset(tensor_datasets)

# Shuffled loader over the combined, unbalanced training data
train_loader = DataLoader(imagenette_train, batch_size=batch_size, shuffle=True)

# Sequential (unshuffled) loader over the validation split used as the test set
test_loader = DataLoader(imagenette_test, batch_size=batch_size, shuffle=False)

# Gather all (data, label) pairs into NumPy arrays so SMOTE can consume them
def extract_data_labels(concat_dataset):
    """Collect every (data, label) pair from *concat_dataset* into two NumPy arrays.

    Works both on batched iterables such as a ``DataLoader`` (each item already
    has a leading batch dimension and 1-D labels) and on per-sample datasets
    such as a ``ConcatDataset``/``TensorDataset`` (each label is a scalar).
    Scalar labels are promoted to a batch of one so every item can be
    concatenated along dim 0.

    Args:
        concat_dataset: iterable yielding ``(data, label)`` pairs of CPU tensors
            (labels may also be plain Python ints).

    Returns:
        ``(X, y)`` where ``X`` concatenates all data along axis 0 and ``y`` is
        the matching 1-D label array.

    Raises:
        ValueError: if the iterable yields no items.
    """
    data_list = []
    labels_list = []
    for data, label in concat_dataset:
        data = torch.as_tensor(data)
        label = torch.as_tensor(label)
        if label.dim() == 0:
            # Per-sample item: add the batch dimension torch.cat expects
            data = data.unsqueeze(0)
            label = label.unsqueeze(0)
        data_list.append(data)
        labels_list.append(label)
    if not data_list:
        raise ValueError("extract_data_labels received an empty dataset")
    # .numpy() views the freshly concatenated tensor instead of copying it again
    X = torch.cat(data_list, dim=0).numpy()
    y = torch.cat(labels_list, dim=0).numpy()
    return X, y

# Oversample the minority classes with SMOTE on flattened images, then rebuild image-shaped tensors
X_train_flat, y_train_flat = extract_data_labels(train_loader)
X_train_flat = X_train_flat.reshape(len(X_train_flat), -1)  # SMOTE expects 2-D (n_samples, n_features)
smote = SMOTE()
X_smote, y_smote = smote.fit_resample(X_train_flat, y_train_flat)
# Share memory with the resampled array instead of duplicating several GB of images;
# ascontiguousarray only copies if the array is not already contiguous float32.
X_smote_tensor = torch.from_numpy(np.ascontiguousarray(X_smote, dtype=np.float32)).view(-1, 3, 224, 224)
y_smote_tensor = torch.as_tensor(np.asarray(y_smote), dtype=torch.long)
smote_dataset = TensorDataset(X_smote_tensor, y_smote_tensor)
smote_loader = DataLoader(smote_dataset, batch_size=batch_size, shuffle=True)

# Shuffle indices at the start of each epoch
def shuffle_class_data(tensor_datasets):
    """Draw a fresh random ordering for every per-class TensorDataset.

    Args:
        tensor_datasets: sequence of TensorDatasets; the size of the first
            tensor of each one gives its number of samples.

    Returns:
        A list holding one ``torch.randperm`` permutation per dataset, in the
        same order as the input.
    """
    return [torch.randperm(ds.tensors[0].size(0)) for ds in tensor_datasets]

# Define a ResNet model for classification on Imagenette
class ResNetImagenette(nn.Module):
    """ResNet-18 whose 1000-way ImageNet head is replaced by a ``num_classes``-way head.

    ``models.resnet18()`` is called without any weights argument, so the
    backbone starts from random initialisation (it is not ImageNet-pretrained).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Randomly initialised ResNet-18 (no pretrained weights are requested)
        self.resnet = models.resnet18()

        # Replace the final fully connected layer: ResNet-18's head maps
        # fc.in_features (512) -> 1000; map it to num_classes instead.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Return the class logits produced by the wrapped ResNet for batch ``x``."""
        return self.resnet(x)

# Define Focal Loss as per the original paper: FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
class FocalLoss(nn.Module):
    """Multi-class, optionally alpha-balanced focal loss.

    Args:
        alpha: optional per-class weights (tensor or sequence, one entry per
            class). When given, each sample's loss is scaled by
            ``alpha[target]``. Stored as a non-persistent buffer so that
            ``.to(device)`` on the module also moves it.
        gamma: focusing parameter; ``gamma=0`` reduces to (alpha-weighted)
            cross-entropy.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, alpha=None, gamma=2, reduction='mean'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        # A buffer (rather than a plain attribute) follows the module across .to(device)
        self.register_buffer('alpha', None if alpha is None else torch.as_tensor(alpha), persistent=False)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss of logits ``inputs`` against class indices ``targets``."""
        # Unweighted CE is exactly -log(p_t). Passing alpha as the CE weight here
        # would make exp(-ce) equal p_t**alpha_t and corrupt the (1 - p_t) factor.
        ce_loss = torch.nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # p_t is the probability of the true class
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        if self.alpha is not None:
            focal_loss = self.alpha[targets] * focal_loss  # alpha_t weighting
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

# Device configuration: prefer the GPU whenever PyTorch can see one
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Compute class frequencies across all datasets
def compute_class_counts(tensor_datasets):
    """Tally how many samples carry each label across several TensorDatasets.

    Args:
        tensor_datasets: iterable of two-tensor TensorDatasets (features, labels).

    Returns:
        dict mapping each label (the NumPy scalar produced by ``.numpy()``) to
        its count, with keys in order of first appearance.
    """
    tally = {}
    for ds in tensor_datasets:
        _features, label_tensor = ds.tensors
        for lbl in label_tensor.numpy():
            if lbl in tally:
                tally[lbl] += 1
            else:
                tally[lbl] = 1
    return tally

# Per-class sample counts of the unbalanced training set and the sorted list of labels
class_counts = compute_class_counts(tensor_datasets)
classes = sorted(class_counts)
num_classes = len(classes)

# Calculate class weights (alpha) inversely proportional to class frequency
def compute_class_weights(class_counts, target_device=None):
    """Return normalised inverse-frequency weights, one per class label.

    Labels are taken from ``class_counts`` itself and sorted ascending, so the
    i-th weight belongs to the i-th smallest label. The weights are scaled to
    sum to the number of classes.

    Args:
        class_counts: mapping label -> number of samples.
        target_device: device for the returned tensor; defaults to the
            module-level ``device``.

    Returns:
        1-D float32 tensor of per-class weights.

    Raises:
        ValueError: if ``class_counts`` is empty.
    """
    if not class_counts:
        raise ValueError("class_counts must contain at least one class")
    # Derive the label order from the argument, not from module globals, so the
    # weights always line up with the counts actually passed in.
    labels = sorted(class_counts)
    counts = np.array([class_counts[c] for c in labels], dtype=np.float32)
    weights = 1.0 / (counts + 1e-6)  # Inverse frequency; epsilon avoids dividing by a zero count
    weights = weights / np.sum(weights) * len(labels)  # Normalize so the weights sum to the class count
    if target_device is None:
        target_device = device
    return torch.tensor(weights, dtype=torch.float32).to(target_device)

class_weights = compute_class_weights(class_counts)

# Focal loss driven by the inverse-frequency class weights (alpha) and focusing parameter gamma
gamma = 2  # As per the original paper
focal_criterion = FocalLoss(alpha=class_weights, gamma=gamma, reduction='mean').to(device)

criterion = nn.CrossEntropyLoss()  # Plain cross-entropy for the remaining models
N_total = sum(map(len, tensor_datasets))

# Experiment sizes
epochs = 60
iterations = 5
M = N_total // batch_size  # Number of training batches per epoch (floor division)

# Result matrices indexed [epoch, run]. Despite the "errors" names, the training
# loop fills them with accuracy percentages.
(train_errors, classical_errors, balanced_errors, smote_errors,
 test_errors, classical_test_errors, balanced_test_errors, smote_test_errors) = (
    np.zeros((epochs, iterations)) for _ in range(8)
)

# Run `iterations` independent repetitions. Each one trains four identically
# initialised ResNets (classical CE, Learn2Mix, class-balanced focal loss, SMOTE)
# and records their train/test accuracy after every epoch.
for iters in range(iterations):
    # Initialize the models (all four start from model_l2m's weights) and their optimizers
    model = ResNetImagenette().to(device)
    model_l2m = ResNetImagenette().to(device)
    model_cbl = ResNetImagenette().to(device)
    model_smote = ResNetImagenette().to(device)
    model.load_state_dict(model_l2m.state_dict())
    model_cbl.load_state_dict(model_l2m.state_dict())
    model_smote.load_state_dict(model_l2m.state_dict())
    optimizer = optim.Adam(model.parameters(), lr=1e-6)
    optimizer_l2m = optim.Adam(model_l2m.parameters(), lr=1e-6)
    optimizer_cbl = optim.Adam(model_cbl.parameters(), lr=1e-6)
    optimizer_smote = optim.Adam(model_smote.parameters(), lr=1e-6)
    # Learn2Mix mixing proportions, initialised to the empirical class frequencies
    alpha = torch.tensor([len(dataset) / N_total for dataset in tensor_datasets], dtype=torch.float32).to(device)
    alpha_lr = 1e-1

    # Training loop
    for epoch in range(epochs):
        correct_old = 0; correct_l2m = 0; correct_cbl = 0; correct_smote = 0; total = 0
        correct_test_old = 0; correct_test_l2m = 0; correct_test_cbl = 0; correct_test_smote = 0; total_test = 0
        prev_indices = np.zeros(k, dtype=int)  # Per-class read position into shuffled_indices
        tracking_error = np.zeros(k, dtype=float)
        shuffled_indices = shuffle_class_data(tensor_datasets)

        # Learn2Mix Training
        model_l2m.train()
        for j in range(M):  # For each batch in an epoch
            # Number of samples drawn from each class for this batch, fixed before alpha
            # is updated so that batch assembly and per-class slicing agree. The floor of
            # 1 keeps every class's loss observable; without it, alpha[i] * batch_size < 1
            # would leave data_list empty and torch.cat would raise.
            class_sizes = [max(int(alpha[i].item() * batch_size), 1) for i in range(k)]
            combined_data = []; combined_labels = []
            for i in range(k):  # For each dataset
                num_samples = class_sizes[i]
                data_list = []; labels_list = []; remaining = num_samples
                while remaining > 0:
                    available = len(tensor_datasets[i]) - prev_indices[i]
                    take = min(available, remaining)
                    indices = shuffled_indices[i][prev_indices[i]:prev_indices[i] + take]
                    data_list.append(tensor_datasets[i].tensors[0][indices].to(device))
                    labels_list.append(tensor_datasets[i].tensors[1][indices].to(device))

                    # Update the remaining samples and the previous index, and wrap around if the end of the dataset is reached
                    remaining -= take; prev_indices[i] += take
                    if prev_indices[i] >= len(tensor_datasets[i]): prev_indices[i] = 0  # Reset to start

                # Concatenate all collected data and labels
                combined_data.append(torch.cat(data_list, dim=0))
                combined_labels.append(torch.cat(labels_list, dim=0))

            # Inner optimization step: update model parameters
            optimizer_l2m.zero_grad()
            x_mixed = torch.cat(combined_data, dim=0)
            y_mixed = torch.cat(combined_labels, dim=0)
            z_mixed = model_l2m(x_mixed)
            loss = criterion(z_mixed, y_mixed)
            loss.backward()
            optimizer_l2m.step()

            # Compute class-wise losses; samples were concatenated class by class in class_sizes order
            start_idx = 0
            for i in range(k):
                num_samples = class_sizes[i]
                tracking_error[i] = criterion(z_mixed[start_idx:start_idx + num_samples], y_mixed[start_idx:start_idx + num_samples]).item()
                start_idx += num_samples

            # Outer optimization step: move alpha toward the normalised class-wise losses
            rewards = tracking_error / np.sum(tracking_error) if np.sum(tracking_error) > 0 else np.ones(k) / k
            alpha += alpha_lr * (torch.tensor(rewards, dtype=torch.float32).to(device) - alpha)

        # Class-balanced focal loss training
        model_cbl.train()
        for X_train, y_train in train_loader:
            X_train = X_train.to(device)
            y_train = y_train.to(device)
            # Forward pass and optimization with Focal Loss
            optimizer_cbl.zero_grad()
            z_cbl = model_cbl(X_train)
            loss_cbl = focal_criterion(z_cbl, y_train)
            loss_cbl.backward()
            optimizer_cbl.step()

        # SMOTE-based NN Training
        model_smote.train()
        smote_iterator = iter(smote_loader)
        for j in range(M):  # Same number of batches as Learn2Mix, restarting the loader if exhausted
            try: X_train, y_train = next(smote_iterator)
            except StopIteration:
                smote_iterator = iter(smote_loader)
                X_train, y_train = next(smote_iterator)
            X_train = X_train.to(device)
            y_train = y_train.to(device)
            optimizer_smote.zero_grad()
            outputs_smote = model_smote(X_train)
            loss_smote = criterion(outputs_smote, y_train)
            loss_smote.backward()
            optimizer_smote.step()

        # Classical NN Training
        model.train()
        for X_train, y_train in train_loader:
            X_train = X_train.to(device)
            y_train = y_train.to(device)
            optimizer.zero_grad()
            classical_outputs = model(X_train)
            loss = criterion(classical_outputs, y_train)
            loss.backward()
            optimizer.step()

        # Computing Train Accuracy (on the unbalanced training set)
        model.eval(); model_l2m.eval(); model_cbl.eval(); model_smote.eval()
        with torch.no_grad():
            for X_train, y_train in train_loader:
                X_train = X_train.to(device)
                y_train = y_train.to(device)
                train_outputs = model_l2m(X_train)
                classical_outputs = model(X_train)
                balanced_outputs = model_cbl(X_train)
                smote_outputs = model_smote(X_train)

                _, predicted_l2m = torch.max(train_outputs.data, 1)
                _, predicted_old = torch.max(classical_outputs.data, 1)
                _, predicted_cbl = torch.max(balanced_outputs.data, 1)
                _, predicted_smote = torch.max(smote_outputs.data, 1)

                total += y_train.size(0)
                correct_l2m += (predicted_l2m == y_train).sum().item()
                correct_old += (predicted_old == y_train).sum().item()
                correct_cbl += (predicted_cbl == y_train).sum().item()
                correct_smote += (predicted_smote == y_train).sum().item()

        # Computing Test Accuracy
        model.eval(); model_l2m.eval(); model_cbl.eval(); model_smote.eval()
        with torch.no_grad():
            for X_test, y_test in test_loader:
                X_test = X_test.to(device)
                y_test = y_test.to(device)
                test_outputs = model_l2m(X_test)
                classical_outputs = model(X_test)
                balanced_outputs = model_cbl(X_test)
                smote_outputs = model_smote(X_test)

                _, predicted_l2m = torch.max(test_outputs.data, 1)
                _, predicted_old = torch.max(classical_outputs.data, 1)
                _, predicted_cbl = torch.max(balanced_outputs.data, 1)
                _, predicted_smote = torch.max(smote_outputs.data, 1)

                total_test += y_test.size(0)
                correct_test_l2m += (predicted_l2m == y_test).sum().item()
                correct_test_old += (predicted_old == y_test).sum().item()
                correct_test_cbl += (predicted_cbl == y_test).sum().item()
                correct_test_smote += (predicted_smote == y_test).sum().item()

        # Store accuracies (percent) for this epoch and run
        train_errors[epoch, iters] = 100 * correct_l2m / total
        classical_errors[epoch, iters] = 100 * correct_old / total
        balanced_errors[epoch, iters] = 100 * correct_cbl / total
        smote_errors[epoch, iters] = 100 * correct_smote / total
        test_errors[epoch, iters] = 100 * correct_test_l2m / total_test
        classical_test_errors[epoch, iters] = 100 * correct_test_old / total_test
        balanced_test_errors[epoch, iters] = 100 * correct_test_cbl / total_test
        smote_test_errors[epoch, iters] = 100 * correct_test_smote / total_test

        if epoch % 1 == 0:
            print(f"Iter {iters}, Epoch {epoch}, Cls Acc: {classical_test_errors[epoch, iters]:.3f}, L2M Acc: {test_errors[epoch, iters]:.3f}, CBL Acc: {balanced_test_errors[epoch, iters]:.3f}, SMOTE Acc: {smote_test_errors[epoch, iters]:.3f}, Alpha: {np.round(alpha.detach().cpu().numpy(),3)}")

# Step 2: Compute means and 95% confidence intervals
def compute_mean_and_CI(data, z=1.96, ddof=1):
    """Summarise repeated runs column by column.

    Statistics are taken over axis 0, so for a (runs, epochs) matrix every
    returned array has one entry per epoch.

    Args:
        data: 2-D array-like whose rows are independent repetitions.
        z: critical value for the interval half-width (1.96 gives ~95% under a
            normal approximation).
        ddof: delta degrees of freedom for the standard deviation. The default
            of 1 is the sample standard deviation, the usual estimator behind a
            standard error of the mean; ``ddof=0`` gives the population value.

    Returns:
        ``(mean, ci, stdev)``: per-column mean, CI half-width
        ``z * stdev / sqrt(n_runs)``, and standard deviation. When there are
        too few rows to estimate spread (``n_runs <= ddof``), stdev and ci are
        zeros.

    Raises:
        ValueError: if ``data`` has no rows.
    """
    data = np.asarray(data, dtype=float)
    n_runs = data.shape[0]
    if n_runs == 0:
        raise ValueError("compute_mean_and_CI needs at least one row of data")
    mean = np.mean(data, axis=0)
    if n_runs > ddof:
        stdev = np.std(data, axis=0, ddof=ddof)
    else:
        # e.g. a single run with ddof=1: the variance is undefined (NumPy would return nan)
        stdev = np.zeros_like(mean)
    std_error = stdev / np.sqrt(n_runs)  # Standard error of the mean
    ci = z * std_error
    return mean, ci, stdev

# Per-epoch statistics across runs: after transposing, rows are runs and columns are epochs
A = classical_errors.T[:,:]; B = train_errors.T[:,:]; C = balanced_errors.T[:,:]; D = smote_errors.T[:,:]
mean_A, ci_A, stdev_A = compute_mean_and_CI(A)
mean_B, ci_B, stdev_B = compute_mean_and_CI(B)
mean_C, ci_C, stdev_C = compute_mean_and_CI(C)
mean_D, ci_D, stdev_D = compute_mean_and_CI(D)

X = classical_test_errors.T[:,:]; Z = test_errors.T[:,:]; Y = balanced_test_errors.T[:,:]; W = smote_test_errors.T[:,:]
mean_X, ci_X, stdev_X = compute_mean_and_CI(X)
mean_Z, ci_Z, stdev_Z = compute_mean_and_CI(Z)
mean_Y, ci_Y, stdev_Y = compute_mean_and_CI(Y)
mean_W, ci_W, stdev_W = compute_mean_and_CI(W)

# Report test accuracy (mean and stdev across runs) at the final, half-way and
# quarter-way epochs. One loop replaces twelve copy-pasted prints so every method
# indexes its mean and stdev the same way.
report_methods = [
    ('DNN', mean_X, stdev_X),
    ('L2M', mean_Z, stdev_Z),
    ('CBL', mean_Y, stdev_Y),
    ('SMT', mean_W, stdev_W),
]
for epoch_tag, epoch_frac in (('E', 1.0), ('0.5E', 0.5), ('0.25E', 0.25)):
    for position, (method_name, method_mean, method_stdev) in enumerate(report_methods):
        # Index of the last epoch within the first epoch_frac of training; clamped so
        # very short runs report epoch 1 instead of wrapping around to index -1.
        epoch_idx = max(int(epoch_frac * len(method_mean)) - 1, 0)
        line_end = '\n' if position == len(report_methods) - 1 else ''
        print(f'Mean {method_name} t = {epoch_tag}:       {method_mean[epoch_idx].round(4)}, STDEV {method_name} t = {epoch_tag}:      {method_stdev[epoch_idx].round(4)}{line_end}')

# Step 3: Plot the results
epochs_all = list(range(1, X.shape[1] + 1))
plt.figure(figsize=(10, 6))

# One entry per curve: (mean, ci, band colour, band alpha, band zorder,
# line format, extra line kwargs, legend label, line zorder).
# The list order fixes the legend order and the stacking of equal-zorder artists.
curve_specs = [
    (mean_A, ci_A, 'green', 0.1, 0, 'g-', {}, "Classical Train Accuracy", 20),
    (mean_X, ci_X, 'blue', 0.1, 0, 'b-', {}, "Classical Test Accuracy", 20),
    (mean_C, ci_C, 'pink', 0.2, 0, '-', {'color': 'pink'}, "Focal Train Accuracy", 20),
    (mean_Y, ci_Y, 'purple', 0.1, 0, '-', {'color': 'purple'}, "Focal Test Accuracy", 20),
    (mean_D, ci_D, 'skyblue', 0.2, 0, '-', {'color': 'skyblue'}, "SMOTE Train Accuracy", 20),
    (mean_W, ci_W, 'cyan', 0.1, 0, '-', {'color': 'cyan'}, "SMOTE Test Accuracy", 20),
    (mean_B, ci_B, 'orange', 0.3, 10, '-', {'color': 'orange'}, "Learn2Mix Train Accuracy", 30),
    (mean_Z, ci_Z, 'red', 0.3, 10, 'r-', {}, "Learn2Mix Test Accuracy", 30),
]
for (curve_mean, curve_ci, band_color, band_alpha, band_zorder,
     line_fmt, line_kwargs, curve_label, line_zorder) in curve_specs:
    # Shaded confidence band first, then the mean curve on top of it
    plt.fill_between(epochs_all, curve_mean - curve_ci, curve_mean + curve_ci,
                     color=band_color, alpha=band_alpha, zorder=band_zorder)
    plt.plot(epochs_all, curve_mean, line_fmt, label=curve_label, zorder=line_zorder, **line_kwargs)

# Additional plot settings
plt.xlabel("Number of Epochs", fontsize=20)
plt.ylabel("Accuracy (%)", fontsize=20)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plot_legend = plt.legend(prop={'size': 14}, loc='lower right', framealpha=0.7)
plot_legend.set_zorder(50)
plt.grid(True)
plt.savefig('test_imagenette.png', bbox_inches='tight')

# Persist every per-epoch accuracy matrix (rows = epochs, columns = runs) as CSV
for csv_path, result_matrix in (
    ('train_errors_imagenette.csv', train_errors),
    ('classical_errors_imagenette.csv', classical_errors),
    ('balanced_errors_imagenette.csv', balanced_errors),
    ('smote_errors_imagenette.csv', smote_errors),
    ('test_errors_imagenette.csv', test_errors),
    ('classical_test_errors_imagenette.csv', classical_test_errors),
    ('balanced_test_errors_imagenette.csv', balanced_test_errors),
    ('smote_test_errors_imagenette.csv', smote_test_errors),
):
    np.savetxt(csv_path, result_matrix, delimiter=',')