import torch
import numpy as np


def operation_list(model, num_ops, num_blocks):
    """
    Flatten the model's binary gates into one operation encoding.

    Each tensor yielded by ``model.binary_gates()`` is kept as-is. A gate whose
    first dimension differs from ``num_ops`` is followed by a block of
    ``num_ops - gate.shape[0]`` zeros created on the gate's device. All pieces
    are concatenated and reshaped to ``(-1, num_ops * num_blocks)``.

    This has to be adapted to the NAS backbone in use.
    """
    pieces = []
    for gate in model.binary_gates():
        pieces.append(gate)
        missing = num_ops - gate.shape[0]
        if missing != 0:
            # A negative ``missing`` produces an empty padding tensor.
            pieces.append(torch.zeros(max(missing, 0), device=gate.device))

    return torch.cat(pieces).reshape(-1, num_ops * num_blocks)


def hw_metrics(model, evaluator, gs=True, print_pair=False):
    """
    Estimate (latency, energy, area) for the model's current operations.

    The operation encoding built by ``operation_list`` is passed to
    ``evaluator.generator`` (with ``gs`` forwarded as ``eval_gumbel``) to obtain
    hardware parameters. Encoding and hardware parameters are concatenated on
    the last dimension and scored by ``evaluator.estimator``; the first row of
    its output is unpacked into the three returned metrics.

    If ``print_pair`` is true, the encoding and the hardware parameters are
    printed.
    """
    ops = operation_list(model, evaluator.num_ops, evaluator.num_blocks)
    accel = evaluator.generator(ops, eval_gumbel=gs)
    prediction = evaluator.estimator(torch.cat([ops, accel], dim=-1))
    # Only the first row of the estimator output is used.
    latency, energy, area = prediction[0]
    if print_pair:
        print(f'ops_list: {ops.data}, hw_params: {accel.data}')
    return latency, energy, area


def get_targets(model, evaluator, gs):
    """
    Switch the generator to eval mode and return the current hardware metrics.

    NOTE(review): ``gs`` is accepted but never used; ``hw_metrics`` is always
    called with ``gs=False``. Confirm whether callers expect it to be honoured.
    """
    evaluator.generator.eval()
    targets = hw_metrics(model, evaluator, gs=False)
    return targets


def train_generator_no_const(model, evaluator, args):
    """
    Warm up generator without constraints.

    One-hot operation choices are sampled for every module in
    ``model.redundant_modules`` from a detached copy of its ``probs_over_ops``.
    The generator maps them to hardware parameters and the estimator scores the
    concatenated (operations, hardware) pairs. The batch means of estimator
    output columns 0, 1 and 2, weighted by ``args.lambdas`` and scaled by
    ``args.scale_value``, form the loss for one optimizer step.

    When ``args.except_const`` is set, a column only contributes if its
    constraint value (latency, energy, area for columns 0, 1, 2) equals
    ``np.inf``. If no column contributes there is nothing to optimise, so the
    backward pass and the optimizer step are skipped instead of failing on a
    non-tensor loss.

    The generator is left in eval mode. Returns ``None``.
    """
    evaluator.generator.train()
    evaluator.generator_opt.zero_grad()

    batch = args.generator_batch_size
    ops_cat = []
    for layer in model.redundant_modules:
        probs = layer.probs_over_ops.clone().detach()
        indices = probs.multinomial(batch, replacement=True)
        # Allocate next to the sampled indices so the index assignment never
        # mixes devices; the final tensor is moved to the estimator below.
        one_hot = torch.zeros(batch, args.num_ops, device=indices.device)
        one_hot[range(batch), indices] = 1.
        ops_cat.append(one_hot)

    ops_cat = torch.cat(ops_cat, dim=-1).to(next(evaluator.estimator.parameters()).device)

    hw_params = evaluator.generator(ops_cat)
    eval_input = torch.cat([ops_cat, hw_params], dim=-1)
    output = evaluator.estimator(eval_input)

    if args.except_const:
        limits = (
            args.latency_constrained_value,
            args.energy_constrained_value,
            args.area_constrained_value,
        )
        active = [idx for idx, limit in enumerate(limits) if limit == np.inf]
    else:
        active = [0, 1, 2]

    if active:
        hw_loss = 0
        for idx in active:
            hw_loss += output[:, idx].mean() * args.lambdas[idx]
        hw_loss *= args.scale_value
        hw_loss.backward()
        evaluator.generator_opt.step()

    evaluator.generator.eval()


def train_generator_flexible_scaler(model, evaluator, args, warmup=False):
    """
    Train generator with flexible scale value in order to comply with constraints

    One optimizer step is taken on the weighted batch means of estimator output
    columns 0, 1 and 2 (weights ``args.lambdas``, scale ``args.scale_value``)
    for one-hot operations sampled from each module's ``probs_over_ops``.
    Afterwards the generator is put in eval mode and the target metrics are
    read with ``hw_metrics(..., gs=False)``.

    Unless ``warmup`` is true, ``args.scale_value`` and ``args.in_const`` are
    then updated in place. A metric counts as within its constraint when its
    target is below the limit or no sampled row reaches the limit. If all three
    are within constraint and ``args.in_const`` was already true, the scale is
    divided by ``args.p``; if all three are violated-state and ``args.in_const``
    was already false, it is multiplied by ``args.p``; when the state flips, the
    scale is reset to ``args.origin_scaler`` and ``args.in_const`` is toggled.

    Returns ``(target_latency, target_energy, target_area)``.
    """
    evaluator.generator.train()
    evaluator.generator_opt.zero_grad()

    batch = args.generator_batch_size
    ops_cat = []
    for layer in model.redundant_modules:
        probs = layer.probs_over_ops.clone().detach()
        indices = probs.multinomial(batch, replacement=True)
        # Allocate next to the sampled indices so the index assignment never
        # mixes devices; the final tensor is moved to the estimator below.
        one_hot = torch.zeros(batch, args.num_ops, device=indices.device)
        one_hot[range(batch), indices] = 1.
        ops_cat.append(one_hot)

    ops_cat = torch.cat(ops_cat, dim=-1).to(next(evaluator.estimator.parameters()).device)

    hw_params = evaluator.generator(ops_cat)
    eval_input = torch.cat([ops_cat, hw_params], dim=-1)
    output = evaluator.estimator(eval_input)

    hw_loss = (output[:, 0].mean() * args.lambdas[0]) \
              + (output[:, 1].mean() * args.lambdas[1]) \
              + (output[:, 2].mean() * args.lambdas[2])
    hw_loss *= args.scale_value
    hw_loss.backward()
    evaluator.generator_opt.step()

    evaluator.generator.eval()
    target_latency, target_energy, target_area = hw_metrics(model, evaluator, gs=False)

    if not warmup:
        targets = (target_latency, target_energy, target_area)
        limits = (
            args.latency_constrained_value,
            args.energy_constrained_value,
            args.area_constrained_value,
        )
        # ``output`` is the pre-step estimator output; a single tensor
        # reduction replaces a Python-level sum over the batch.
        in_const = all(
            bool(target < limit) or not bool((output[:, idx] >= limit).any())
            for idx, (target, limit) in enumerate(zip(targets, limits))
        )

        if in_const:
            if args.in_const:
                args.scale_value /= args.p
            else:
                args.scale_value = args.origin_scaler
                args.in_const = True
        else:
            if args.in_const:
                args.scale_value = args.origin_scaler
                args.in_const = False
            else:
                args.scale_value *= args.p

    return target_latency, target_energy, target_area


def train_generator_slope(model, evaluator, args):
    """
    Train generator with soft constraint

    One-hot operations are sampled from each module's ``probs_over_ops``, the
    generator produces hardware parameters and the estimator scores the pairs.
    For each metric (estimator output columns 0, 1, 2 = latency, energy, area):

    * if its constraint value is ``np.inf``, the weighted batch mean is added;
    * otherwise rows at or above the metric's own constraint value are weighted
      by ``args.slope.tau`` and rows below it by 1, each sum divided by the
      batch size.

    Every metric is compared against its own constraint value; the energy and
    area branches previously compared against the latency constraint.

    The loss is scaled by ``args.scale_value``, one optimizer step is taken,
    the generator is put in eval mode and the target metrics from
    ``hw_metrics(..., gs=False)`` are returned as
    ``(target_latency, target_energy, target_area)``.
    """
    evaluator.generator.train()
    evaluator.generator_opt.zero_grad()

    batch = args.generator_batch_size
    ops_cat = []
    for layer in model.redundant_modules:
        probs = layer.probs_over_ops.clone().detach()
        indices = probs.multinomial(batch, replacement=True)
        # Allocate next to the sampled indices so the index assignment never
        # mixes devices; the final tensor is moved to the estimator below.
        one_hot = torch.zeros(batch, args.num_ops, device=indices.device)
        one_hot[range(batch), indices] = 1.
        ops_cat.append(one_hot)

    ops_cat = torch.cat(ops_cat, dim=-1).to(next(evaluator.estimator.parameters()).device)

    hw_params = evaluator.generator(ops_cat)
    eval_input = torch.cat([ops_cat, hw_params], dim=-1)
    output = evaluator.estimator(eval_input)

    limits = (
        args.latency_constrained_value,
        args.energy_constrained_value,
        args.area_constrained_value,
    )

    hw_loss = 0
    batch_size = len(output)

    for idx, limit in enumerate(limits):
        metric = output[:, idx]
        weight = args.lambdas[idx]
        if limit == np.inf:
            hw_loss += metric.mean() * weight
        else:
            excess_const = metric >= limit
            under_const = metric < limit
            # Rows violating the constraint get the steeper slope ``tau``.
            hw_loss += metric[excess_const].sum() * weight * args.slope.tau / batch_size
            hw_loss += metric[under_const].sum() * weight / batch_size

    hw_loss *= args.scale_value
    hw_loss.backward()
    evaluator.generator_opt.step()

    evaluator.generator.eval()
    target_latency, target_energy, target_area = hw_metrics(model, evaluator, gs=False)

    return target_latency, target_energy, target_area


def update_generator_grads(model, evaluator, args, warmup=False):
    """
    Train generator with gradient manipulation

    Steps, in order:

    1. With the generator in eval mode, read the current target metrics via
       ``hw_metrics(..., gs=False)``.
    2. Sample one-hot operations from each module's ``probs_over_ops``, run
       generator and estimator, and build ``hw_loss`` from the weighted batch
       means of estimator output columns 0, 1, 2 (latency, energy, area). With
       ``args.except_const`` only columns whose constraint value is ``np.inf``
       contribute. The loss is scaled by ``args.scale_value``.
    3. If every metric is within its constraint, back-propagate ``hw_loss``.
       Otherwise build ``const_loss`` from the SiLU (``args.swish``) or ReLU of
       ``output - limit`` for each metric whose target is at or above its
       limit. If ``const_loss > 0``, the gradients of both losses are computed
       separately: when their dot product is non-negative the ``hw_loss``
       gradient is used unchanged, otherwise a multiple of the ``const_loss``
       gradient (coefficient derived from ``args.delta_generator``) is added.
    4. Take one optimizer step.
    5. Unless ``warmup``: re-read the targets in eval mode and update
       ``args.delta_generator`` / ``args.in_const`` in place (divide by
       ``args.p`` while satisfied, multiply while violated, reset to
       ``args.origin_delta_generator`` when the state flips).

    Returns ``(target_latency, target_energy, target_area)``; in warmup these
    are the values read in step 1.
    """
    # Metrics of the generator before this update.
    evaluator.generator.eval()

    target_latency, target_energy, target_area = hw_metrics(model, evaluator, gs=False)

    evaluator.generator.train()
    evaluator.generator_opt.zero_grad()

    batch = args.generator_batch_size
    ops_cat = []

    # set arch_params one_hot
    for layer in model.redundant_modules:
        probs = layer.probs_over_ops.clone().detach()
        indices = probs.multinomial(batch, replacement=True)
        # Allocate next to the sampled indices so the index assignment never
        # mixes devices; the final tensor is moved to the estimator below.
        one_hot = torch.zeros(batch, args.num_ops, device=indices.device)
        one_hot[range(batch), indices] = 1.
        ops_cat.append(one_hot)
    ops_cat = torch.cat(ops_cat, dim=-1).to(next(evaluator.estimator.parameters()).device)

    hw_params = evaluator.generator(ops_cat)
    eval_input = torch.cat([ops_cat, hw_params], dim=-1)
    output = evaluator.estimator(eval_input)

    limits = (
        args.latency_constrained_value,
        args.energy_constrained_value,
        args.area_constrained_value,
    )
    targets = (target_latency, target_energy, target_area)

    if args.except_const:
        # NOTE(review): if no constraint value equals np.inf, hw_loss stays the
        # int 0 and the ``backward`` calls below fail. Confirm that
        # except_const is only used with at least one unconstrained metric.
        hw_loss = 0
        for idx, limit in enumerate(limits):
            if limit == np.inf:
                hw_loss += output[:, idx].mean() * args.lambdas[idx]
    else:
        hw_loss = (output[:, 0].mean() * args.lambdas[0]) \
                  + (output[:, 1].mean() * args.lambdas[1]) \
                  + (output[:, 2].mean() * args.lambdas[2])
    hw_loss *= args.scale_value

    # True where no sampled row reaches the limit; one tensor reduction per
    # metric instead of a Python-level sum over the batch.
    batch_within = [
        not bool((output[:, idx] >= limit).any())
        for idx, limit in enumerate(limits)
    ]
    if warmup:
        in_const = all(batch_within)
    else:
        in_const = all(
            bool(target < limit) or within
            for target, limit, within in zip(targets, limits, batch_within)
        )

    if in_const:
        hw_loss.backward()

    else:
        const_loss = torch.tensor(0.).to(output.device)

        # add hw loss term to const_loss when (target - constraint) is big
        for idx, (target, limit) in enumerate(zip(targets, limits)):
            if target >= limit:
                excess = output[:, idx] - limit
                if args.swish:
                    const_loss += torch.nn.SiLU()(excess).mean()
                else:
                    const_loss += torch.relu(excess).mean()

        if const_loss > 0:
            # Gradient of the constraint loss.
            const_loss.backward(retain_graph=True)
            g_f = [
                tensor.grad.clone()
                for tensor in evaluator.generator.parameters()
                if tensor.grad is not None
            ]

            evaluator.generator_opt.zero_grad()

            # Gradient of the hardware loss; it stays in ``.grad`` afterwards.
            hw_loss.backward(retain_graph=True)
            g_u = [
                tensor.grad.clone()
                for tensor in evaluator.generator.parameters()
                if tensor.grad is not None
            ]

            u_vector = []
            f_vector = []

            for tensor_u, tensor_f in zip(g_u, g_f):
                assert tensor_u.shape == tensor_f.shape, 'shape of binary gates should be same'
                u_vector.append(tensor_u.reshape(-1))
                f_vector.append(tensor_f.reshape(-1))

            u_vector = torch.cat(u_vector)
            f_vector = torch.cat(f_vector)

            direction = torch.dot(u_vector, f_vector)
            if direction >= 0:
                i = 0
                for tensor in evaluator.generator.parameters():
                    if tensor.grad is not None:
                        tensor.grad = g_u[i]
                        i += 1
            else:
                # (u + f) . f as one vectorised dot product rather than a
                # Python loop over every scalar gradient entry.
                dot_product = torch.dot(u_vector + f_vector, f_vector)

                v_coeff = (-dot_product + args.delta_generator) / (f_vector.norm() ** 2)

                i = 0
                for tensor in evaluator.generator.parameters():
                    if tensor.grad is not None:
                        tensor.grad = g_u[i] + (v_coeff * g_f[i])
                        i += 1
        else:
            hw_loss.backward()

    evaluator.generator_opt.step()

    if not warmup:
        evaluator.generator.eval()
        target_latency, target_energy, target_area = hw_metrics(model, evaluator, gs=False)

        satisfied = (
            target_latency <= args.latency_constrained_value
            and target_energy <= args.energy_constrained_value
            and target_area <= args.area_constrained_value
        )
        if satisfied:
            if args.in_const:
                args.delta_generator /= args.p
            else:
                args.delta_generator = args.origin_delta_generator
                args.in_const = True
        else:
            if args.in_const:
                args.delta_generator = args.origin_delta_generator
                args.in_const = False
            else:
                args.delta_generator *= args.p

    return target_latency, target_energy, target_area


def train_arch_param_flexible_scaler(model, arch_params, ce_loss, targets, evaluator, args):
    """
    Back-propagate ``ce_loss`` plus a scaled hardware loss, then adapt the scale.

    The hardware loss is the ``args.lambdas``-weighted sum of the latency,
    energy and area returned by ``hw_metrics``, multiplied by
    ``args.scale_value``. Gradients of the model are cleared before the
    backward pass of ``ce_loss + hw_loss``.

    Afterwards ``args.scale_value`` / ``args.in_const`` are updated in place
    from ``targets`` (indexed 0, 1, 2 against the latency, energy and area
    constraint values): while the satisfied/violated state is unchanged the
    scale is divided (satisfied) or multiplied (violated) by ``args.p``; when
    the state flips it is reset to ``args.origin_scaler``.

    ``arch_params`` is not used here. Returns the scaled hardware loss.
    """
    evaluator.generator.eval()

    latency, energy, area = hw_metrics(model, evaluator)

    weights = args.lambdas
    weighted = (latency * weights[0]) + (energy * weights[1]) + (area * weights[2])
    hw_loss = weighted * args.scale_value

    model.zero_grad()
    total = ce_loss + hw_loss
    total.backward()

    satisfied = bool(
        targets[0] <= args.latency_constrained_value
        and targets[1] <= args.energy_constrained_value
        and targets[2] <= args.area_constrained_value
    )

    if satisfied != bool(args.in_const):
        # The constraint state flipped: start again from the original scale.
        args.scale_value = args.origin_scaler
        args.in_const = satisfied
    elif satisfied:
        args.scale_value /= args.p
    else:
        args.scale_value *= args.p

    return hw_loss


def train_arch_param_slope(model, ce_loss, targets, evaluator, args):
    """
    Back-propagate ``ce_loss`` plus a soft-constraint hardware loss.

    Each metric from ``hw_metrics`` (latency, energy, area) is weighted by the
    matching entry of ``args.lambdas`` and by a slope: 1 when the corresponding
    entry of ``targets`` is below its constraint value, ``args.slope.tau``
    otherwise. The sum is multiplied by ``args.scale_value``; model gradients
    are cleared and ``ce_loss + hw_loss`` is back-propagated.

    Returns the scaled hardware loss.
    """
    evaluator.generator.eval()
    latency, energy, area = hw_metrics(model, evaluator)

    metrics = (latency, energy, area)
    limits = (
        args.latency_constrained_value,
        args.energy_constrained_value,
        args.area_constrained_value,
    )

    hw_loss = None
    for idx, metric in enumerate(metrics):
        slope = 1 if targets[idx] < limits[idx] else args.slope.tau
        term = metric * args.lambdas[idx] * slope
        hw_loss = term if hw_loss is None else hw_loss + term

    hw_loss = hw_loss * args.scale_value

    model.zero_grad()
    total = ce_loss + hw_loss
    total.backward()

    return hw_loss


def update_arch_param_grads(model, arch_params, ce_loss, targets, evaluator, args):
    """
    Train architecture parameter with given loss function and hardware loss

    The hardware loss is the ``args.lambdas``-weighted sum of the latency,
    energy and area returned by ``hw_metrics`` (with ``args.except_const`` only
    metrics whose constraint value is ``np.inf`` contribute), multiplied by
    ``args.scale_value``.

    For every metric whose entry in ``targets`` exceeds its constraint value, a
    constraint term is collected: SiLU of ``metric - limit`` when
    ``args.swish`` is set, otherwise ``metric - limit`` if the metric itself
    exceeds the limit. If the terms sum to zero (or there are none),
    ``ce_loss + hw_loss`` is simply back-propagated. Otherwise the gradients of
    ``ce_loss + hw_loss`` and of the constraint sum are computed separately on
    the tensors yielded by ``arch_params()``: when their dot product is
    non-negative the first gradient is kept, otherwise a multiple of the
    constraint gradient (coefficient derived from ``args.delta_supernet``) is
    added to it.

    ``arch_params`` must be a callable returning the architecture parameter
    tensors; it is called several times. ``args.delta_supernet`` and
    ``args.in_const`` may be updated in place. Returns the scaled hardware
    loss.
    """
    evaluator.generator.eval()

    latency, energy, area = hw_metrics(model, evaluator)

    metrics = (latency, energy, area)
    limits = (
        args.latency_constrained_value,
        args.energy_constrained_value,
        args.area_constrained_value,
    )

    # train arch_params
    if args.except_const:
        hw_loss = torch.tensor(0.).to(latency.device)
        for idx, limit in enumerate(limits):
            if limit == np.inf:
                hw_loss += metrics[idx] * args.lambdas[idx]
    else:
        hw_loss = (latency * args.lambdas[0]) + (energy * args.lambdas[1]) + (area * args.lambdas[2])

    out_const = []
    for idx, (metric, limit) in enumerate(zip(metrics, limits)):
        if targets[idx] > limit:
            if args.swish:
                out_const.append(torch.nn.SiLU()(metric - limit))
            elif metric > limit:
                out_const.append(metric - limit)

    hw_loss *= args.scale_value

    # set grad of arch_params zero
    model.zero_grad()

    const_loss = sum(out_const)
    loss = ce_loss + hw_loss

    if const_loss == 0:
        loss.backward()

    else:
        # Gradient of the task + hardware loss.
        loss.backward(retain_graph=True)
        g_u = [tensor.grad.clone() for tensor in arch_params() if tensor.grad is not None]

        model.zero_grad()

        # Gradient of the constraint loss; it stays in ``.grad`` afterwards.
        const_loss.backward(retain_graph=True)
        g_f = [tensor.grad.clone() for tensor in arch_params() if tensor.grad is not None]

        u_vector = []
        f_vector = []
        for tensor_u, tensor_f in zip(g_u, g_f):
            assert tensor_u.shape == tensor_f.shape, 'shape of binary gates should be same'
            u_vector.append(tensor_u.reshape(-1))
            f_vector.append(tensor_f.reshape(-1))
        u_vector = torch.cat(u_vector)
        f_vector = torch.cat(f_vector)

        direction = torch.dot(u_vector, f_vector)

        if direction >= 0:
            i = 0
            for tensor in arch_params():
                if tensor.grad is not None:
                    tensor.grad = g_u[i]
                    i += 1
        else:
            # (u + f) . f as one vectorised dot product rather than a Python
            # loop over every scalar gradient entry.
            dot_product = torch.dot(u_vector + f_vector, f_vector)

            v_coeff = ((-dot_product + args.delta_supernet) / (f_vector.norm() ** 2))

            i = 0
            for tensor in arch_params():
                if tensor.grad is not None:
                    tensor.grad = g_u[i] + (v_coeff * g_f[i])
                    i += 1

        # NOTE(review): this update only runs on the constraint-violating
        # path, where at least one target exceeds its limit, so the
        # "all targets within constraint" branch below cannot be taken here.
        # It may have been meant to run at function level - confirm.
        if not args.no_const:
            if targets[0] <= args.latency_constrained_value and targets[1] <= args.energy_constrained_value and targets[2] <= args.area_constrained_value:
                if args.in_const:
                    args.delta_supernet /= args.p
                else:
                    args.delta_supernet = args.origin_delta_supernet
                    args.in_const = True
            else:
                if args.in_const:
                    args.delta_supernet = args.origin_delta_supernet
                    args.in_const = False
                else:
                    args.delta_supernet *= args.p

    return hw_loss

