import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import wandb
from collections import Counter
import json
import random


def set_seed(seed=42):
    """Seed every RNG this script relies on and make cuDNN deterministic.

    Covers Python's ``random``, NumPy's global RNG and torch (CPU plus all
    CUDA devices). cuDNN autotuning is disabled and deterministic kernels
    are requested, trading some speed for run-to-run reproducibility.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed everything once at import time so dataset sampling and model init are reproducible.
set_seed(42)


# Global plotting defaults shared by every figure this script saves.
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.figsize': (10, 8),
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'legend.fontsize': 12,
})

class ImbalancedCIFAR100Dataset(Dataset):
    """Index-subset view over another dataset, with an optional transform.

    This class does no sampling of its own. It exposes only the positions
    listed in ``indices`` (for example, those chosen by
    ``exponential_sampling``) and applies ``transform`` to the input of each
    ``(input, label)`` pair on access.

    Args:
        original_dataset: indexable dataset returning ``(input, label)`` pairs.
        indices: positions into ``original_dataset`` that make up this view.
        transform: optional callable applied to the input; skipped when falsy.
    """

    def __init__(self, original_dataset, indices, transform=None):
        self.dataset = original_dataset
        self.indices = indices
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        sample, target = self.dataset[self.indices[idx]]
        return (self.transform(sample) if self.transform else sample), target

def exponential_sampling(dataset, tail_index_a, val_split=0.1, num_classes=100):
    """
    Sample an exponentially long-tailed subset of a classification dataset.

    Class ``c`` (0-based) keeps ``floor(n_max * exp(-a * (c + 1)))`` samples,
    i.e. ``p(x) = exp(-a * x)`` evaluated at the 1-based class index ``x``.
    Here ``n_max`` is the size of the largest class and ``a`` is the tail index.
    Counts are capped at the number of samples actually available. The drawn
    indices are shuffled and split into validation and training parts.

    Args:
        dataset: indexable dataset whose items are ``(input, label)`` pairs.
            If it exposes a ``targets`` sequence (torchvision's CIFAR datasets
            do), labels are read from it instead of loading every sample.
        tail_index_a: decay rate ``a``; larger values give a steeper tail.
        val_split: fraction of the drawn indices assigned to validation.
        num_classes: number of classes; labels are taken to be
            ``0..num_classes-1``.

    Returns:
        ``(train_indices, val_indices, samples_per_class)``. The index lists
        refer to positions in ``dataset``. ``samples_per_class`` is an int
        array of length ``num_classes`` with the per-class counts drawn.

    Uses NumPy's global RNG, so the result depends on ``np.random.seed``.
    """
    # Indexing a torchvision CIFAR dataset builds a PIL image per item. Read
    # the label list directly when available instead of decoding 50k images.
    targets = getattr(dataset, 'targets', None)
    if targets is not None:
        labels = np.asarray(targets)
    else:
        labels = np.array([dataset[i][1] for i in range(len(dataset))])

    # Size of the largest class; the head of the exponential curve is scaled to it.
    max_per_class = max(Counter(labels.tolist()).values())

    class_ranks = np.arange(1, num_classes + 1)
    samples_per_class = (np.exp(-tail_index_a * class_ranks) * max_per_class).astype(int)

    # Never ask for more samples than a class actually has.
    class_indices = {c: np.where(labels == c)[0] for c in range(num_classes)}
    for c in range(num_classes):
        samples_per_class[c] = min(samples_per_class[c], len(class_indices[c]))

    # Draw without replacement, class by class (same RNG call order as before).
    sampled_indices = []
    for c in range(num_classes):
        if samples_per_class[c] > 0:
            selected = np.random.choice(class_indices[c],
                                        size=samples_per_class[c],
                                        replace=False)
            sampled_indices.extend(selected.tolist())

    np.random.shuffle(sampled_indices)

    # The first val_split fraction of the shuffled pool becomes validation.
    val_size = int(len(sampled_indices) * val_split)
    val_indices = sampled_indices[:val_size]
    train_indices = sampled_indices[val_size:]

    train_labels = [labels[idx] for idx in train_indices]

    print(f"\nDataset statistics for a={tail_index_a}:")
    print(f"Total train samples: {len(train_indices)}")
    print(f"Total val samples: {len(val_indices)}")
    print(f"Train class distribution (top 10): {Counter(train_labels).most_common(10)}")
    print(f"Train class distribution (bottom 10): {Counter(train_labels).most_common()[-10:]}")

    return train_indices, val_indices, samples_per_class

def get_cifar100_transforms():
    """Return the tensor-conversion + per-channel normalization pipeline.

    Converts an image to a tensor and normalizes each RGB channel with the
    CIFAR-100 statistics hard-coded below. No data augmentation is applied.
    The script uses this same pipeline for the train, val and test splits.
    """
    steps = [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5071, 0.4867, 0.4408],
            std=[0.2675, 0.2565, 0.2761],
        ),
    ]
    return transforms.Compose(steps)

def create_model(num_classes=100, pretrained=False):
    """Build a ResNet-50 adapted to 32x32 inputs.

    Replaces the 7x7/stride-2 stem with a 3x3/stride-1 conv, removes the
    initial max-pool, and resizes the final linear layer to ``num_classes``.

    Args:
        num_classes: output size of the classification head.
        pretrained: forwarded to torchvision's ``resnet50``. Even when True,
            the new stem and head are freshly initialized.
    """
    net = resnet50(pretrained=pretrained)

    # Small-image stem: keep full spatial resolution at the input.
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.maxpool = nn.Identity()

    # New classification head sized for the target label set.
    head_in = net.fc.in_features
    net.fc = nn.Linear(head_in, num_classes)
    return net

def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one training pass over ``train_loader``.

    Args:
        model: network to train (switched to train mode).
        train_loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss returning the per-batch *mean* (the default
            reduction, as with ``nn.CrossEntropyLoss()`` in this script).
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the batches are moved to.

    Returns:
        ``(avg_loss, accuracy)``: the per-sample mean loss over the epoch and
        the top-1 accuracy in percent.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_size = targets.size(0)
        # The criterion returns the batch mean. Re-weight it by batch size so
        # a short final batch does not skew the epoch average.
        loss_sum += loss.item() * batch_size
        _, predicted = outputs.max(1)
        total += batch_size
        correct += predicted.eq(targets).sum().item()

    if total == 0:
        raise ValueError("train_epoch() received a loader with no samples")

    avg_loss = loss_sum / total
    accuracy = 100. * correct / total
    return avg_loss, accuracy

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating it.

    Args:
        model: network to evaluate (switched to eval mode; no gradients).
        val_loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss returning the per-batch *mean* (the default
            reduction, as with ``nn.CrossEntropyLoss()`` in this script).
        device: device the batches are moved to.

    Returns:
        ``(avg_loss, accuracy, all_preds, all_targets)``: the per-sample mean
        loss, the top-1 accuracy in percent, and the predicted and true labels
        per sample (NumPy scalars, in loader order).

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            batch_size = targets.size(0)
            # Weight the batch-mean loss by batch size so the final, usually
            # smaller, batch is not over-represented in the average.
            loss_sum += loss.item() * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(targets).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

    if total == 0:
        raise ValueError("validate() received a loader with no samples")

    avg_loss = loss_sum / total
    accuracy = 100. * correct / total
    return avg_loss, accuracy, all_preds, all_targets

def calculate_per_class_accuracy(y_true, y_pred, num_classes=100):
    """Compute the top-1 accuracy (in percent) for each class.

    Args:
        y_true: sequence of true integer labels.
        y_pred: sequence of predicted integer labels, aligned with ``y_true``.
        num_classes: classes ``0..num_classes-1`` are reported.

    Returns:
        Dict mapping class index to accuracy as a float percentage. A class
        with no samples in ``y_true`` is reported as ``0.0``, which is
        indistinguishable from "all wrong". Check the class counts when that
        matters.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` differ in length.
    """
    # Convert once, outside the per-class loop.
    true_arr = np.asarray(y_true)
    pred_arr = np.asarray(y_pred)
    if true_arr.shape[0] != pred_arr.shape[0]:
        raise ValueError(
            f"y_true and y_pred lengths differ: {true_arr.shape[0]} vs {pred_arr.shape[0]}"
        )

    per_class_acc = {}
    for class_idx in range(num_classes):
        class_mask = true_arr == class_idx
        n_class = int(class_mask.sum())
        if n_class > 0:
            n_correct = int((pred_arr[class_mask] == class_idx).sum())
            per_class_acc[class_idx] = n_correct / n_class * 100
        else:
            per_class_acc[class_idx] = 0.0
    return per_class_acc

def plot_confusion_matrix(y_true, y_pred, tail_index, save_path, num_classes=100):
    """Plot a confusion-matrix heatmap and save it to ``save_path``.

    Args:
        y_true: true integer labels.
        y_pred: predicted integer labels, aligned with ``y_true``.
        tail_index: tail index shown in the title.
        save_path: output image path (saved at 300 dpi).
        num_classes: matrix size; rows and columns are classes
            ``0..num_classes-1`` (rows = true, columns = predicted).
    """
    # Pin the label set explicitly. Without `labels`, classes absent from both
    # y_true and y_pred (common for tail classes in an imbalanced split) are
    # dropped, the matrix shrinks, and axis positions stop matching class ids.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))

    plt.figure(figsize=(20, 16))
    sns.heatmap(cm, cmap='Blues', cbar=True, square=True)
    plt.title(f'Confusion Matrix - CIFAR-100 with Exponential a={tail_index}')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_training_curves(train_losses, val_losses, train_accs, val_accs, tail_index, save_path):
    """Save side-by-side loss and accuracy curves (train vs. val) per epoch.

    Args:
        train_losses, val_losses: per-epoch losses.
        train_accs, val_accs: per-epoch accuracies in percent.
        tail_index: tail index shown in both titles.
        save_path: output image path (saved at 300 dpi).
    """
    epochs = range(1, len(train_losses) + 1)

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 6))

    panels = (
        (loss_ax, train_losses, val_losses, 'Train Loss', 'Val Loss', 'Loss',
         f'Loss Curves - CIFAR-100 with Exponential a={tail_index}'),
        (acc_ax, train_accs, val_accs, 'Train Accuracy', 'Val Accuracy', 'Accuracy (%)',
         f'Accuracy Curves - CIFAR-100 with Exponential a={tail_index}'),
    )
    for ax, train_series, val_series, train_label, val_label, y_label, title in panels:
        ax.plot(epochs, train_series, 'b-', label=train_label)
        ax.plot(epochs, val_series, 'r-', label=val_label)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_class_distribution(samples_per_class, tail_index, save_path):
    """Save a bar chart of the number of samples drawn per class.

    Args:
        samples_per_class: per-class sample counts, in class order.
        tail_index: tail index shown in the title.
        save_path: output image path (saved at 300 dpi).
    """
    plt.figure(figsize=(15, 6))
    # Bars sit at 1-based class positions. Derive the count from the data
    # instead of hard-coding 100, so other class counts do not trigger a
    # shape mismatch in plt.bar.
    classes = range(1, len(samples_per_class) + 1)
    plt.bar(classes, samples_per_class, color='skyblue', edgecolor='navy', alpha=0.7)
    plt.xlabel('Class Index')
    plt.ylabel('Number of Samples')
    plt.title(f'Class Distribution - CIFAR-100 with Exponential a={tail_index}')
    plt.grid(True, axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def train_model_for_tail_index(tail_index_a, device, base_dir='./experiments'):
    """Train ResNet50 on imbalanced CIFAR-100 for a specific tail index.

    Builds an exponentially imbalanced train/val split from the CIFAR-100
    training set (read from ``../data_`` with ``download=False``). It then
    trains with SGD and cosine annealing, keeps the checkpoint with the best
    validation accuracy, and reloads that checkpoint to evaluate on the val
    split and the official test split.

    Artifacts go to ``<base_dir>/cifar100_exponential_a<tail_index_a>/``:
    the class-distribution plot, best and last checkpoints, training curves,
    val/test confusion matrices, and ``results_a<tail_index_a>.json``.
    Metrics are also logged to a wandb run.

    Args:
        tail_index_a: exponential decay rate passed to ``exponential_sampling``.
        device: torch device for the model and batches.
        base_dir: parent directory for the experiment folder.

    Returns:
        The results dict that is also written to JSON: best val accuracy,
        test accuracy and loss of the best checkpoint, per-class accuracies,
        and per-class sample counts.
    """
    # Single source of truth for hyperparameters. The same values feed the
    # wandb config and the loaders/optimizer/scheduler, so they cannot drift.
    num_epochs = 200
    batch_size = 128
    learning_rate = 0.1
    momentum = 0.9
    weight_decay = 5e-4

    exp_name = f'cifar100_exponential_a{tail_index_a}'
    exp_dir = os.path.join(base_dir, exp_name)
    os.makedirs(exp_dir, exist_ok=True)

    wandb.init(
        project="cifar100_imbalanced",
        name=exp_name,
        config={
            "tail_index_a": tail_index_a,
            "model": "resnet50",
            "dataset": "cifar100",
            "epochs": num_epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "momentum": momentum,
            "weight_decay": weight_decay,
            "lr_schedule": "cosine"
        }
    )

    transform = get_cifar100_transforms()

    # Raw (untransformed) train split; ImbalancedCIFAR100Dataset applies the transform.
    train_dataset = torchvision.datasets.CIFAR100(
        root='../data_', train=True, download=False, transform=None
    )
    test_dataset = torchvision.datasets.CIFAR100(
        root='../data_', train=False, download=False, transform=transform
    )

    train_indices, val_indices, samples_per_class = exponential_sampling(
        train_dataset, tail_index_a, val_split=0.1
    )

    plot_class_distribution(
        samples_per_class, tail_index_a,
        os.path.join(exp_dir, f'class_distribution_a{tail_index_a}.png')
    )

    train_dataset_imbalanced = ImbalancedCIFAR100Dataset(
        train_dataset, train_indices, transform=transform
    )
    val_dataset_imbalanced = ImbalancedCIFAR100Dataset(
        train_dataset, val_indices, transform=transform
    )

    train_loader = DataLoader(
        train_dataset_imbalanced, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset_imbalanced, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True
    )

    model = create_model(num_classes=100).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    best_ckpt_path = os.path.join(exp_dir, f'best_model_a{tail_index_a}.pth')

    # Accuracy is always >= 0, so starting below zero guarantees that the first
    # epoch writes a best checkpoint. Starting at 0 with a strict '>' meant a
    # run stuck at 0% never saved one, and the reload below would crash.
    best_val_acc = -1.0
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    for epoch in range(num_epochs):
        # Record the LR this epoch actually trains with. Reading it after
        # scheduler.step() would log the following epoch's rate.
        epoch_lr = optimizer.param_groups[0]['lr']

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, _, _ = validate(model, val_loader, criterion, device)

        scheduler.step()

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "train_acc": train_acc,
            "val_acc": val_acc,
            "learning_rate": epoch_lr
        })

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_acc': val_acc,
                'train_acc': train_acc,
                'val_loss': val_loss,
                'train_loss': train_loss,
                'tail_index_a': tail_index_a
            }
            torch.save(checkpoint, best_ckpt_path)

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}] - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
                  f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    # Final-epoch checkpoint, saved regardless of validation accuracy.
    checkpoint = {
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_acc': val_acc,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'train_loss': train_loss,
        'tail_index_a': tail_index_a
    }
    torch.save(
        checkpoint,
        os.path.join(exp_dir, f'last_model_a{tail_index_a}.pth')
    )

    plot_training_curves(
        train_losses, val_losses, train_accs, val_accs, tail_index_a,
        os.path.join(exp_dir, f'training_curves_a{tail_index_a}.png')
    )

    # Reload the best checkpoint onto the training device for final evaluation.
    checkpoint = torch.load(best_ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    val_loss, val_acc, val_preds, val_targets = validate(model, val_loader, criterion, device)
    per_class_acc_val = calculate_per_class_accuracy(val_targets, val_preds)

    test_loss, test_acc, test_preds, test_targets = validate(model, test_loader, criterion, device)
    per_class_acc_test = calculate_per_class_accuracy(test_targets, test_preds)

    plot_confusion_matrix(
        val_targets, val_preds, tail_index_a,
        os.path.join(exp_dir, f'confusion_matrix_val_a{tail_index_a}.png')
    )
    plot_confusion_matrix(
        test_targets, test_preds, tail_index_a,
        os.path.join(exp_dir, f'confusion_matrix_test_a{tail_index_a}.png')
    )

    wandb.log({
        "best_val_acc": best_val_acc,
        "final_test_acc": test_acc,
        "final_test_loss": test_loss
    })

    results = {
        "tail_index_a": tail_index_a,
        "best_val_acc": best_val_acc,
        "final_test_acc": test_acc,
        "final_test_loss": test_loss,
        "per_class_acc_val": per_class_acc_val,
        "per_class_acc_test": per_class_acc_test,
        "samples_per_class": samples_per_class.tolist()
    }

    with open(os.path.join(exp_dir, f'results_a{tail_index_a}.json'), 'w') as f:
        json.dump(results, f, indent=4)

    wandb.finish()

    return results

def main():
    """Sweep the tail indices, train one model per index, and summarize.

    Trains via ``train_model_for_tail_index`` for a = 0.01, 0.02 and 0.03.
    All per-run results are written to ``all_results_summary.json`` and a
    summary plot is saved in the working directory.
    """
    has_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if has_cuda else 'cpu')
    print(f"Using device: {device}")
    if has_cuda:
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    base_dir = './'
    os.makedirs(base_dir, exist_ok=True)

    banner = '=' * 60
    all_results = {}
    for a in (0.01, 0.02, 0.03):
        print(f"\n{banner}")
        print(f"Training model with tail index a = {a}")
        print(banner)
        all_results[f'a{a}'] = train_model_for_tail_index(a, device, base_dir)

    summary_path = os.path.join(base_dir, 'all_results_summary.json')
    with open(summary_path, 'w') as fh:
        json.dump(all_results, fh, indent=4)

    plot_summary_results(all_results, base_dir)

    print("\nTraining completed for all tail indices!")

def plot_summary_results(all_results, base_dir):
    """Plot final test accuracy against tail index for every finished run.

    Args:
        all_results: mapping of run key (e.g. ``'a0.01'``) to the results dict
            from ``train_model_for_tail_index``. Each value must contain
            ``'tail_index_a'`` and ``'final_test_acc'``.
        base_dir: directory where ``test_accuracy_vs_tail_index.png`` is saved.
    """
    # Take the tail indices from the results themselves, sorted so the line
    # runs left to right. A hard-coded list had to be kept in sync with the
    # sweep in main() and silently dropped any other runs.
    points = sorted(
        (res['tail_index_a'], res['final_test_acc']) for res in all_results.values()
    )
    tail_indices = [a for a, _ in points]
    test_accuracies = [acc for _, acc in points]

    plt.figure(figsize=(10, 6))
    plt.plot(tail_indices, test_accuracies, 'bo-', markersize=10, linewidth=2)
    plt.xlabel('Tail Index (a)')
    plt.ylabel('Test Accuracy (%)')
    plt.title('Test Accuracy vs Imbalance Level (Exponential Tail Index)')
    plt.grid(True, alpha=0.3)
    plt.xticks(tail_indices)

    # Annotate each point with its accuracy, slightly above the marker.
    for a, acc in points:
        plt.text(a, acc + 0.5, f'{acc:.2f}%', ha='center', va='bottom')

    plt.tight_layout()
    plt.savefig(os.path.join(base_dir, 'test_accuracy_vs_tail_index.png'), dpi=300, bbox_inches='tight')
    plt.close()

if __name__ == "__main__":
    # Launch the full tail-index sweep only when run as a script, not on import.
    main()

















