import random
import numpy as np
import math
import PIL

import torch


def set_seed(seed):
    """Seed the random number generators used in this module.

    Seeds PyTorch (``manual_seed``, ``cuda.manual_seed``,
    ``cuda.manual_seed_all``), Python's ``random`` and NumPy's global
    generator with ``seed``, and switches cuDNN to deterministic,
    non-benchmark mode.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    np.random.seed(seed)


def cal_subtb_coef_matrix(lamda, N):
    """
    Build the (N+1, N+1) matrix of normalised weights lamda^(j-i).

    offsets[i, j] = j - i:
    0, 1, 2, ...
    -1, 0, 1, ...
    -2, -1, 0, ...

    coef[i, j] = lamda^(j-i) / total  if i < j else 0,
    where total is the sum of lamda^(j-i) over all pairs i < j.
    The normalisation is carried out in log-space.
    """
    idx = torch.arange(N + 1)
    offsets = idx.unsqueeze(0) - idx.unsqueeze(1)
    log_weights = np.log(lamda) * offsets
    # Entries on or below the diagonal get log-weight -inf, i.e. weight 0.
    log_weights = log_weights.masked_fill(offsets <= 0, -np.inf)
    log_norm = torch.logsumexp(log_weights.reshape(-1), dim=0)
    return torch.exp(log_weights - log_norm)


def logmeanexp(x, dim=0):
    """Return ``log(mean(exp(x)))`` along ``dim``, computed via ``logsumexp``."""
    log_sum = torch.logsumexp(x, dim)
    # log(mean) = log(sum) - log(count)
    return log_sum - math.log(x.size(dim))


def dcp(tensor):
    """Return ``tensor`` detached from the autograd graph and moved to the CPU."""
    detached = tensor.detach()
    return detached.cpu()


def gaussian_params(tensor):
    """Split ``tensor`` into two chunks along its last dimension.

    The first chunk is returned as the mean and the second as the
    log-variance, i.e. ``(mean, logvar)``.
    """
    first_chunk, second_chunk = tensor.chunk(2, dim=-1)
    return first_chunk, second_chunk


def uniform_discretizer(bsz, trajectory_length):
    """Return ``bsz`` copies of the evenly spaced grid from 0 to 1.

    The result has shape ``(bsz, trajectory_length + 1)``.
    """
    n_points = trajectory_length + 1
    grid = torch.linspace(0, 1, n_points)
    return grid.repeat(bsz, 1)


def harmonic_discretizer(bsz, trajectory_length):
    """Return a time grid whose k-th step is proportional to ``1 / k``.

    The step sizes are normalised to sum to one and accumulated; a leading
    0 is prepended and the grid is repeated ``bsz`` times, giving shape
    ``(bsz, trajectory_length + 1)``.
    """
    weights = 1 / torch.arange(1, trajectory_length + 1)
    fractions = weights / weights.sum()
    interior = fractions.cumsum(dim=0)
    grid = torch.cat([torch.tensor([0]), interior])
    return grid.repeat(bsz, 1)


def sqrt_harmonic_discretizer(bsz, trajectory_length):
    """Return a time grid whose k-th step is proportional to ``1 / sqrt(k)``.

    The step sizes are normalised to sum to one and accumulated; a leading
    0 is prepended and the grid is repeated ``bsz`` times, giving shape
    ``(bsz, trajectory_length + 1)``.
    """
    step_index = torch.arange(1, trajectory_length + 1)
    weights = 1 / torch.sqrt(step_index)
    fractions = weights / weights.sum()
    interior = fractions.cumsum(dim=0)
    grid = torch.cat([torch.tensor([0]), interior])
    return grid.repeat(bsz, 1)


def first_big_step_discretizer(bsz, trajectory_length, big_step_ratio=0.9):
    """Return a time grid whose first step jumps from 0 to ``big_step_ratio``.

    The remaining ``trajectory_length - 1`` steps are spaced evenly between
    ``big_step_ratio`` and 1. The grid is repeated ``bsz`` times, giving
    shape ``(bsz, trajectory_length + 1)``.

    With ``trajectory_length == 1`` the single step spans the whole
    interval, so the grid is ``[0, 1]`` exactly as for the other
    discretizers in this module.
    """
    if trajectory_length == 1:
        # torch.linspace(start, end, 1) yields only `start`, which would end
        # the grid at big_step_ratio instead of 1.
        grid = torch.tensor([0.0, 1.0])
    else:
        tail = torch.linspace(big_step_ratio, 1, trajectory_length)
        grid = torch.cat([torch.zeros(1), tail])
    return grid.repeat(bsz, 1)


def random_discretizer(bsz, trajectory_length, max_ratio):
    """Return a random time grid per batch row.

    Each row draws ``trajectory_length`` raw step sizes uniformly from
    ``[1, max_ratio)``, accumulates them, prepends a 0 and divides by the
    row's total so the grid ends at 1. Shape: ``(bsz, trajectory_length + 1)``.
    """
    raw_steps = torch.rand(bsz, trajectory_length) * (max_ratio - 1) + 1
    cumulative = raw_steps.cumsum(1)
    row_total = cumulative[:, -1].unsqueeze(1)
    padded = torch.cat([torch.zeros(bsz, 1), cumulative], 1)
    return padded / row_total


# def low_discrepancy_discretizer(bsz, traj_length=2):
#     u = torch.rand(1, traj_length).cumsum(1)
#     shift_vector = (torch.arange(bsz) / bsz).unsqueeze(1).repeat(1, traj_length-1)
#     u = (u/u[:, -1])[:, :-1]
#     timestep = u + shift_vector
#     timesteps_in_range = timestep % 1.0
#     timesteps_sorted, indices = torch.sort(timesteps_in_range, dim=-1, descending=False)
#     x = torch.cat([torch.zeros(bsz, 1), timesteps_sorted, torch.ones(bsz, 1)], dim=1)
#     return x


def low_discrepancy_discretizer(bsz, traj_length=2):
    """Return a time grid whose interior points are shifted per batch row.

    One set of ``traj_length - 1`` uniform draws is shared by the whole
    batch. Row ``b`` adds ``b / bsz`` to every draw, wraps the result into
    ``[0, 1)`` and sorts it. A 0 is prepended and a 1 appended, giving
    shape ``(bsz, traj_length + 1)``.
    """
    n_interior = traj_length - 1
    u = torch.rand(1, n_interior)
    # (bsz, 1) column of shifts; broadcasting against u expands it to
    # (bsz, n_interior) without an explicit repeat.
    shift = (torch.arange(bsz) / bsz).unsqueeze(1)
    timesteps = (u + shift) % 1.0
    # Wrapping can break the order within a row, so sort after the shift.
    timesteps_sorted, _ = torch.sort(timesteps, dim=-1, descending=False)
    return torch.cat([torch.zeros(bsz, 1), timesteps_sorted, torch.ones(bsz, 1)], dim=1)

    # old code below:
    # u = torch.rand(1)
    # shift_vector = torch.arange(bsz)/bsz
    # timestep = u + shift_vector
    # timestep_in_range = timestep % 1.0
    # timestep_in_range = timestep_in_range.unsqueeze(-1)
    # x = torch.cat([torch.zeros(bsz, 1), timestep_in_range, torch.ones(bsz, 1)], 1)
    # return x


def low_discrepancy_discretizer2(bsz, traj_length=2):
    """Return stratified time points of shape ``(bsz, traj_length - 1)``.

    One uniform draw per column is shared by the batch; row ``b`` shifts it
    by ``b / bsz`` and wraps it into ``[0, 1)``. Column ``k`` is then mapped
    to ``(value + k) / (traj_length - 1)``, and each column is independently
    permuted across the batch. No 0 / 1 endpoints are added here.
    """
    n_strata = traj_length - 1
    base = torch.rand(1, n_strata)
    row_shift = (torch.arange(bsz) / bsz).unsqueeze(-1)
    wrapped = (base + row_shift) % 1.0
    stratified = (wrapped + torch.arange(n_strata).unsqueeze(0)) / n_strata

    # Shuffle every column on its own, drawing one permutation per column
    # in column order.
    shuffled_columns = []
    for column in stratified.t():
        order = torch.randperm(column.size(0))
        shuffled_columns.append(column[order])
    return torch.stack(shuffled_columns).t()


class CustomGFNOptimizer:
    """Adam optimizer(s) with exponential learning-rate decay for a GFN model.

    With ``use_2optimizers`` false, a single Adam optimizer (``optimizer`` /
    ``lr_scheduler``) holds one parameter group each for ``t_model``,
    ``s_model``, ``joint_model`` and ``langevin_scaling_model``, optionally
    the backward model, and the flow model. With ``use_2optimizers`` true,
    separate forward (``pf_optimizer`` / ``pf_lr_scheduler``) and backward
    (``pb_optimizer`` / ``pb_lr_scheduler``) optimizers are built instead.
    """

    def __init__(
        self,
        gfn_model,
        lr_policy,
        lr_flow,
        lr_back_multiplier,
        learn_pb=False,
        conditional_flow_model=False,
        use_weight_decay=False,
        weight_decay=1e-7,
        use_2optimizers=False,
        share_backbone=False,
        gamma=0.9999,
    ):
        """Create the optimizer(s) and scheduler(s).

        Args:
            gfn_model: object exposing ``t_model``, ``s_model``,
                ``joint_model``, ``langevin_scaling_model``, ``flow_model``
                and, when the backward policy is learned, ``back_model``.
            lr_policy: learning rate of the policy parameter groups.
            lr_flow: learning rate of the flow parameter group.
            lr_back_multiplier: backward-policy groups use
                ``lr_policy * lr_back_multiplier``.
            learn_pb: add the backward model to the optimised parameters.
            conditional_flow_model: if true ``flow_model`` is optimised via
                ``flow_model.parameters()``; otherwise ``flow_model`` itself
                is used as the single parameter.
            use_weight_decay: if false, weight decay is forced to 0.
            weight_decay: Adam weight decay when enabled.
            use_2optimizers: build separate forward / backward optimizers
                (requires ``learn_pb``).
            share_backbone: in single-optimizer mode, optimise only
                ``back_model.last_layer`` for the backward policy.
            gamma: multiplicative decay of ``ExponentialLR``.
        """
        self.use_2optimizers = use_2optimizers
        self.gamma = gamma
        decay = weight_decay if use_weight_decay else 0

        def module_group(module, lr):
            return {"params": module.parameters(), "lr": lr}

        def forward_groups():
            names = ("t_model", "s_model", "joint_model", "langevin_scaling_model")
            return [module_group(getattr(gfn_model, attr), lr_policy) for attr in names]

        def flow_group():
            if conditional_flow_model:
                return module_group(gfn_model.flow_model, lr_flow)
            # Unconditional flow: flow_model is itself the parameter.
            return {"params": [gfn_model.flow_model], "lr": lr_flow}

        def build(groups):
            optimizer = torch.optim.Adam(groups, weight_decay=decay)
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
            return optimizer, scheduler

        if use_2optimizers:
            assert learn_pb, "if use_2optimizers is True, then learn_pb must be enabled."
            self.pf_optimizer, self.pf_lr_scheduler = build(forward_groups() + [flow_group()])

            back_lr = lr_policy * lr_back_multiplier
            backward_modules = (
                gfn_model.t_model,
                gfn_model.s_model,
                gfn_model.langevin_scaling_model,
                gfn_model.back_model,
            )
            self.pb_optimizer, self.pb_lr_scheduler = build(
                [module_group(module, back_lr) for module in backward_modules]
            )
            return

        groups = forward_groups()
        if learn_pb:
            back_module = gfn_model.back_model.last_layer if share_backbone else gfn_model.back_model
            groups.append(module_group(back_module, lr_policy * lr_back_multiplier))
        groups.append(flow_group())
        self.optimizer, self.lr_scheduler = build(groups)

    def step(self):
        """Take one optimizer step (both optimizers in two-optimizer mode)."""
        if not self.use_2optimizers:
            self.optimizer.step()
            return
        self.pf_optimizer.step()
        self.pb_optimizer.step()

    def get_last_lr(self):
        """Return the first group's last learning rate.

        In two-optimizer mode a ``(forward_lr, backward_lr)`` tuple is
        returned instead of a single value.
        """
        if not self.use_2optimizers:
            return self.lr_scheduler.get_last_lr()[0]
        forward_lr = self.pf_lr_scheduler.get_last_lr()[0]
        backward_lr = self.pb_lr_scheduler.get_last_lr()[0]
        return forward_lr, backward_lr

    def lr_scheduler_step(self):
        """Advance the learning-rate scheduler(s) by one step."""
        if not self.use_2optimizers:
            self.lr_scheduler.step()
            return
        self.pf_lr_scheduler.step()
        self.pb_lr_scheduler.step()


def update_target_network(gfn_model, target_gfn_model, tau):
    """Soft-update the target parameters towards the online parameters.

    For each parameter pair, taken in ``parameters()`` order, the target is
    overwritten in place with ``(1 - tau) * target + tau * online``.
    """
    keep = 1 - tau
    with torch.no_grad():
        pairs = zip(gfn_model.parameters(), target_gfn_model.parameters())
        for online, target in pairs:
            # Scale the old target, then add the online contribution in place.
            target.data.mul_(keep).add_(online.data, alpha=tau)


def get_gfn_loss(
    is_fwd,
    learn_pb,
    pf_loss_fn,
    pb_loss_fn,
    start_states,
    gfn_model,
    target_gfn_model,
    energy,
    coef_matrix,
    discretizer_fn,
    target_pf=False,
    target_pb=False,
    exploration_std=None,
    return_exp=False,
    states=None,
    huber_loss_quantile=1,
    log_r=None,
    pis=False,
):
    """Compute the forward-policy (pf) and backward-policy (pb) losses for one batch.

    The trajectory terms ``log_pfs``, ``log_pbs`` and ``log_fs`` come from
    ``gfn_model``. If ``states`` is supplied, ``get_trajectory_fwd`` is called
    with ``None`` as the start-state argument and ``states`` passed through.
    Otherwise ``get_trajectory_fwd`` (``is_fwd`` true) or
    ``get_trajectory_bwd`` is called with ``start_states`` and also returns
    the states. The log-reward is ``log_r`` if given, else it is computed with
    ``energy.log_reward``. Both loss functions are called with the positional
    arguments ``(log_pfs, log_pbs, log_fs, log_r, energy, states,
    coef_matrix, huber_loss_quantile)``; which of those are detached or taken
    from the target model depends on ``pis``, ``target_pf`` and ``target_pb``.

    Args:
        is_fwd: choose ``get_trajectory_fwd`` over ``get_trajectory_bwd``
            when ``states`` is not supplied.
        learn_pb: not referenced in this function body; ``pb_loss`` is
            always computed.
        pf_loss_fn: callable producing the forward-policy loss.
        pb_loss_fn: callable producing the backward-policy loss.
        start_states: forwarded to the trajectory method as its first
            argument.
        gfn_model: model exposing ``get_trajectory_fwd`` and
            ``get_trajectory_bwd``.
        target_gfn_model: model whose ``get_trajectory_fwd`` is called on
            ``states`` when ``target_pf`` or ``target_pb`` is set.
        energy: object exposing ``log_reward``; also passed to the loss
            functions.
        coef_matrix: passed unchanged to the loss functions.
        discretizer_fn: passed unchanged to the trajectory methods.
        target_pf: use the target model's ``log_pfs`` / ``log_fs`` (detached)
            in the pb loss.
        target_pb: use the target model's ``log_pbs`` in the pf loss.
        exploration_std: passed unchanged to ``get_trajectory_fwd``.
        return_exp: also return states, log-terms and ``log_r``.
        states: optional pre-existing states to evaluate instead of calling
            the model with ``start_states``.
        huber_loss_quantile: passed unchanged to the loss functions.
        log_r: optional precomputed log-reward.
        pis: switches to the branch in which ``log_r`` is computed under
            ``torch.enable_grad()`` and the pf loss receives undetached
            ``log_pbs``.

    Returns:
        ``(pf_loss, pb_loss)``, or ``(pf_loss, pb_loss, states, log_pfs,
        log_pbs, log_fs, log_r)`` when ``return_exp`` is true.
    """
    if states is not None:
        # States supplied by the caller: no start state is passed, and the
        # returned states are discarded.
        _, log_pfs, log_pbs, log_fs = gfn_model.get_trajectory_fwd(None, discretizer_fn, exploration_std, energy.log_reward, states, pis)
    else:
        if is_fwd:
            states, log_pfs, log_pbs, log_fs = gfn_model.get_trajectory_fwd(
                start_states, discretizer_fn, exploration_std, energy.log_reward, None, pis
            )
        else:
            # The backward call receives neither exploration_std nor pis.
            states, log_pfs, log_pbs, log_fs = gfn_model.get_trajectory_bwd(start_states, discretizer_fn, energy.log_reward)
    if log_r is None:
        if not pis:
            # Reward is treated as a constant: computed without grad and detached.
            with torch.no_grad():
                if isinstance(states, list):
                    # List-valued states: index -1 along dim 1 of the first
                    # element, with the second element passed alongside
                    # (meaning of the second element not visible here — TODO confirm).
                    log_r = energy.log_reward([states[0][:, -1], states[1]])
                else:
                    # Index -1 along dim 1 (presumably the final time step — confirm).
                    log_r = energy.log_reward(states[:, -1])
            log_r = log_r.detach()
        else:
            # pis mode: gradients are explicitly enabled for the reward.
            # NOTE(review): this branch does not handle list-valued states.
            with torch.enable_grad():
                log_r = energy.log_reward(states[:, -1])
    if target_pf or target_pb:
        # Evaluate the target model on the same states; `pis` is not passed here.
        _, target_log_pfs, target_log_pbs, target_log_fs = target_gfn_model.get_trajectory_fwd(
            None, discretizer_fn, exploration_std, energy.log_reward, states
        )

    if not pis:
        # pf loss: the backward log-probs are detached, so pf_loss carries no
        # gradient through them.
        if target_pb:
            pf_loss = pf_loss_fn(log_pfs, target_log_pbs.detach(), log_fs, log_r, energy, states, coef_matrix, huber_loss_quantile)
        else:
            pf_loss = pf_loss_fn(log_pfs, log_pbs.detach(), log_fs, log_r, energy, states, coef_matrix, huber_loss_quantile)

        # pb loss: the forward log-probs and log-flows are detached instead.
        if target_pf:
            pb_loss = pb_loss_fn(
                target_log_pfs.detach(),
                log_pbs,
                target_log_fs.detach(),
                log_r.detach(),
                energy,
                states,
                coef_matrix,
                huber_loss_quantile,
            )
        else:
            pb_loss = pb_loss_fn(
                log_pfs.detach(),
                log_pbs,
                log_fs.detach(),
                log_r,
                energy,
                states,
                coef_matrix,
                huber_loss_quantile,
            )
    else:
        # NOTE(review): this second call passes `start_states` with states=None
        # and no pis flag, so its log-terms presumably belong to a different
        # trajectory than `states`, which is what pb_loss_fn receives below —
        # confirm this is intended.
        _, not_pis_log_pfs, not_pis_log_pbs, not_pis_log_fs = gfn_model.get_trajectory_fwd(
            start_states, discretizer_fn, exploration_std, energy.log_reward, None
        )

        # In pis mode the pf loss receives the backward log-probs undetached.
        if target_pb:
            pf_loss = pf_loss_fn(log_pfs, target_log_pbs, log_fs, log_r, energy, states, coef_matrix, huber_loss_quantile)
        else:
            pf_loss = pf_loss_fn(log_pfs, log_pbs, log_fs, log_r, energy, states, coef_matrix, huber_loss_quantile)

        # The pb loss uses a gradient-free reward, recomputed from `states`
        # even when the caller supplied `log_r`.
        with torch.no_grad():
            no_grad_log_r = energy.log_reward(states[:, -1]).detach()

        if target_pf:
            pb_loss = pb_loss_fn(
                target_log_pfs.detach(),
                not_pis_log_pbs,
                target_log_fs.detach(),
                no_grad_log_r,
                energy,
                states,
                coef_matrix,
                huber_loss_quantile,
            )
        else:
            pb_loss = pb_loss_fn(
                not_pis_log_pfs.detach(),
                not_pis_log_pbs,
                not_pis_log_fs.detach(),
                no_grad_log_r,
                energy,
                states,
                coef_matrix,
                huber_loss_quantile,
            )

    if return_exp:
        return pf_loss, pb_loss, states, log_pfs, log_pbs, log_fs, log_r
    else:
        return pf_loss, pb_loss


def get_exploration_std(iter, exploratory, exploration_factor=0.1, exploration_wd=False):
    """Return a callable giving the exploration standard deviation, or ``None``.

    Args:
        iter: current training iteration (only used when ``exploration_wd``
            is true).
        exploratory: whether exploration is enabled. Any falsy value
            (``False``, ``0``, ``None``, a NumPy false) disables it.
        exploration_factor: base standard deviation.
        exploration_wd: if true, scale the base value linearly from 1 at
            iteration 0 down to 0 at iteration 10000, clamped at 0 afterwards.

    Returns:
        ``None`` when exploration is disabled; otherwise a one-argument
        callable that ignores its argument and returns the constant standard
        deviation computed here.
    """
    # Truthiness rather than `is False`: an identity test lets falsy
    # non-bool flags such as 0, None or numpy.bool_(False) enable exploration.
    if not exploratory:
        return None

    if exploration_wd:
        exploration_std = exploration_factor * max(0, 1.0 - iter / 10000.0)
    else:
        exploration_std = exploration_factor

    def expl(x):
        return exploration_std

    return expl


def get_name(args):
    """Build the results-directory path that identifies a run.

    The path is assembled from the fields of ``args``: results root and
    version, energy (with optional GAN / distortion suffixes and the data
    dimension), time horizon and discretizer settings, training direction
    and loss modes, optional local-search settings, several hyperparameters,
    backward-policy settings, and finally the initialisation mode and seed.
    Option-specific fields are only read when the corresponding option is
    active. The returned string ends with ``/``.
    """
    # Sampler prefix: horizon, replay ratio, discretizer and its options.
    prefix_parts = [f"T_{args.T}/RRN_{args.replay_ratio_n}/", f"{args.discretizer}/"]
    if args.discretizer == "random":
        prefix_parts.append(f"max_ratio_{args.discretizer_max_ratio}/")
    if args.traj_length_strategy == "dynamic":
        prefix_parts.append(f"dynamic_{args.min_traj_length}_{args.max_traj_length}/")
    if args.langevin:
        prefix_parts.append("langevin_scaling_per_dimension_" if args.langevin_scaling_per_dimension else "langevin_")
    if args.exploratory and (args.exploration_factor is not None):
        exploration_tag = "exploration_wd" if args.exploration_wd else "exploration"
        prefix_parts.append(f"{exploration_tag}_{args.exploration_factor}")
    prefix = "".join(prefix_parts)

    # Training direction and the loss modes used in each direction.
    if args.both_ways:
        ways = f"fwd_bwd/fwd_{args.pf_mode_fwd}_{args.pb_mode_fwd}_bwd_{args.pf_mode_bwd}_{args.pb_mode_bwd}"
    elif args.bwd:
        ways = f"bwd/bwd_{args.pf_mode_bwd}_{args.pb_mode_bwd}"
    else:
        ways = f"fwd/fwd_{args.pf_mode_fwd}_{args.pb_mode_fwd}"
    if args.local_search:
        ways += (
            f"/local_search_iter_{args.max_iter_ls}_burn_{args.burn_in}_cycle_{args.ls_cycle}_step_{args.ld_step}_"
            f"beta_{args.beta}_rankw_{args.rank_weight}_prioritized_{args.prioritized}"
        )

    results_root = "results_pis_architectures" if args.pis_architectures else "results"

    # Energy directory: energy name plus energy-specific suffixes.
    energy = args.energy
    energy_suffixes = []
    if "gan" in energy:
        energy_suffixes.append(f"_mc={args.gan_magic_const}_prompt={args.gan_prompt.replace(' ', '_').lower()}")
    if energy == "gan_cifar10":
        energy_suffixes.append("_dcgan" if args.is_dcgan else "_sngan")
    if "distorted" in energy:
        energy_suffixes.append(f"_dc={args.distortion_coef}")
    energy_suffixes.append(f"_dim={args.data_dim}")
    energy_suffix = "".join(energy_suffixes)

    segments = [
        f"{results_root}-v{args.version}/{energy}{energy_suffix}/",
        f"{prefix}gfn/{ways}/",
        f"tscale_{args.t_scale}/lvr_{args.log_var_range}/gamma_{args.gamma}/",
        f"tau={args.tau}_target_pf={args.target_pf}_target_pb={args.target_pb}/",
        f"huber_q={args.huber_loss_quantile}_clip_grad_norm={args.clip_grad_norm}_clip_grad_q={args.clip_grad_quantile}/",
    ]

    # Backward-policy description.
    if args.learn_pb:
        segments += [
            f"lr_back_multiplier={args.lr_back_multiplier}",
            f"_learn_pb_{args.process_param}_scale_range_{args.pb_scale_range}_scale_policy_{args.pb_scale_policy}",
            "_2optimizers" if args.use_2optimizers else "_1optimizer",
            "_shared_backbone" if args.share_backbone else "_separate_backbones",
        ]
    else:
        segments.append("fixed_pb")
    segments.append("_learned_variance" if args.learned_variance else "_fixed_variance")

    init_tag = "zero_init" if args.zero_init else "rand_init"
    perfect_tag = "perfect_" if args.perfect else ""
    segments.append(f"/{init_tag}/{perfect_tag}seed_{args.seed}/")

    return "".join(segments)


def get_wandb_name(args):
    """Build the short run name used for Weights & Biases.

    Returns ``"perfect"`` when ``args.perfect`` is set. Otherwise the name
    is an optional backward-policy description (only when ``args.learn_pb``)
    followed by ``/seed_<seed>``.
    """
    if args.perfect:
        return "perfect"

    pieces = []
    if args.learn_pb:
        pieces.append(f"_learn_pb_{args.process_param}_scale_range_{args.pb_scale_range}_gamma_{args.gamma}_")
        if args.use_2optimizers:
            pieces.append("use_2optimizers_")
        if args.share_backbone:
            pieces.append("share_backbone_")
    pieces.append(f"/seed_{args.seed}")

    return "".join(pieces)
