import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os,time
#import cvxpy as cp
import copy

# Print tensors with 4 decimal places on lines up to 500 characters wide, so the
# logged component vectors stay on a single line.
torch.set_printoptions(precision=4, linewidth=500)
def get_Hessian_trace(model, A):
    """Return the mean, over rows of ``A``, of the squared parameter-gradient norm.

    For every index ``i`` the model is run on the one-row batch ``A[i:i+1]``,
    the first element of its output is back-propagated, and the squared L2
    norms of all parameter gradients are summed.  The total is divided by
    ``len(A)``.  NOTE(review): the name suggests this is used as a proxy for the
    trace of a squared-error loss Hessian; that interpretation is not checked here.

    Args:
        model: A ``torch.nn.Module`` whose output's first element is a scalar
            for a one-row batch (assumed — TODO confirm for new models).
        A: Inputs indexed along their first dimension.

    Returns:
        A 0-dim tensor (or ``0.0`` if no parameter received a gradient).

    Raises:
        ValueError: if ``A`` is empty.

    Side effects:
        Calls ``model.zero_grad()`` and leaves the gradients of the last row in
        each parameter's ``.grad``.
    """
    n = len(A)
    if n == 0:
        raise ValueError("get_Hessian_trace needs at least one input row")
    total = 0
    for i in range(n):
        model.zero_grad()
        # Back-propagate the first output element of the one-row batch.
        model(A[i:i + 1])[0].backward()
        for param in model.parameters():
            # Frozen or unused parameters have no gradient; skip them instead
            # of crashing on ``None.norm()``.
            if param.grad is not None:
                total += param.grad.norm() ** 2
    return total / n

class Net(nn.Module):
    """One hidden layer with the module-level ``poly_activation`` and a sum readout.

    ``forward`` computes ``poly_activation(x @ W).sum(dim=1)``; there is no
    bias and no trainable output layer.

    Args:
        M: Accepted for interface compatibility; this class does not use it.
        D: Input dimension (rows of ``W``).
        h_dim: Number of hidden units (columns of ``W``).
    """

    def __init__(self, M, D, h_dim):
        super().__init__()
        # Trainable weights drawn from the global torch RNG (so seeding matters).
        self.W = nn.Parameter(torch.randn(D, h_dim))

    def forward(self, x):
        """Map inputs to hidden activations and sum them over the hidden axis.

        For a 2-D ``x`` of shape ``(n, D)`` the result has shape ``(n,)``.
        """
        hidden = poly_activation(x @ self.W)
        return hidden.sum(dim=1)

def poly_activation(val):
    """Return ``val + val**3`` (elementwise for tensors)."""
    return val + val ** 3


class LinearWarmupLR:
    """Linear learning-rate warmup over a fixed number of steps, then constant.

    On the k-th call to ``step()`` each param group's ``'lr'`` is set to
    ``initial_lr * k / warmup_epochs`` while ``k <= warmup_epochs``, and to
    ``initial_lr`` afterwards.

    Args:
        optimizer: A torch optimizer.  Each param group's ``'initial_lr'`` is the
            target rate; if a group has none, its current ``'lr'`` is recorded
            as ``'initial_lr'`` here (an existing value is left untouched).
        warmup_epochs: Number of ``step()`` calls over which the rate ramps up.
        total_epochs: Stored on the instance; the schedule does not use it.
    """

    def __init__(self, optimizer, warmup_epochs, total_epochs):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs  # kept for reference; unused by step()
        self.current_epoch = 0
        # step() rescales from 'initial_lr'; record it when the caller has not,
        # so the scheduler works with a plain optimizer instead of raising KeyError.
        for param_group in optimizer.param_groups:
            param_group.setdefault('initial_lr', param_group['lr'])

    def step(self):
        """Advance one step and write the scaled rate into every param group."""
        self.current_epoch += 1
        if self.current_epoch <= self.warmup_epochs:
            lr_scale = self.current_epoch / self.warmup_epochs
        else:
            lr_scale = 1.0  # Constant LR after warmup

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = param_group['initial_lr'] * lr_scale

def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description='Train a polynomial-activation network with noisy-label SGD')
    parser.add_argument('--N', type=int, default=5000, help='number of samples')
    parser.add_argument('--D', type=int, default=100, help='input dimension')
    parser.add_argument('--M', type=int, default=3, help='number of matrices (passed to Net, currently unused)')
    parser.add_argument('--gt_rank', type=int, default=30, help='rank of ground truth matrix (currently unused)')
    parser.add_argument('--batch_size', type=int, default=2, help='batch size of SGD')
    parser.add_argument('--h_dim', type=int, default=20, help='hidden dimension (must be >= 3)')
    parser.add_argument('--noise_scale', type=float, default=2, help='standard deviation of label noise')
    parser.add_argument('--n_iter', type=int, default=3000000, help='number of iterations')
    parser.add_argument('--warmup_iters', type=int, default=1000000,
                        help='number of iterations of linear learning-rate warmup')
    parser.add_argument('--eval_interval', type=int, default=2000,
                        help='iterations between evaluations/log entries')
    parser.add_argument('--eval_until', type=int, default=1100000,
                        help='no evaluations are logged at or after this iteration')

    parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                        help='unused; the logging frequency is set by --eval_interval')
    args = parser.parse_args()
    if args.h_dim < 3:
        # Three principal components are logged per evaluation.
        parser.error('--h_dim must be at least 3')
    return args


def _unit_rows(n, d):
    """Return an (n, d) standard-Gaussian matrix with every row scaled to unit L2 norm."""
    X = torch.randn((n, d))
    return X / X.norm(dim=1, keepdim=True)


def _evaluate_and_log(it, model, A, b, A_test, b_test, log_file_path):
    """Print and append to ``log_file_path`` the diagnostics for iteration ``it``.

    Logged: train/test mean-squared error, the ``get_Hessian_trace`` value on
    ``A``, and the top three left singular vectors of the hidden-activation
    matrix ``poly_activation(A @ W).T``.
    """
    with torch.no_grad():
        # Shape (h_dim, N): one row of activations per hidden neuron.
        activation = poly_activation(torch.matmul(A, model.W)).T
        loss_train = ((model(A) - b) ** 2).mean().item()
        loss_test = ((model(A_test) - b_test) ** 2).mean().item()

    # Reduced SVD: U is the same (h_dim x h_dim) matrix as with
    # full_matrices=True, but the unused N x N right factor is never built.
    # NumPy needs a host array, hence .cpu().
    U, _, _ = np.linalg.svd(activation.cpu().numpy(), full_matrices=False)
    main_comp = np.transpose(U)[0:3, :]
    print('main comp is ', main_comp)

    # Needs autograd, so it runs outside the no_grad block above.
    ht = float(get_Hessian_trace(model, A))

    with open(log_file_path, "a") as f:
        f.write("---------------------------------------------------------\n")
        f.write('Iter: {}, Train Loss: {}, Test Loss: {}, Hessian Trace: {}, Components First: {}, Components Second: {}, Components Third: {}\n'.format(
            it, loss_train, loss_test, ht,
            torch.from_numpy(main_comp[0]), torch.from_numpy(main_comp[1]), torch.from_numpy(main_comp[2])))


def main():
    """Fit a ``Net`` to noisy labels from a random width-1 teacher and log diagnostics."""
    args = _parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    print(args)
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    # Unit-norm Gaussian train/test inputs; labels come from a random teacher.
    N, D = args.N, args.D
    A = _unit_rows(N, D)
    A_test = _unit_rows(N, D)

    model_gt = Net(args.M, args.D, 1)
    b = model_gt(A).detach()
    b_test = model_gt(A_test).detach()
    # The teacher must live on the same device as the data it is evaluated on.
    model_gt = model_gt.to(device)
    A, b = A.to(device), b.to(device)
    A_test, b_test = A_test.to(device), b_test.to(device)
    gt_tr = get_Hessian_trace(model_gt, A)
    print('gt Hessian trace:', float(gt_tr))

    log_file_path = "repeated_log_Final_hidden_dim{}_D{}_N{}_SEED{}_lr{}_noiseScale{}_batchSize{}".format(args.h_dim, args.D, args.N, args.seed, args.lr, args.noise_scale, args.batch_size)

    model = Net(args.M, args.D, args.h_dim).to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr)
    for param_group in optimizer.param_groups:
        param_group['initial_lr'] = param_group['lr']

    scheduler = LinearWarmupLR(optimizer, warmup_epochs=args.warmup_iters, total_epochs=args.n_iter)
    for it in range(args.n_iter):
        optimizer.zero_grad()
        idx = torch.randperm(A.size(0))[:args.batch_size]
        A_i, b_i = A[idx], b[idx]

        # Fresh Gaussian label noise each step, drawn from the CPU generator and
        # then moved to the device.
        b_noise = (torch.randn(len(b_i)) * args.noise_scale).to(device)
        b_i = b_i + b_noise

        loss = ((model(A_i) - b_i) ** 2).mean()
        loss.backward()
        # The schedule advances before the update, so step k uses lr * k / warmup_iters.
        scheduler.step()
        optimizer.step()

        if it % args.eval_interval == 0 and it < args.eval_until:
            _evaluate_and_log(it, model, A, b, A_test, b_test, log_file_path)


if __name__ == '__main__':
    main()
