import torch
import torch.nn as nn
from torch.nn.functional import softplus, relu
import numpy as np

import rl_utils
import exp_utils as PQ


# Hyper-parameter container for this module; it is also re-exported as the
# class attribute LOptimizer.FLAGS.
class FLAGS(PQ.BaseFLAGS):
    # Passed as the `weight_decay` argument of the Adam optimizer built in LOptimizer.__init__.
    weight_decay = 1e-4
    # Passed as the `lr` (learning rate) argument of that same Adam optimizer.
    lr = 0.0003

    # obj_fn = ''
    # mask_fn = ''
    # NOTE(review): no live code in the visible part of this file reads `lambda_2`;
    # presumably it selects how a Lagrange-style multiplier is computed — confirm before relying on it.
    lambda_2 = 'norm'
    # NOTE(review): purpose is not visible here — presumably consumed by PQ.BaseFLAGS; confirm.
    locals = {}


@torch.enable_grad()
def constrained_optimize(fx, gx, x, opt, reg=0.0):  # \grad_y [max_{x: g_y(x) <= 0} f(x)]
    """Take one optimizer step on the surrogate loss ``sum(fx - lambda * gx + reg)``.

    ``lambda`` is the ratio of gradient norms ``||d sum(fx)/dx|| / ||d sum(gx)/dx||``
    taken over the last dimension of ``x`` (denominator clamped to at least 1e-6).
    It is computed without tracking gradients, so it acts as a constant in the
    backward pass.

    Args:
        fx: Objective values; must have been computed from ``x`` with autograd enabled.
        gx: Constraint values; must have been computed from ``x`` with autograd enabled.
        x: Tensor with ``requires_grad`` set that ``fx`` and ``gx`` depend on.
        opt: Optimizer whose ``zero_grad()`` and ``step()`` are invoked around the
            backward pass.
        reg: Extra term added element-wise to the loss before it is summed, so a
            scalar ``reg`` is counted once per element.

    Returns:
        Dict with ``'df'`` and ``'dg'``: the gradients of ``fx.sum()`` and
        ``gx.sum()`` with respect to ``x``.
    """
    objective_total = fx.sum()
    constraint_total = gx.sum()
    with torch.no_grad():
        # retain_graph=True: the same graph is traversed again by backward() below.
        grad_f, = torch.autograd.grad(objective_total, x, retain_graph=True)
        grad_g, = torch.autograd.grad(constraint_total, x, retain_graph=True)
        grad_g_norm = grad_g.norm(dim=-1).clamp(min=1e-6)
        multiplier = grad_f.norm(dim=-1) / grad_g_norm

    opt.zero_grad()
    surrogate = (fx - gx * multiplier + reg).sum()
    surrogate.backward()
    opt.step()

    return dict(df=grad_f, dg=grad_g)


class LOptimizer(nn.Module):
    """Drives `constrained_optimize` updates for the parameters behind `obj_eval`.

    `obj_eval` is called on a batch of states and must return a dict. The keys
    read here are `'mask'`, `'obj'` and `'constraint'` (in `step`), `'L'`
    (in `step`, only when `L_ref` is given) and `'hard_obj'` (in `evaluate`).
    """
    FLAGS = FLAGS

    def __init__(self, dim_state, obj_eval, params, L_ref):
        """Build the Adam optimizer and store the collaborators.

        Args:
            dim_state: Stored as `self.dim_state`; not otherwise used by this class.
            obj_eval: Callable with `train()` / `eval()` methods that maps states
                to the result dict described on the class.
            params: Parameters handed to `torch.optim.Adam` (with `FLAGS.lr` and
                `FLAGS.weight_decay`).
            L_ref: Optional callable evaluated on the states; when not `None`,
                `step` penalizes `result['L']` exceeding `L_ref(s)`.
        """
        super().__init__()

        self.opt_params = torch.optim.Adam(params, lr=FLAGS.lr, weight_decay=FLAGS.weight_decay)
        self.dim_state = dim_state
        self.obj_eval = obj_eval
        self.L_ref = L_ref

    # NOTE: the former lambda_2 / gradient-alignment diagnostics were dead
    # (commented-out) code that referenced `U_pi` and `L` attributes this class
    # does not define; they have been removed.

    def step(self, s, should_update=True):
        """Evaluate `obj_eval` on `s` in train mode and optionally do one update.

        An update (one `constrained_optimize` call) happens only when
        `should_update` is true and the mask has a positive sum. Either way the
        `'opt_L/update_prob'` meter is fed 1 or 0 so its mean reflects how often
        an update was performed.

        Args:
            s: Batch of states; a detached clone with `requires_grad` set is used
                internally, so the caller's tensor is left untouched.
            should_update: When false, only the forward pass and meter update run.

        Returns:
            The dict returned by `obj_eval`.
        """
        self.obj_eval.train()
        try:
            s = s.detach().clone().requires_grad_()
            result = self.obj_eval(s)
            mask, obj = result['mask'], result['obj']

            regularization = 0
            # we want L(s) < L_ref(s)
            if self.L_ref is not None:
                excess = (result['L'] - self.L_ref(s)).clamp(min=0.)
                regularization = regularization + excess.mean() * 0.001

            # Reduce the mask once; it is both the update guard and the normalizer.
            mask_total = mask.sum()
            if mask_total > 0 and should_update:
                constrained_optimize(obj * mask / mask_total, result['constraint'], s, self.opt_params,
                                     reg=regularization)
                PQ.meters['opt_L/update_prob'] += 1
            else:
                PQ.meters['opt_L/update_prob'] += 0
        finally:
            # Always leave obj_eval in eval mode, even if the forward/backward pass raised.
            self.obj_eval.eval()
        return result

    @torch.no_grad()
    def evaluate(self, s, *, step):
        """Log a debug summary for the batch `s` and reset the `opt_L/` meters.

        Args:
            s: Batch of states passed to `obj_eval` (run without gradient tracking).
            step: Keyword-only; accepted for the caller's convenience but not used here.
        """
        result = self.obj_eval(s)
        hardD = result['hard_obj']

        # NOTE(review): 'opt_L/n_valid' is not written anywhere in this class —
        # presumably recorded by the caller or by obj_eval; confirm.
        PQ.log.debug(f"[L opt]: optimal = {hardD.max().item():.6f}, "
                     f"E[# valid] = {PQ.meters['opt_L/n_valid'].mean:.0f}, "
                     f"Pr[update] = {PQ.meters['opt_L/update_prob'].mean:.6f}"
                     )
        PQ.meters.purge('opt_L/')
