import torch
import tqdm
from sam import SAM
import argparse
import os
import numpy as np
import copy
from torch.nn.modules.batchnorm import _BatchNorm


def disable_running_stats(model):
    """Zero the momentum of every BatchNorm-derived module in ``model``.

    The previous momentum is stashed on the module as ``backup_momentum`` so
    that ``enable_running_stats`` can put it back later. Returns ``None``.
    """
    def _freeze_momentum(layer):
        if not isinstance(layer, _BatchNorm):
            return
        # Right-hand side is evaluated first, so the old momentum is saved
        # before it is overwritten with 0.
        layer.backup_momentum, layer.momentum = layer.momentum, 0

    model.apply(_freeze_momentum)

def enable_running_stats(model):
    """Restore the BatchNorm momentum previously saved as ``backup_momentum``.

    BatchNorm-derived modules that carry no ``backup_momentum`` attribute are
    left untouched. Returns ``None``.
    """
    def _restore_momentum(layer):
        if not isinstance(layer, _BatchNorm):
            return
        if hasattr(layer, "backup_momentum"):
            layer.momentum = layer.backup_momentum

    model.apply(_restore_momentum)


def compute_trace(network, dataset, has_bias):
    """Average squared gradient norm of the network output over the training split.

    For every sample in ``dataset[0]`` the output is back-propagated and the
    squared gradient norms of the ``weight``/``bias`` of each ``Linear``,
    ``ConstantScale`` and ``LayerNorm`` layer in ``network.network`` are
    accumulated. The result is ``2 * mean`` of these per-sample sums, returned
    as a Python float.

    ``has_bias`` is accepted for signature compatibility and is not used.
    The original note says this is meant for an MLP trained with MSE.
    """
    samples = dataset[0]
    tracked_layers = (torch.nn.Linear, ConstantScale, torch.nn.LayerNorm)
    total = 0
    network.zero_grad()
    for index in range(len(samples)):
        features, target = samples[index]
        # Move to the GPU and add a leading batch axis of size one.
        features = features.cuda().unsqueeze(dim = 0)
        target = target.cuda().unsqueeze(dim = 0)
        prediction = forward(network, features, target, full_dataset = samples, return_output = True)
        prediction.backward()
        for layer in network.network:
            if not isinstance(layer, tracked_layers):
                continue
            # weight first, then bias; either may be absent (None).
            for param in (layer.weight, layer.bias):
                if param is not None:
                    total += param.grad.norm() ** 2
        # Clear gradients so the next sample starts from scratch.
        network.zero_grad()
    total /= len(samples)
    total *= 2
    return float(total)
@torch.no_grad()
def compute_trace_old(network, dataset, has_bias):
    """Closed-form variant of the trace computation for a three-module network.

    Uses ``network.network[0]`` and ``[1]`` to compute the hidden activation of
    each sample in ``dataset[0]`` and ``network.network[2].weight`` as the
    output weights. Per sample it adds ``||h||^2`` and
    ``||a * (h > 0)||^2 * (||x||^2 + has_bias)``; the return value is twice the
    mean over samples (a tensor).
    """
    samples = dataset[0]
    first_layer = network.network[0]
    activation_fn = network.network[1]
    out_weight = network.network[2].weight
    total = 0
    for index in range(len(samples)):
        features, _ = samples[index]
        features = features.cuda()
        hidden = activation_fn(first_layer(features))
        active = hidden > 0
        total += hidden.norm() ** 2
        total += (out_weight * active).norm() ** 2 * (features.norm() ** 2 + has_bias)
    return (total / len(samples)) * 2
@torch.no_grad()
def min_ground_truth_trace(train_dataset, has_bias):
    """Return ``(4 / N) * sqrt(A * B)`` for the given dataset.

    ``A`` sums ``x[0]**2 + x[1]**2`` over all inputs and ``B`` sums
    ``2 * (||x||**2 + has_bias)``; ``N`` is ``len(train_dataset)``. Labels are
    ignored.
    """
    coord_energy = 0
    norm_energy = 0
    for features, _ in train_dataset:
        coord_energy += (features[0] ** 2 + features[1] ** 2)
        norm_energy += 2 * (features.norm() ** 2 + has_bias)
    return torch.sqrt(coord_energy * norm_energy) * (4 / len(train_dataset))
def predict_trace(train_dataset):
    """Return ``(4 / N) * sum(||y|| * sqrt(||x||**2 + 1))`` over the dataset.

    ``N`` is ``len(train_dataset)``; each item is an ``(input, label)`` pair of
    tensors.
    """
    running = 0
    for features, target in train_dataset:
        running += target.norm() * torch.sqrt(features.norm() ** 2 + 1)
    # Same arithmetic order as before: double, average, double again.
    return running * 2 / len(train_dataset) * 2

# Command-line interface of the experiment. Arguments are parsed at import
# time, so this file is meant to be run as a script.
parser = argparse.ArgumentParser("Argument Parser for Trace Experiment")
# data: input dimension, number of training / held-out samples, and the
# synthetic distribution ("gaussian" or "boolean") used to generate them.
parser.add_argument("--dimension", default = 30, type = int)
parser.add_argument("--train_number", default = 100, type = int)
parser.add_argument("--valid_number", default = 1000, type = int)
parser.add_argument("--data_type", default = "boolean", type = str)
# seed: passed to torch.random.manual_seed.
parser.add_argument("--seed", default = 0, type = int)
# training: epoch count, training batch size (--bz), learning rate and
# weight decay handed to the optimizer.
parser.add_argument("--epochs", default = 100000, type = int)
parser.add_argument("--bz", default = 1, type = int)
parser.add_argument("--lr", default = 1e-2, type = float)
parser.add_argument("--wd", default = 0, type = float)
# optimizer
# --label_noise_strength / --sigma: scale of Gaussian noise added to the MSE
#   targets during training (skipped when forward(..., validation=True)).
# --rho: SAM is used whenever rho > 0 (the default), plain SGD otherwise.
# --lambda_trace: weight of the get_trace penalty; active only when > 0.
# --kernel: select only the last layer's parameters for optimisation.
parser.add_argument("--label_noise_strength", "--sigma", default = 0.0, type = float)
parser.add_argument("--rho", default = 5e-2, type = float)
parser.add_argument("--lambda_trace", default = 0, type = float)
parser.add_argument("--kernel", action = "store_true")
# model
# --epsilon: eps of LayerNorm/BatchNorm1d and the additive constant in the
#   --global_bn normalisation denominators.
# --pre_norm: place normalisation before the activation instead of after it.
# --has_bias: bias on the first linear layer.
# --normalize / --raw_normalize / --batch_normalize: add LayerNorm /
#   RawNormalize / BatchNorm1d layers; --no_affine drops their affine params.
# --addition_affine: insert a ConstantScale layer before the output layer;
#   --addition_beta gives that layer a bias.
# --depth: number of hidden layers; --width: hidden width (<= 0 means "use
#   the number of training samples").
parser.add_argument("--epsilon", default = 0.0, type = float)
parser.add_argument("--pre_norm", action = "store_true")
parser.add_argument("--has_bias", action = "store_true")
parser.add_argument("--normalize",action = "store_true")
parser.add_argument("--raw_normalize",action = "store_true")
parser.add_argument("--batch_normalize",action = "store_true")
parser.add_argument("--no_affine", action = "store_true")
parser.add_argument("--addition_affine", action = "store_true")
parser.add_argument("--addition_beta", action = "store_true")
parser.add_argument("--depth", default = 1, type = int)
parser.add_argument("--width", default = 500, type = int)
# --save_to: output directory; the script aborts if it already exists.
parser.add_argument("--save_to", default = "trace_exp/final/regression_boolean_bias_simple_ln_wd", type = str)
# --activation: one of "relu", "softplus", "srelu"; --beta is the Softplus beta.
parser.add_argument("--activation", default = "relu", type = str)
parser.add_argument("--beta", default = 10, type = float)
# --NTK: mirrored first-layer initialisation (second half of the rows is the
#   negation of the first half); --scale multiplies the first-layer parameters.
parser.add_argument("--NTK", action = "store_true")
parser.add_argument("--scale", default = 1 , type = float)
# --load_dataset / --resume_network: paths passed to torch.load to replace the
#   generated (train, test) datasets / the freshly built model.
parser.add_argument("--load_dataset", default = "", type = str)
parser.add_argument("--resume_network", default = "", type = str)
# --global_bn: normalise activations with statistics of the full training set
#   inside forward(); --global_bn_no_mean divides by the norm without centring.
parser.add_argument("--global_bn", action = "store_true")
parser.add_argument("--global_bn_no_mean", action = "store_true")
# --renormalization: after every optimizer step rescale the first layer's
#   weight (and bias, with --has_bias) to norm --renormalization_norm.
parser.add_argument("--renormalization", action="store_true")
parser.add_argument("--renormalization_norm", default = 1, type = float)

# criterion: "mse", "logistic" or "kl"; --soft_label is the label_smoothing
# value of the CrossEntropyLoss used for "logistic".
parser.add_argument("--loss", default = "mse", type = str)
parser.add_argument("--soft_label", default = 0.2, type = float)
args = parser.parse_args()

# Echo the full configuration, then copy the frequently used options into
# module-level names that the rest of the script reads.
print(args)
# data
dimension = args.dimension
train_number = args.train_number
valid_number = args.valid_number
# seed (only the torch RNG is seeded here)
torch.random.manual_seed(seed = args.seed)
# training
epochs = args.epochs
bz = args.bz
lr = args.lr
wd = args.wd

# optimizer
label_noise_strength = args.label_noise_strength
rho = args.rho
lambda_trace = args.lambda_trace
# SAM is selected by a positive rho; the trace penalty by a positive lambda.
use_sam = rho > 0
use_trace = args.lambda_trace > 0


# model
has_bias = args.has_bias

class RawNormalize(torch.nn.Module):
    """Parameter-free layer that divides each vector by its L2 norm.

    The norm is taken over the last axis and 1e-6 is added to the denominator,
    so an all-zero vector maps to zeros instead of NaN.
    """

    def forward(self, x):
        length = x.norm(p = 2, dim = -1, keepdim = True)
        return x / (length + 1e-6)



class ConstantScale(torch.nn.Module):
    """Learnable element-wise scale with an optional learnable shift.

    ``weight`` has shape ``(width,)`` and starts at one; when ``has_bias`` is
    true ``bias`` has shape ``(1, width)`` and starts at zero, otherwise
    ``bias`` is ``None``.
    """

    def __init__(self, width, has_bias = False):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones((width,)))
        self.has_bias = has_bias
        self.bias = torch.nn.Parameter(torch.zeros((1, width))) if has_bias else None

    def forward(self, x):
        scaled = self.weight * x
        if not self.has_bias:
            return scaled
        return scaled + self.bias

class sReLU(torch.nn.Module):
    """Symmetric ReLU: concatenates ``relu(x)`` and ``relu(-x)`` on the last axis.

    The output therefore has twice as many features as the input.
    """

    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        positive_part = self.relu(x)
        negative_part = self.relu(-x)
        return torch.cat((positive_part, negative_part), dim = -1)
# Model
class MyModel(torch.nn.Module):
    """Fully connected network with a scalar output, stored in ``self.network``.

    Layout: ``Linear(dimension, width)`` followed by an activation, then
    ``depth - 1`` bias-free ``Linear(width, width)`` + activation stages, an
    optional ``ConstantScale`` and a final bias-free ``Linear(width, 1)``.
    Normalisation layers are inserted before each activation when ``pre_norm``
    is set and after it otherwise.

    Args:
        dimension: input feature size.
        width: hidden width (doubled after every "srelu" activation).
        has_bias: whether the first linear layer has a bias.
        depth: number of hidden layers.
        normalize: add ``LayerNorm`` at each normalisation point.
        activation: "relu", "softplus" or "srelu"; anything else raises
            ``KeyError``. "softplus" reads ``beta`` from the module-level
            ``args``.
        raw_normalize: add ``RawNormalize`` at each normalisation point.
        batch_normalize: add ``BatchNorm1d`` at each normalisation point.
        no_affine: disable the affine parameters of LayerNorm / BatchNorm1d.
        addition_affine: add a ``ConstantScale`` before the output layer.
        epsilon: ``eps`` of LayerNorm / BatchNorm1d.
        pre_norm: normalise before the activation instead of after it.
        addition_beta: give the ``ConstantScale`` layer a bias.
    """

    def __init__(self, dimension, width, has_bias = True, depth = 1, normalize = False,
                 activation = "relu", raw_normalize = False, batch_normalize = False,
                 no_affine = False, addition_affine = False, epsilon = 1e-5,
                 pre_norm = False, addition_beta = False):
        super().__init__()
        act_function = {"relu":torch.nn.ReLU, "softplus":torch.nn.Softplus, "srelu": sReLU}[activation]
        modules = [torch.nn.Linear(dimension, width, bias = has_bias)]

        def _add_normalize():
            # Uses the current feature width, which grows after each sReLU.
            if(normalize):
                modules.append(torch.nn.LayerNorm([width], elementwise_affine = not no_affine, eps = epsilon))
            if(batch_normalize):
                modules.append(torch.nn.BatchNorm1d(width, affine = not no_affine, eps = epsilon))
            if(raw_normalize):
                modules.append(RawNormalize())

        def _add_activation():
            # ``width`` is rebound below; without ``nonlocal`` the assignment
            # would make it a local of this helper and the "srelu" branch
            # would fail with UnboundLocalError.
            nonlocal width
            if(activation == "softplus"):
                modules.append(act_function(beta = args.beta, threshold = 1))
            else:
                modules.append(act_function())
                if(activation == "srelu"):
                    # sReLU concatenates relu(x) and relu(-x): features double.
                    width = width * 2

        if(pre_norm):
            _add_normalize()
        _add_activation()
        for _ in range(depth - 1):
            if(not pre_norm):
                _add_normalize()
            modules.append(torch.nn.Linear(width, width, bias = False))
            if(pre_norm):
                _add_normalize()
            _add_activation()
        if(not pre_norm):
            _add_normalize()
        if(addition_affine):
            modules.append(ConstantScale(width, has_bias = addition_beta))
        modules.append(torch.nn.Linear(width, 1, bias = False))
        self.network = torch.nn.Sequential(*modules)

    def forward(self, x):
        """Apply the sequential stack to ``x``."""
        return self.network(x)

# Refuse to overwrite a previous run: the output directory must be new.
if not os.path.exists(args.save_to):
    os.makedirs(args.save_to)
else:
    print("{} already exists".format(args.save_to))
    raise NotADirectoryError


# Generate train_number + valid_number synthetic samples, then split them
# randomly into the training and held-out sets.
if(args.data_type == 'gaussian'):
    # Gaussian directions rescaled to norm sqrt(dimension); the target is the
    # real value |x_0| - |x_1|.
    X = torch.randn(train_number + valid_number, dimension)
    X = X/X.norm(dim = 1).unsqueeze(dim = 1)
    X = X*torch.sqrt(torch.tensor([dimension]))
    Y = abs(X[:,0]) - abs(X[:,1])
elif(args.data_type == 'boolean'):
    # Uniform +/-1 coordinates; the label says whether the first two
    # coordinates have the same sign.
    rand_matrix = torch.ones((train_number + valid_number, dimension))
    rand_matrix = rand_matrix * 0.5
    X = 2 * torch.bernoulli(rand_matrix) - 1
    X = X.float()
    Y = (X[:,0]) * (X[:,1]) > 0
    Y = Y.float()
    if(args.loss == 'mse'):
        # MSE regresses onto +/-1 targets instead of 0/1.
        Y = 2 * Y - 1
else:
    # Previously an unknown value fell through and crashed below with a
    # NameError on X; fail with an explicit message instead.
    raise ValueError("unsupported --data_type {!r}; expected 'gaussian' or 'boolean'".format(args.data_type))
dataset = torch.utils.data.TensorDataset(X, Y)
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_number, valid_number])

# Optionally replace the generated splits with ones stored on disk; the
# sample count and input dimension are then taken from the loaded data.
if args.load_dataset:
    # NOTE(review): torch.load unpickles the file - only load trusted paths.
    train_dataset, test_dataset = torch.load(args.load_dataset)
    train_number, dimension = len(train_dataset), len(train_dataset[0][0])
    print(dimension)


train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = bz, shuffle = True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = 10, shuffle = False)
# A non-positive --width means "one hidden unit per training sample".
width = args.width if args.width > 0 else train_number

# Build the network from the command-line options.
model = MyModel(dimension, width, has_bias = has_bias, depth = args.depth, normalize = args.normalize,\
    activation = args.activation, raw_normalize = args.raw_normalize, no_affine = args.no_affine, addition_affine = args.addition_affine, epsilon = args.epsilon, pre_norm = args.pre_norm,
    batch_normalize = args.batch_normalize, addition_beta = args.addition_beta)
scale = args.scale
# --scale: replace the first layer's weight (and bias) by scaled copies.
if(scale != 1):
    weight = model.network[0].__dict__['_parameters']['weight'] * scale
    if(has_bias):
        model.network[0].__dict__['_parameters']['bias'] = torch.nn.Parameter(model.network[0].__dict__['_parameters']['bias'].data.detach()  * scale)
    model.network[0].__dict__['_parameters']['weight'] = torch.nn.Parameter(weight)
# --NTK: scale the first layer and network[2], then make the second half of
# the first-layer weight rows the negation of the first half.
# NOTE(review): when --scale != 1 the first layer was already scaled above and
# is multiplied by scale again here - confirm this is intended.
# NOTE(review): network[2] is assumed to be the output layer, which holds for
# depth 1 without normalisation layers - verify for other configurations.
if(args.NTK):
    assert (width%2 == 0)
    weight = model.network[0].__dict__['_parameters']['weight'] * scale
    if(has_bias):
        model.network[0].__dict__['_parameters']['bias'] = torch.nn.Parameter(model.network[0].__dict__['_parameters']['bias'].data.detach()  * scale)
    model.network[2].__dict__['_parameters']['weight'] = torch.nn.Parameter(model.network[2].__dict__['_parameters']['weight'].data.detach() * scale)
    weight = copy.deepcopy(weight.data.detach())
    first_half = list(range(width//2))
    second_half = np.array(first_half) + width//2
    weight[second_half] = -weight[first_half]
    model.network[0].__dict__['_parameters']['weight'] = torch.nn.Parameter(weight)
# --resume_network discards the model built above and loads a pickled one.
# NOTE(review): torch.load unpickles the file - only load trusted paths.
if(args.resume_network):
    model = torch.load(args.resume_network)
    
model = model.cuda()

# Loss object used by forward() for the "mse" and "logistic" losses (the "kl"
# branch of forward() computes its loss by hand).
if(args.loss == 'mse'):
    criterion = torch.nn.MSELoss()
elif(args.loss == 'logistic'):
    criterion = torch.nn.CrossEntropyLoss(label_smoothing=args.soft_label)
elif(args.loss == 'kl'):
    criterion = torch.nn.KLDivLoss()

# Choose what is trained: the whole model, or only the last layer with
# --kernel.
if(not args.kernel):
    parameters = model.parameters()
    named_trainable = model.named_parameters()
else:
    parameters = model.network[-1].parameters()
    named_trainable = model.network[-1].named_parameters()

# Weights whose leading dimension is 1 (the scalar read-out layer) are kept
# out of weight decay. The split is built for both branches above; before,
# --kernel combined with SAM crashed with a NameError because the lists were
# only created in the non-kernel branch.
no_decay = []
decay = []
for name, parameter in named_trainable:
    if(parameter.shape[0] == 1 and 'weight' in name):
        no_decay.append(parameter)
    else:
        decay.append(parameter)

if(use_sam):
    param_groups = [{'params': no_decay, 'weight_decay': 0}, {'params': decay, 'weight_decay': wd}]
    # Drop empty groups (e.g. ``decay`` with --kernel); they contribute nothing
    # and a SAM implementation may index into the first group's parameters.
    param_groups = [group for group in param_groups if group['params']]
    optimizer = SAM(param_groups, torch.optim.SGD, rho = rho, lr = lr)
else:
    optimizer = torch.optim.SGD(parameters, lr = lr, weight_decay = wd)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = epochs)

# # raw sgd
def _normalize_with_dataset_stats(x, x_list, minus_mean):
    """Normalise ``x`` and ``x_list`` feature-wise with statistics of ``x_list``."""
    if(minus_mean):
        center = x_list.mean(dim = 0, keepdim = True)
        spread = x_list.std(dim = 0, keepdim = True) + args.epsilon
        return (x - center) / spread, (x_list - center) / spread
    spread = x_list.norm(dim = 0, keepdim = True) + args.epsilon
    return x / spread, x_list / spread


def _forward_with_global_norm(model, x, full_dataset):
    """Run ``x`` through ``model.network`` with full-dataset normalisation."""
    if(full_dataset is None):
        raise ValueError("--global_bn requires full_dataset to be passed to forward()")
    # Push every input of the reference dataset through the network alongside
    # the batch so that its activations provide the normalisation statistics.
    x_list = torch.stack([x_prime for x_prime, _ in full_dataset]).to(x.device)
    minus_mean = not args.global_bn_no_mean
    activation_types = (torch.nn.ReLU, torch.nn.Softplus, sReLU)
    for component in model.network:
        x_list = component(x_list)
        x = component(x)
        after_activation = isinstance(component, activation_types) and not args.pre_norm
        after_linear = isinstance(component, torch.nn.Linear) and args.pre_norm
        if(after_activation or after_linear):
            x, x_list = _normalize_with_dataset_stats(x, x_list, minus_mean)
    return x


def forward(model, x, y, return_correct = False, full_dataset = None, return_output = False, validation = False):
    """Run the model on a batch and compute the configured loss.

    Args:
        model: network exposing ``model.network``.
        x: input batch.
        y: targets for the batch.
        return_correct: also return the number of correct hard predictions.
        full_dataset: dataset of ``(input, label)`` pairs supplying the
            normalisation statistics; required with ``--global_bn``.
        return_output: return the raw predictions and skip the loss.
        validation: disable label noise.

    Returns:
        The predictions when ``return_output`` is set; otherwise the loss, or
        a ``(loss, correct_count)`` pair when ``return_correct`` is set.

    Raises:
        ValueError: for an unsupported ``--loss`` / ``--data_type``
            combination, or when ``--global_bn`` is set without
            ``full_dataset``.
    """
    if(args.global_bn):
        preds = _forward_with_global_norm(model, x, full_dataset)
    else:
        preds = model(x)

    if return_output:
        return preds

    if(args.loss == "logistic"):
        hard_preds = preds > 0
        # Two-class logits: class 0 is fixed at 0, class 1 is the model output.
        logits = torch.zeros((preds.shape[0], 2), device = preds.device)
        logits[:,1] = preds.squeeze(dim = 1)
        y = y.long()
        loss = criterion(logits, y)
        if(return_correct):
            return loss, (hard_preds.squeeze() == y.squeeze()).sum()
        return loss
    elif(args.loss in ["mse"]):
        add_noise = label_noise_strength > 0 and not validation
        if(args.data_type == 'gaussian'):
            hard_preds = preds > 0.5
            y = y.reshape([-1,1])
            if(add_noise):
                # Noise is drawn on the targets' own device (was hard-coded to CUDA).
                y = y + label_noise_strength*torch.randn(y.shape, device = y.device)
            loss = criterion(preds, y)
            if(return_correct):
                return loss, (hard_preds.squeeze() == y.squeeze()).sum()
            return loss
        elif(args.data_type == 'boolean'):
            hard_preds = preds > 0
            y = y.reshape([-1,1])
            # Accuracy is measured against the clean labels, before noise.
            gt_preds = y > 0
            if(add_noise):
                y = y + label_noise_strength*torch.randn(y.shape, device = y.device)
            loss = criterion(preds, y)
            if(return_correct):
                return loss, (hard_preds.squeeze() == gt_preds.squeeze()).sum()
            return loss
        raise ValueError("unsupported --data_type {!r} for the mse loss".format(args.data_type))
    elif(args.loss == "kl"):
        y = y.reshape([-1,1])
        # Target probability is sigmoid(y); the loss is the KL divergence
        # between Bernoulli(p) and Bernoulli(sigmoid(preds)).
        p = torch.exp(y) / (torch.exp(y) + 1)
        loss = (p * torch.log(p) + (1 - p) * torch.log(1 - p)) - (p * preds - torch.log(torch.exp(preds) + 1))
        loss = loss.mean()
        if(args.data_type == "boolean"):
            hard_preds = preds > 0.5
        elif(args.data_type == 'gaussian'):
            hard_preds = preds > 0
            y = y > 0
        if(return_correct):
            return loss, (hard_preds.squeeze() == y.squeeze()).sum()
        return loss
    raise ValueError("unsupported --loss {!r}".format(args.loss))
def get_trace(model, x, y):
    """Differentiable trace penalty for a batch on a three-module network.

    For every row of ``x`` this adds ``||h||^2`` and
    ``||a * (h > 0)||^2 * (||x||^2 + has_bias)``, where ``h`` is the output of
    the first two modules and ``a`` is the weight of the third; the result is
    twice the batch mean. ``y`` is unused and ``has_bias`` is the module-level
    flag.
    """
    assert (len(model.network) == 3)
    first_layer, activation_fn, output_layer = model.network
    out_weight = output_layer.weight
    total = 0
    batch_size = x.shape[0]
    for row in range(batch_size):
        sample = x[row].cuda()
        hidden = activation_fn(first_layer(sample))
        active = hidden > 0
        total += hidden.norm() ** 2
        total += (out_weight * active).norm() ** 2 * (sample.norm() ** 2 + has_bias)
    return (total / batch_size) * 2
                    


# Per-epoch histories; each list gets one entry per epoch and is saved after
# training.
valid_loss_list = []
train_loss_list = []
train_acc_list = []
valid_acc_list = []
trace_list = []
norm_list = []
# Main loop: one pass over train_loader per epoch, followed by evaluation on
# the training and held-out sets.
with tqdm.trange(epochs) as t:
    for epoch in t:
        # train
        model.train()
        for x,y in train_loader:
            x = x.cuda()
            y = y.cuda()
            if(use_trace):
                # The penalty gradient is accumulated before the loss gradient,
                # so both contribute to the same optimizer step.
                trace = get_trace(model, x, y) * lambda_trace
                trace.backward()
            # Restore the BatchNorm momentum (if it was zeroed) for this pass.
            enable_running_stats(model)
            loss = forward(model, x, y, full_dataset=train_dataset)
            loss.backward()
            if(use_sam): 
                # SAM takes two forward/backward passes per update; first_step
                # and second_step come from the local ``sam`` module.
                optimizer.first_step(zero_grad = True)
                # Zero the BatchNorm momentum so the second pass leaves the
                # running statistics unchanged.
                disable_running_stats(model)
                loss2 = forward(model, x, y, full_dataset=train_dataset)
                loss2.backward()
                optimizer.second_step(zero_grad=True)
            else:
                optimizer.step()
                model.zero_grad()
            with torch.no_grad():
                if(args.renormalization):
                    # Rescale the whole first-layer weight (and bias) to the
                    # requested norm after every update.
                    model.network[0].weight.data = model.network[0].weight / model.network[0].weight.norm() * args.renormalization_norm
                    if(args.has_bias):
                        model.network[0].bias.data = model.network[0].bias / model.network[0].bias.norm() * args.renormalization_norm
                        
        # validate
        model.eval()
        # compute_trace needs gradients, so it runs outside torch.no_grad().
        trace = compute_trace(model, [train_dataset, test_dataset], has_bias)
        trace_list.append(float(trace))
        correct_cnt = 0
        with torch.no_grad():
            # Training metrics: mean of the per-batch losses and the overall
            # accuracy, computed without label noise (validation = True).
            train_loss = 0
            for x, y in train_loader:
                x = x.cuda()
                y = y.cuda()
                loss, cnt = forward(model, x, y, True, full_dataset=train_dataset, validation = True)
                train_loss += loss
                correct_cnt += cnt
            train_loss /= len(train_loader)
            train_acc = correct_cnt.float()/ len(train_dataset)
        train_loss_list.append(float(train_loss.detach()))
        train_acc_list.append(float(train_acc.detach()))
        # Norm of the whole first-layer weight matrix.
        norm_list.append(float(model.network[0].weight.norm().detach()))
        correct_cnt = 0
        with torch.no_grad():
            # Held-out metrics; normalisation statistics (with --global_bn)
            # still come from the training set.
            valid_loss = 0
            for x, y in test_loader:
                x = x.cuda()
                y = y.cuda()
                loss, cnt  = forward(model, x, y, True, full_dataset=train_dataset, validation = True)
                correct_cnt += cnt
                valid_loss += loss
            valid_acc = correct_cnt.float() / len(test_dataset)
            valid_loss /= len(test_loader)
        valid_loss_list.append(float(valid_loss.detach()))
        valid_acc_list.append(float(valid_acc.detach()))
        # Show the latest metrics in the progress bar.
        t.set_postfix(valid = valid_loss_list[-1], train = train_loss_list[-1], trace = trace_list[-1], acc = valid_acc_list[-1], norm = norm_list[-1])
    # scheduler.step()
# Persist the training histories, the final model and the dataset splits
# into the run directory, one file per artefact.
training_artifacts = (
    ('norm', norm_list),
    ('train_loss', train_loss_list),
    ('train_acc', train_acc_list),
    ('valid_loss', valid_loss_list),
    ('valid_acc', valid_acc_list),
    ('trace', trace_list),
    ('final_model', model),
    ('dataset', (train_dataset, test_dataset)),
)
for artifact_name, artifact in training_artifacts:
    torch.save(artifact, os.path.join(args.save_to, artifact_name))
def get_valid_error(model, valid_dataset):
    """Fraction of samples whose predicted sign disagrees with the label sign.

    A prediction or label counts as positive when it is strictly greater than
    zero. ``valid_dataset`` yields ``(input, label)`` pairs and must support
    ``len``.
    """
    mistakes = 0
    for features, target in valid_dataset:
        predicted_positive = model(features) > 0
        is_positive = target > 0
        mistakes = mistakes + (predicted_positive != is_positive)
    return float(mistakes / len(valid_dataset))


# Hand-build a reference network with one hidden unit per training sample
# (default ReLU MyModel with a first-layer bias).
train_num = len(train_dataset)
network = MyModel(dimension, train_num, True)
weight = torch.zeros(train_num, dimension)
b = torch.zeros(train_num)
a = torch.zeros(1, train_num)
for i, (input, label) in enumerate(train_dataset):
    # Target pre-activation of unit i on its own sample:
    # sqrt(sqrt(||x||^2 + 1) * |y|).
    dynamic = torch.sqrt(torch.sqrt(input.norm()**2 + 1) * (label.float()).norm())
    # Unit i points along its sample (shifted by 0.5 per coordinate for
    # boolean data).
    weight[i] = input
    if(args.data_type == 'boolean'):
        weight[i] -= torch.ones(dimension) * 0.5
    # Bias chosen so that the pre-activation on the unit's own sample is the
    # small positive margin 1e-2 * ||x||.
    # NOTE(review): presumably this keeps the unit inactive on the other
    # samples - not established here, confirm.
    b[i] = - torch.dot(weight[i], input) + 1e-2*input.norm()
    C = torch.dot(weight[i], input) + b[i]
    # Rescale so that the pre-activation on the own sample equals ``dynamic``.
    weight[i] *= (dynamic/C)
    b[i] *= (dynamic/C)
    # Output weight label / dynamic makes the unit contribute ``label`` on its
    # own sample; a (near-)zero ``dynamic`` gets a zero output weight.
    if(abs(dynamic) < 1e-5):
        a[0, i] = 0
    else:
        a[0, i] = label / (dynamic)
# Install the constructed tensors as the parameters of the two linear layers.
network.network[0].__dict__['_parameters']['bias'] = torch.nn.Parameter(b)
network.network[0].__dict__['_parameters']['weight'] = torch.nn.Parameter(weight)
network.network[2].__dict__['_parameters']['weight'] = torch.nn.Parameter(a)
# Squared error of the hand-built network on the training and held-out sets
# (evaluated on the CPU, before the network is moved to the GPU).
loss_list = torch.stack([(network(sample) - target)**2 for (sample, target) in train_dataset])
loss = float(loss_list.mean())
valid_loss_list = torch.stack([(network(sample) - target)**2 for (sample, target) in test_dataset])
valid_loss = float(valid_loss_list.mean())
valid_error = get_valid_error(network, test_dataset)
print(loss, valid_loss)
# The trace of the constructed network is measured on the GPU and compared
# with the two dataset-only estimates.
network.cuda()
real_trace = float(compute_trace(network, [train_dataset, test_dataset], has_bias))
estimate_trace = float(predict_trace(train_dataset))
estimate_ground_truth_trace = float(min_ground_truth_trace(train_dataset, has_bias))
print(real_trace, estimate_trace)

# Persist the results for the hand-built network together with the parsed
# command-line arguments, one file per artefact.
memorize_artifacts = (
    ('memorize_train_loss', loss),
    ('memorize_valid_error', valid_error),
    ('memorize_valid_loss', valid_loss),
    ('memorize_trace', real_trace),
    ('predict_memorize_trace', estimate_trace),
    ('min_ground_truth_trace', estimate_ground_truth_trace),
    ('args', args),
)
for artifact_name, artifact in memorize_artifacts:
    torch.save(artifact, os.path.join(args.save_to, artifact_name))

