import torch
import torch.nn as nn
from torch.nn.functional import relu
import numpy as np

import exp_utils as PQ
import rl_utils


# Configuration flags for the state-space sampler defined in this file.
# Values are read as class attributes (e.g. ``FLAGS.batch_size``).
class FLAGS(PQ.BaseFLAGS):
    # Temperature schedule bounds. ``set_temperature(p)`` interpolates
    # log-linearly between them: ``max`` at p = 0 and ``min`` at p = 1.
    class temperature(PQ.BaseFLAGS):
        max = 0.1
        min = 0.001

    # Filtering options.
    class filter(PQ.BaseFLAGS):
        # Not referenced by the code visible in this file.
        top_k = 10000
        # When true, ``filtered_s`` returns ``self.pool.s`` instead of ``self.s``.
        pool = False

    # Not referenced by the code visible in this file.
    n_steps = 1
    # Not referenced by the code visible in this file.
    method = 'grad'
    # Rate of the multiplicative step-size (tau) adaptation in ``step``.
    lr = 0.01
    # Number of states held and updated in parallel.
    batch_size = 1000
    # Not referenced by the code visible in this file.
    extend_region = 0.0
    # Not referenced by the code visible in this file.
    barrier_coef = 0.
    # Not referenced by the code visible in this file.
    L_neg_coef = 1
    # Not referenced by the code visible in this file.
    resample = False

    # Maximum number of Adam iterations performed by ``project_back``.
    n_proj_iters = 10


class SLangevinOptimizer(nn.Module):
    """Batch of states ``s`` updated by noisy gradient steps on ``crabs.obj_eval``.

    Each call to :meth:`step` first pushes the states back towards the region
    where ``crabs.barrier`` is small (:meth:`project_back`), then proposes
    ``s + tau * grad + sqrt(2 * tau) * noise`` with a per-sample step size
    ``tau`` that is adapted from an estimated acceptance ratio.  Proposals whose
    barrier value is positive are discarded and the previous state is kept.
    """

    FLAGS = FLAGS

    def __init__(self, crabs, nmPQ):
        """Create the state batch and its optimizers.

        Args:
            crabs: object providing ``obj_eval(s)`` (returns a dict with at
                least ``'hard_obj'``, ``'L'``, ``'constraint'`` and ``'mask'``)
                and ``barrier(s)``.
            nmPQ: object exposing ``mean`` and ``std`` tensors; the initial
                states are drawn as ``mean + randn * std``.
        """
        super().__init__()
        self.crabs = crabs
        self.temperature = FLAGS.temperature.max
        self.nmPQ = nmPQ

        # Every tensor owned by this object lives on the same device as the
        # prior statistics, so the sampler also works on CPU-only hosts.
        device = nmPQ.mean.device

        rand = torch.randn(FLAGS.batch_size, *nmPQ.std.shape, device=device)
        s_init = nmPQ.mean + rand * nmPQ.std
        self.s = nn.Parameter(s_init, requires_grad=True)
        # Per-sample step size; adapted manually in `step`, never by autograd.
        self.tau = nn.Parameter(
            torch.full([FLAGS.batch_size], 1e-2, device=device), requires_grad=False)
        self.opt = torch.optim.Adam([self.s])

        self.mask = torch.tensor([0], dtype=torch.int64)
        self.n_failure = torch.zeros(FLAGS.batch_size, dtype=torch.int64, device=device)
        self.n_resampled = 0

        # Optimizer used only by `project_back`.
        self.adam = torch.optim.Adam([self.s], betas=(0, 0.999), lr=0.001)
        # (previous state, previous objective, previous forward log-proposal);
        # scalar zeros broadcast against the batch on the very first `step`.
        self.last_info = (
            torch.tensor(0., device=device),
            torch.tensor(0., device=device),
            torch.tensor(0., device=device),
        )

    def reinit(self):
        """Redraw ``s`` from a standard normal and reset every ``tau`` to 0.01.

        Note that this does not use ``nmPQ`` and does not reset ``last_info``.
        """
        nn.init.normal_(self.s)
        nn.init.constant_(self.tau, 0.01)

    def set_temperature(self, p):
        """Set the temperature by log-linear interpolation.

        Args:
            p: interpolation weight; 0 gives ``FLAGS.temperature.max`` and
                1 gives ``FLAGS.temperature.min``.
        """
        t_max = FLAGS.temperature.max
        t_min = FLAGS.temperature.min
        self.temperature = np.exp(np.log(t_max) * (1 - p) + np.log(t_min) * p)

    @property
    def filtered_s(self):
        """Return ``self.pool.s`` when ``FLAGS.filter.pool`` is set, else ``self.s``.

        ``self.pool`` is not assigned anywhere in this class; it must be
        attached externally before enabling ``FLAGS.filter.pool``.
        """
        if FLAGS.filter.pool:
            return self.pool.s
        return self.s

    def pdf(self, s):
        """Evaluate the objective at ``s``.

        Returns:
            A pair ``(hard_obj / temperature, result)`` where ``result`` is the
            full dict returned by ``crabs.obj_eval(s)``.
        """
        result = self.crabs.obj_eval(s)
        return result['hard_obj'] / self.temperature, result

    def project_back(self, should_print=False):
        """Run up to ``FLAGS.n_proj_iters`` Adam steps on ``relu(barrier(s)).sum()``.

        Stops early once the summed barrier value drops below 1000 and records
        the number of iterations used in the ``'opt_s/projection'`` meter.
        ``should_print`` is accepted for interface compatibility and is unused.
        """
        for i in range(FLAGS.n_proj_iters):
            with torch.enable_grad():
                h = self.crabs.barrier(self.s)
                # NOTE(review): the stopping test uses the signed sum of the
                # barrier, not the summed relu loss that is minimized below --
                # confirm this is intended.
                if h.sum() < 1000:
                    PQ.meters['opt_s/projection'] += i
                    break
                loss = relu(h)
                self.adam.zero_grad()
                loss.sum().backward()
                self.adam.step()
        else:
            # Loop ran to completion without meeting the stopping test.
            PQ.meters['opt_s/projection'] += FLAGS.n_proj_iters

    @torch.no_grad()
    def resample(self, f: torch.Tensor, idx):
        """Overwrite the samples at ``idx`` with copies of other samples.

        Replacement indices are drawn with replacement from ``softmax(f)``
        over the batch dimension; ``tau`` is copied alongside ``s`` and the
        failure counters of the overwritten slots are cleared.

        Args:
            f: per-sample scores used as softmax logits.
            idx: indices of the samples to replace; a no-op when empty.
        """
        # `len(...) == 0` rather than truthiness: `idx` may be a tensor.
        if len(idx) == 0:
            return
        new_idx = f.softmax(0).multinomial(len(idx), replacement=True)
        self.s[idx] = self.s[new_idx]
        self.tau[idx] = self.tau[new_idx]
        self.n_failure[idx] = 0
        self.n_resampled += len(idx)

    def step(self):
        """Perform one projection + noisy gradient step on the whole batch.

        Returns:
            ``{'optimal': float}`` -- the maximum ``hard_obj`` over the batch,
            evaluated at the states *before* this step's proposal is applied.
        """
        self.project_back()

        a, f_a, log_p_a2b = self.last_info
        tau = self.tau
        b = self.s

        f_b, b_info = self.pdf(b)
        grad_b = torch.autograd.grad(f_b.sum(), b)[0]

        # Proposal: gradient drift plus Gaussian noise with variance 2 * tau.
        z = torch.randn_like(b)
        center = b + tau[:, None] * grad_b
        c = center + (tau[:, None] * 2).sqrt() * z

        # use last information to update tau
        with torch.no_grad():
            # Log-density (up to a constant) of proposing the previous state
            # `a` from the current state `b`.
            log_p_b2a = -((a - center)**2).sum(dim=-1) / tau / 4
            log_ratio = (f_b + log_p_b2a) - (f_a + log_p_a2b)
            ratio = log_ratio.clamp(max=0).exp()
            # Grow tau where the estimated acceptance exceeds 0.574, shrink it
            # otherwise.
            self.tau.mul_(FLAGS.lr * (ratio - 0.574) + 1)  # .clamp_(max=1.0)
            PQ.meters['opt_s/accept'] += ratio.mean().item()

            # NOTE(review): with the 1 / (4 * tau) scaling used for
            # `log_p_b2a`, the forward term for c = center + sqrt(2 tau) z
            # would be -||z||^2 / 2; the value stored here is -||z||^2.
            # Left unchanged to preserve the sampler's numerics -- confirm.
            log_p_b2c = -z.norm(dim=-1)**2
            self.last_info = b.detach().clone(), f_b.detach(), log_p_b2c.detach()

            # Reject proposals with a positive barrier value: keep `b` there.
            c_out = self.crabs.barrier(c) > 0
            c[c_out] = b[c_out]
            self.s.set_(c)

        return {
            'optimal': b_info['hard_obj'].max().item(),
        }

    @torch.no_grad()
    def evaluate(self, *, step):
        """Log diagnostics for the current batch and reset the ``opt_s/`` meters.

        Args:
            step: global step passed to ``PQ.writer.add_scalar``.

        Returns:
            ``{'inside': int}`` -- number of samples with ``constraint <= 0``.
        """
        result = self.crabs.obj_eval(self.s)
        L_s = result['L']
        hardD_s = result['hard_obj'].max().item()
        inside = (result['constraint'] <= 0).sum().item()
        cut_size = result['mask'].sum().item()

        geo_mean_tau = self.tau.log().mean().exp().item()
        max_tau = self.tau.max().item()
        PQ.writer.add_scalar('opt_s/hardD', hardD_s, global_step=step)
        PQ.writer.add_scalar('opt_s/inside', inside / FLAGS.batch_size, global_step=step)
        PQ.writer.add_scalar('opt_s/P_accept', PQ.meters['opt_s/accept'].mean, global_step=step)

        # Quartiles of L restricted to the samples satisfying the constraint.
        L_inside = L_s.cpu().numpy()
        L_inside = L_inside[np.where(result['constraint'].cpu() <= 0)]
        L_dist = np.percentile(L_inside, [25, 50, 75]) if len(L_inside) else []
        PQ.log.debug(f"[S Langevin]: temperature = {self.temperature:.3f}, hardD = {hardD_s:.6f}, "
                     f"inside/cut = {inside}/{cut_size}, "
                     f"n_proj = {PQ.meters['opt_s/projection'].mean:.3f}, "
                     f"tau = [geo mean {geo_mean_tau:.3e}, max {max_tau:.3e}], "
                     # f"Pr[out => in] = {PQ.meters['opt_s/out_to_in'].mean:.6f}, "
                     f"Pr[accept] = {PQ.meters['opt_s/accept'].mean:.3f}, "
                     # f"# valid = [s = {PQ.meters['opt_s/n_s_valid'].mean:.0f}, "
                     # f"pool = {PQ.meters['opt_s/n_pool_valid'].mean:.0f}], "
                     f"L 25/50/75% = {L_dist}, "
                     f"resampled = {self.n_resampled}")
        PQ.meters.purge('opt_s/')
        self.n_resampled = 0

        return {
            'inside': inside
        }
