import torch
from torch.autograd import grad 
from torch.nn import functional as F
from src.numpy_ops import Isotonic
import numpy as np
from src import projection



def TTSA_poison(model, x_poi, y_p, x_tr, y_tr, x_v, y_v, criterion, eta, Q, rw,
                delta, N, zo_num=1):
    """Estimate the TTSA hypergradient of the validation objective w.r.t. ``x_poi``.

    Upper level: ``f = -criterion(model(x_v), y_v) / len(x_v)``.
    Lower level: one projected gradient step ``P(y - eta * grad_y g)`` on
    ``model.parameters()``. Here ``g`` is the clean-plus-poisoned training loss
    divided by the total number of samples, and ``P`` is
    ``projection.projection_l2_ball(., max_value=rw)``.

    The inverse-Hessian term uses a stochastic Neumann series with ``cQ + 1``
    steps, where ``cQ ~ U{0, ..., Q-1}``. The projection Jacobian is replaced by
    a zeroth-order forward-difference estimate along Gaussian directions.

    Args:
        model: torch module whose ``parameters()`` are the lower-level variables.
        x_poi, y_p: poisoned inputs/labels. ``x_poi`` must require grad, since
            the result is a gradient w.r.t. it.
        x_tr, y_tr: clean training inputs/labels.
        x_v, y_v: validation inputs/labels.
        criterion: ``criterion(outputs, targets)``. Losses are divided by batch
            sizes below, so presumably it sums over the batch (TODO confirm).
        eta: lower-level step size.
        Q: upper bound of the random Neumann length. Also scales the result.
        rw: radius passed to ``projection.projection_l2_ball``.
        delta: finite-difference step of the zeroth-order estimate.
        N: unused; kept for interface compatibility.
        zo_num: number of Gaussian directions averaged per zeroth-order
            estimate. The default 1 is the original single-sample estimator.

    Returns:
        ``-eta * Q * d<grad_y g, v>/d x_poi``, a tensor shaped like ``x_poi``.

    Raises:
        ValueError: if ``zo_num < 1``.
    """
    if zo_num < 1:
        raise ValueError("zo_num must be a positive integer")

    def _flatten(tensors):
        # Concatenate tensors into one (numel, 1) column vector.
        return torch.unsqueeze(torch.cat([torch.reshape(t, [-1]) for t in tensors]), 1)

    def _inner_grad():
        # Flattened gradient of the lower-level loss, kept differentiable
        # (create_graph=True) so that second-order products can be formed.
        poison_outputs = model(x_poi)
        ce_loss_poison = criterion(poison_outputs, y_p)
        train_outputs = model(x_tr)
        ce_loss_train = criterion(train_outputs, y_tr)
        g = (ce_loss_train + ce_loss_poison) / (batch_size_poisoned + batch_size_clean)
        return _flatten(torch.autograd.grad(g, model.parameters(), create_graph=True))

    def _zo_projection_jvp(z, vec):
        # Average over zo_num Gaussian directions u of
        #   (<P(z + delta*u), vec> - <P(z), vec>) / delta * u,
        # a zeroth-order estimate of the projection Jacobian applied to vec.
        with torch.no_grad():
            estimate = None
            for _ in range(zo_num):
                u = torch.randn_like(z)
                p_upper = projection.projection_l2_ball(z + delta * u, max_value=rw)
                p_lower = projection.projection_l2_ball(z, max_value=rw)
                sample = 1 / delta * (
                        torch.sum(p_upper * vec) - torch.sum(p_lower * vec)) * u
                estimate = sample if estimate is None else estimate + sample
            return estimate / zo_num

    params = _flatten(model.parameters())
    batch_size_poisoned = len(x_poi)
    batch_size_clean = len(x_tr)
    v_len = len(x_v)

    val_outputs = model(x_v)
    f = -criterion(val_outputs, y_v) / v_len
    v = _flatten(torch.autograd.grad(f, model.parameters())).detach()

    # NOTE(review): this runs cQ + 1 Neumann steps (1..Q), whereas TTSA_AA and
    # TTSA_nas run cQ steps (0..Q-1). Confirm which count is intended.
    cQ = np.random.randint(Q)
    for _ in range(cQ + 1):
        z = params - eta * _inner_grad()
        v = _zo_projection_jvp(z, v)
        # v is detached here, so d<z, v>/d params = (I - eta * Hessian_y g) v.
        v = _flatten(torch.autograd.grad(torch.sum(z * v), model.parameters()))

    gy = _inner_grad()
    z = params - eta * gy
    v = _zo_projection_jvp(z, v)
    # Mixed second derivative: d<grad_y g, v>/d x_poi.
    gyx = torch.autograd.grad(torch.sum(gy * v), x_poi)[0]
    return -eta * Q * gyx


def TTSA_denoise_lasso(model, weight, train_x, train_y, val_x, val_y, out_f, inner_g, eta, Q=3, rw=0.1, delta=1e-8,N=10):
    """Estimate the implicit TTSA hypergradient of ``out_f`` w.r.t. ``weight``.

    Upper level: ``out_f(model, val_x, val_y)``.
    Lower level: one projected gradient step of ``inner_g(model, weight,
    train_x, train_y)`` on ``model.parameters()``, using
    ``projection.projection_l1_ball(., max_value=rw)``.

    The inverse-Hessian term uses a stochastic Neumann series with ``cQ + 1``
    steps, where ``cQ ~ U{0, ..., Q-1}``. The projection Jacobian is replaced by
    a forward-difference estimate along one Gaussian direction per step.

    Only the implicit term is returned. No direct gradient of ``out_f`` w.r.t.
    ``weight`` is added here.

    Args:
        model: torch module whose ``parameters()`` are the lower-level variables.
        weight: upper-level tensor (must require grad) used by ``inner_g``.
        train_x, train_y: data passed to ``inner_g``.
        val_x, val_y: data passed to ``out_f``.
        out_f: callable returning the scalar upper-level loss.
        inner_g: callable returning the scalar lower-level loss.
        eta: lower-level step size.
        Q: upper bound of the random Neumann length. Also scales the result.
        rw: radius passed to ``projection.projection_l1_ball``.
        delta: finite-difference step. NOTE(review): if the parameters are
            float32, 1e-8 is below the resolution of O(1) entries, so the
            difference may be dominated by rounding. Confirm.
        N: unused; kept for interface compatibility.

    Returns:
        ``-eta * Q * d<grad_y g, v>/d weight``, shaped like ``weight``.
    """
    def _flatten(tensors):
        # Concatenate tensors into one (numel, 1) column vector.
        return torch.unsqueeze(torch.cat([torch.reshape(t, [-1]) for t in tensors]), 1)

    def _inner_grad():
        # Flattened lower-level gradient, kept differentiable for second-order products.
        g = inner_g(model, weight, train_x, train_y)
        return _flatten(torch.autograd.grad(g, model.parameters(), create_graph=True))

    def _zo_projection_jvp(z, vec):
        # (<P(z + delta*u), vec> - <P(z), vec>) / delta * u with Gaussian u:
        # a zeroth-order estimate of the l1-projection Jacobian applied to vec.
        with torch.no_grad():
            u = torch.randn_like(z)
            p_upper = projection.projection_l1_ball(z + delta * u, max_value=rw)
            p_lower = projection.projection_l1_ball(z, max_value=rw)
            return 1 / delta * (
                    torch.sum(p_upper * vec) - torch.sum(p_lower * vec)) * u

    params = _flatten(model.parameters())
    f = out_f(model, val_x, val_y)
    v = _flatten(torch.autograd.grad(f, model.parameters())).detach()

    cQ = np.random.randint(Q)
    for _ in range(cQ + 1):
        z = params - eta * _inner_grad()
        v = _zo_projection_jvp(z, v)
        # Keep the gradient of EVERY parameter tensor. Taking only [0] made v
        # shorter than z for multi-parameter models, which either raises or
        # silently broadcasts.
        v = _flatten(torch.autograd.grad(torch.sum(z * v), model.parameters()))

    gy = _inner_grad()
    z = params - eta * gy
    v = _zo_projection_jvp(z, v)
    # Mixed second derivative: d<grad_y g, v>/d weight.
    gyx = torch.autograd.grad(torch.sum(gy * v), weight)[0]
    return -eta * Q * gyx


def TTSA_AA(model, input, trans_images_list, target, input_search, target_search,
            criterion, eta, Q, rw, delta):
    """Estimate the TTSA hypergradient w.r.t. ``model.augment_parameters()``.

    Upper level: ``model._loss(input_search, target_search)`` plus
    ``0.01 * ||flattened augment parameters||_2``.
    Lower level: one projected gradient step of
    ``criterion(model(input, trans_images_list), target)`` on
    ``model.parameters()``, using ``projection.projection_l2_ball(., max_value=rw)``.

    The inverse-Hessian term uses a stochastic Neumann series with ``cQ`` steps,
    where ``cQ ~ U{0, ..., Q-1}``. The projection Jacobian is replaced by a
    central-difference estimate along a random unit direction, scaled by the
    parameter count ``d``.

    Args:
        model: module exposing ``parameters()``, ``augment_parameters()``, ``_loss``.
        input, trans_images_list, target: lower-level batch, used as
            ``model(input, trans_images_list)``.
        input_search, target_search: upper-level batch for ``model._loss``.
        criterion: lower-level loss ``criterion(logits, target)``.
        eta: lower-level step size.
        Q: upper bound of the random Neumann length. Also scales the implicit term.
        rw: radius passed to ``projection.projection_l2_ball``.
        delta: finite-difference step.

    Returns:
        A detached ``(n_aug, 1)`` column: the direct upper-level gradient minus
        ``eta * Q`` times the implicit term.
    """
    torch.cuda.empty_cache()

    def _flatten(tensors):
        # Concatenate tensors into one (numel, 1) column vector.
        return torch.unsqueeze(torch.cat([torch.reshape(t, [-1]) for t in tensors]), 1)

    def _lower_grad():
        # Flattened lower-level gradient, kept differentiable for second-order products.
        logits = model(input, trans_images_list)
        g = criterion(logits, target)
        return _flatten(torch.autograd.grad(g, model.parameters(), create_graph=True))

    def _zo_projection_jvp(z, vec):
        # d / (2*delta) * (<P(z + delta*u), vec> - <P(z - delta*u), vec>) * u
        # with u a random unit direction, where d = number of parameters.
        with torch.no_grad():
            dim = len(z)
            u = torch.randn_like(z)
            u = u / torch.norm(u)
            p_upper = projection.projection_l2_ball(z + delta * u, max_value=rw)
            p_lower = projection.projection_l2_ball(z - delta * u, max_value=rw)
            return dim / (2 * delta) * (
                    torch.sum(p_upper * vec) - torch.sum(p_lower * vec)) * u

    params = _flatten(model.parameters())
    aug_params = _flatten(model.augment_parameters())
    f = model._loss(input_search, target_search) + 0.01 * torch.norm(aug_params)
    Fy_gradient = torch.autograd.grad(f, model.parameters(), retain_graph=True)
    # Last use of f's graph: don't retain it, so the search-loss activations
    # are freed now instead of staying alive through the Neumann loop.
    Fx_gradient = torch.autograd.grad(f, model.augment_parameters())
    v = _flatten(Fy_gradient).detach()
    vx = _flatten(Fx_gradient).detach()

    cQ = np.random.randint(Q)
    for _ in range(cQ):
        z = params - eta * _lower_grad()
        v = _zo_projection_jvp(z, v)
        # v is detached here, so d<z, v>/d params = (I - eta * Hessian) v.
        v = _flatten(torch.autograd.grad(torch.sum(z * v), model.parameters()))

    gy = _lower_grad()
    z = params - eta * gy
    v = _zo_projection_jvp(z, v)
    # Mixed second derivative: d<grad_y g, v>/d augment_parameters.
    gyx = torch.autograd.grad(torch.sum(gy * v), model.augment_parameters())
    outer_update = vx - eta * Q * _flatten(gyx).detach()
    torch.cuda.empty_cache()
    return outer_update


def TTSA_nas(model, train_x, train_y, val_x, val_y, criterion, eta, Q=3,
             rw=0.1, delta=1e-8):
    """Estimate the TTSA hypergradient w.r.t. ``model.arch_parameters()``.

    Upper level: ``model._loss(val_x, val_y)``.
    Lower level: one projected gradient step of
    ``criterion(model(train_x), train_y)`` on ``model.parameters()``, using
    ``projection.projection_l2_ball(., max_value=rw)``.

    The inverse-Hessian term uses a stochastic Neumann series with ``cQ`` steps,
    where ``cQ ~ U{0, ..., Q-1}``. The projection Jacobian is replaced by a
    central-difference estimate along a random unit direction, scaled by the
    parameter count ``d``.

    Args:
        model: module exposing ``parameters()``, ``arch_parameters()``, ``_loss``.
        train_x, train_y: lower-level batch.
        val_x, val_y: upper-level batch for ``model._loss``.
        criterion: lower-level loss ``criterion(logits, train_y)``.
        eta: lower-level step size.
        Q: upper bound of the random Neumann length. Also scales the implicit term.
        rw: radius passed to ``projection.projection_l2_ball``.
        delta: finite-difference step.

    Returns:
        A detached ``(n_arch, 1)`` column: the direct upper-level gradient minus
        ``eta * Q`` times the implicit term.
    """
    torch.cuda.empty_cache()

    def _flatten(tensors):
        # Concatenate tensors into one (numel, 1) column vector.
        return torch.unsqueeze(torch.cat([torch.reshape(t, [-1]) for t in tensors]), 1)

    def _lower_grad():
        # Flattened lower-level gradient, kept differentiable for second-order products.
        logits = model(train_x)
        g = criterion(logits, train_y)
        return _flatten(torch.autograd.grad(g, model.parameters(), create_graph=True))

    def _zo_projection_jvp(z, vec):
        # d / (2*delta) * (<P(z + delta*u), vec> - <P(z - delta*u), vec>) * u
        # with u a random unit direction, where d = number of parameters.
        with torch.no_grad():
            dim = len(z)
            u = torch.randn_like(z)
            u = u / torch.norm(u)
            p_upper = projection.projection_l2_ball(z + delta * u, max_value=rw)
            p_lower = projection.projection_l2_ball(z - delta * u, max_value=rw)
            return dim / (2 * delta) * (
                    torch.sum(p_upper * vec) - torch.sum(p_lower * vec)) * u

    params = _flatten(model.parameters())
    f = model._loss(val_x, val_y)
    Fy_gradient = torch.autograd.grad(f, model.parameters(), retain_graph=True)
    # Last use of f's graph: don't retain it, so the validation activations
    # are freed now instead of staying alive through the Neumann loop.
    Fx_gradient = torch.autograd.grad(f, model.arch_parameters())
    v = _flatten(Fy_gradient).detach()
    vx = _flatten(Fx_gradient).detach()

    cQ = np.random.randint(Q)
    for _ in range(cQ):
        z = params - eta * _lower_grad()
        v = _zo_projection_jvp(z, v)
        # v is detached here, so d<z, v>/d params = (I - eta * Hessian) v.
        v = _flatten(torch.autograd.grad(torch.sum(z * v), model.parameters()))

    gy = _lower_grad()
    z = params - eta * gy
    v = _zo_projection_jvp(z, v)
    # Mixed second derivative: d<grad_y g, v>/d arch_parameters.
    gyx = torch.autograd.grad(torch.sum(gy * v), model.arch_parameters())
    outer_update = vx - eta * Q * _flatten(gyx).detach()
    torch.cuda.empty_cache()
    return outer_update