import numpy as np
import matplotlib.pyplot as plt
import math
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import random
import scipy.stats as st
from tqdm import tqdm
from multiprocessing import Pool
import os
from multiprocessing import get_context
import pickle
import torch.nn.functional as F
import argparse
import torch.nn.init as init

from data import gen_data
from model_linear import SixLayerNeuralNet, FourLayerNeuralNet, TwoLayerNeuralNet, SixLayerNeuralNet_small, FourLayerNeuralNet_small, TwoLayerNeuralNet_small

# Number of samples in the held-out test set generated for every seed.
test_data_num = 2000
# A parameter draw is treated as a global solution once its training loss
# drops below this threshold.
epsilon = 0.01
# Half-width of the uniform range used when drawing student parameters.
scaling_student = 1.0
# Counterpart for the teacher; not referenced by active code visible here.
scaling_teacher = 1.0
# Use the first CUDA GPU when one is available, otherwise fall back to the
# CPU so the script also runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Mean-squared error is used for both the training and the test loss.
criterion = nn.MSELoss()

@torch.no_grad()
def gc(model_name, data_num_list, true_network, seed):
    """Find low-training-loss students by repeated random re-initialisation.

    (The name presumably abbreviates "guess and check" -- TODO confirm.)

    For the given ``seed`` a student network is built, a pool of training
    samples and ``test_data_num`` test samples are generated with
    ``gen_data``, and then, for each entry of ``data_num_list``, the student's
    parameters are redrawn until the MSE on the first ``data_num`` training
    samples is below ``epsilon``.  No gradient step is ever taken; the whole
    function runs under ``torch.no_grad``.  The accepted parameters and both
    losses are saved to
    ``easy_experiment_result/gc_linear/<model_name>/<data_num>/model_weights_<seed>.pth``.

    Args:
        model_name: One of ``"six_layer"``, ``"four_layer"``, ``"two_layer"``.
        data_num_list: Training-set sizes to process, each in ``1..2000``.
        true_network: Teacher network, forwarded to ``gen_data``.
        seed: Seed for torch, NumPy and ``random``; also part of the output
            file name.

    Raises:
        ValueError: If ``model_name`` is unknown or a requested training-set
            size does not fit into the generated training pool.
    """
    # Size of the training pool; each run uses a prefix of it.
    # Assumes gen_data(n, ...) yields n samples -- TODO confirm.
    train_pool_size = 2000

    # Validate up front: an unknown name used to leave `net` unbound, and an
    # oversized data_num was silently truncated by the slice below while the
    # result was still stored under the larger data_num.
    if model_name not in ("six_layer", "four_layer", "two_layer"):
        raise ValueError(
            f"unknown model_name {model_name!r}; expected 'six_layer', "
            "'four_layer' or 'two_layer'"
        )
    invalid_sizes = [n for n in data_num_list if not 0 < n <= train_pool_size]
    if invalid_sizes:
        raise ValueError(
            f"data_num values must lie in 1..{train_pool_size}, got {invalid_sizes}"
        )

    torch.manual_seed(seed)
    np.random.seed(seed=seed)
    random.seed(seed)

    # The student is constructed before any data is generated so the order in
    # which the seeded generators are consumed stays fixed.
    if model_name == "six_layer":
        net = SixLayerNeuralNet()
    elif model_name == "four_layer":
        net = FourLayerNeuralNet()
    else:
        net = TwoLayerNeuralNet()

    np_rng = np.random.default_rng(seed)

    train_dataset = gen_data(train_pool_size, true_network, np_rng)

    print(train_dataset[:10])

    test_dataset = gen_data(test_data_num, true_network, np_rng)

    print(test_dataset[:10])
    # The whole test set is served as a single batch.
    testloader = DataLoader(test_dataset,
                            batch_size=test_data_num,
                            shuffle=False,
                            num_workers=0)

    for data_num in data_num_list:
        # Full-batch loader over the first `data_num` training samples.
        trainloader = DataLoader(train_dataset[:data_num],
                                 batch_size=data_num,
                                 shuffle=False,
                                 num_workers=0)

        # Start above the threshold so the search loop runs at least once.
        loss = torch.tensor(100).to(device)

        # Search: redraw all parameters until the training loss is small enough.
        while loss >= epsilon:
            # Draw every parameter uniformly from
            # [-scaling_student, scaling_student) and freeze it.
            for f in net.parameters():
                f.data = torch.as_tensor(np_rng.random(f.shape) * 2 * scaling_student - scaling_student, dtype=torch.float32).to(device)
                f.requires_grad = False
            # For nn.Linear modules the values drawn above are overwritten
            # here: Xavier-uniform weights, and biases drawn uniformly within
            # sqrt(6 / (fan_in + fan_out)).
            for m in net.modules():
                if isinstance(m, nn.Linear):
                    init.xavier_uniform_(m.weight, gain=1.0)  # uniform
                    if m.bias is not None:
                        fan_in, fan_out = m.weight.size(1), m.weight.size(0)
                        bound = math.sqrt(6.0 / (fan_in + fan_out))
                        nn.init.uniform_(m.bias, -bound, bound)
            net = net.to(device)

            # Training loss of the current draw, summed over the batches.
            loss = torch.tensor(0.0).to(device)
            for xs, ys in trainloader:
                xs = xs.to(device)
                ys = ys.to(device)

                output = net(xs)
                loss += criterion(output, ys)

        # Evaluation of the accepted draw on the held-out test set.
        test_loss = torch.tensor(0.0).to(device)
        for xs, ys in testloader:
            xs = xs.to(device)
            ys = ys.to(device)
            output = net(xs)
            test_loss += criterion(output, ys)

        print(f"found, data_num={data_num}, seed={seed}, train_loss={loss}, test_loss={test_loss}")

        checkpoint = {
            "model_state": net.state_dict(),
            "train_loss": loss.item(),
            "test_loss": test_loss.item()
        }
        folder = f"easy_experiment_result/gc_linear/{model_name}/{data_num}"
        os.makedirs(folder, exist_ok=True)
        torch.save(checkpoint, f"{folder}/model_weights_{seed}.pth")


def main():
    """Parse the command line, create the teacher network and run the search.

    The RNGs of torch, NumPy and ``random`` are seeded with 0, a
    ``TwoLayerNeuralNet_small`` teacher is created, and each of its
    ``nn.Linear`` modules gets Xavier-uniform weights and zero biases.  The
    teacher's state dict is saved to
    ``easy_experiment_result/gc_linear/<model_name>/true_network/model_weights.pth``.
    ``gc`` is then called once for every seed in ``range(sample_num)``.

    Command-line arguments:
        --model_name: Student architecture ("six_layer", "four_layer" or
            "two_layer").
        --sample_num: Number of seeds to run (seeds ``0 .. sample_num - 1``).
        --data_num_list: Training-set sizes passed to ``gc``.
    """
    parser = argparse.ArgumentParser()
    # Both options are mandatory: without them the run used to continue with
    # None and fail later (or write results under a ".../None/..." folder).
    parser.add_argument('--model_name', type=str, required=True,
                        choices=["six_layer", "four_layer", "two_layer"],
                        help='student architecture to search over')
    parser.add_argument('--sample_num', type=int, required=True,
                        help='number of seeds to run (0 .. sample_num - 1)')
    parser.add_argument('--data_num_list', required=True, nargs="*", type=int, help='a list of int variables')

    args = parser.parse_args()
    model_name = args.model_name
    sample_num = args.sample_num
    data_num_list = args.data_num_list

    # Fixed seed so the teacher is identical across invocations.
    torch.manual_seed(0)
    np.random.seed(seed=0)
    random.seed(0)
    true_network = TwoLayerNeuralNet_small()

    # Teacher initialisation: Xavier-uniform weights, zero biases.
    for m in true_network.modules():
        if isinstance(m, nn.Linear):
            init.xavier_uniform_(m.weight, gain=1.0)  # uniform
            if m.bias is not None:
                init.zeros_(m.bias)

    checkpoint = {
        "model_state": true_network.state_dict()
    }

    folder = f"easy_experiment_result/gc_linear/{model_name}/true_network"

    os.makedirs(folder, exist_ok=True)
    torch.save(checkpoint, f"{folder}/model_weights.pth")

    # One independent search per seed.
    for seed in range(sample_num):
        gc(model_name, data_num_list, true_network, seed)


# Run the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()