from pickletools import optimize
from re import A
import torch
import numpy as np
# import numpy as np
import torch.nn as nn
from IPython import embed
from .network import ScalarNetVec, ScalarNetVecTorch

from tqdm import tqdm

def quadratic_loss(Y_hat, Y):
    """Return the half squared error ``(Y_hat - Y) ** 2 / 2``.

    The result is computed element-wise with whatever arithmetic the
    operands provide; nothing is summed or averaged here, so any
    reduction over a batch is left to the caller.

    Args:
        Y_hat: Predicted value(s).
        Y: Reference value(s), compatible with ``Y_hat`` under subtraction.

    Returns:
        Half of the squared residual, with the same shape as ``Y_hat - Y``.
    """
    residual = Y_hat - Y
    return residual ** 2 / 2


def train_net_vec(weights, eta, epochs, mu, verbose=False, sharpness=True, traj_comp=False, record_epochs=15, torch=False, progress=False):
    """Run gradient descent on a scalar network and classify how the run ended.

    Args:
        weights: Initial weights passed to ``ScalarNetVec`` (or
            ``ScalarNetVecTorch`` when ``torch`` is true).
        eta: Step size. Also defines the sharpness threshold ``2 / eta``.
        epochs: Number of gradient steps to attempt.
        mu: Scalar forwarded to the network's ``sharpness``, ``loss`` and
            ``gradient_comp`` methods; also used in the convergence check.
        verbose: When true, print the final weights and the last recorded
            sharpness (``None`` if nothing was recorded).
        sharpness: Never read by this function; kept in the signature only
            so existing call sites keep working.
        traj_comp: When true, record weights/sharpness/loss at every step.
        record_epochs: When ``traj_comp`` is false, only steps with
            ``i + record_epochs > epochs`` are recorded, i.e. the final
            ``record_epochs - 1`` steps.
        torch: Selects ``ScalarNetVecTorch``. Note that this name shadows
            the ``torch`` module inside the function body.
        progress: Wrap the step loop in a ``tqdm`` progress bar.

    Returns:
        ``(eos, losses, sharpnesses, traj)`` where the three lists hold the
        recorded losses, sharpness values and weight snapshots, and ``eos``
        is one of:

        * ``0``  - the mean gap between ``2 / eta`` and the last (up to 10)
          recorded sharpness values is below 5% of the threshold and the
          successive changes of that gap alternate in sign.
        * ``2``  - the mean absolute deviation of ``|weights|`` from
          ``mu ** (1 / len(weights))`` is below ``1e-3``.
        * ``1``  - the recorded sharpness is on average below ``2 / eta``.
        * ``-1`` - none of the above, or the run was aborted early because
          ``net.diverge()`` returned true or a step raised an exception.
    """
    net = ScalarNetVecTorch(weights) if torch else ScalarNetVec(weights)
    losses = []
    sharpnesses = []
    traj = []

    steps = tqdm(range(epochs)) if progress else range(epochs)
    for i in steps:
        if net.diverge():
            return -1, losses, sharpnesses, traj
        try:
            if traj_comp or i + record_epochs > epochs:
                traj.append(net.weight_clone())
                sharpnesses.append(net.sharpness(mu))
                losses.append(net.loss(1, mu))
            grad = net.gradient_comp(mu)
            net.weights -= grad * eta
        except Exception:
            # Numerical failures are reported as an aborted run.  Catching
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate instead of being mistaken for divergence.
            return -1, losses, sharpnesses, traj

    threshold = 2 / eta
    # Signed distance of the most recent sharpness values from the threshold.
    diff = threshold - np.array(sharpnesses[-10:])
    # Step-to-step change of that distance; alternating signs mean oscillation.
    diff2 = diff[1:] - diff[:-1]
    if abs(np.mean(diff)) < threshold / 20 and np.all(diff2[:-1] * diff2[1:] < 0):
        eos = 0
    elif np.mean(np.abs(np.abs(net.weight_clone()) - np.power(mu, 1 / len(net.weights)))) < 0.001:
        eos = 2
    elif np.mean(diff) > 0:
        eos = 1
    else:
        eos = -1
    if verbose:
        # Print both values; guard against an empty record list.
        print(net.weights, sharpnesses[-1] if sharpnesses else None)
    return eos, losses, sharpnesses, traj

def train_net_vec_traj(weights, eta, epochs, mu=1, traj_freq=99, progress=False, torch=False):
    """Run gradient descent on a scalar network and sample its weight trajectory.

    Args:
        weights: Initial weights passed to ``ScalarNetVec`` (or
            ``ScalarNetVecTorch`` when ``torch`` is true).
        eta: Step size applied to the gradient.
        epochs: Number of gradient steps to attempt.
        mu: Scalar forwarded to the network's ``gradient_comp`` method.
        traj_freq: A weight snapshot is stored on every step ``i`` with
            ``i % traj_freq == 0``.  A value of ``0`` raises
            ``ZeroDivisionError`` inside the guarded step and is therefore
            reported as a failed run.
        progress: Wrap the step loop in a ``tqdm`` progress bar.
        torch: Selects ``ScalarNetVecTorch``. Note that this name shadows
            the ``torch`` module inside the function body.

    Returns:
        ``(ok, traj)``: ``ok`` is ``False`` when ``net.diverge()`` returned
        true or a step raised an exception, ``True`` when all ``epochs``
        steps completed.  ``traj`` holds the snapshots taken so far.
    """
    net = ScalarNetVecTorch(weights) if torch else ScalarNetVec(weights)
    traj = []

    steps = tqdm(range(epochs)) if progress else range(epochs)
    for i in steps:
        if net.diverge():
            return False, traj
        try:
            if i % traj_freq == 0:
                traj.append(net.weight_clone())
            grad = net.gradient_comp(mu)
            net.weights -= grad * eta
        except Exception:
            # Catch Exception rather than using a bare except, so Ctrl-C
            # (KeyboardInterrupt) and SystemExit are not reported as a
            # failed run but propagate to the caller.
            return False, traj
    return True, traj


def train_net_4num_simp(weights, eta, epochs, traj_freq=99, progress=False):
    """Run gradient descent on a two-weight product model with NumPy only.

    With ``x, y = weights`` each step subtracts ``eta * y * (x*y) * ((x*y)**2 - 1)``
    from ``x`` and ``eta * x * (x*y) * ((x*y)**2 - 1)`` from ``y``, which is
    the gradient of ``((x*y)**2 - 1)**2 / 4``.  Both updates use the values
    from the start of the step.

    Args:
        weights: Two-element sequence/array ``(x, y)``; a leading axis of
            length 2 with extra trailing axes is updated element-wise.
            The input is copied and never modified.
        eta: Step size.
        epochs: Number of gradient steps.
        traj_freq: A snapshot of the weights is stored before every step
            ``i`` with ``i % traj_freq == 0``.
        progress: Wrap the step loop in a ``tqdm`` progress bar.

    Returns:
        ``(True, traj)``; the flag is always ``True`` (no divergence check is
        performed) and ``traj`` is the list of weight snapshots.
    """
    weights = np.array(weights)
    if weights.dtype.kind in "biu":
        # Integer/bool input would make the in-place updates below truncate
        # every step back to whole numbers; promote to float first.
        weights = weights.astype(float)
    traj = []

    steps = tqdm(range(epochs)) if progress else range(epochs)
    for i in steps:
        if i % traj_freq == 0:
            traj.append(np.copy(weights))
        x, y = weights
        xy = x * y
        c2 = xy * xy - 1
        # Compute both steps before writing: for array-valued rows, x and y
        # are views into `weights`, so updating row 0 first would otherwise
        # leak the new x into the update of row 1.
        step_x = eta * y * xy * c2
        step_y = eta * x * xy * c2
        weights[0] -= step_x
        weights[1] -= step_y
    return True, traj
