import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Dataset, Subset
import torchvision
from torchvision import transforms, datasets
import numpy as np
from PIL import Image
# import matplotlib.pyplot as plt
from collections import OrderedDict
from scipy.stats import sem
from tqdm import tqdm as tqdm
import torch.nn.functional as F
import torchvision.models as models
import argparse
import os
import sys
import wandb
from collections import defaultdict

def _str2bool(value):
    """Convert a command-line token to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False") as True, so boolean flags are parsed explicitly here.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"1", "true", "t", "yes", "y"}:
        return True
    if token in {"0", "false", "f", "no", "n"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description='CIFAR10 with label noise and L2 regularization')
parser.add_argument('--eps', type=float, default=0.0, help='Contamination ratio (default: 0.0)')
parser.add_argument('--annp', type=float, default=100.0, help='ANNP threshold percentile (default: 100.0)')
parser.add_argument('--run', type=int, default=0, help='Run (default: 0)')
# Accepts both a bare `--force` and the older `--force True/False` spelling.
parser.add_argument("--force", help="Force overwriting old run (`--force` or `--force true|false`)",
                    type=_str2bool, nargs="?", const=True, default=False)
parser.add_argument('--size_percent', type=float, default=100.0, help='Initial reduction of dataset (default: 100.0)')
parser.add_argument('--comment', type=str, default="", help='Suffix appended to the output filename (default: "")')
args = parser.parse_args()
print(args)
eps = args.eps
annp = args.annp
run = args.run
size_percent = args.size_percent
comment = args.comment
P_NORM = 2.0
# Clean-label runs (eps == 0) are only performed without prefiltering (annp == 100).
if eps == 0.0 and annp != 100.0:
    print(f"Eps value 0.0 only runs with annp=100.0; received eps={eps}, annp={annp}")
    sys.exit(0)

# Results of this configuration are written here; skip the run if it already exists
# unless --force is given.
filename = f"cifar10regL{P_NORM}v5_label_partialv2_{size_percent}{comment}_eps_{eps}_annp_{annp}_run_{run}.npz"
if os.path.isfile(filename) and (not args.force):
    print(f"File {filename} exists")
    print("=" * 50)
    sys.exit(0)

# Hyperparameters shared by the main regularisation sweep.
num_classes = 10
NUM_EPOCHS = 40
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
DATA_DIR = "./data"
# Prefer the GPU whenever one is visible.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class BalancedCIFAR10(Dataset):
    """CIFAR-10 subset that keeps the same fraction of images from every class.

    Args:
        root: directory forwarded to ``datasets.CIFAR10``.
        train: forwarded to ``datasets.CIFAR10`` to choose the split.
        transform: optional callable applied to each PIL image in ``__getitem__``.
        download: forwarded to ``datasets.CIFAR10``.
        percentage: fraction of each class to keep, sampled without replacement.
    """

    def __init__(self, root='./data', train=True, transform=None, download=True, percentage=0.75):
        self.percentage = percentage
        self.transform = transform
        self.train = train

        source = datasets.CIFAR10(root=root, train=train, download=download, transform=transform)
        self.data, self.targets = self._create_balanced_subset(source)

    def _create_balanced_subset(self, dataset):
        """Sample ``self.percentage`` of every class, shuffle, and return ``(images, labels)`` lists."""
        labels = np.array(dataset.targets)

        # Bucket sample positions by label, in order of first appearance.
        by_class = defaultdict(list)
        for position, cls in enumerate(labels):
            by_class[cls].append(position)

        chosen = []
        for members in by_class.values():
            keep = int(len(members) * self.percentage)
            chosen.extend(np.random.choice(members, keep, replace=False))

        # Interleave the classes so they are not stored in contiguous runs.
        np.random.shuffle(chosen)

        images = [dataset.data[k] for k in chosen]
        picked_labels = [labels[k] for k in chosen]
        return images, picked_labels

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        # Stored images are raw arrays; convert to PIL before applying the transform.
        picture = Image.fromarray(self.data[index])
        if self.transform:
            picture = self.transform(picture)
        return picture, self.targets[index]

class NoisyCIFAR10(Dataset):
    """In-memory dataset of ``(sample, label)`` pairs that also yields each sample's index.

    The returned index lets callers map per-sample quantities (such as losses)
    back to positions in ``data`` / ``labels``.
    """

    def __init__(self, data, labels, transform=None):
        self.data, self.labels, self.transform = data, labels, transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, self.labels[idx], idx

class ConvBlock(nn.Module):
    """3x3 convolution (no bias) -> BatchNorm -> ReLU, optionally followed by 2x2 max-pooling.

    With ``padding=1`` the spatial size is kept; ``pool=True`` halves it.
    """

    def __init__(self, in_channels, out_channels, pool=False):
        super().__init__()
        stages = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)]
        stages.append(nn.BatchNorm2d(out_channels))
        stages.append(nn.ReLU(inplace=True))
        stages += [nn.MaxPool2d(2)] if pool else []
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        out = self.block(x)
        return out

class ResidualBlock(nn.Module):
    """Two 3x3 conv/BatchNorm layers with an identity skip connection: ``ReLU(x + F(x))``.

    Channel count and spatial size are preserved (same channels, padding=1).
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)

        self.block = nn.Sequential(
            conv3x3(),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        residual = self.block(x)
        return F.relu(x + residual)

class ResNet9(nn.Module):
    """Compact ResNet-9 classifier.

    For a ``[B, in_channels, 32, 32]`` input the feature map shrinks
    32 -> 16 -> 8 -> 4 through the three pooling ConvBlocks. It is then
    globally average-pooled to ``[B, 512]`` and mapped to ``num_classes`` logits.
    """

    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        self.conv1 = ConvBlock(in_channels, 64)
        self.conv2 = ConvBlock(64, 128, pool=True)
        self.res1 = ResidualBlock(128)
        self.conv3 = ConvBlock(128, 256, pool=True)
        self.conv4 = ConvBlock(256, 512, pool=True)
        self.res2 = ResidualBlock(512)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        # Feature extractor applied in order; the final stage leaves [B, 512, 1, 1].
        for stage in (self.conv1, self.conv2, self.res1, self.conv3, self.conv4, self.res2, self.pool):
            x = stage(x)
        return self.fc(torch.flatten(x, 1))
    
transform = transforms.ToTensor()
# BalancedCIFAR10 loads (and downloads if needed) CIFAR-10 itself; its images are
# kept raw here and converted to tensors by NoisyCIFAR10 through `transform`.
trainset_clean = BalancedCIFAR10(root=DATA_DIR, train=True, download=True, transform=None,
                                 percentage=size_percent / 100)
original_labels = np.array(trainset_clean.targets)

# Introduce uniform label noise: a random eps-fraction of samples gets a label drawn
# uniformly from the num_classes - 1 classes other than its true one.
noisy_labels = original_labels.copy()
num_noisy = int(eps * len(original_labels))
noisy_indices = np.random.choice(len(original_labels), size=num_noisy, replace=False)
# Shifting by an offset in [1, num_classes - 1] (mod num_classes) never yields the
# true class and reaches each of the other classes with equal probability.
offsets = np.random.randint(1, num_classes, size=num_noisy)
noisy_labels[noisy_indices] = (original_labels[noisy_indices] + offsets) % num_classes

noisy_trainset = NoisyCIFAR10(trainset_clean.data, noisy_labels, transform=transform)
def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass of `model` over `loader` (batches of inputs, labels, indices)."""
    model.train()
    for inputs, labels, _ in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()


def _per_sample_losses(model, loader, num_samples, device):
    """Return a length-`num_samples` array of unreduced cross-entropy losses, placed by sample index."""
    model.eval()
    criterion_no_reduce = nn.CrossEntropyLoss(reduction='none')
    losses = np.zeros(num_samples)
    with torch.no_grad():
        for inputs, labels, indices in loader:
            outputs = model(inputs.to(device))
            losses[indices.numpy()] = criterion_no_reduce(outputs, labels.to(device)).cpu().numpy()
    return losses


def train_and_prefilter(noisy_trainset, percentile, contaminated_idx=None, *,
                        epochs=7, batch_size=128, learning_rate=1e-3):
    """Train a ResNet-9 on the noisy data and keep only its lowest-loss samples.

    After every epoch the per-sample training loss is computed over the full
    dataset. Samples with loss <= the `percentile`-th percentile are kept, and
    diagnostics are printed. The subset selected after the final epoch is returned.

    Args:
        noisy_trainset: dataset yielding ``(input, label, index)`` triples.
        percentile: percentile (0-100) of the loss distribution used as the cut-off.
        contaminated_idx: indices of samples whose labels were corrupted, used only
            for the printed diagnostic; defaults to the module-level ``noisy_indices``.
        epochs: number of training epochs (must be >= 1).
        batch_size: mini-batch size for training and loss evaluation.
        learning_rate: Adam learning rate.

    Returns:
        ``Subset`` of `noisy_trainset` containing the retained samples.

    Raises:
        ValueError: if `epochs` < 1.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    if contaminated_idx is None:
        contaminated_idx = noisy_indices

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    trainloader = DataLoader(noisy_trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    evalloader = DataLoader(noisy_trainset, batch_size=batch_size, shuffle=False)

    model = ResNet9(in_channels=3, num_classes=10).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        # _train_one_epoch re-enables train mode: the loss pass below switches the
        # model to eval mode, which would otherwise freeze BatchNorm statistics.
        _train_one_epoch(model, trainloader, optimizer, criterion, device)
        losses = _per_sample_losses(model, evalloader, len(noisy_trainset), device)

        threshold = np.percentile(losses, percentile)
        selected_indices = np.where(losses <= threshold)[0]
        subset_dataset = Subset(noisy_trainset, selected_indices)

        if len(contaminated_idx):
            fraction_above_threshold = np.mean(losses[contaminated_idx] > threshold)
        else:
            fraction_above_threshold = float("nan")  # no contaminated points to measure

        print(f"Subset fraction: {len(subset_dataset)} points out of {len(noisy_trainset)} ({100 * len(subset_dataset) / len(noisy_trainset):.2f}%)")
        print(f"Epoch {epoch}")
        print(f"{percentile:g}th percentile of all losses: {threshold:.4f}")
        print(f"Fraction of contaminated points with loss above threshold: {fraction_above_threshold:.4f}")

    return subset_dataset

# Optionally drop the highest-loss samples before the sweep (annp == 100 keeps all).
full_train_dataset = (
    train_and_prefilter(noisy_trainset=noisy_trainset, percentile=annp)
    if annp != 100.
    else noisy_trainset
)

# 80/20 train/validation split of the (possibly filtered) noisy data.
train_size = int(0.8 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

# Test set is loaded normally without contamination
test_dataset = datasets.CIFAR10(root=DATA_DIR, train=False, transform=transform, download=True)

# Only the training loader shuffles; evaluation order is fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=2)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=2)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=2)

# Small CNN trained once per L_p regularisation strength in `lambda_regs`.
num_channels = 3
width = 32
num_classes = 10

num_models = 9
# Regularisation strengths actually swept. NOTE: this grid has num_models - 1
# entries, so the last column of every results array below is never written and
# stays zero; the arrays keep num_models columns so the saved .npz layout is unchanged.
lambda_regs = torch.logspace(-4, -1.64, num_models - 1).to(device)

# Per-epoch metrics, shape (NUM_EPOCHS, num_models); column i belongs to lambda_regs[i].
all_train_losses = torch.zeros((NUM_EPOCHS, num_models)).to(device)
all_train_acc = torch.zeros((NUM_EPOCHS, num_models)).to(device)
all_val_losses = torch.zeros((NUM_EPOCHS, num_models)).to(device)
all_val_acc = torch.zeros((NUM_EPOCHS, num_models)).to(device)
all_test_losses = torch.zeros((NUM_EPOCHS, num_models)).to(device)
all_test_acc = torch.zeros((NUM_EPOCHS, num_models)).to(device)
# Test accuracy at the epoch with the lowest validation loss, per model.
all_best_val_test_acc = torch.zeros(num_models).to(device)


def _build_cnn():
    """Return a freshly initialised small CNN placed on `device`."""
    return nn.Sequential(
        OrderedDict(
            [
                ("conv0", nn.Conv2d(num_channels, 1 * width, kernel_size=3, padding=1)),
                ("relu0", nn.ReLU()),
                ("conv1", nn.Conv2d(1 * width, 2 * width, kernel_size=3, padding=1)),
                ("relu1", nn.ReLU()),
                ("conv2", nn.Conv2d(2 * width, 4 * width, kernel_size=3, stride=2, padding=1)),
                ("relu2", nn.ReLU()),
                ("pool0", nn.MaxPool2d(3)),
                ("conv3", nn.Conv2d(4 * width, 4 * width, kernel_size=3, stride=2, padding=1)),
                ("relu3", nn.ReLU()),
                ("pool1", nn.AdaptiveAvgPool2d(1)),
                ("flatten", nn.Flatten()),
                ("linear", nn.Linear(4 * width, num_classes)),
            ]
        )
    ).to(device)


def _lp_norm(model):
    """Return the L_{P_NORM} norm (a 0-dim tensor) of all trainable parameters of `model`."""
    total = sum(torch.sum(torch.abs(param) ** P_NORM) for param in model.parameters() if param.requires_grad)
    return total.pow(1.0 / P_NORM)


def _evaluate(model, loader, criterion):
    """Return (mean loss, accuracy) of `model` over `loader`.

    Batches may carry extra fields after ``(inputs, targets)`` (e.g. sample indices).
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for batch in loader:
            inputs, targets = batch[0].to(device), batch[1].to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, targets).item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return total_loss / total, correct / total


for i, lambda_reg in enumerate(lambda_regs):
    lambda_value = lambda_reg.item()  # plain float for printing and the wandb config
    print(f"lambda_reg = {lambda_value:.4e}")
    wandb.init(project="cifar10_contaminated_weighted", config={
        "batch_size": BATCH_SIZE,
        "num_epochs": NUM_EPOCHS,
        "learning_rate": LEARNING_RATE,
        "architecture": "custom CNN as provided",
        "contamination_prob": eps,
        "p_norm": P_NORM,
        "lambda_reg": lambda_value,
    })

    model = _build_cnn()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    best_val_loss = float('inf')
    best_test_acc = 0.0

    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for inputs, targets, _ in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            # Objective: cross-entropy + lambda * ||w||_p (the norm, not its p-th power).
            loss = criterion(outputs, targets) + lambda_reg * _lp_norm(model)
            loss.backward()
            optimizer.step()

            # NOTE: the recorded train loss includes the penalty term, while the
            # validation/test losses below are pure cross-entropy.
            train_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()

        train_loss /= train_total
        train_acc = train_correct / train_total

        val_loss, val_acc = _evaluate(model, val_loader, criterion)
        # Test metrics are tracked every epoch; the headline number is the test
        # accuracy at the epoch with the lowest validation loss.
        test_loss, test_acc = _evaluate(model, test_loader, criterion)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_test_acc = test_acc

        # L_p norm of the weights at the end of this epoch.
        with torch.no_grad():
            lp_norm = _lp_norm(model).item()

        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss,
            "train_accuracy": train_acc,
            "val_loss": val_loss,
            "val_accuracy": val_acc,
            "test_loss": test_loss,
            "test_accuracy": test_acc,
            "lp_norm": lp_norm,
        })
        all_train_losses[epoch - 1, i] = train_loss
        all_train_acc[epoch - 1, i] = train_acc
        all_val_losses[epoch - 1, i] = val_loss
        all_val_acc[epoch - 1, i] = val_acc
        all_test_losses[epoch - 1, i] = test_loss
        all_test_acc[epoch - 1, i] = test_acc

        print(f"Epoch [{epoch}/{NUM_EPOCHS}] "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
            f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

    # Log the best test accuracy (i.e., test accuracy at epoch with minimal validation loss)
    wandb.log({"best_test_accuracy_at_best_val_epoch": best_test_acc})
    all_best_val_test_acc[i] = best_test_acc
    print(f"\nBest test accuracy (at minimal val loss): {best_test_acc:.4f}")

    wandb.finish()

# Arrays are saved positionally (arr_0 ... arr_6) in this order: train loss, train acc,
# val loss, val acc, test loss, test acc, best-val test acc.
np.savez(filename, *(t.detach().cpu().numpy() for t in (
    all_train_losses, all_train_acc, all_val_losses, all_val_acc,
    all_test_losses, all_test_acc, all_best_val_test_acc,
)))
print("=" * 50)