import numpy as np
import os
import pickle
import random
import struct
import urllib.request
from collections import defaultdict

# Load the CIFAR-10 dataset from local python-format pickle batches (no download)
def load_cifar10_data(batch_size=64):
    def unpickle(file):
        with open(file, 'rb') as fo:
            dict = pickle.load(fo, encoding='bytes')
        return dict

    # CIFAR-10 dataset file path
    cifar10_dir = '/data/cifar-10-batches-py'
    # load training set
    train_data = []
    train_labels = []
    for i in range(1, 6):
        batch = unpickle(f'{cifar10_dir}/data_batch_{i}')
        train_data.append(batch[b'data'])
        train_labels.append(batch[b'labels'])
    X_train = np.concatenate(train_data)
    Y_train = np.concatenate(train_labels)
    
    # load test set
    test_batch = unpickle(f'{cifar10_dir}/test_batch')
    X_test = test_batch[b'data']
    Y_test = test_batch[b'labels']
    
    # normalize to [0, 1]
    X_train = X_train / 255.0
    X_test = X_test / 255.0
    
    # divide by batches
    def get_batch(X, Y, batch_size):
        n_samples = X.shape[0]
        for i in range(0, n_samples, batch_size):
            yield X[i:i + batch_size], Y[i:i + batch_size]
    
    return get_batch(X_train, Y_train, batch_size), get_batch(X_test, Y_test, batch_size)

# Basic 2-D convolution layer (NumPy)
def conv2d(X, W, b, stride=1, padding=0):
    """Naive 2-D convolution (cross-correlation, as in CNN layers).

    Args:
        X: input batch of shape (N, C, H, W).
        W: filters of shape (F, C, HH, WW).
        b: bias, either a scalar or an array with F elements, e.g. (F,) or
            (F, 1, 1). It is flattened and added per filter.
        stride: step between neighbouring windows (>= 1).
        padding: zero rows/columns added on each side of H and W (>= 0).

    Returns:
        float64 array of shape (N, F, H_out, W_out), where
        H_out = (H - HH + 2 * padding) // stride + 1, and likewise W_out.

    Raises:
        ValueError: if X is not 4-D, its channel count differs from W's,
            b has the wrong size, or stride/padding are out of range.
    """
    X = np.asarray(X)
    W = np.asarray(W)
    if X.ndim != 4:
        raise ValueError(f"conv2d expects X of shape (N, C, H, W), got {X.shape}")
    if stride < 1 or padding < 0:
        raise ValueError(f"need stride >= 1 and padding >= 0, got {stride}, {padding}")

    N, C, H, W_in = X.shape
    F, C_w, HH, WW = W.shape
    if C != C_w:
        raise ValueError(f"input has {C} channels but filters expect {C_w}")

    # Flatten so (F, 1, 1) biases add per filter on the (N, F) result;
    # a scalar becomes shape (1,) and broadcasts over all filters.
    bias = np.asarray(b).reshape(-1)
    if bias.size not in (1, F):
        raise ValueError(f"bias must have 1 or {F} elements, got {bias.size}")

    # Zero-pad only the two spatial axes.
    X_padded = np.pad(
        X,
        ((0, 0), (0, 0), (padding, padding), (padding, padding)),
        mode='constant',
        constant_values=0,
    )

    H_out = (H - HH + 2 * padding) // stride + 1
    W_out = (W_in - WW + 2 * padding) // stride + 1

    out = np.zeros((N, F, H_out, W_out))
    for i in range(H_out):
        top = i * stride
        for j in range(W_out):
            left = j * stride
            x_slice = X_padded[:, :, top:top + HH, left:left + WW]
            # Contract (C, HH, WW) of the window with (C, HH, WW) of each
            # filter -> (N, F).
            out[:, :, i, j] = np.tensordot(x_slice, W, axes=((1, 2, 3), (1, 2, 3))) + bias
    return out

# ReLU activation (NumPy)
def relu(X):
    """Rectified linear unit, element-wise: every negative entry becomes 0.

    Accepts a scalar or array-like of any shape and returns a NumPy result
    of the same shape.
    """
    rectified = np.maximum(0, X)
    return rectified

# Average pooling (NumPy)
def avg_pool(X, pool_size=2, stride=2):
    """Average-pool the two spatial axes of a 4-D batch.

    Args:
        X: array of shape (N, C, H, W).
        pool_size: side length of the square pooling window.
        stride: step between neighbouring windows.

    Returns:
        float64 array of shape (N, C, (H - pool_size) // stride + 1,
        (W - pool_size) // stride + 1) holding each window's mean.
    """
    batch, channels, height, width = X.shape
    rows = (height - pool_size) // stride + 1
    cols = (width - pool_size) // stride + 1

    pooled = np.zeros((batch, channels, rows, cols))
    for row in range(rows):
        top = row * stride
        for col in range(cols):
            left = col * stride
            window = X[:, :, top:top + pool_size, left:left + pool_size]
            pooled[:, :, row, col] = np.mean(window, axis=(2, 3))
    return pooled

# Fully connected (dense) layer (NumPy)
def fully_connected(X, W, b):
    """Dense (affine) layer computing ``np.dot(X, W) + b``.

    X: inputs whose last axis matches the first axis of a 2-D ``W``,
       e.g. (N, D) or (D,).
    W: weight matrix, e.g. (D, K).
    b: bias broadcast onto the product (a scalar or shape (K,)).
    """
    projected = np.dot(X, W)
    return projected + b

# BasicBlock (based on the ResNet basic block)
class BasicBlock:
    """ResNet-style basic block: two 3x3 convolutions plus a shortcut.

    ``forward`` computes ``relu(conv2(relu(conv1(x))) + shortcut(x))``.
    The shortcut is the identity when input and output shapes agree, and a
    1x1 projection convolution otherwise. Weights are random and there is no
    batch normalisation.

    The block also owns ``fc``: the weights of a linear classifier over an
    ``out_channels x 8 x 8`` feature map with 10 outputs. ``forward`` does not
    use ``fc``.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        # Stride applied by conv1 and by the projection shortcut.
        self.stride = stride
        self.conv1 = (np.random.randn(out_channels, in_channels, 3, 3) * 0.1).astype(np.float32)
        self.conv2 = (np.random.randn(out_channels, out_channels, 3, 3) * 0.1).astype(np.float32)
        self.fc = (np.random.randn(out_channels * 8 * 8, 10) * 0.1).astype(np.float32)
        # NOTE(review): despite the "bn" names these are used as the
        # per-filter conv biases (initialised to 1), not as batch-norm
        # parameters.
        self.bn1 = np.ones((out_channels, 1, 1), dtype=np.float32)
        self.bn2 = np.ones((out_channels, 1, 1), dtype=np.float32)
        # Projection shortcut, needed when channels or resolution change.
        # It is drawn after the weights above so that, for a given seed,
        # conv1/conv2/fc get the same values as before. None means identity.
        if stride != 1 or in_channels != out_channels:
            self.shortcut = (np.random.randn(out_channels, in_channels, 1, 1) * 0.1).astype(np.float32)
        else:
            self.shortcut = None

    def forward(self, x):
        """Apply the block to ``x`` of shape (N, in_channels, H, W).

        Both 3x3 convolutions use padding 1, so the result has shape
        (N, out_channels, ceil(H / stride), ceil(W / stride)).
        """
        # The biases are flattened to (F,) to match conv2d's (N, F) output slots.
        residual = relu(conv2d(x, self.conv1, self.bn1.reshape(-1), stride=self.stride, padding=1))
        residual = conv2d(residual, self.conv2, self.bn2.reshape(-1), stride=1, padding=1)
        if self.shortcut is None:
            identity = x
        else:
            identity = conv2d(x, self.shortcut, 0, stride=self.stride, padding=0)
        return relu(residual + identity)

# ResNet
class ResNet:
    """Minimal CIFAR-10 classifier: one BasicBlock, average pooling, a linear layer.

    The linear layer's weight matrix is ``self.block.fc``; its bias is 0.
    """

    def __init__(self):
        self.block = BasicBlock(3, 16)

    def features(self, x):
        """Return the flattened, pooled block activations fed to the linear layer.

        Args:
            x: image batch of shape (N, 3, 32, 32), or flat rows of shape
                (N, 3072). Flat rows are reshaped assuming the CIFAR-10
                python layout: three channel-major 32x32 planes.

        Returns:
            Array of shape (N, self.block.fc.shape[0]).

        Raises:
            ValueError: if the block output cannot be pooled to the square
                grid that ``self.block.fc`` expects.
        """
        x = np.asarray(x)
        if x.ndim == 2:
            x = x.reshape(x.shape[0], 3, 32, 32)
        out = self.block.forward(x)
        n, channels, height, width = out.shape

        # fc has channels * grid * grid rows, so pool the map down to that grid.
        n_inputs = self.block.fc.shape[0]
        grid = int(round(np.sqrt(n_inputs / channels)))
        if channels * grid * grid != n_inputs or height != width or height < grid:
            raise ValueError(
                f"cannot pool block output {out.shape} to the {n_inputs} "
                f"features expected by fc"
            )
        stride = height // grid
        # Window size for which exactly `grid` windows of this stride fit:
        # (height - window) // stride + 1 == grid.
        window = height - (grid - 1) * stride
        pooled = avg_pool(out, pool_size=window, stride=stride)
        # Flatten per sample, not across the whole batch.
        return pooled.reshape(n, -1)

    def forward(self, x):
        """Return class logits of shape (N, self.block.fc.shape[1])."""
        return fully_connected(self.features(x), self.block.fc, 0)


# Training loop
def train(model, trainloader, epochs=10, learning_rate=0.001):
    """Fit the model's final linear layer with softmax cross-entropy and SGD.

    Only ``model.block.fc`` is updated, in place. The convolution weights
    keep their random initialisation: there is no backpropagation into the
    block.

    Args:
        model: object exposing ``features(x)`` (the inputs of the linear
            layer) and ``block.fc`` (its weight matrix).
        trainloader: re-iterable of ``(X_batch, y_batch)`` pairs. ``y_batch``
            is either integer class labels of shape (B,) or one-hot /
            probability rows of shape (B, n_classes). A one-shot generator is
            exhausted after the first epoch.
        epochs: number of passes over ``trainloader``.
        learning_rate: SGD step size.

    Returns:
        List of the mean training loss of each epoch.
    """
    fc = model.block.fc  # same array object, so `-=` below updates the model
    n_classes = fc.shape[1]
    epoch_losses = []

    for epoch in range(epochs):
        total_loss = 0.0
        n_batches = 0
        for X_batch, y_batch in trainloader:
            y = np.asarray(y_batch)
            if y.shape[0] == 0:
                continue

            # Forward pass: the logits equal model.forward(X_batch) (the fc
            # bias is 0), but the features are reused for the gradient.
            feats = np.asarray(model.features(np.asarray(X_batch)), dtype=np.float64)
            logits = np.dot(feats, fc)

            # Integer labels -> one-hot targets.
            if y.ndim == 1:
                targets = np.zeros((y.shape[0], n_classes))
                targets[np.arange(y.shape[0]), y.astype(np.intp)] = 1.0
            else:
                targets = y.astype(np.float64)

            # Numerically stable log-softmax and cross-entropy loss.
            shifted = logits - logits.max(axis=1, keepdims=True)
            log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
            loss = np.mean(np.sum(-targets * log_probs, axis=1))

            # d(loss)/d(logits) for softmax cross-entropy, then chain rule
            # through logits = feats @ fc.
            grad_logits = (np.exp(log_probs) - targets) / logits.shape[0]
            fc -= learning_rate * np.dot(feats.T, grad_logits)

            total_loss += loss
            n_batches += 1

        loss = total_loss / n_batches if n_batches else float('nan')
        epoch_losses.append(loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss:.4f}")

    return epoch_losses

# Main Function
def main():
    """Script entry point: train a fresh ResNet on CIFAR-10 for 10 epochs.

    Builds both loaders with batch size 64. Only the training loader is used
    here; the test loader is left unused.
    """
    train_loader, _unused_test_loader = load_cifar10_data(batch_size=64)
    train(ResNet(), train_loader, epochs=10)


if __name__ == "__main__":
    main()
