import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Dataset, Subset
import torchvision
from torchvision import transforms, datasets
import numpy as np
from PIL import Image
from collections import OrderedDict
from scipy.stats import sem
from tqdm import tqdm as tqdm
import torch.nn.functional as F
import torchvision.models as models
import argparse
import os
import sys
import wandb
from collections import defaultdict
from copy import deepcopy


def _str2bool(value):
    """Convert a command-line token to a real boolean.

    argparse's ``type=bool`` simply calls ``bool(token)``, and every non-empty
    string (including ``"False"``) is truthy, so a flag declared that way can
    never be switched off explicitly. This converter understands the usual
    spellings instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='CIFAR10 with label noise and L2 regularization')
parser.add_argument('--eps', type=float, default=0.0, help='Contamination ratio (default: 0.0)')
parser.add_argument('--annp', type=float, default=100.0, help='ANNP threshold percentile (default: 100.0)')
parser.add_argument('--run', type=int, default=0, help='Run (default: 0)')
# `--force`, `--force True` and `--force False` are all accepted; a bare
# `--force` means True.
parser.add_argument("--force", help="Force overwriting old run", type=_str2bool,
                    nargs="?", const=True, default=False)
parser.add_argument('--size_percent', type=float, default=100.0, help='Initial reduction of dataset (default: 100.0)')
parser.add_argument('--tp', type=float, default=0.7, help='TP rate of the prefilter')
parser.add_argument('--comment', type=str, default="", help='Free-form tag inserted into the output filename (default: empty)')
args = parser.parse_args()
print(args)
eps = args.eps
annp = args.annp
run = args.run
size_percent = args.size_percent
comment = args.comment
P_NORM = 2.0  # order p of the parameter norm used as regulariser
# With no contamination only the unfiltered setting (annp == 100) is run; any
# other combination exits immediately with status 0.
if eps == 0.0 and annp != 100.0:
    print(f"Eps value 0.0 only runs with annp=100.0; received eps={eps}, annp={annp}")
    sys.exit(0)

TP = args.tp
# The result file name encodes the full configuration, so an existing file
# means this exact configuration has already produced output.
filename = f"cifar10oracle{TP}v2regL{P_NORM}v5_label_partialv2_{size_percent}{comment}_eps_{eps}_annp_{annp}_run_{run}.npz"
if os.path.isfile(filename) and (not args.force):
    # File exists
    print(f"File {filename} exists")
    print("=" * 50)
    sys.exit(0)


# Hyperparameters
num_classes = 10
NUM_EPOCHS = 30
DATA_DIR = "./data"
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class NoisyFilteredCIFAR10(Dataset):
    """CIFAR-10 subset with injected label noise and a simulated pre-filter.

    The constructor
      1. draws the same number of images from each of the 10 classes,
      2. selects ``int(eps * N)`` samples as "noisy" and replaces each of their
         labels by a different class chosen uniformly at random,
      3. marks ``int(p_remove * N)`` samples for removal, of which at most
         ``int(TP_t * p_remove * N)`` are noisy ones (true positives) and the
         remainder are clean ones (false positives),
      4. drops the marked samples.

    All randomness comes from the global ``np.random`` state. ``seed`` is
    accepted but is not read anywhere in the constructor.

    Attributes:
        clean_labels / noisy_labels: labels of the balanced subset *before*
            the marked samples are removed.
        data / targets: images and (possibly corrupted) labels that survive
            the filter; these are what ``__getitem__`` serves.
    """

    def __init__(self, root, train=True, transform=None, download=True,
                 p_remove=0.1, TP_t=0.02, eps=0.3, ps_size=1.0, seed=42):

        # Fraction of the subset that is both noisy and marked for removal.
        TP = TP_t * p_remove
        assert 0 < ps_size <= 1.0
        assert 0 <= p_remove <= 1 and 0 <= eps <= 1
        assert 0 <= TP <= 1

        self.transform = transform
        self.num_classes = 10

        # Raw images and labels of the requested CIFAR-10 split.
        base = datasets.CIFAR10(root=root, train=train, download=download)
        images = np.array(base.data)
        labels = np.array(base.targets)

        # Positions of every class, in dataset order.
        by_class = defaultdict(list)
        for position, cls_label in enumerate(labels):
            by_class[cls_label].append(position)

        # Same number of samples per class, drawn without replacement.
        per_class = int(len(labels) * ps_size // self.num_classes)
        picked = []
        for cls in range(self.num_classes):
            picked.extend(np.random.choice(by_class[cls], size=per_class, replace=False))

        subset_idx = np.array(sorted(picked))
        self.clean_labels = labels[subset_idx]
        self.data = images[subset_idx]

        total = len(self.clean_labels)

        # Sizes of the noisy / marked groups and of their overlap.
        n_noisy = int(eps * total)
        n_marked = int(p_remove * total)
        n_tp = min(int(TP * total), n_marked, n_noisy)
        n_fp = n_marked - n_tp
        n_fn = n_noisy - n_tp

        assert 0 <= n_fp <= total - n_noisy, f"Too many false positives {n_fp} {total} {n_noisy} {total-n_noisy}"
        assert 0 <= n_fn <= n_noisy, f"Too many false negatives {n_fn} {total} {n_noisy} {total-n_noisy}"

        positions = np.arange(total)
        noisy_idx = np.random.choice(positions, size=n_noisy, replace=False)
        noisy_set = set(noisy_idx)

        clean_idx = np.setdiff1d(positions, noisy_idx)

        # Marked samples: true positives come from the noisy group, false
        # positives from the clean group.
        tp_idx = np.random.choice(list(noisy_set), size=n_tp, replace=False)
        fp_idx = np.random.choice(clean_idx, size=n_fp, replace=False)
        marked = set(tp_idx) | set(fp_idx)

        # Uniform label noise: every noisy sample gets one of the other labels.
        self.noisy_labels = self.clean_labels.copy()
        every_class = set(range(self.num_classes))
        for pos in noisy_set:
            alternatives = list(every_class - {self.clean_labels[pos]})
            self.noisy_labels[pos] = np.random.choice(alternatives)

        # Keep only the samples the pre-filter did not mark.
        survivors = sorted(set(positions) - marked)
        self.data = self.data[survivors]
        self.targets = self.noisy_labels[survivors]

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image goes through ``transform`` if one is set."""
        img, label = self.data[idx], self.targets[idx]
        return (self.transform(img) if self.transform else img), label

    def __len__(self):
        """Number of samples left after filtering."""
        return len(self.data)

class BalancedCIFAR10(Dataset):
    """Random, class-stratified subset of CIFAR-10.

    From every class ``int(class_size * percentage)`` images are drawn without
    replacement (global ``np.random`` state); the combined selection is then
    shuffled. Items are ``(image, label)`` pairs where the image is converted
    to a PIL image and, if a transform is set, passed through it.
    """

    def __init__(self, root='./data', train=True, transform=None, download=True, percentage=0.75):
        self.percentage = percentage
        self.transform = transform
        self.train = train

        # The underlying dataset only supplies raw arrays and labels.
        source = datasets.CIFAR10(root=root, train=train, download=download, transform=transform)
        self.data, self.targets = self._create_balanced_subset(source)

    def _create_balanced_subset(self, dataset):
        """Pick the per-class subset of ``dataset``.

        Returns:
            ``(images, labels)`` as two parallel Python lists.
        """
        labels = np.array(dataset.targets)

        # Positions of every class, in order of first appearance.
        members = defaultdict(list)
        for position, label in enumerate(labels):
            members[label].append(position)

        # Same fraction of every class.
        chosen = []
        for positions in members.values():
            keep = int(len(positions) * self.percentage)
            chosen.extend(np.random.choice(positions, keep, replace=False))

        # Mix the classes so samples are not grouped by label.
        np.random.shuffle(chosen)

        images = [dataset.data[i] for i in chosen]
        kept_labels = [labels[i] for i in chosen]
        return images, kept_labels

    def __len__(self):
        """Number of selected samples."""
        return len(self.targets)

    def __getitem__(self, index):
        """Return ``(image, label)`` for position ``index``."""
        raw, label = self.data[index], self.targets[index]

        # Stored images are arrays; hand a PIL image to the transform.
        img = Image.fromarray(raw)

        return (self.transform(img), label) if self.transform else (img, label)

class NoisyCIFAR10(Dataset):
    """Thin dataset over pre-built ``data`` / ``labels`` sequences.

    The class itself does not alter any label; it serves what it is given and
    additionally returns the position of each sample, so callers can map a
    batch entry back to its place in the dataset.
    """

    def __init__(self, data, labels, transform=None):
        self.data, self.labels, self.transform = data, labels, transform

    def __len__(self):
        """Number of labels (and therefore samples)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label, idx)``; the image is transformed if a transform is set."""
        sample = self.data[idx]
        target = self.labels[idx]
        processed = self.transform(sample) if self.transform else sample
        return processed, target, idx  # idx lets the caller identify the sample

class ConvBlock(nn.Module):
    """Bias-free 3x3 convolution -> batch norm -> ReLU, optionally followed by 2x2 max pooling.

    The convolution uses ``padding=1``, so only the optional pooling changes
    the spatial size.
    """

    def __init__(self, in_channels, out_channels, pool=False):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        stages = [conv, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)]
        # Pooling, when requested, is always the last stage.
        stages += [nn.MaxPool2d(2)] if pool else []
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through all stages in order."""
        out = self.block(x)
        return out

class ResidualBlock(nn.Module):
    """Two bias-free 3x3 conv + batch-norm layers with an identity skip connection.

    Channel count and spatial size are unchanged, so the input can be added
    directly to the branch output before the final ReLU.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            # Shape-preserving convolution shared by both halves of the branch.
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)

        self.block = nn.Sequential(
            conv3x3(),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        """Return ``relu(x + block(x))``."""
        residual = self.block(x)
        return F.relu(x + residual)

class ResNet9(nn.Module):
    """Small residual CNN: four conv blocks, two residual blocks, global pooling and a linear head.

    Args:
        in_channels: number of channels of the input images.
        num_classes: size of the output (one logit per class).
    """

    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        # Feature extractor; each pooled block halves the spatial size.
        self.conv1 = ConvBlock(in_channels, 64)
        self.conv2 = ConvBlock(64, 128, pool=True)
        self.res1 = ResidualBlock(128)
        self.conv3 = ConvBlock(128, 256, pool=True)
        self.conv4 = ConvBlock(256, 512, pool=True)
        self.res2 = ResidualBlock(512)
        # Classifier head.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map a batch of images to class logits of shape ``[B, num_classes]``."""
        # Channel / spatial sizes below are for 32x32 inputs:
        # 64@32 -> 128@16 -> 128@16 -> 256@8 -> 512@4 -> 512@4 -> 512@1
        pipeline = (self.conv1, self.conv2, self.res1,
                    self.conv3, self.conv4, self.res2, self.pool)
        for stage in pipeline:
            x = stage(x)
        flat = x.view(x.size(0), -1)  # [B, 512]
        return self.fc(flat)
    

# Images are only converted to tensors; no augmentation or normalisation.
transform = transforms.ToTensor()

# Noisy, pre-filtered training pool. The CLI percentages are converted to
# fractions here: (100 - annp)% of the pool is marked for removal.
# DATA_DIR is used for both datasets so they always share one download
# location.
full_train_dataset = NoisyFilteredCIFAR10(
    root=DATA_DIR, train=True, transform=transform, download=True,
    p_remove=(100. - annp) / 100., TP_t=TP, eps=eps, ps_size=size_percent / 100.,
)

# 80/20 train/validation split of the (noisy) pool.
train_size = int(0.8 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

# Test set is loaded normally without contamination
test_dataset = datasets.CIFAR10(root=DATA_DIR, train=False, transform=transform, download=True)

# DataLoaders (only the training loader shuffles)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Model dimensions and regularisation sweep
num_channels = 3
width = 32
num_classes = 10

num_models = 9
# Regularisation strengths, log-spaced between 1e-4 and 10**-1.64.
# NOTE(review): this holds num_models - 1 values, so when the buffers below
# are filled one column per lambda their last column stays zero; confirm that
# this is intended.
lambda_regs = torch.logspace(-4, -1.64, steps=num_models - 1).to(device)

# Per-epoch metric history, one row per epoch and one column per model.
(all_train_losses, all_train_acc,
 all_val_losses, all_val_acc,
 all_test_losses, all_test_acc) = (
    torch.zeros((NUM_EPOCHS, num_models), device=device) for _ in range(6)
)

# Test accuracy at the epoch with the lowest validation loss, per model.
all_best_val_test_acc = torch.zeros(num_models, device=device)
def _lp_norm(model):
    """Return the P_NORM-norm over all trainable parameters of ``model`` as a 0-dim tensor."""
    powered = sum(torch.sum(torch.abs(param) ** P_NORM) for param in model.parameters() if param.requires_grad)
    return powered.pow(1.0 / P_NORM)


def _evaluate(model, loader, criterion):
    """Return ``(mean loss, accuracy)`` of ``model`` over ``loader``.

    The model is put in eval mode and no gradients are tracked.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            # criterion returns a batch mean; weight by batch size so the
            # final division yields a per-sample mean.
            loss_sum += criterion(outputs, targets).item() * inputs.size(0)
            _, predicted = outputs.max(1)
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()
    return loss_sum / n_seen, n_correct / n_seen


# One full training run (and one wandb run) per regularisation strength.
for i, lambda_reg in enumerate(lambda_regs):
    print(lambda_reg)
    # Initialize wandb
    wandb.init(project="cifar10_contaminated_weighted", config={
        "batch_size": BATCH_SIZE,
        "num_epochs": NUM_EPOCHS,
        "learning_rate": LEARNING_RATE,
        "architecture": "custom CNN as provided",
        "contamination_prob": eps,
        "p_norm": P_NORM,
        # Plain float rather than a device tensor, so the config serialises cleanly.
        "lambda_reg": lambda_reg.item()
    })

    # `.to(device)` instead of `.cuda()` so the CPU fallback chosen for
    # `device` actually works.
    model = nn.Sequential(
        OrderedDict(
            [
                ("conv0", nn.Conv2d(num_channels, 1 * width, kernel_size=3, padding=1)),
                ("relu0", nn.ReLU()),
                ("conv1", nn.Conv2d(1 * width, 2 * width, kernel_size=3, padding=1)),
                ("relu1", nn.ReLU()),
                ("conv2", nn.Conv2d(2 * width, 4 * width, kernel_size=3, stride=2, padding=1)),
                ("relu2", nn.ReLU()),
                ("pool0", nn.MaxPool2d(3)),
                ("conv3", nn.Conv2d(4 * width, 4 * width, kernel_size=3, stride=2, padding=1)),
                ("relu3", nn.ReLU()),
                ("pool1", nn.AdaptiveAvgPool2d(1)),
                ("flatten", nn.Flatten()),
                ("linear", nn.Linear(4 * width, num_classes)),
            ]
        )
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    best_val_loss = float('inf')
    best_test_acc = 0.0

    # Training loop
    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            ce_loss = criterion(outputs, targets)
            # Objective = cross-entropy + lambda * ||params||_p
            lp_reg = _lp_norm(model)
            loss = ce_loss + lambda_reg * lp_reg
            loss.backward()
            optimizer.step()

            # The reported training loss includes the regularisation term.
            train_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()

        train_loss /= train_total
        train_acc = train_correct / train_total

        # Validation (noisy labels) and test (clean labels) after every epoch;
        # both report plain cross-entropy without the regulariser.
        val_loss, val_acc = _evaluate(model, val_loader, criterion)
        test_loss, test_acc = _evaluate(model, test_loader, criterion)

        # Save best test accuracy based on minimal val loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_test_acc = test_acc

        # Norm of the parameters as they are at the end of the epoch. It is
        # recomputed here: reusing `lp_reg` would take the root of an
        # already-rooted value from before the last optimizer step.
        with torch.no_grad():
            lp_norm = _lp_norm(model).item()

        # Log metrics to wandb
        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss,
            "train_accuracy": train_acc,
            "val_loss": val_loss,
            "val_accuracy": val_acc,
            "test_loss": test_loss,
            "test_accuracy": test_acc,
            "lp_norm": lp_norm,
        })
        all_train_losses[epoch-1, i] = train_loss
        all_train_acc[epoch-1, i] = train_acc
        all_val_losses[epoch-1, i] = val_loss
        all_val_acc[epoch-1, i] = val_acc
        all_test_losses[epoch-1, i] = test_loss
        all_test_acc[epoch-1, i] = test_acc

        print(f"Epoch [{epoch}/{NUM_EPOCHS}] "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
            f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

    # Log the best test accuracy (i.e., test accuracy at epoch with minimal validation loss)
    wandb.log({"best_test_accuracy_at_best_val_epoch": best_test_acc})
    all_best_val_test_acc[i] = best_test_acc
    print(f"\nBest test accuracy (at minimal val loss): {best_test_acc:.4f}")

    wandb.finish()

# Arrays are stored positionally, in this order: train loss, train acc,
# val loss, val acc, test loss, test acc, best-val test acc.
np.savez(
    filename,
    all_train_losses.detach().cpu().numpy(),
    all_train_acc.detach().cpu().numpy(),
    all_val_losses.detach().cpu().numpy(),
    all_val_acc.detach().cpu().numpy(),
    all_test_losses.detach().cpu().numpy(),
    all_test_acc.detach().cpu().numpy(),
    all_best_val_test_acc.detach().cpu().numpy(),
)
print("=" * 50)