import torch
import torch.nn as nn
#from utils.prox import flatten_quant_weights_and_grads, compute_midpoints
#from logic import log_not, log_and


def log_and(*conds):
    """Return the element-wise logical AND of two or more conditions.

    The conditions are combined left to right with ``torch.logical_and``
    (so they follow its broadcasting rules).

    Raises:
        ValueError: if fewer than two conditions are given.
    """
    if len(conds) < 2:
        raise ValueError('Need at least two conditions here.')

    head, *tail = conds
    combined = head
    for extra in tail:
        combined = torch.logical_and(combined, extra)
    return combined

def log_not(a):
    """Return the element-wise logical negation of ``a`` via ``torch.logical_not``."""
    negated = torch.logical_not(a)
    return negated

class Prox(nn.Module):
    """Base interface for the proximal operators in this module.

    Every configuration/update hook is a no-op here, so a subclass only
    overrides the hooks it actually uses. ``forward`` has no default and must
    be provided by a concrete subclass.
    """

    def __init__(self, *args, **kwargs):
        # Any positional/keyword arguments are accepted and ignored.
        super().__init__()

    def set_weights(self, weights):
        """Hook: receive the weight levels (no-op by default)."""

    def set_cutoffs(self, cutoffs):
        """Hook: receive the interval cutoffs (no-op by default)."""

    def set_optimizer(self):
        """Hook: build an optimizer (no-op by default)."""

    def set_scheduler(self):
        """Hook: build a scheduler (no-op by default)."""

    def scheduler_step(self):
        """Hook: advance the scheduler (no-op by default)."""

    def set_lambda(self, lambda_in):
        """Hook: receive a step size (no-op by default)."""

    def update(self, model):
        """Hook: update internal state from ``model`` (no-op by default)."""

    def freeze_weights(self):
        """Hook: stop updating the weights (no-op by default)."""

    def forward(self, x):
        # Was ``NotImplementeError`` (typo), which surfaced as a NameError.
        raise NotImplementedError(
            f'{type(self).__name__} does not implement forward()')


class Identity(Prox):
    """Proximal operator that leaves its input untouched."""

    def forward(self, x):
        # No mapping is applied: the very same input object is handed back.
        unchanged = x
        return unchanged


class FixedProx(Prox):
    """Shared state for the ``*Fixed`` operators: a method tag, weights, cutoffs.

    ``set_weights`` wraps the given tensor in an ``nn.Parameter`` (so it is
    registered on the module); ``set_cutoffs`` stores the cutoffs as given.
    """

    def __init__(self, method, *args, **kwargs):
        super().__init__()
        # Stored as-is; this class itself never reads it.
        self._method = method

    def set_weights(self, weights):
        levels = nn.Parameter(weights)
        self._weights = levels

    def set_cutoffs(self, cutoffs):
        # Plain attribute, not a registered buffer.
        self._cutoffs = cutoffs


class ConstantFixed(FixedProx):
    """Piecewise-constant quantizer with caller-supplied weights and cutoffs.

    Assuming ascending cutoffs ``c_0 < ... < c_{n-1}``, an input ``x`` is
    replaced by ``weights[0]`` if ``x < c_0``, by ``weights[i]`` if
    ``c_{i-1} <= x < c_i``, and by ``weights[-1]`` if ``x >= c_{n-1}``.
    Expects ``len(weights) == len(cutoffs) + 1``; with fewer weights, the last
    weight is used for both of the two top intervals.
    """

    def forward(self, x):
        cutoffs = self._cutoffs
        weights = self._weights
        out = x
        for i, c in enumerate(cutoffs):
            # Masks are taken from the untouched input ``x`` so a value already
            # replaced by an earlier weight can never be re-bucketed.
            if i == 0:
                in_bucket = x < c
            else:
                in_bucket = log_and(x >= cutoffs[i - 1], x < c)
            out = torch.where(in_bucket, weights[i], out)

        # ``>=`` (not ``>``) so inputs lying exactly on the last cutoff are mapped too.
        return torch.where(x >= cutoffs[-1], weights[-1], out)


class LinearFixed(FixedProx):
    """Quantizer with a linear proximal step in training and hard assignment in eval.

    Expects ``len(weights) == len(cutoffs) + 1``. Training mode: an input in
    the interval of weight ``w`` (``(-inf, c_0]``, ``(c_{i-1}, c_i]``, or
    ``(c_{-1}, inf)``) is moved by ``lambda`` toward ``w`` when it lies farther
    than ``lambda`` from it, and snapped onto ``w`` otherwise. Eval mode: every
    input is replaced by its interval's weight (``x < c_0``,
    ``c_{i-1} <= x < c_i``, ``x >= c_{-1}``).
    """

    def set_lambda(self, lambda_in):
        """Set the step size used by the training-mode update."""
        self._lambda = lambda_in

    def forward(self, x):
        if self.training:
            return self._forward_train(x)
        return self._forward_eval(x)

    def _forward_train(self, x):
        """Linear proximal step toward each interval's weight (training mode)."""
        for i, c in enumerate(self._cutoffs):
            w = self._weights[i]
            cond1 = x < w - self._lambda
            cond2 = log_not(cond1)
            cond3 = x <= w + self._lambda
            cond4 = log_not(cond3)

            # NOTE(review): masks here are evaluated on the partially updated
            # ``x``; this relies on every weight lying inside its own interval
            # (shifted values then stay in it) - TODO confirm for all callers.
            if i == 0:
                x = torch.where(cond1, x + self._lambda, x)
                x = torch.where(log_and(cond2, cond3, x <= c), w, x)
                x = torch.where(log_and(cond4, x <= c), x - self._lambda, x)
            else:
                c_left = self._cutoffs[i - 1]
                x = torch.where(log_and(x > c_left, cond1), x + self._lambda, x)
                x = torch.where(log_and(x > c_left, cond2, cond3, x <= c), w, x)
                x = torch.where(log_and(cond4, x <= c), x - self._lambda, x)

        # Top interval (c_{-1}, inf) uses the last weight.
        w = self._weights[-1]
        cond1 = x < w - self._lambda
        cond2 = log_not(cond1)
        cond3 = x <= w + self._lambda
        cond4 = log_not(cond3)

        c_left = self._cutoffs[-1]
        x = torch.where(log_and(x > c_left, cond1), x + self._lambda, x)
        x = torch.where(log_and(x > c_left, cond2, cond3), w, x)
        x = torch.where(cond4, x - self._lambda, x)
        return x

    def _forward_eval(self, x):
        """Hard assignment of every input to its interval's weight (eval mode)."""
        cutoffs = self._cutoffs
        out = x
        for i, c in enumerate(cutoffs):
            # Masks come from the original input so replaced values are not re-bucketed.
            if i == 0:
                in_bucket = x < c
            else:
                in_bucket = log_and(x >= cutoffs[i - 1], x < c)
            out = torch.where(in_bucket, self._weights[i], out)

        # ``>=`` (not ``>``) so inputs exactly on the last cutoff are mapped as well.
        return torch.where(x >= cutoffs[-1], self._weights[-1], out)



class LearnableProx(Prox):
    """Proximal operator whose weight levels are trained by a private optimizer.

    Constructor arguments:
        method: tag passed to ``flatten_quant_weights_and_grads``.
        optim_cls, optim_params: optimizer class and keyword arguments,
            instantiated by ``set_optimizer`` over ``self._weights``.
        scheduler_cls, scheduler_params: optional scheduler class (``None``
            disables it) constructed with the optimizer, and its kwargs.
        learning_threshold: tolerance used by ``_compute_grad`` to decide
            which flattened model entries sit on a given weight level.
    """

    # Outer bounds placed around the user cutoffs so every weight gets a
    # [lower, upper] interval in ``_init_cutoffs`` (used by ``_clip_weights``).
    _OUTER_LOW = -1.5
    _OUTER_HIGH = 1.5
    # Gap kept between a clipped weight and its interval bounds.
    _CLIP_MARGIN = 1e-4

    def __init__(self, method, optim_cls, optim_params, scheduler_cls, scheduler_params, learning_threshold=1e-6):
        super().__init__()

        self._method = method
        self._optim_cls = optim_cls
        self._optim_params = optim_params
        self._scheduler_cls = scheduler_cls
        self._scheduler_params = scheduler_params
        self._learning_threshold = learning_threshold
        # Defined up front so ``scheduler_step`` is a no-op until ``set_scheduler`` runs.
        self._scheduler = None

    def set_weights(self, weights):
        """Register ``weights`` as a trainable parameter and keep a detached copy."""
        self._init_weights = weights.clone().detach()
        self._weights = nn.Parameter(weights, requires_grad=True)

    def set_cutoffs(self, cutoffs):
        """Store ``cutoffs`` and build the per-weight clipping bounds.

        ``_init_cutoffs`` becomes ``[_OUTER_LOW, *cutoffs, _OUTER_HIGH]``,
        built on the device and in the dtype of ``cutoffs`` (no host round-trip).
        """
        flat = cutoffs.reshape(-1)
        low = flat.new_full((1,), self._OUTER_LOW)
        high = flat.new_full((1,), self._OUTER_HIGH)
        self._init_cutoffs = torch.cat([low, flat, high])
        self._cutoffs = cutoffs

    def freeze_weights(self):
        """Stop gradient tracking on the weights and drop any stored gradient."""
        self._weights.requires_grad = False
        self._weights.grad = None

    def set_optimizer(self):
        """Create ``self._optim`` over the weights (only if they still require grad)."""
        trainable_params = [p for p in (self._weights,) if p.requires_grad]
        self._optim = self._optim_cls(trainable_params, **self._optim_params)

    def set_scheduler(self):
        """Create ``self._scheduler`` from ``scheduler_cls``, or ``None`` if not configured."""
        if self._scheduler_cls is not None:
            self._scheduler = self._scheduler_cls(self._optim, **self._scheduler_params)
        else:
            self._scheduler = None

    def scheduler_step(self):
        """Advance the scheduler, if one has been created."""
        if self._scheduler is not None:
            self._scheduler.step()

    def _clip_weights(self):
        """Clamp each weight into its [lower, upper] interval from ``_init_cutoffs``, minus a margin."""
        lower = self._init_cutoffs[:-1] + self._CLIP_MARGIN
        upper = self._init_cutoffs[1:] - self._CLIP_MARGIN
        self._weights.data = torch.minimum(torch.maximum(self._weights.data, lower), upper)

    def _compute_grad(self, model):
        """Set ``self._weights.grad`` to the mean model gradient at each weight level.

        NOTE(review): ``flatten_quant_weights_and_grads`` is not defined in the
        visible file (its import at the top is commented out) - confirm it is
        provided. It is expected to return parallel lists of weight tensors and
        their gradients. For level ``i`` the gradient is the sum of gradients
        of all entries within ``learning_threshold`` of ``weights[i]``, divided
        by how many such entries exist (left at 0 when there are none).
        """
        weight_list, grad_list = flatten_quant_weights_and_grads(self._method, model, self)

        grad = torch.zeros_like(self._weights)
        # Pure bookkeeping: keep these comparisons out of the autograd graph.
        with torch.no_grad():
            for i, w in enumerate(self._weights):
                n_total_involved = 0
                for j, wl in enumerate(weight_list):
                    index = torch.abs(wl - w) < self._learning_threshold
                    grad[i] += grad_list[j][index].sum()
                    n_total_involved += index.sum()

                if n_total_involved > 0:
                    grad[i] /= n_total_involved
        self._weights.grad = grad


class LearnableMovingCutoffs(LearnableProx):
    """Learnable quantizer whose cutoffs are re-derived from the weights after each step."""

    def _compute_cutoffs(self):
        # NOTE(review): ``compute_midpoints`` is not defined in the visible file
        # (its import from utils.prox is commented out); presumably it returns the
        # midpoints between consecutive weights - confirm.
        self._cutoffs = compute_midpoints(self._weights)

    def _map_weights(self):
        # Reorder the weight values ascending (writes ``.data`` directly).
        ordered, _ = torch.sort(self._weights.data)
        self._weights.data = ordered

    def update(self, model):
        # Nothing to do once the weights no longer require grad.
        if not self._weights.requires_grad:
            return
        self._compute_grad(model)
        self._optim.step()
        # Re-sorting the levels is currently disabled:
        # self._map_weights()
        self._clip_weights()
        self._compute_cutoffs()
    

class ConstantLearnableMovingCutoffs(LearnableMovingCutoffs):
    """Piecewise-constant quantizer over the currently stored weights and cutoffs.

    Assuming ascending cutoffs, an input ``x`` becomes ``weights[0]`` if
    ``x < c_0``, ``weights[i]`` if ``c_{i-1} <= x < c_i``, and ``weights[-1]``
    if ``x >= c_{-1}``. Expects ``len(weights) == len(cutoffs) + 1``.
    """

    def forward(self, x):
        cutoffs = self._cutoffs
        weights = self._weights
        out = x
        for i, c in enumerate(cutoffs):
            # Masks come from the untouched input so replaced values are not re-bucketed.
            if i == 0:
                in_bucket = x < c
            else:
                in_bucket = log_and(x >= cutoffs[i - 1], x < c)
            out = torch.where(in_bucket, weights[i], out)

        # ``>=`` (not ``>``) so inputs exactly on the last cutoff are mapped too.
        return torch.where(x >= cutoffs[-1], weights[-1], out)

class LinearLearnableMovingCutoffs(LearnableMovingCutoffs):
    """Learnable quantizer with a linear proximal step in training, hard assignment in eval.

    Expects ``len(weights) == len(cutoffs) + 1``. Training mode: an input in
    the interval of weight ``w`` (``(-inf, c_0]``, ``(c_{i-1}, c_i]``, or
    ``(c_{-1}, inf)``) is moved by ``lambda`` toward ``w`` when it lies farther
    than ``lambda`` from it, and snapped onto ``w`` otherwise. Eval mode: every
    input is replaced by its interval's weight (``x < c_0``,
    ``c_{i-1} <= x < c_i``, ``x >= c_{-1}``).
    """

    def set_lambda(self, lambda_in):
        """Set the step size used by the training-mode update."""
        self._lambda = lambda_in

    def forward(self, x):
        if self.training:
            return self._forward_train(x)
        return self._forward_eval(x)

    def _forward_train(self, x):
        """Linear proximal step toward each interval's weight (training mode)."""
        for i, c in enumerate(self._cutoffs):
            w = self._weights[i]
            cond1 = x < w - self._lambda
            cond2 = log_not(cond1)
            cond3 = x <= w + self._lambda
            cond4 = log_not(cond3)

            # NOTE(review): masks here are evaluated on the partially updated
            # ``x``; this relies on every weight lying inside its own interval
            # (shifted values then stay in it) - TODO confirm for learned levels.
            if i == 0:
                x = torch.where(cond1, x + self._lambda, x)
                x = torch.where(log_and(cond2, cond3, x <= c), w, x)
                x = torch.where(log_and(cond4, x <= c), x - self._lambda, x)
            else:
                c_left = self._cutoffs[i - 1]
                x = torch.where(log_and(x > c_left, cond1), x + self._lambda, x)
                x = torch.where(log_and(x > c_left, cond2, cond3, x <= c), w, x)
                x = torch.where(log_and(cond4, x <= c), x - self._lambda, x)

        # Top interval (c_{-1}, inf) uses the last weight.
        w = self._weights[-1]
        cond1 = x < w - self._lambda
        cond2 = log_not(cond1)
        cond3 = x <= w + self._lambda
        cond4 = log_not(cond3)

        c_left = self._cutoffs[-1]
        x = torch.where(log_and(x > c_left, cond1), x + self._lambda, x)
        x = torch.where(log_and(x > c_left, cond2, cond3), w, x)
        x = torch.where(cond4, x - self._lambda, x)
        return x

    def _forward_eval(self, x):
        """Hard assignment of every input to its interval's weight (eval mode)."""
        cutoffs = self._cutoffs
        out = x
        for i, c in enumerate(cutoffs):
            # Masks come from the original input so replaced values are not re-bucketed.
            if i == 0:
                in_bucket = x < c
            else:
                in_bucket = log_and(x >= cutoffs[i - 1], x < c)
            out = torch.where(in_bucket, self._weights[i], out)

        # ``>=`` (not ``>``) so inputs exactly on the last cutoff are mapped as well.
        return torch.where(x >= cutoffs[-1], self._weights[-1], out)


def main():
    """Smoke-test ``LinearFixed`` and ``ConstantFixed`` on a fixed input.

    Both operators are configured with three weight levels ``[-1, 0, 1]`` and
    two cutoffs, which is the configuration the hard-coded expected tensors
    correspond to. The linear prox additionally uses ``lambda = 0.2``.
    Prints ``Passed``/``Failed`` for each test.
    """
    torch.manual_seed(0)

    x = torch.tensor([-1.1258, -1.1524, -0.2506, -0.4339,  0.5988,
                      -1.5551, -0.3414,  1.8530,  0.4681, -0.1577,
                       1.4437,  0.2660,  1.3894,  1.5863,  0.9463,
                      -0.8437,  0.9318,  1.2590,  2.0050,  0.0537])

    # Test 1: a freshly built nn.Module is in training mode, so LinearFixed
    # applies its linear proximal step here.
    linear_fixed = LinearFixed('da')
    linear_fixed.set_weights(torch.Tensor([-1.0, 0.0, 1.0]))
    linear_fixed.set_cutoffs(torch.Tensor([-0.5, 0.5]))
    linear_fixed.set_lambda(0.2)

    res = torch.tensor([   -1.0,    -1.0, -0.0506, -0.2339, 0.7988,
                        -1.3551, -0.1414,  1.6530,  0.2681,   -0.0,
                         1.2437,  0.0660,  1.1894,  1.3863,    1.0,
                           -1.0,     1.0,  1.0590,  1.8050,    0.0])

    if torch.all(torch.abs(res - linear_fixed(x)) < 1e-6):
        print('Passed Test 1.')
    else:
        print('Failed Test 1.')

    # Test 2: hard piecewise-constant assignment.
    constant_fixed = ConstantFixed('da')
    constant_fixed.set_weights(torch.Tensor([-1.0, 0.0, 1.0]))
    constant_fixed.set_cutoffs(torch.Tensor([-0.2, 0.2]))

    res = torch.tensor([-1.0, -1.0, -1.0, -1.0,  1.0,
                        -1.0, -1.0,  1.0,  1.0,  0.0,
                         1.0,  1.0,  1.0,  1.0,  1.0,
                        -1.0,  1.0,  1.0,  1.0,  0.0])

    if torch.all(torch.abs(res - constant_fixed(x)) < 1e-6):
        print('Passed Test 2.')
    else:
        print('Failed Test 2.')


if __name__ == '__main__':
    main()
