import datetime
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.autograd import Variable

from carla import log
from carla.recourse_methods.processing import reconstruct_encoding_constraints

from methods.reup.chebysev import chebysev_center, sdp_cost
from methods.reup.q_determine import exhaustive_search

# Model-score cut-off used by the counterfactual search: optimisation keeps
# going while the model output is at or below this value.
DECISION_THRESHOLD: float = 0.5


def gd(
    torch_model,
    x: np.ndarray,
    cat_feature_indices: List[int],
    binary_cat_features: bool,
    lr: float,
    lambda_param: float,
    n_iter: int,
    t_max_min: float,
    clamp: bool,
) -> Tuple[np.ndarray, bool]:
    """Search for a counterfactual of ``x`` by gradient descent on its features.

    Minimises ``BCE(f(x'), 1) + lambda * ||x' - x||^2`` with Adam (amsgrad),
    where ``f`` is ``torch_model`` applied to the encoding-constrained
    candidate ``x'``. Each round runs up to ``n_iter`` optimisation steps;
    if the model score is still at or below ``DECISION_THRESHOLD`` the
    distance weight ``lambda`` is halved and a new round starts. The search
    stops once the score exceeds the threshold or ``t_max_min`` minutes
    have elapsed (the timeout is checked between rounds).

    Parameters
    ----------
    torch_model:
        Classifier called on the candidate; its output is fed to
        ``torch.nn.BCELoss`` so it must lie in ``[0, 1]``. It must also
        expose ``predict``. The model is moved to the selected device in
        place.
    x:
        Factual instance. It is flattened to a single row ``(1, dim)``;
        presumably callers pass one row of shape ``(1, dim)`` — confirm.
    cat_feature_indices, binary_cat_features:
        Forwarded to ``reconstruct_encoding_constraints`` to keep
        categorical encodings valid.
    lr:
        Adam learning rate.
    lambda_param:
        Initial weight of the squared-distance penalty.
    n_iter:
        Maximum optimisation steps per round.
    t_max_min:
        Time budget in minutes.
    clamp:
        If True, raw candidate features are clamped to ``[0, 1]`` after
        every optimiser step.

    Returns
    -------
    (counterfactual, feasible)
        The encoding-constrained candidate as a 1-D numpy array, and whether
        ``torch_model.predict`` labels it as class 1.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Fixed seed for reproducibility (global side effect, kept for callers).
    torch.manual_seed(0)

    torch_model.to(device)
    lamb = torch.tensor(lambda_param).float().to(device)

    x_orig = torch.from_numpy(x).float().to(device).reshape(1, -1)
    # Independent leaf tensor that Adam optimises; the factual stays untouched.
    x_new = x_orig.clone().requires_grad_(True)
    x_new_enc = reconstruct_encoding_constraints(
        x_new, cat_feature_indices, binary_cat_features
    )

    y_target = torch.tensor([1]).float().to(device)

    optimizer = optim.Adam([x_new], lr, amsgrad=True)
    loss_fn = torch.nn.BCELoss()

    f_x = torch_model(x_new_enc).reshape(-1)

    t0 = datetime.datetime.now()
    t_max = datetime.timedelta(minutes=t_max_min)

    while f_x.item() <= DECISION_THRESHOLD:
        it = 0
        while f_x.item() <= DECISION_THRESHOLD and it < n_iter:
            optimizer.zero_grad()

            # Squared Euclidean distance (== d^T I d) computed on the tensors'
            # own device, so no CPU-allocated identity matrix is involved.
            cost = torch.sum((x_new_enc - x_orig) ** 2)
            f_loss = loss_fn(f_x, y_target)
            loss = f_loss + lamb * cost
            loss.backward()
            optimizer.step()

            if clamp:
                # Clamp the optimised leaf itself; no_grad is required for an
                # in-place update on a tensor that requires grad.
                with torch.no_grad():
                    x_new.clamp_(0, 1)

            x_new_enc = reconstruct_encoding_constraints(
                x_new, cat_feature_indices, binary_cat_features
            )
            f_x = torch_model(x_new_enc).reshape(-1)
            it += 1

        # Relax the distance penalty so the next round can move further.
        lamb *= 0.5

        if datetime.datetime.now() - t0 > t_max:
            log.info("Timeout - No Counterfactual Explanation Found")
            break
        elif f_x.item() > DECISION_THRESHOLD:
            log.info("Counterfactual Explanation Found")

    feasible = bool(torch_model.predict(x_new_enc) == 1)

    return x_new_enc.detach().cpu().numpy().squeeze(axis=0), feasible