import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import Subset
import random
import os
import argparse

from data import dataset
from model import FourLayerNeuralNet, TwoLayerNeuralNet, LeNet, ThreeLayerNeuralNet

# Training for one data_num stops once the epoch-level training loss falls
# below this threshold.
epsilon = 0.01
# Mini-batch size shared by the training and validation loaders.
batch_size = 1024

# Prefer the first CUDA GPU, but fall back to the CPU when CUDA is unavailable
# instead of failing on the first ``.to(device)`` call.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

criterion = nn.CrossEntropyLoss()

# Both splits are loaded once, at import time; ``sgd`` reads them as globals.
train_dataset, valid_dataset = dataset()

# The validation set is iterated in a fixed order (no shuffling).
validloader = DataLoader(valid_dataset,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=0)

def sgd(model_name, data_num_list, seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    train_dataset_list = list(train_dataset)
    random.shuffle(train_dataset_list)
    for data_num in data_num_list:
        print(f"Data_num: {data_num}", flush=True)
        if model_name == "lenet":
            net = LeNet(1)
        if model_name == "two_layer":
            net = TwoLayerNeuralNet()
        if model_name == "three_layer":
            net = ThreeLayerNeuralNet()
        if model_name == "four_layer":
            net = FourLayerNeuralNet()

        net = net.to(device)
        net.train()
        optimizer = optim.Adam(net.parameters(), lr=0.001)
        trainloader = DataLoader(train_dataset_list[:data_num],
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=0)

        loss = torch.tensor(100).to(device)
        err_rate = 100
        total_loss = 100

        # Training
        while total_loss >= epsilon:
            correct = 0
            total = 0
            preds = torch.tensor([]).float().to(device)
            trues = torch.tensor([]).int().to(device)
            for i, data in enumerate(trainloader, 0):
                xs, ys = data
                xs = xs.to(torch.float32)
                xs = xs.to(device)
                ys = ys.to(torch.int64)
                ys = ys.to(device)
                output = net(xs)
                loss = criterion(output, ys)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                _, predicted = torch.max(output.data, 1)
                correct += (predicted == ys).sum().item()
                total += ys.size(0)
                preds = torch.cat((preds, output))
                trues = torch.cat((trues, ys))
            
            total_loss = criterion(preds, trues)
            train_err_rate = 100 * (1 - correct / total)

    
        # Evaluation
        net.eval()
        with torch.no_grad():
            test_loss = torch.tensor(0.0).to(device)
            test_preds = torch.tensor([]).float().to(device)
            test_trues = torch.tensor([]).int().to(device)
            correct = 0
            total = 0
            for i, data in enumerate(validloader, 0):
                xs, ys = data
                xs = xs.to(device)
                ys = ys.to(torch.int64)
                ys = ys.to(device)
                output = net(xs)
                _, predicted = torch.max(output.data, 1)
                correct += (predicted == ys).sum().item()
                total += ys.size(0)
                test_preds = torch.cat((test_preds, output))
                test_trues = torch.cat((test_trues, ys))

            test_loss = criterion(test_preds, test_trues)
            err_rate = 100 * (1 - correct / total)
        print(f"found, data_num={data_num}, seed={seed}, train_loss={total_loss}, train_err_rate={train_err_rate}, test_loss={test_loss}, test_err_rate={err_rate}")
        checkpoint = {
            "model_state": net.state_dict(),
            "train_loss": total_loss,
            "train_error": train_err_rate,
            "test_loss": test_loss,
            "test_error": err_rate
        }
        os.makedirs(f"large_experiment_result/sgd/{model_name}/{data_num}", exist_ok=True)
        torch.save(checkpoint, f"large_experiment_result/sgd/{model_name}/{data_num}/model_weights_{seed}.pth")
   
    return


def main(argv=None):
    """Parse command-line arguments and run ``sgd`` once per seed.

    Seeds ``0, 1, ..., sample_num - 1`` are run in order.

    Args:
        argv: argument list to parse; ``None`` (default) uses ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Train models on growing training-set sizes.")
    parser.add_argument('--model_name', type=str, required=True,
                        choices=("lenet", "two_layer", "three_layer",
                                 "four_layer"),
                        help='network architecture to train')
    parser.add_argument('--sample_num', type=int, required=True,
                        help='number of seeds to run (0 .. sample_num-1)')
    parser.add_argument('--data_num_list', required=True, nargs="+", type=int,
                        help='training-set sizes to train on')

    args = parser.parse_args(argv)

    for seed in range(args.sample_num):
        sgd(args.model_name, args.data_num_list, seed)


if __name__ == "__main__":
    # Parse CLI arguments and run the experiments only when executed as a script.
    main()
