import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, ConcatDataset, Dataset
from torch.utils.data import Subset, ConcatDataset, TensorDataset
from collections import defaultdict
import numpy as np
import os
import random
from PIL import Image
import csv


class Config:
    """Hyper-parameters and data-mix settings for a single training run.

    Args:
        mode: Data-mix scenario, one of 1-5:
            1 - real images only, using ``total_num // 2`` samples;
            2 - ``total_num`` real images, half of them replaced by noisy copies
                (``use_noisy_mnist=True``, ``noise_ratio=0.5``);
            3 - half real images, half generated images read from
                ``generated_images/<exper_name>/samples60000``;
            4 - half real images, half generated images read from
                ``generated_images/<exper_name>/samples<total_num // 2>``;
            5 - real images only, using all ``total_num`` samples.
        total_num: Requested number of training samples (halved for mode 1).
        train_step: Number of optimisation steps (stored as ``max_steps``).
        exper_name: Dataset name; also used to build ``custom_data_dir``.

    Raises:
        ValueError: if ``mode`` is not one of 1-5.
    """

    def __init__(self, mode, total_num, train_step, exper_name):
        if mode not in (1, 2, 3, 4, 5):
            raise ValueError("Invalid mode. Choose 1, 2, 3, 4 or 5.")

        # Optimisation / evaluation settings shared by every mode.
        # NOTE(review): the GPU index is hard-coded to cuda:1.
        self.device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
        self.uniform_noise_scale = 0.00001
        self.normal_noise_std = 0.00001
        self.batch_size_train = 100
        self.batch_size_test = 5000
        self.batch_size_step = 10
        self.max_steps = train_step
        self.test_interval = 100
        self.lr = 1e-4
        self.data_root = './data'
        self.seed = 1
        self.mode = mode
        self.exper_name = exper_name

        # Data-mix defaults; the individual modes below override what differs.
        self.total_num = total_num
        self.custom_data_ratio = 0.0
        self.use_noisy_mnist = False
        self.noise_type = 'both'
        self.noise_ratio = 0
        self.custom_data_dir = None  # only modes 3 and 4 load generated images

        if mode == 1:
            self.total_num = total_num // 2
        elif mode == 2:
            self.use_noisy_mnist = True
            self.noise_ratio = 0.5
        elif mode == 3:
            self.custom_data_ratio = 0.5
            self.custom_data_dir = 'generated_images/' + exper_name + '/samples60000'
        elif mode == 4:
            self.custom_data_ratio = 0.5
            self.custom_data_dir = 'generated_images/' + exper_name + '/samples' + str(total_num // 2)
        # mode 5: the defaults above already describe it.


def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU + CUDA) RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Deterministic kernels and no autotuning, so runs are repeatable.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False


class CustomImageDataset(Dataset):
    """Labelled images loaded from per-class sub-directories ``root_dir/0`` .. ``root_dir/9``.

    Files ending in ``.png``, ``.jpg`` or ``.jpeg`` are used. Missing class
    directories are skipped. Images are converted to greyscale ('L') on access.

    Args:
        root_dir: Directory containing one sub-directory per label.
        transform: Optional callable applied to each PIL image.
        num_per_class: If given, at most this many files are sampled per class
            (without replacement, seeded by ``config.seed``).
        config: Object with a ``seed`` attribute; only read when
            ``num_per_class`` is given. If None, sampling is unseeded.
    """

    def __init__(self, root_dir, transform=None, num_per_class=None, config=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []

        seed = config.seed if config is not None else None

        for label in range(10):
            class_dir = os.path.join(root_dir, str(label))
            if not os.path.exists(class_dir):
                continue

            # os.listdir order is filesystem-dependent; sort it so the seeded
            # subsample picks the same files on every machine.
            image_files = sorted(
                f for f in os.listdir(class_dir) if f.endswith(('.png', '.jpg', '.jpeg'))
            )
            if num_per_class is not None:
                # A fresh RandomState per class gives the same draws as
                # np.random.seed(seed) + np.random.choice, but leaves the global
                # NumPy RNG untouched.
                rng = np.random.RandomState(seed)
                selected_files = rng.choice(
                    image_files, min(num_per_class, len(image_files)), replace=False
                )
            else:
                selected_files = image_files

            for file in selected_files:
                self.images.append(os.path.join(class_dir, file))
                self.labels.append(label)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)``, where label is a 0-d ``torch.long`` tensor."""
        # Context manager releases the file handle as soon as the pixels are converted.
        with Image.open(self.images[idx]) as img:
            image = img.convert('L')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.long)


class Net(nn.Module):
    """Fully connected classifier: flattened 28x28 input -> 128 -> 64 -> 32 -> 10 logits.

    Hidden layers use ReLU; the final layer returns raw logits.
    """

    def __init__(self):
        super().__init__()
        widths = [28 * 28, 128, 64, 32, 10]
        layers = [nn.Flatten()]
        for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if depth:
                # ReLU sits between consecutive Linear layers, never after the last.
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.fc(x)
        return logits

def add_noise(images, noise_type='both', config=None, generator=None):
    """Return a copy of *images* with additive noise, clamped to [0, 1].

    Args:
        images: Image tensor; presumably ToTensor output in [0, 1], since the
            result is clamped to that range. Not modified.
        noise_type: 'uniform' (U[-1, 1] scaled by ``config.uniform_noise_scale``),
            'normal' (N(0, 1) scaled by ``config.normal_noise_std``), or 'both'.
        config: Object providing ``uniform_noise_scale`` and ``normal_noise_std``.
        generator: Optional ``torch.Generator`` for the noise draws. When None,
            the global torch RNG is used.

    Returns:
        A new tensor with the same shape, dtype and device as *images*.

    Raises:
        ValueError: if *config* is None or *noise_type* is not recognised.
    """
    if config is None:
        raise ValueError("add_noise requires a config providing the noise scales")
    if noise_type not in ('uniform', 'normal', 'both'):
        raise ValueError(f"Unknown noise_type: {noise_type!r}")

    # The global seed is deliberately NOT reset here: reseeding on every call
    # gave every image the identical noise pattern and clobbered the caller's
    # RNG state. Seed once upstream, or pass a dedicated generator.
    draw_kwargs = dict(generator=generator, dtype=images.dtype, device=images.device)
    noisy_images = images.clone()
    if noise_type in ('uniform', 'both'):
        uniform_noise = torch.rand(images.shape, **draw_kwargs) * 2 - 1  # [-1, 1]
        noisy_images += uniform_noise * config.uniform_noise_scale
    if noise_type in ('normal', 'both'):
        noisy_images += torch.randn(images.shape, **draw_kwargs) * config.normal_noise_std
    return torch.clamp(noisy_images, 0.0, 1.0)


def create_clean_dataset(dataset, clean_num_per_class, normalize, config):
    """Sample up to *clean_num_per_class* examples of each label 0-9 and normalize them.

    Args:
        dataset: Indexable, iterable collection of ``(image_tensor, label)`` pairs.
        clean_num_per_class: Maximum number of examples drawn per label.
        normalize: Callable applied to every selected image.
        config: Object with a ``seed`` attribute used for the sampling.

    Returns:
        ``TensorDataset(images, labels)``, ordered label by label. It is empty
        if nothing was selected.
    """
    # Private RandomState: same draws as np.random.seed(seed) + np.random.choice,
    # without resetting the caller's global NumPy RNG.
    rng = np.random.RandomState(config.seed)

    all_labels = np.array([label for _, label in dataset])

    selected_indices = []
    for label in range(10):
        label_indices = np.where(all_labels == label)[0]
        count = min(clean_num_per_class, len(label_indices))
        selected_indices.extend(rng.choice(label_indices, count, replace=False))

    if not selected_indices:
        # torch.stack rejects an empty list.
        return TensorDataset(torch.empty(0), torch.empty(0, dtype=torch.long))

    clean_imgs = []
    clean_labels = []
    for idx in selected_indices:
        img, label = dataset[idx]
        clean_imgs.append(normalize(img))
        clean_labels.append(torch.as_tensor(label, dtype=torch.long))

    return TensorDataset(torch.stack(clean_imgs), torch.stack(clean_labels))


def create_noisy_dataset(dataset, selected_clean_indices, noisy_num_per_class, normalize, noise_type, add_noise_fn, config):
    """Sample up to *noisy_num_per_class* examples per label from the non-clean pool and add noise.

    Args:
        dataset: Indexable, iterable collection of ``(image_tensor, label)`` pairs.
        selected_clean_indices: Indices already used for the clean subset; they
            are excluded from sampling.
        noisy_num_per_class: Maximum number of examples drawn per label 0-9.
        normalize: Callable applied to each image after noise is added.
        noise_type: Forwarded unchanged to *add_noise_fn*.
        add_noise_fn: Callable ``(image, noise_type, config) -> image``.
        config: Object with a ``seed`` attribute; also forwarded to *add_noise_fn*.

    Returns:
        ``TensorDataset(images, labels)``, ordered label by label. It is empty
        if nothing was selected.
    """
    # Private RandomState: same draws as np.random.seed(seed) + np.random.choice,
    # without resetting the caller's global NumPy RNG.
    rng = np.random.RandomState(config.seed)
    excluded = set(selected_clean_indices)

    candidates = [(idx, label) for idx, (_, label) in enumerate(dataset) if idx not in excluded]
    candidate_indices = [idx for idx, _ in candidates]
    candidate_labels = np.array([label for _, label in candidates])

    selected_indices = []
    for label in range(10):
        positions = np.where(candidate_labels == label)[0]
        if len(positions) > 0:
            picked = rng.choice(positions, min(noisy_num_per_class, len(positions)), replace=False)
            selected_indices.extend(candidate_indices[i] for i in picked)

    if not selected_indices:
        # torch.stack rejects an empty list.
        return TensorDataset(torch.empty(0), torch.empty(0, dtype=torch.long))

    noisy_imgs = []
    noisy_labels = []
    for idx in selected_indices:
        img, label = dataset[idx]
        noisy_imgs.append(normalize(add_noise_fn(img, noise_type, config)))
        noisy_labels.append(torch.as_tensor(label, dtype=torch.long))

    return TensorDataset(torch.stack(noisy_imgs), torch.stack(noisy_labels))


def _pairs_to_tensor_dataset(pairs, image_fn):
    """Apply *image_fn* to each image of ``(image, label)`` *pairs* and stack them into a TensorDataset."""
    if not pairs:
        # torch.stack rejects an empty list.
        return TensorDataset(torch.empty(0), torch.empty(0, dtype=torch.long))
    images = torch.stack([image_fn(img) for img, _ in pairs])
    labels = torch.tensor([label for _, label in pairs], dtype=torch.long)
    return TensorDataset(images, labels)


def load_datasets(config):
    """Build the mixed training set and the test set described by *config*.

    The training set concatenates:
      * clean real images (normalized),
      * noisy copies of other real images, if ``config.use_noisy_mnist`` and
        ``noise_ratio`` yield a positive count,
      * generated images from ``config.custom_data_dir``, if
        ``custom_data_ratio`` yields a positive count.
    Per-class counts are the respective totals // 10. Sampling uses the global
    NumPy RNG, so call ``set_seed`` first for reproducibility.

    Args:
        config: A ``Config`` instance.

    Returns:
        ``(train_dataset, test_dataset)``.

    Raises:
        ValueError: if ``config.exper_name`` is not 'mnist' or 'fashion_mnist'.
    """
    base_transform = transforms.ToTensor()
    # NOTE(review): these are MNIST statistics and are also applied to
    # FashionMNIST; confirm that is intended.
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    full_transform = transforms.Compose([base_transform, normalize])

    dataset_classes = {"mnist": datasets.MNIST, "fashion_mnist": datasets.FashionMNIST}
    try:
        dataset_cls = dataset_classes[config.exper_name]
    except KeyError:
        raise ValueError(
            f"Unknown exper_name {config.exper_name!r}; expected 'mnist' or 'fashion_mnist'."
        ) from None
    # Both splits come from the same root (the train split used to ignore data_root).
    full_train_dataset = dataset_cls(root=config.data_root, train=True, transform=base_transform, download=True)
    test_dataset = dataset_cls(root=config.data_root, train=False, transform=full_transform, download=True)

    custom_num = int(config.total_num * config.custom_data_ratio)
    remaining_num = config.total_num - custom_num
    noisy_num = int(remaining_num * config.noise_ratio) if config.use_noisy_mnist else 0
    clean_num = remaining_num - noisy_num

    per_class_clean = clean_num // 10
    per_class_noisy = noisy_num // 10
    per_class_custom = custom_num // 10

    # Read labels from the targets tensor instead of decoding all images.
    all_labels = np.asarray(full_train_dataset.targets)

    clean_indices = []
    for label in range(10):
        label_indices = np.where(all_labels == label)[0]
        selected = np.random.choice(label_indices, min(per_class_clean, len(label_indices)), replace=False)
        clean_indices.extend(selected)

    clean_pairs = [full_train_dataset[i] for i in clean_indices]
    clean_dataset = _pairs_to_tensor_dataset(clean_pairs, normalize)

    noisy_dataset = None
    if noisy_num > 0:
        # Ascending order made explicit instead of relying on set iteration order.
        clean_set = set(clean_indices)
        remaining_indices = [i for i in range(len(full_train_dataset)) if i not in clean_set]
        remaining_labels = all_labels[remaining_indices]

        noisy_indices = []
        for label in range(10):
            positions = np.where(remaining_labels == label)[0]
            if len(positions) > 0:
                selected = np.random.choice(positions, min(per_class_noisy, len(positions)), replace=False)
                noisy_indices.extend(remaining_indices[i] for i in selected)

        noisy_pairs = [full_train_dataset[i] for i in noisy_indices]
        noisy_dataset = _pairs_to_tensor_dataset(
            noisy_pairs, lambda img: normalize(add_noise(img, config.noise_type, config))
        )

    custom_dataset = None
    if custom_num > 0:
        custom_dataset = CustomImageDataset(
            root_dir=config.custom_data_dir,
            transform=full_transform,
            num_per_class=per_class_custom,
            config=config,
        )

    # Empty optional parts are skipped, as before.
    dataset_list = [clean_dataset]
    for part in (noisy_dataset, custom_dataset):
        if part is not None and len(part) > 0:
            dataset_list.append(part)

    return ConcatDataset(dataset_list), test_dataset


def test(model, test_loader, device, criterion, step):
    model.eval()
    correct = 0
    test_loss = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()

    test_loss /= len(test_loader)
    acc = 100. * correct / len(test_loader.dataset)
    #print(f'[Step {step}] Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({acc:.2f}%)')
    return test_loss, acc

def train(config):
    """Train a fresh ``Net`` on the data mix described by *config*.

    Runs ``config.max_steps`` Adam steps on shuffled mini-batches of size
    ``config.batch_size_step``, reshuffling each time the data is exhausted.
    The model is evaluated every ``config.test_interval`` steps and always
    after the final step.

    Args:
        config: A ``Config`` instance.

    Returns:
        Test accuracy (percent) of the final model, or None if no step ran.

    Raises:
        ValueError: if the assembled training set is empty.
    """
    set_seed(config.seed)

    train_dataset, test_dataset = load_datasets(config)
    if len(train_dataset) == 0:
        raise ValueError("training dataset is empty; check total_num and the data-mix ratios")

    # The former batch_size_train loader was never iterated; DataLoader
    # construction draws no random numbers, so dropping it changes nothing.
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size_test, shuffle=False)
    step_loader = DataLoader(train_dataset, batch_size=config.batch_size_step, shuffle=True)

    model = Net().to(config.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.lr)

    step = 0
    last_eval_step = None
    last_test_result = None
    batches = iter(step_loader)

    while step < config.max_steps:
        try:
            data, target = next(batches)
        except StopIteration:
            # Epoch exhausted: start a new shuffled pass.
            batches = iter(step_loader)
            data, target = next(batches)

        model.train()
        data, target = data.to(config.device), target.to(config.device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        step += 1

        if step % config.test_interval == 0:
            last_test_result = test(model, test_loader, config.device, criterion, step)
            last_eval_step = step

    # Report the final model, not a stale checkpoint, when max_steps is not a
    # multiple of test_interval.
    if step > 0 and last_eval_step != step:
        last_test_result = test(model, test_loader, config.device, criterion, step)

    return last_test_result[1] if last_test_result is not None else None


def main():
    """Train every mode for each sample budget and write the accuracies to a CSV.

    Writes ``classification/result/<exper_name>_train_step=<train_step>.csv``
    with one header row, then one row per budget (``total_num`` = 2 * num).
    """
    exper_name = 'fashion_mnist'   # or 'mnist'
    train_step = 10000
    csv_name = 'classification/result/' + exper_name + '_train_step=' + str(train_step) + '.csv'
    modes = range(1, 6)

    # open() does not create parent directories.
    out_dir = os.path.dirname(csv_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(csv_name, 'w', newline='') as csvfile:
        csv.writer(csvfile).writerow(['total_num'] + [f'mode{m}_acc' for m in modes])

    num_list = [50, 100, 200, 500, 1000, 2000, 5000, 10000]

    for num in num_list:
        total_num = num * 2
        mode_results = {}
        for mode in modes:
            config = Config(mode=mode, total_num=total_num, train_step=train_step, exper_name=exper_name)
            mode_results[mode] = train(config)

        # Append after each budget so partial results survive an interrupted run.
        with open(csv_name, 'a', newline='') as csvfile:
            csv.writer(csvfile).writerow([total_num] + [mode_results.get(m, 'N/A') for m in modes])


if __name__ == '__main__':
    main()






