import argparse
import copy
import os
os.environ["OMP_NUM_THREADS"] = "1"
import pickle
import random 
import datetime
import time
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import torch.nn as nn
# from entropy_estimators import continuous


################################################
class MeanFieldNN(nn.Module):
    """Two-layer tanh network whose output layer is frozen to the constant R/m.

    Computes ``f(x) = sum_j (R/m) * [tanh(a_j . x + b_j) + tanh(c_j)]`` where
    ``a`` is ``fc1.weight`` (m x d), ``b`` is ``fc1.bias`` (m), ``c`` is ``fc3``
    (m) and the output weights ``fc2.weight`` (1 x m) are all equal to R/m and
    excluded from training.

    Args:
        d: input dimension.
        m: number of hidden units.
        R: scale of the output layer; every output weight equals R/m.
        init_a_std: standard deviation of the normal init of a_{ij}
            (default 1/sqrt(d), i.e. a_{ij} ~ N(0, 1/d)).
        init_b_std: standard deviation of the normal init of b_j
            (default 1/sqrt(d)).
        init_c_std: standard deviation of the normal init of c_j (default 1).
    """

    def __init__(self, d, m, R, init_a_std=None, init_b_std=None, init_c_std=None):
        super().__init__()
        self.d, self.m, self.R = d, m, R
        self.init_a_std = init_a_std
        self.init_b_std = init_b_std
        self.init_c_std = init_c_std

        # Creation order matters: nn.Linear draws its default init from the
        # global RNG, and registration order fixes the parameters() order.
        self.fc1 = nn.Linear(d, m)
        self.fc2 = nn.Linear(m, 1, bias=False)
        self.fc3 = nn.Parameter(torch.zeros(m), requires_grad=True)

        fan_in_std = 1 / d ** 0.5
        a_std = fan_in_std if init_a_std is None else init_a_std
        b_std = fan_in_std if init_b_std is None else init_b_std
        c_std = 1. if init_c_std is None else init_c_std

        nn.init.normal_(self.fc1.weight, std=a_std)  # a_{ij} ~ N(0, a_std^2)
        nn.init.normal_(self.fc1.bias, std=b_std)    # b_j    ~ N(0, b_std^2)
        nn.init.normal_(self.fc3, std=c_std)         # c_j    ~ N(0, c_std^2)
        # Output layer: constant R/m, never trained.
        nn.init.constant_(self.fc2.weight, R / m)
        self.fc2.weight.requires_grad_(False)

    def forward(self, x):
        # Per-unit activation tanh(a_j . x + b_j) plus the input-independent
        # offset tanh(c_j), then the fixed R/m-weighted sum over units.
        hidden = self.fc1(x).tanh() + self.fc3.tanh()
        return self.fc2(hidden)

################################################
class DataLoader():
    """Sampler for the 4-sparse parity task on the scaled hypercube.

    Inputs are uniform on {-1/sqrt(d), +1/sqrt(d)}^d and the target is
    ``y = d**2 * x_0 * x_1 * x_2 * x_3`` (the +/-1 parity of the first four
    coordinates), shaped (count, 1). Requires d >= 4.

    Args:
        d: input dimension.
        n: batch size returned by :meth:`sample`.
        N: if None ("stochastic" mode), a fresh batch of n points is drawn on
            every call. Otherwise ("finitesample" mode), N points are drawn once
            up front and each call returns n of them chosen without replacement
            (or the whole set when n == N).
    """

    def __init__(self, d, n, N=None):
        self.d = d
        self.n = n
        self.N = N
        self.mode = "stochastic" if N is None else "finitesample"
        if self.mode == "finitesample":
            features, labels = self._draw(N)
            self.X = features.to(device)
            self.Y = labels.to(device)

    def _draw(self, count):
        """Draw `count` fresh (X, Y) pairs on the CPU."""
        centred_bits = torch.randint(2, (count, self.d)) - 0.5
        features = centred_bits * 2 / self.d ** 0.5
        labels = (features[:, 0] * features[:, 1] * features[:, 2] * features[:, 3] * (self.d ** 2)).reshape(count, 1)
        return features, labels

    def sample(self):
        """Return one batch (X, Y) located on `device`."""
        if self.mode == "stochastic":
            features, labels = self._draw(self.n)
            return features.to(device), labels.to(device)
        if self.n == self.N:
            # Full-batch: hand back the stored tensors directly.
            return self.X, self.Y
        picks = random.sample(range(self.N), k=self.n)
        return self.X[picks, :], self.Y[picks, :]


################################################
class NoisySGD():
    """Noisy gradient-descent trainer with optional adaptive noise level.

    Each :meth:`update` applies, to every parameter with ``requires_grad``::

        p <- p - [eta * (m * grad_p(loss) + 2 * lda_L2 * lda * p)
                  + sqrt(2 * lda * eta) * xi],   xi ~ N(0, I)

    Metrics are appended to ``self.log`` at construction and then every
    ``log_interval`` steps. When ``annealing`` is on, ``lda`` is adapted every
    ``annealing_interval`` steps. The mean training loss of the window that just
    ended is compared with the previous window's. If it decreased (and
    lda >= 1e-5), lda is divided by 1.2. Otherwise, while lda < 0.1, it is
    multiplied by 1.2.

    Args:
        model: network exposing ``m`` plus ``R``, ``init_a_std``,
            ``init_b_std``, ``init_c_std`` (copied into ``log["setting"]``).
        TrainDataLoader: object with ``sample() -> (X, Y)`` and ``d``, ``n``,
            ``N`` attributes.
        TestDataLoader: object with ``sample() -> (X, Y)`` and ``n``.
        eta: step size.
        lda_L2: L2-penalty coefficient (multiplied by the current ``lda``).
        lda_0: initial ``lda`` (noise level / penalty scale).
        criteria: loss called as ``criteria(output, (Y >= 0).float())``.
            ``None`` (the default) builds a fresh ``nn.BCEWithLogitsLoss()``.
        log_interval: steps between records (``None`` -> 1000).
        annealing: enable adaptation of ``lda``.
        entropy_evaluation: also log ``-continuous.get_h(params, k=5)``. This
            needs ``continuous`` to be in scope; its import is commented out at
            the top of this file.
        annealing_interval: length (in steps) of one annealing window.
    """

    def __init__(self, model, TrainDataLoader, TestDataLoader, eta, lda_L2, lda_0, criteria=None, log_interval = None, annealing = False, entropy_evaluation = False, annealing_interval = 3000):
        self.model = model.to(device)
        self.TrainDataLoader = TrainDataLoader
        self.TestDataLoader = TestDataLoader
        # Build the default loss per instance rather than sharing one module
        # object created once at class-definition time.
        self.criteria = nn.BCEWithLogitsLoss() if criteria is None else criteria
        self.eta = eta
        self.lda_L2 = lda_L2
        self.lda = lda_0
        self.annealing = annealing
        self.annealing_interval = annealing_interval
        self.entropy_evaluation = entropy_evaluation
        self.log_interval = log_interval if log_interval is not None else 1000

        self.t = 0
        setting = {"d": TrainDataLoader.d, "n": TrainDataLoader.n, "N": TrainDataLoader.N, "n_test": TestDataLoader.n, "lda_L2": lda_L2, "lda_0": lda_0, "eta": eta, "annealing": annealing, "annealing_interval": annealing_interval, "m": model.m, "R": model.R, "init_a_std": model.init_a_std, "init_b_std": model.init_b_std, "init_c_std": model.init_c_std}
        self.log={"setting": setting, "t":[], "train_loss":[], "train_acc":[], "L2":[], "NegativeEntropy": [], "lda": [], "test_loss": [], "test_acc":[]}
        if self.annealing:
            # average_1: running mean train loss of the current window.
            # average_2: mean train loss of the previous, completed window.
            self.average_1 = 0.
            self.average_2 = 0.

        self.record()

    def update(self):
        """Run one noisy gradient step, then record and/or anneal if due."""
        self.t = self.t + 1
        X, Y = self.TrainDataLoader.sample()
        self.model.zero_grad()
        output = self.model(X)
        train_loss = self.criteria(output, (Y >= 0.0).float())
        train_acc = ((output >= 0.0) == (Y >= 0.0)).float().sum() / len(Y)
        train_loss.backward()
        noise_std = (2 * self.lda * self.eta) ** 0.5
        with torch.no_grad():
            for param in self.model.parameters():
                if param.requires_grad:
                    # The gradient is multiplied by m, presumably to offset the
                    # R/m output weights of MeanFieldNN (mean-field time scale).
                    param -= self.eta * (param.grad * self.model.m + 2 * self.lda_L2 * self.lda * param) + torch.randn_like(param) * noise_std
        L2 = self._trainable_l2()

        if self.annealing:
            self.average_1 += train_loss.detach() / self.annealing_interval

        if self.t % self.log_interval == 0:
            self.record(train_loss.detach().to("cpu"), train_acc.detach().to("cpu"), L2.detach().to("cpu"))

        if self.annealing and self.t % self.annealing_interval == 0:
            # Every window boundary rolls the averages. lda is only adapted once
            # two complete windows exist; the first one just becomes the baseline.
            if self.t >= self.annealing_interval * 2:
                if self.average_1 < self.average_2 and self.lda >= 0.00001:
                    self.lda = self.lda / 1.2
                elif self.lda < 0.1:
                    # Also reached when the loss improved but lda < 1e-5,
                    # so lda bounces back up from that floor.
                    self.lda = self.lda * 1.2
            self.average_2 = self.average_1
            self.average_1 = 0.

    def record(self, train_loss = None, train_acc = None, L2 = None):
        """Append the current metrics to ``self.log`` and print a summary line.

        If ``train_loss``/``train_acc`` are not given, they are computed on a
        fresh training batch. If ``L2`` is not given, it is computed from the
        trainable parameters. Test loss/accuracy are always computed on a fresh
        test batch.
        """
        self.log["t"].append(self.t)

        # Pure evaluation: no autograd graph is needed (the test batch may be large).
        with torch.no_grad():
            if train_loss is None:
                X, Y = self.TrainDataLoader.sample()
                output = self.model(X)
                train_loss = self.criteria(output, (Y >= 0.0).float())
                train_acc = ((output >= 0.0) == (Y >= 0.0)).float().sum() / len(Y)
            self.log["train_loss"].append(train_loss.detach().to("cpu"))
            self.log["train_acc"].append(train_acc.detach().to("cpu"))

            if L2 is None:
                # Same quantity update() logs: trainable parameters only.
                L2 = self._trainable_l2()
            self.log["L2"].append(L2.detach().to("cpu"))

            if self.entropy_evaluation:
                # Concatenate trainable params column-wise. 1-D params become a
                # single column. Assumes all share the same leading dimension
                # (m for MeanFieldNN), giving one row per hidden unit.
                columns = []
                for param in self.model.parameters():
                    if not param.requires_grad:
                        continue
                    if param.dim() == 1:
                        columns.append(param.detach().to("cpu").clone().reshape(param.shape[0], 1))
                    elif param.dim() == 2:
                        columns.append(param.detach().to("cpu").clone())
                params_all = torch.cat(columns, dim=1)
                NegativeEntropy = - continuous.get_h(params_all, k=5)
            else:
                NegativeEntropy = 0.
            # Append like every other metric (previously the list was overwritten).
            self.log["NegativeEntropy"].append(NegativeEntropy)

            self.log["lda"].append(self.lda)

            Xt, Yt = self.TestDataLoader.sample()
            output_t = self.model(Xt)
            test_loss = self.criteria(output_t, (Yt >= 0.0).float())
            test_acc = ((output_t >= 0.0) == (Yt >= 0.0)).float().sum() / len(Yt)
            self.log["test_loss"].append(test_loss.detach().to("cpu"))
            self.log["test_acc"].append(test_acc.detach().to("cpu"))

        print(f"t={self.t}, train_loss: {train_loss}, train_acc: {train_acc}, L2: {L2}, lda: {self.lda}, test_loss: {test_loss}, test_acc: {test_acc}")

    def _trainable_l2(self):
        """Return sum over trainable params of ||p||^2 / m as a 0-dim tensor."""
        total = torch.zeros((), device=device)
        with torch.no_grad():
            for param in self.model.parameters():
                if param.requires_grad:
                    total = total + torch.norm(param) ** 2 / self.model.m
        return total

################################################
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Problem size and run length have no sensible default. Make argparse
    # reject a missing value up front instead of failing later on None.
    parser.add_argument("--d", type=int, required=True)
    parser.add_argument("--m", type=int, required=True)
    parser.add_argument("--R", type=float, required=True)
    parser.add_argument("--n", type=int, required=True)
    parser.add_argument("--eta", type=float, default=0.1)
    parser.add_argument("--T", type=int, required=True)
    parser.add_argument("--lda_L2", type=float, default=0.001)
    parser.add_argument("--lda", type=float, default=0.001)
    parser.add_argument("--log_interval", type=int, default=1000)
    parser.add_argument("--annealing", type=int, default=0)
    parser.add_argument("--annealing_interval", type=int, default=1000)
    parser.add_argument("--entropy_evaluation", type=int, default=0)
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()

    # 0/1 command-line flags -> booleans.
    args.annealing = bool(args.annealing)
    args.entropy_evaluation = bool(args.entropy_evaluation)

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        # Select the GPU with set_device only. CUDA_VISIBLE_DEVICES is not set
        # here. By now CUDA has already been probed (torch.cuda.is_available()
        # runs at import time), so the variable is at best ignored. If it were
        # honoured, the chosen GPU would be renumbered to 0 and
        # set_device(gpu_id) would fail for gpu_id > 0.
        torch.cuda.set_device(args.gpu_id)
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        print("GPU is being used")
    else:
        print("GPU is NOT being used")

    # Run start time in Japan Standard Time (UTC+9); used for the experiment id.
    JST = datetime.timezone(datetime.timedelta(hours=9), 'JST')
    now = datetime.datetime.now(JST)

    random.seed(args.seed)

    # Create the output folders before training, so a missing directory cannot
    # discard a finished run at save time.
    os.makedirs("./Experiment/log", exist_ok=True)

    model = MeanFieldNN(args.d, args.m, args.R, init_a_std=1.0, init_b_std=1.0, init_c_std=1.0)
    TrainDataLoader = DataLoader(args.d, args.n, args.n)
    TestDataLoader = DataLoader(args.d, 50000, 50000)
    Optimizer = NoisySGD(model, TrainDataLoader, TestDataLoader, args.eta, args.lda_L2, args.lda, log_interval=args.log_interval, entropy_evaluation = args.entropy_evaluation, annealing = args.annealing, annealing_interval=args.annealing_interval)

    for _ in range(args.T):
        Optimizer.update()

    Optimizer.log["setting"]["T"] = args.T
    Optimizer.log["setting"]["gpu_id"] = args.gpu_id
    Optimizer.log["setting"]["gpu_enabled"] = device

    # Id = start time (YYYYMMDDhhmmss + 6-digit microseconds) "." 6-digit random suffix.
    experiment_id = now.strftime("%Y%m%d%H%M%S%f") + "." + str(random.randrange(1000000)).zfill(6)
    final_test_acc = Optimizer.log["test_acc"][-1].item()
    # NOTE(review): the trailing literal 4 is written to every row. It may be
    # the parity degree (the target uses 4 coordinates) or a leftover of
    # round(..., 4). It is kept so the existing log.txt column layout is unchanged.
    summary_row = [args.d, args.m, args.R, args.n, args.eta, args.T, args.lda_L2, args.lda, args.annealing, args.annealing_interval, args.log_interval, args.entropy_evaluation, args.gpu_id, args.seed, final_test_acc, 4]
    with open("./Experiment/log.txt", mode="a") as f:
        f.write(experiment_id + ", " + ", ".join(map(str, summary_row)) + "\n")

    with open("./Experiment/log/" + experiment_id + ".pickle", "wb") as f:
        pickle.dump(Optimizer.log, f)