from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
# Larger default font size for every figure produced by this script.
plt.rcParams['font.size'] = '14'
# Use CUDA when torch reports it available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

from net import model
from data import *
from vis import *


def creat_network(args, in_dim, out_dim):
    """Construct the project ``model`` described by ``args``, print it, and move it to ``device``.

    Args:
        args: configuration namespace; ``hid_width``, ``depth``, ``relu``,
            ``bias`` and ``init`` are read and forwarded to ``model`` positionally.
        in_dim: input dimension, passed as ``model``'s first argument.
        out_dim: output dimension, passed as ``model``'s third argument.

    Returns:
        The result of calling ``.to(device)`` on the constructed network.
    """
    ctor_args = (in_dim, args.hid_width, out_dim, args.depth,
                 args.relu, args.bias, args.init)
    net = model(*ctor_args)
    print(net)
    return net.to(device)


def _reduce_loss(values, reduction):
    """Collapse element-wise loss ``values`` according to ``reduction``.

    Supports the same options as ``torch.nn`` loss modules: ``'mean'``,
    ``'sum'`` and ``'none'`` (return ``values`` unchanged).

    Raises:
        ValueError: if ``reduction`` is not one of the supported options.
    """
    if reduction == 'mean':
        return torch.mean(values)
    if reduction == 'sum':
        return torch.sum(values)
    if reduction == 'none':
        return values
    raise ValueError("{} is not a valid value for reduction".format(reduction))


def loss_func(args, y, y_hat):
    """Compute the training loss selected by ``args.loss``.

    Args:
        args: namespace providing ``loss`` (``"mse"``, ``"logistic"`` or
            ``"exponential"``) and ``reduction`` (``'mean'``, ``'sum'`` or ``'none'``).
        y: targets. NOTE(review): the margin losses presumably expect labels
            in {-1, +1} — confirm against ``prep_data``.
        y_hat: network predictions, broadcast-compatible with ``y``.

    Returns:
        - ``"mse"``: ``0.5 * MSELoss(y, y_hat)``.
        - ``"logistic"``: ``log(1 + exp(-y * y_hat))``.
        - ``"exponential"``: ``exp(-y * y_hat)``.

        Every result is reduced with ``args.reduction``.

    Raises:
        NotImplementedError: for an unknown ``args.loss``.
        ValueError: for an unknown ``args.reduction``.
    """
    if args.loss == "mse":
        return 0.5*nn.MSELoss(reduction=args.reduction)(y, y_hat)
    if args.loss == "logistic":
        # softplus(z) == log(1 + exp(z)), but it does not overflow to inf for
        # large z. Above its threshold (20) it returns z itself, which differs
        # from the exact value by < 1e-8.
        return _reduce_loss(nn.functional.softplus(-y * y_hat), args.reduction)
    if args.loss == "exponential":
        # Honors args.reduction like the other losses instead of always averaging.
        return _reduce_loss(torch.exp(-y * y_hat), args.reduction)
    raise NotImplementedError


def _save_final_params(network):
    """Save each parameter of ``network`` to ``param_<idx>.npy`` and print its shape.

    Files are written to the current working directory. ``idx`` follows the
    order of ``network.parameters()``.
    """
    for idx, param in enumerate(network.parameters()):
        values = param.data
        np.save("param_{}.npy".format(idx), values.cpu().detach().numpy())
        print(values.shape)


def train(data, args):
    """Train a freshly built network on ``data`` with SGD.

    Each epoch performs one optimizer step on the whole ``x_tensor`` returned
    by ``prep_data``.

    Args:
        data: raw dataset, converted by ``prep_data`` into tensors on ``device``.
        args: configuration namespace. This function reads ``lr``, ``reg``,
            ``relu``, ``epoch``, ``track_weight``, ``sweep``, ``reduction`` and
            ``data``, plus whatever ``prep_data``, ``creat_network`` and
            ``loss_func`` read.

    Returns:
        dict with:
          - ``'Ls'``: ndarray of length ``args.epoch``. Entry ``k`` is the loss
            computed in epoch ``k``, before that epoch's optimizer step.
          - ``'W'``: zeros of shape ``(args.epoch, in_dim)`` when
            ``args.track_weight`` is set, else ``None``. NOTE(review): never
            filled in here — confirm whether a consumer still needs it.
          - ``'layer'``: when ``args.track_weight`` is set, one deep-copied list
            of parameter tensors per epoch, taken after that epoch's step.
            Otherwise an empty list.

    Side effects:
        With ``args.track_weight`` (and at least one epoch), saves the final
        parameters to ``param_<idx>.npy`` and prints their shapes. When
        ``args.sweep == 'single'`` and ``args.reduction == 'mean'``, writes
        ``loss.txt``, plots the run and calls ``plt.show()``.
    """
    x_tensor, y_tensor, in_dim, out_dim = prep_data(args, data, device)
    network = creat_network(args, in_dim, out_dim)
    # NOTE(review): weight decay is scaled by (args.relu + 1); presumably relu
    # is a 0/1 flag, so decay doubles for the ReLU model — confirm intent.
    optimizer = optim.SGD(network.parameters(), lr=args.lr,
                          weight_decay=args.reg*(args.relu+1))

    weights = np.zeros((args.epoch, in_dim)) if args.track_weight else None
    results = {'Ls': np.zeros(args.epoch),
               'W': weights,
               'layer': []}

    for epoch in range(args.epoch):
        optimizer.zero_grad()
        predictions = network(x_tensor)
        loss = loss_func(args, y_tensor, predictions)
        loss.backward()
        optimizer.step()
        results['Ls'][epoch] = loss.item()

        if args.track_weight:
            # Snapshot a copy so later optimizer steps do not mutate history.
            snapshot = [param.data for param in network.parameters()]
            results['layer'].append(deepcopy(snapshot))

    if args.track_weight and args.epoch > 0:
        # The network now holds the parameters produced by the final step.
        _save_final_params(network)

    if args.sweep == 'single' and args.reduction == 'mean':
        np.savetxt("loss.txt", results['Ls'])
        plot_training(args, data, results)
        if args.data in ['fan', 'circle', 'toy']:
            # Index of the final epoch. This is passed explicitly so the call
            # does not depend on a loop variable leaking out of the for-loop,
            # which raised NameError when args.epoch == 0.
            vis_classify(args, data, network, args.epoch - 1)
        plt.show()
    return results
