import numpy as np 
import torch
from torch import nn 
import torchvision
from matplotlib import pyplot as plt
import math
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from torchmin import Minimizer

import random
import copy
import argparse
import sys
import yaml
import time
import pathlib

# args = None
def parse_arguments(argv=None):
    """Define and parse the command-line options of the single-neuron experiment.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is how the script itself calls this.

    Returns:
        argparse.Namespace with one attribute per option declared below
        (dashes in option names become underscores, e.g. ``--num-samples``
        is exposed as ``num_samples``).
    """
    parser = argparse.ArgumentParser(description="PyTorch single neuron case")
    parser.add_argument("--optimizer", help="Which optimizer to use", default="sgd")
    parser.add_argument(
        "--epochs",
        default=100000,
        type=int,
        help="number of total epochs to run")

    parser.add_argument("--num-samples", default=100, type=int, help='number of samples')
    parser.add_argument(
        "--lr",
        default=1e-4,
        type=float,
        help="initial learning rate"
    )
    parser.add_argument(
        "--seed", default=1, type=int, help="seed for initializing training. "
    )
    parser.add_argument("--reset-weights", action='store_true', default=False)
    parser.add_argument("--initialize-negative", action='store_true', default=False)
    parser.add_argument("--pruner", type=str, default='mag', help="pruning method")
    parser.add_argument("--dim", type=int, default=10, help="input dim / overparam")
    parser.add_argument("--prune-iters", type=int, default=2, help="num pruning steps")
    parser.add_argument("--noise", type=float, default=0.001, help="noise")
    parser.add_argument("--density", type=float, default=0.5, help="density")
    parser.add_argument("--name", type=str, default='trial', help="name of the experiments")
    parser.add_argument("--optim-method", type=str, default='sgd', help="type of optimizer method")
    # Sign combination used to initialise (a, w); the default is a real int so
    # it no longer depends on argparse converting a string default via `type`.
    parser.add_argument("--init-case", type=int, default=1, help="one of four cases for sign combinations of a, w")

    parser.add_argument("--result-dir", type=str, default='results-finetune-theory', help="result dir of the experiments")

    return parser.parse_args(argv)

# Parse the command line once at module level; the resulting namespace is read
# as a global by the code below (e.g. args.result_dir, args.seed).
args = parse_arguments()
# Echo the parsed options to stdout.
print(args)
################ Save Results to CSV
def write_result_to_csv(**kwargs):
    """Append one result row to ``<args.result_dir>.csv``, creating it if needed.

    The target path is built from the module-level ``args.result_dir``. When
    the file does not exist yet it is created with a header line first; the
    row built from ``kwargs`` is then appended.

    Keyword Args:
        name, dim, reset_weights, num_samples: written with default formatting.
        noise, train_loss, test_loss, weight_init, weight_learnt, a_init,
            a_learnt: written with five decimals.
        density: written with two decimals.

    Raises:
        KeyError: if any of the fields listed above is missing from ``kwargs``.
    """
    results = pathlib.Path(args.result_dir + '.csv')

    if not results.exists():
        # write_text opens, writes and closes the file on its own, so no
        # surrounding open() handle is needed.
        results.write_text(
            "Name, "
            "Dim,"
            "Noise,"
            "Density,"
            "Reset Weights,"
            "Train Loss,"
            "Test Loss,"
            "Num Samples,"
            "Weight Init,"
            "Weight Learnt,"
            "A init,"
            "A learnt\n"
        )

    # Column order must stay in sync with the header above.
    row = (
        "{name}, "
        "{dim}, "
        "{noise:.05f}, "
        "{density:.02f}, "
        "{reset_weights}, "
        "{train_loss:.05f}, "
        "{test_loss:.05f}, "
        "{num_samples}, "
        "{weight_init:.05f}, "
        "{weight_learnt:.05f},"
        "{a_init:.05f},"
        "{a_learnt:.05f}\n"
    ).format(**kwargs)

    with results.open("a") as f:
        f.write(row)


##############

# Seed Python's RNG and PyTorch's CPU / CUDA generators with the same value.
for seed_fn in (
    random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_fn(args.seed)

# ---- Synthetic regression data ----
# n: number of training samples, d: input dimension, sigma: noise std-dev.
n, d = args.num_samples, args.dim
sigma = args.noise

# Training inputs are uniform on [-1, 1]^d (float64).
x = torch.rand([n, d], dtype=torch.float64) * 2 - 1
# Target: ReLU of the first input coordinate plus Gaussian noise scaled by sigma.
y = torch.relu(x[:, 0]) + sigma * torch.randn(n, dtype=torch.float64)

# Test split: 20% as many samples as the training set, drawn the same way.
ntest = int(0.2 * n)
xtest = torch.rand([ntest, d], dtype=torch.float64) * 2 - 1
ytest = torch.relu(xtest[:, 0]) + sigma * torch.randn(ntest, dtype=torch.float64)



######### Define Model ##########
class LinearER(nn.Linear):
    """``nn.Linear`` whose weight (and bias, if present) is multiplied by a mask.

    The masks are registered as buffers and start as all ones, so a freshly
    built layer behaves exactly like ``nn.Linear``. Overwriting
    ``weight_mask`` / ``bias_mask`` changes which entries contribute to the
    output.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        self.register_buffer('weight_mask', torch.ones(self.weight.size()))
        if self.bias is not None:
            self.register_buffer('bias_mask', torch.ones(self.bias.size()))

    def forward(self, input):
        """Return ``F.linear`` of ``input`` with the masked weight and bias."""
        # Keep the mask on the same device as the weight before multiplying.
        self.weight_mask = self.weight_mask.to(self.weight.device)
        masked_weight = self.weight_mask * self.weight
        # Without a bias parameter, self.bias is None and is passed through.
        masked_bias = self.bias if self.bias is None else self.bias_mask * self.bias
        return F.linear(input, masked_weight, masked_bias)
    
class neuron(nn.Module):
    """Two-layer bias-free network: maskable linear layer, ReLU, linear read-out.

    ``layer1`` is a ``LinearER`` (in_dim -> width) and ``layer2`` an
    ``nn.Linear`` (width -> out_dim). Weights are set by ``initialize``, which
    reads the module-level ``args``.
    """

    def __init__(self, in_dim, out_dim, width):
        super().__init__()
        self.layer1 = LinearER(in_dim, width, bias=False)
        self.layer2 = nn.Linear(width, out_dim, bias=False)
        print(self.layer2.weight.shape)
        self.initialize()
        print(self.layer2.weight.shape)

    def forward(self, x):
        """Compute ``layer2(relu(layer1(x)))``."""
        hidden = F.relu(self.layer1(x))
        return self.layer2(hidden)

    def initialize(self):
        """Overwrite the layer weights according to ``args``.

        ``layer1`` gets float64 Gaussian weights divided by the square root
        of the weight's second dimension. For ``args.init_case`` in 1..4,
        ``layer1.weight[0, 0]`` is then set to +/-1 and the first row of
        ``layer2.weight`` to +/- the L2 norm of all ``layer1`` weights.
        """
        # nn.Linear stores its weight as (out_features, in_features), so the
        # divisor below is sqrt(in_features).
        n_rows, n_cols = self.layer1.weight.shape
        self.layer1.weight.data = torch.randn((n_rows, n_cols), dtype=float) / np.sqrt(n_cols)

        if args.initialize_negative:
            print('weight is initialized as negative')
            self.layer1.weight.data[0, 0] = -1

        # init_case -> (value written to w[0, 0], sign of the read-out weight a).
        # NOTE(review): for init_case in 1..4 this overwrites the -1 written
        # by the --initialize-negative branch above.
        sign_cases = {
            1: (1, 1),
            2: (-1, 1),
            3: (-1, -1),
            4: (1, -1),
        }
        if args.init_case in sign_cases:
            first_weight, a_sign = sign_cases[args.init_case]
            self.layer1.weight.data[0, 0] = first_weight
            self.layer2.weight.data[0] = a_sign * torch.sqrt(torch.sum(self.layer1.weight.data**2))

        


# One hidden unit and a scalar output; cast to float64, the dtype of x and y.
model = neuron(in_dim=d, out_dim=1, width=1).to(torch.double)
print('The model is defined as')
print(model)
# Snapshot layer1.weight[0, 0] and the read-out weight as initialised.
weight_init = model.layer1.weight.data[0, 0].clone()
a_init = model.layer2.weight.data[0].item()
#################
# Defining Magnitude Pruner

def prune_mag(model, density):
    """Globally magnitude-prune every ``LinearER`` sub-module of ``model`` in place.

    Each entry is scored by ``|weight_mask * weight|``. With ``N`` scored
    entries in total, ``k = int((1 - density) * N)`` and the k-th smallest
    score becomes the threshold: entries scoring ``<= threshold`` get mask 0,
    all others mask 1. Already-masked entries score 0 and therefore stay
    masked. When ``k < 1`` nothing has to be removed and the masks are left
    untouched.

    Args:
        model: module whose ``LinearER`` layers are pruned.
        density: target fraction of entries to keep.

    Returns:
        The same ``model`` object, with updated ``weight_mask`` buffers.
    """
    prunable = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, LinearER)
    ]
    if not prunable:
        # Nothing to score: torch.cat on an empty list would raise and the
        # density ratio below would divide by zero.
        return model

    score_list = {
        name: (module.weight_mask.to(module.weight.device) * module.weight).detach().abs_()
        for name, module in prunable
    }
    global_scores = torch.cat([torch.flatten(v) for v in score_list.values()])
    k = int((1 - density) * global_scores.numel())

    # torch.kthvalue requires k >= 1, so only look up a threshold when there
    # is at least one entry to prune.
    if k >= 1:
        threshold, _ = torch.kthvalue(global_scores, k)
        for name, module in prunable:
            device = module.weight.device
            score = score_list[name].to(device)
            zero = torch.tensor([0.]).to(device)
            one = torch.tensor([1.]).to(device)
            module.weight_mask = torch.where(score <= threshold, zero, one)

    # Report the achieved density whether or not the masks changed.
    total_num = 0
    total_den = 0
    for _, module in prunable:
        total_num += (module.weight_mask == 1).sum()
        total_den += module.weight_mask.numel()

    print('Overall model density after magnitude pruning at current iteration = ', total_num / total_den)
    return model

########Train the Model

# Build the optimizer selected by --optimizer.
# NOTE(review): the learning rates below are hard-coded; args.lr is not
# consulted here.
if args.optimizer == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.0)
elif args.optimizer == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
elif args.optimizer == 'lbfgs':
    optimizer = torch.optim.LBFGS(model.parameters())
elif args.optimizer == 'newton':
    optimizer = Minimizer(model.parameters(), method='newton-exact', tol=1e-6, max_iter=1000, disp=2)
else:
    # Fail here with a clear message instead of a NameError on `optimizer`
    # further down.
    raise ValueError(
        "Unknown optimizer {!r}; expected one of 'sgd', 'adam', 'lbfgs', 'newton'".format(args.optimizer)
    )

# The scheduler is constructed here; it is not stepped in the loops below.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
loss_criterion = nn.MSELoss()

#before training
print('weights before training')
print(model.layer1.weight)
print(model.layer2.weight)
# Checkpoint the initial model / optimizer state; the --reset-weights path
# reloads these files after each pruning step.
torch.save(model.state_dict(),"model_{}_init.pt".format(args.name))
torch.save(optimizer.state_dict(),"optimizer_{}_init.pt".format(args.name))

# Output of the untrained model, used as the starting value of outMin.
output = model(x)
outMin = output
# Running minima start at 1, so only values below 1 are ever recorded.
minval = 1
best_test_loss = 1


# Iterative pruning: train for args.epochs epochs, then prune, args.prune_iters times.
for iters in range(args.prune_iters):
    # Geometric density schedule that reaches args.density on the last iteration.
    density = args.density ** ((iters + 1) / args.prune_iters)
    for epoch in range(args.epochs):
        # The optimizer may call the closure more than once per step (e.g.
        # LBFGS); every evaluation is recorded and the last one is used below.
        loss = []
        closure_outputs = []

        def closure():
            optimizer.zero_grad()
            output = model(x)
            curr_loss = loss_criterion(output.flatten(), y.flatten())
            loss.append(curr_loss)
            closure_outputs.append(output)
            curr_loss.backward()
            return curr_loss

        optimizer.step(closure)
        if torch.mean(loss[-1]) < minval:
            minval = torch.mean(loss[-1])
            # Take the output that produced this loss; the closure's `output`
            # is local, so the module-level `output` would be stale here.
            outMin = closure_outputs[-1].detach()

        if epoch % 1000 == 0:
            print("Epoch " + str(epoch) + ": " + str(torch.mean(loss[-1]).item()) + " min: " + str(minval))
            # Evaluation only: no autograd graph is needed.
            with torch.no_grad():
                test_out = model(xtest)
                test_loss = loss_criterion(test_out.flatten(), ytest.flatten())
            if torch.mean(test_loss) < best_test_loss:
                best_test_loss = torch.mean(test_loss)

    model = prune_mag(model, density)
    # reset weights
    if args.reset_weights:
        original_dict = torch.load("model_{}_init.pt".format(args.name))
        # Restore only parameters ('.weight' / '.bias'); the pruning masks
        # keep their current values.
        original_weights = {
            key: value
            for key, value in original_dict.items()
            if key.endswith(('.weight', '.bias'))
        }
        model_dict = model.state_dict()
        model_dict.update(original_weights)
        model.load_state_dict(model_dict)

        # Reset the optimizer state (the scheduler is not reset here).
        optimizer.load_state_dict(torch.load("optimizer_{}_init.pt".format(args.name)))
        print('Weights of the model reset to initialization weights')

######## Final Training to Convergence after Pruning
# Running minima for the final phase. Both start at 1, so only losses below 1
# are ever recorded.
best_train_loss = 1
best_test_loss = 1

print('Final Training Begins..')
for epoch in range(args.epochs):
    # The optimizer may call the closure more than once per step (e.g. LBFGS);
    # every evaluation is recorded and the last one is used below.
    loss = []
    closure_outputs = []

    def closure():
        optimizer.zero_grad()
        output = model(x)
        curr_loss = loss_criterion(output.flatten(), y.flatten())
        loss.append(curr_loss)
        closure_outputs.append(output)
        curr_loss.backward()
        return curr_loss

    optimizer.step(closure)
    epoch_loss = torch.mean(loss[-1])
    if epoch_loss < minval:
        minval = epoch_loss
        # Take the output that produced this loss; the closure's `output` is
        # local, so the module-level `output` would be stale here.
        outMin = closure_outputs[-1].detach()
    # Track the best training loss of this phase so the value reported after
    # training is not stuck at its initial 1.
    if epoch_loss < best_train_loss:
        best_train_loss = epoch_loss.item()

    if epoch % 1000 == 0:
        print("Epoch " + str(epoch) + ": " + str(epoch_loss.item()) + " min: " + str(minval))
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            test_out = model(xtest)
            test_loss = loss_criterion(test_out.flatten(), ytest.flatten())
        if torch.mean(test_loss) < best_test_loss:
            best_test_loss = torch.mean(test_loss)

# Evaluate the trained model once more on the training set and print the loss.
output = model(x)
final_train_loss = loss_criterion(output.flatten(), y.flatten())
print("train loss:")
print(final_train_loss)
print("================")

print('Printing Model Weights after training')
for trained_layer in (model.layer1, model.layer2):
    print(trained_layer.weight)

# Learnt counterparts of weight_init / a_init.
weight_learnt = model.layer1.weight.data[0, 0].clone()
a_learnt = model.layer2.weight.data[0].item()

print('Weight 1 init and learnt: ', weight_init, weight_learnt)

# Collect the fields of the CSV row and append it to the results file.
result_row = {
    'name': args.name,
    'dim': args.dim,
    'noise': args.noise,
    'reset_weights': args.reset_weights,
    'train_loss': best_train_loss,
    'test_loss': best_test_loss,
    'density': args.density,
    'num_samples': args.num_samples,
    'weight_init': weight_init,
    'weight_learnt': weight_learnt,
    'a_init': a_init,
    'a_learnt': a_learnt,
}
write_result_to_csv(**result_row)