import numpy as np

import torch
from torch import nn
import torch.utils.data
from torch.autograd import Variable
from model_mlp import MLP
from min_norm_solvers import MinNormSolver

# Compute device used by everything in this module: the GPU when CUDA is
# usable, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def get_d_paretomtl_init(grads,value,weights,i):
    """
    Calculate the gradient direction for ParetoMTL initialization.

    Constraints are built from the differences between every preference
    vector and the ``i``-th one. A constraint is active when its inner product
    with the normalized objective vector is strictly positive.

    Args:
        grads: 2-D tensor with one flattened gradient per objective (one row
            per objective).
        value: 1-D tensor with the current objective values.
        weights: 2-D tensor of preference vectors, one per row, with one
            column per objective.
        i: row index in ``weights`` of the targeted preference vector.

    Returns:
        A tuple ``(flag, weight)``. ``flag`` is True when no constraint is
        active; ``weight`` is then an all-zero tensor shaped like ``value``.
        Otherwise ``flag`` is False and ``weight`` holds one coefficient per
        objective (any number of objectives, not only two).
    """
    # Difference of every preference vector to the selected one. The selected
    # row becomes all zeros, so its constraint value is 0 and is never active.
    w = weights - weights[i]

    # check active constraints
    gx = torch.matmul(w, value / torch.norm(value))
    active = gx > 0
    n_active = int(torch.sum(active))

    if n_active == 0:
        # Allocate next to ``value`` so both return paths share a device.
        return True, torch.zeros_like(value)

    # Boolean-index once instead of once per summed term.
    w_active = w[active]

    # calculate the descent direction
    if n_active == 1:
        # A single active constraint needs no min-norm problem.
        sol = torch.ones(1, dtype=w.dtype, device=w.device)
    else:
        vec = torch.matmul(w_active, grads)
        sol, _ = MinNormSolver.find_min_norm_element([[vec[t]] for t in range(len(vec))])
        sol = torch.as_tensor(sol, dtype=w.dtype, device=w.device)

    # weight[k] = sum_j sol[j] * w_active[j, k]
    weight = torch.matmul(sol, w_active)

    return False, weight


def get_d_paretomtl(grads,value,weights,i):
    """
    Calculate the gradient direction for ParetoMTL.

    Constraints are built from the differences between every preference
    vector and the ``i``-th one. A constraint is active when its inner product
    with the normalized objective vector is strictly positive.

    Args:
        grads: 2-D tensor with one flattened gradient per objective (one row
            per objective).
        value: 1-D tensor with the current objective values.
        weights: 2-D tensor of preference vectors, one per row, with one
            column per objective.
        i: row index in ``weights`` of the targeted preference vector.

    Returns:
        1-D tensor with one coefficient per objective. With no active
        constraint it is the min-norm solution over the objective gradients
        alone (cast to float32); otherwise the solution over the objective
        gradients plus the active-constraint gradients, folded back onto the
        objectives.
    """
    # Difference of every preference vector to the selected one.
    w = weights - weights[i]

    # check active constraints
    gx = torch.matmul(w, value / torch.norm(value))
    active = gx > 0

    # calculate the descent direction
    if int(torch.sum(active)) == 0:
        sol, _ = MinNormSolver.find_min_norm_element([[grads[t]] for t in range(len(grads))])
        return torch.as_tensor(sol, device=w.device).float()

    # Boolean-index once instead of once per summed term.
    w_active = w[active]
    vec = torch.cat((grads, torch.matmul(w_active, grads)))
    sol, _ = MinNormSolver.find_min_norm_element([[vec[t]] for t in range(len(vec))])
    sol = torch.as_tensor(sol, dtype=w.dtype, device=w.device)

    # The first n_obj entries of sol belong to the objectives themselves, the
    # remaining ones to the active constraints:
    # weight[k] = sol[k] + sum_j sol[n_obj + j] * w_active[j, k]
    n_obj = grads.shape[0]
    weight = sol[:n_obj] + torch.matmul(sol[n_obj:], w_active)

    return weight


def circle_points(r, n):
    """
    Generate evenly spaced preference vectors for two tasks.

    ``r`` and ``n`` are walked in lockstep. Each (radius, count) pair produces
    one array of ``count`` points whose angles are evenly spaced from 0 to
    pi/2 inclusive, with ``radius * cos(angle)`` in column 0 and
    ``radius * sin(angle)`` in column 1. The arrays are returned as a list,
    one per pair.
    """
    quarter_turn = 0.5 * np.pi
    arcs = []
    for radius, count in zip(r, n):
        theta = np.linspace(0, quarter_turn, count)
        arcs.append(np.column_stack((radius * np.cos(theta), radius * np.sin(theta))))
    return arcs


class OR_model_YS(nn.Module):
    """
    MLP trained with ParetoMTL on two objectives, ``loss_S`` and ``loss_Y``.

    Both objectives are ``exp`` of the mini-batch mean of a negated value
    estimate (the original comment describes them as exp(-V)). The model
    output is used as a Bernoulli probability in ``predict``, so it is
    presumably in [0, 1] - confirm against ``MLP``.
    """

    def __init__(self, input_size):
        super(OR_model_YS, self).__init__()
        self.input_size = input_size
        self.model = MLP(input_size = self.input_size)

    @staticmethod
    def _make_batch(rows, x, a, s, y, e, r, est_r_1, est_r_0, mu0, mu1,
                    bar_mu0, bar_mu1, tilde_mu0, tilde_mu1):
        """Slice the rows of one mini-batch into float tensors on ``device``."""
        def as_is(arr):
            return torch.Tensor(arr[rows]).to(device)

        def column(arr):
            return torch.Tensor(arr.reshape(-1, 1)[rows]).to(device)

        # NOTE(review): x, s, y and r are indexed without reshaping, so s, y
        # and r are presumably already (n, 1) columns; a 1-D array would
        # broadcast against the column terms in the losses - confirm with the
        # callers.
        return {
            "x": as_is(x), "s": as_is(s), "y": as_is(y), "r": as_is(r),
            "a": column(a), "e": column(e),
            "est_r1": column(est_r_1), "est_r0": column(est_r_0),
            "mu0": column(mu0), "mu1": column(mu1),
            "bar_mu0": column(bar_mu0), "bar_mu1": column(bar_mu1),
            "tilde_mu0": column(tilde_mu0), "tilde_mu1": column(tilde_mu1),
        }

    @staticmethod
    def _loss_S(pred, b):
        """Return the ``s`` objective: exp of the batch mean of the negated estimate."""
        p = pred.reshape(-1, 1)
        return torch.exp(torch.mean(
            - p * b["mu1"] - (1 - p) * b["mu0"]
            - p * b["a"] * (b["s"] - b["mu1"]) / b["e"]
            - (1 - p) * (1 - b["a"]) * (b["s"] - b["mu0"]) / (1 - b["e"])))

    @staticmethod
    def _loss_Y(pred, b):
        """Return the ``y`` objective: exp of the batch mean of the negated estimate."""
        p = pred.reshape(-1, 1)
        return torch.exp(torch.mean(
            - p * b["bar_mu1"] - (1 - p) * b["bar_mu0"]
            - p * b["a"] * (b["tilde_mu1"] - b["bar_mu1"]) / b["e"]
            - (1 - p) * (1 - b["a"]) * (b["tilde_mu0"] - b["bar_mu0"]) / (1 - b["e"])
            - p * b["a"] * b["r"] * (b["y"] - b["tilde_mu1"]) / (b["e"] * b["est_r1"])
            - (1 - p) * (1 - b["a"]) * b["r"] * (b["y"] - b["tilde_mu0"]) / (
                (1 - b["e"]) * b["est_r0"])))

    @staticmethod
    def _shuffled_batches(num_sample, batch_size, total_batch):
        """Shuffle the sample indices and return the full mini-batches of one epoch."""
        order = np.arange(num_sample)
        np.random.shuffle(order)
        return [order[batch_size * k:(k + 1) * batch_size] for k in range(total_batch)]

    def _task_gradients(self, optimizer, b):
        """Return (stacked flattened gradients, stacked detached losses), loss_S first."""
        grads, losses = [], []
        for loss_fn in (self._loss_S, self._loss_Y):
            # Separate forward/backward per objective so the gradients do not mix.
            optimizer.zero_grad()
            loss = loss_fn(self.model(b["x"]), b)
            loss.backward()
            grads.append(torch.cat([
                param.grad.detach().clone().flatten()
                for param in self.model.parameters()
                if param.grad is not None
            ]))
            losses.append(loss.detach())
        return torch.stack(grads), torch.stack(losses)

    def _weighted_step(self, optimizer, b, weight_vec):
        """Take one optimizer step on weight_vec[0] * loss_S + weight_vec[1] * loss_Y."""
        pred = self.model(b["x"])
        loss_total = weight_vec[0] * self._loss_S(pred, b) + weight_vec[1] * self._loss_Y(pred, b)
        optimizer.zero_grad()
        loss_total.backward()
        optimizer.step()

    def fit(self,x, a, s, y, e, r, est_r_1, est_r_0, mu0, mu1, bar_mu0, bar_mu1, tilde_mu0, tilde_mu1
              ,niter, npref, pref_idx, batch_size, lr,lamb):
        """
        Train the model with ParetoMTL towards one preference vector.

        Training has two phases: up to 5 epochs that search for a feasible
        initial solution (stopping as soon as ``get_d_paretomtl_init`` reports
        one), followed by ``niter`` epochs of ParetoMTL updates.

        Args:
            x: feature array; rows are selected by index and fed to the model.
            a, e, est_r_1, est_r_0, mu0, mu1, bar_mu0, bar_mu1, tilde_mu0,
                tilde_mu1: arrays with one entry per sample, reshaped to
                columns. ``a`` / ``1 - a`` gate the terms divided by ``e`` /
                ``1 - e`` - presumably a binary treatment and its propensity
                (confirm).
            s, y, r: per-sample arrays used without reshaping. NaNs in ``y``
                are replaced by 0 on a copy; the caller's array is untouched.
            niter: number of ParetoMTL epochs after the initialization phase.
            npref: number of preference vectors to generate.
            pref_idx: index of the preference vector to target.
            batch_size: mini-batch size; a trailing partial batch is dropped.
            lr: Adam learning rate.
            lamb: Adam ``weight_decay``.

        Returns:
            1-D tensor with the normalized objective weights, averaged first
            over the batches of each epoch and then over the epochs.

        Note:
            With ``niter == 0`` or fewer than ``batch_size`` samples there is
            nothing to average and ``torch.cat`` fails on the empty list.
        """
        # generate #npref preference vectors
        n_tasks = 2
        ref_vec = torch.tensor(circle_points([1], [npref])[0]).to(device).float()

        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=lamb)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30,60,90], gamma=0.5)

        # load dataset
        num_sample = len(x)
        total_batch = num_sample // batch_size
        # Zero-fill missing outcomes on a copy instead of writing into the
        # caller's array. The y terms are multiplied by r, but NaN * 0 is
        # still NaN, so the fill is required.
        y = np.where(np.isnan(y), 0, y)
        arrays = (x, a, s, y, e, r, est_r_1, est_r_0, mu0, mu1,
                  bar_mu0, bar_mu1, tilde_mu0, tilde_mu1)

        self.model.to(device)

        # find the initial solution
        feasible = False
        for _ in range(5):
            self.model.train()
            for rows in self._shuffled_batches(num_sample, batch_size, total_batch):
                batch = self._make_batch(rows, *arrays)
                grads, losses_vec = self._task_gradients(optimizer, batch)

                # calculate the weights
                flag, weight_vec = get_d_paretomtl_init(grads, losses_vec, ref_vec, pref_idx)

                # early stop once a feasible solution is obtained
                if flag:
                    print("fealsible solution is obtained.")
                    feasible = True
                    break

                # Not feasible yet: step along the direction just computed.
                # (An `else: continue` used to skip this step, so this phase
                # never changed the model.)
                self._weighted_step(optimizer, batch, weight_vec)
            if feasible:
                break

        # run niter epochs of ParetoMTL
        weight_vec_iter = []
        for _ in range(niter):
            self.model.train()
            weight_vec_batch = []
            for rows in self._shuffled_batches(num_sample, batch_size, total_batch):
                batch = self._make_batch(rows, *arrays)
                grads, losses_vec = self._task_gradients(optimizer, batch)

                # calculate the weights, rescaled so their absolute values sum to n_tasks
                weight_vec = get_d_paretomtl(grads, losses_vec, ref_vec, pref_idx)
                weight_vec = weight_vec * (n_tasks / torch.sum(torch.abs(weight_vec)))
                weight_vec_batch.append(weight_vec.detach().clone().unsqueeze(0))

                # optimization step
                self._weighted_step(optimizer, batch, weight_vec)

            scheduler.step()
            weight_vec_iter.append(
                torch.mean(torch.cat(weight_vec_batch, dim=0), dim=0, keepdim=True))

        return torch.mean(torch.cat(weight_vec_iter, dim=0), dim=0)

    def predict(self, x):
        """
        Sample one Bernoulli draw per row of ``x`` from the model output.

        Returns a flat NumPy array of 0/1 values drawn with NumPy's global
        random state.
        """
        # Follow the model's own device so predicting before ``fit`` (which is
        # what moves the model to ``device``) cannot hit a device mismatch.
        first_param = next(self.model.parameters(), None)
        target = first_param.device if first_param is not None else device
        x = torch.Tensor(x).to(target)

        # Inference mode: ``fit`` leaves the model in train mode, which would
        # keep any dropout / batch-norm layers of the MLP in training behavior.
        self.model.eval()
        with torch.no_grad():
            pred = self.model(x)
        pred = pred.cpu().numpy().flatten()
        return np.random.binomial(1, pred)

class OR_model_linear(nn.Module):
    """
    MLP trained on a fixed linear combination of ``loss_S`` and ``loss_Y``.

    The combination weights are the two components of the selected preference
    vector. The model output is used as a Bernoulli probability in
    ``predict``, so it is presumably in [0, 1] - confirm against ``MLP``.
    """

    def __init__(self, input_size):
        super(OR_model_linear, self).__init__()
        self.input_size = input_size
        self.model = MLP(input_size=self.input_size)

    @staticmethod
    def _make_batch(rows, x, a, s, y, e, r, est_r_1, est_r_0, mu0, mu1,
                    bar_mu0, bar_mu1, tilde_mu0, tilde_mu1):
        """Slice the rows of one mini-batch into float tensors on ``device``."""
        def as_is(arr):
            return torch.Tensor(arr[rows]).to(device)

        def column(arr):
            return torch.Tensor(arr.reshape(-1, 1)[rows]).to(device)

        # NOTE(review): x, s, y and r are indexed without reshaping, so s, y
        # and r are presumably already (n, 1) columns; a 1-D array would
        # broadcast against the column terms in the losses - confirm with the
        # callers.
        return {
            "x": as_is(x), "s": as_is(s), "y": as_is(y), "r": as_is(r),
            "a": column(a), "e": column(e),
            "est_r1": column(est_r_1), "est_r0": column(est_r_0),
            "mu0": column(mu0), "mu1": column(mu1),
            "bar_mu0": column(bar_mu0), "bar_mu1": column(bar_mu1),
            "tilde_mu0": column(tilde_mu0), "tilde_mu1": column(tilde_mu1),
        }

    @staticmethod
    def _loss_S(pred, b):
        """Return the ``s`` objective: exp of the batch mean of the negated estimate."""
        p = pred.reshape(-1, 1)
        return torch.exp(torch.mean(
            - p * b["mu1"] - (1 - p) * b["mu0"]
            - p * b["a"] * (b["s"] - b["mu1"]) / b["e"]
            - (1 - p) * (1 - b["a"]) * (b["s"] - b["mu0"]) / (1 - b["e"])))

    @staticmethod
    def _loss_Y(pred, b):
        """Return the ``y`` objective: exp of the batch mean of the negated estimate."""
        p = pred.reshape(-1, 1)
        return torch.exp(torch.mean(
            - p * b["bar_mu1"] - (1 - p) * b["bar_mu0"]
            - p * b["a"] * (b["tilde_mu1"] - b["bar_mu1"]) / b["e"]
            - (1 - p) * (1 - b["a"]) * (b["tilde_mu0"] - b["bar_mu0"]) / (1 - b["e"])
            - p * b["a"] * b["r"] * (b["y"] - b["tilde_mu1"]) / (b["e"] * b["est_r1"])
            - (1 - p) * (1 - b["a"]) * b["r"] * (b["y"] - b["tilde_mu0"]) / (
                (1 - b["e"]) * b["est_r0"])))

    def fit_w(self,x, a, s, y, e, r, est_r_1, est_r_0, mu0, mu1, bar_mu0, bar_mu1, tilde_mu0, tilde_mu1
              ,niter, npref, pref_idx, weight_ours, batch_size,lr,lamb):
        """
        Train the model on ``w[0] * loss_S + w[1] * loss_Y``.

        ``w`` is preference vector ``pref_idx`` out of ``npref`` generated
        vectors.

        Args:
            x: feature array; rows are selected by index and fed to the model.
            a, e, est_r_1, est_r_0, mu0, mu1, bar_mu0, bar_mu1, tilde_mu0,
                tilde_mu1: arrays with one entry per sample, reshaped to
                columns. ``a`` / ``1 - a`` gate the terms divided by ``e`` /
                ``1 - e`` - presumably a binary treatment and its propensity
                (confirm).
            s, y, r: per-sample arrays used without reshaping. NaNs in ``y``
                are replaced by 0 on a copy; the caller's array is untouched.
            niter: upper bound on the number of epochs (see the stopping
                rules in the loop).
            npref: number of preference vectors to generate.
            pref_idx: index of the preference vector used as loss weights.
            weight_ours: accepted for interface compatibility; not used.
            batch_size: mini-batch size; a trailing partial batch is dropped.
            lr: Adam learning rate.
            lamb: Adam ``weight_decay``.

        Returns:
            The preference vector used as loss weights (1-D tensor).
        """
        # generate #npref preference vectors
        ref_vec = torch.tensor(circle_points([1], [npref])[0]).float().to(device)

        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=lamb)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30,60,90], gamma=0.5)

        early_stop = 0
        last_loss = 1e9
        # load dataset
        num_sample = len(x)
        total_batch = num_sample // batch_size
        # Zero-fill missing outcomes on a copy instead of writing into the
        # caller's array. The y terms are multiplied by r, but NaN * 0 is
        # still NaN, so the fill is required.
        y = np.where(np.isnan(y), 0, y)
        arrays = (x, a, s, y, e, r, est_r_1, est_r_0, mu0, mu1,
                  bar_mu0, bar_mu1, tilde_mu0, tilde_mu1)

        self.model.to(device)

        # The loss weights do not change during training.
        weight_S = ref_vec[pref_idx][0]
        weight_Y = ref_vec[pref_idx][1]

        for t in range(niter):
            self.model.train()
            all_idx = np.arange(num_sample)
            np.random.shuffle(all_idx)
            epoch_loss = 0
            for idx in range(total_batch):
                # mini-batch training
                selected_idx = all_idx[batch_size * idx:(idx + 1) * batch_size]
                batch = self._make_batch(selected_idx, *arrays)

                pred = self.model(batch["x"])
                loss = weight_S * self._loss_S(pred, batch) + weight_Y * self._loss_Y(pred, batch)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            scheduler.step()

            # NOTE(review): training always stops once t exceeds niter / 2, so
            # for niter >= 3 the final-epoch warning below can never print -
            # confirm this budget cap is intended.
            if t > 0.5*(niter):
                break
            # Count epochs whose loss did not improve by at least 1e-4; stop
            # once more than 10 such epochs have been seen (the counter is
            # never reset).
            if epoch_loss > last_loss - 1e-4:
                if early_stop > 10:
                    break
                early_stop += 1

            last_loss = epoch_loss

            if t == niter - 1:
                print("[Waring] Reach present epochs, it seems does not converge.")

        return ref_vec[pref_idx]

    def predict(self, x):
        """
        Sample one Bernoulli draw per row of ``x`` from the model output.

        Returns a flat NumPy array of 0/1 values drawn with NumPy's global
        random state.
        """
        # Follow the model's own device so predicting before ``fit_w`` (which
        # is what moves the model to ``device``) cannot hit a device mismatch.
        first_param = next(self.model.parameters(), None)
        target = first_param.device if first_param is not None else device
        x = torch.Tensor(x).to(target)

        # Inference mode: ``fit_w`` leaves the model in train mode, which
        # would keep any dropout / batch-norm layers of the MLP in training
        # behavior.
        self.model.eval()
        with torch.no_grad():
            pred = self.model(x)
        pred = pred.cpu().numpy().flatten()
        return np.random.binomial(1, pred)
