import torch
import torch.nn as nn

import scipy
import scipy.io
import numpy as np

# Small positive threshold (5e-6). Nothing in the code visible here reads it;
# presumably a lower bound used by other modules -- confirm before removing.
LOWER = 5e-6

class Convex(nn.Module):
    """Three quadratic objectives over a shared 2-D decision variable.

    For a point ``x = (x1, x2)`` and a data vector ``z = (z1, z2)``,
    objective ``k`` (k = 1, 2, 3) is::

        f_k(x; z) = 0.5 * b1k * (x1**2 + x2**2) - b2k * (z1 * x1 + z2 * x2)

    Which ``z`` is used is selected by ``data_type``:

    * ``'pop'``   -- ``z = (0, 0)``.
    * ``'emp'``   -- the mean of the data set loaded in ``__init__``.
    * ``'stoch'`` -- the mean of a random mini-batch drawn (with replacement)
      from that data set; only supported by :meth:`forward`.

    Gradients are approximated with forward finite differences of step
    ``fd_eps``; autograd is not used for them.
    """

    def __init__(self, args):
        """Load the empirical data set and set the objective coefficients.

        Args:
            args: object exposing ``fd_eps``, the finite-difference step used
                by ``__compute_grad__``.

        Reads ``./data/synthetic_NC_data.mat`` (relative to the current
        working directory) and expects it to contain arrays ``z1`` and ``z2``.
        """
        super().__init__()
        # Not read anywhere in this class; kept because external code may
        # access the attribute.
        self.centers = torch.Tensor([
            [-3.0, 0],
            [3.0, 0]])

        # NOTE: 'stoch' sampling in forward() draws indices from
        # range(num_data), so the loaded data must provide at least this many
        # samples per coordinate.
        self.num_data = 20  # number of empirical data
        self.data_mean = 0  # mean of population data distribution
        self.data_sig = 1.  # std of population distribution

        data = scipy.io.loadmat('./data/synthetic_NC_data.mat')
        z1 = data['z1'].reshape(-1)
        z2 = data['z2'].reshape(-1)
        # Row 0 holds the first data coordinate, row 1 the second: shape (2, n).
        self.emp_data_set = torch.stack([torch.from_numpy(z1),
                                         torch.from_numpy(z2)])
        # Per-coordinate mean as a (2, 1) column.
        self.emp_data_mean = torch.mean(self.emp_data_set, dim=1).view([2, 1])

        # b1k scales the quadratic term of objective k.
        self.b11 = 1
        self.b12 = 2
        self.b13 = 1

        # b2k scales the linear (data) term of objective k.
        self.b21 = 1
        self.b22 = 3
        self.b23 = 2

        self.fd_eps = args.fd_eps

    def __forward__(self, x1, x2, z=(0, 0), compute_grad=False):
        """Evaluate the objectives and optionally their gradients.

        Args:
            x1, x2: coordinates of the point(s) to evaluate.
            z: indexable pair ``(z[0], z[1])`` of data coordinates.
            compute_grad: when true, also return the finite-difference
                gradient from ``__compute_grad__``.

        Returns:
            ``f`` of shape (N, 3) with one column per objective, or the pair
            ``(f, g)`` when ``compute_grad`` is true.
        """
        f1, f2, f3 = self.__compute_f123__(x1, x2, z)
        f = torch.cat([f1.view(-1, 1), f2.view(-1, 1), f3.view(-1, 1)], -1)
        if compute_grad:
            return f, self.__compute_grad__(x1, x2, z)
        return f

    def __compute_f123__(self, x1, x2, z=(0, 0)):
        """Return the three objective values ``(f1, f2, f3)`` at ``(x1, x2)``."""
        sq_norm = x1 ** 2 + x2 ** 2
        inner = z[0] * x1 + z[1] * x2
        f1 = 0.5 * self.b11 * sq_norm - self.b21 * inner
        f2 = 0.5 * self.b12 * sq_norm - self.b22 * inner
        f3 = 0.5 * self.b13 * sq_norm - self.b23 * inner
        return f1, f2, f3

    def __compute_grad__(self, x1, x2, z=(0, 0)):
        """Approximate the gradients of all objectives by forward differences.

        Each partial derivative is ``(f(x + eps * e_i) - f(x)) / eps`` with
        ``eps = self.fd_eps``.

        Returns:
            A 2 x 3 tensor whose entry ``[i, k]`` is the derivative of
            objective ``k`` with respect to coordinate ``i``.
        """
        eps = self.fd_eps
        base = self.__compute_f123__(x1, x2, z)
        shifted_x1 = self.__compute_f123__(x1 + eps, x2, z)
        shifted_x2 = self.__compute_f123__(x1, x2 + eps, z)

        d_x1 = [(plus - f0) / eps for plus, f0 in zip(shifted_x1, base)]
        d_x2 = [(plus - f0) / eps for plus, f0 in zip(shifted_x2, base)]
        # torch.Tensor(...) copies the values into a new tensor of the default
        # tensor type, so every entry must be a single-element value and the
        # result carries no autograd history.
        return torch.Tensor([d_x1, d_x2])

    def forward(self, x, compute_grad=False, data_type='pop', batch_size=1):
        """Evaluate the objectives at a single point.

        Args:
            x: indexable with ``x[0]`` and ``x[1]`` the two coordinates.
            compute_grad: when true, also return the finite-difference
                gradient.
            data_type: ``'pop'``, ``'emp'`` or ``'stoch'`` (see class
                docstring).
            batch_size: number of samples drawn when ``data_type='stoch'``;
                ignored otherwise.

        Returns:
            ``f``, or ``(f, g)`` when ``compute_grad`` is true.

        Raises:
            ValueError: if ``data_type`` is not one of the supported values.
        """
        x1 = x[0]
        x2 = x[1]

        if data_type == 'pop':
            z = (0, 0)
        elif data_type == 'stoch':
            # Sample column indices with replacement and average the batch.
            batch_idx = np.random.choice(self.num_data, batch_size)
            batch = self.emp_data_set[:, batch_idx]  # 2*batch_size
            z = torch.mean(batch, dim=1).view([2, 1])
        elif data_type == 'emp':
            z = self.emp_data_mean
        else:
            raise ValueError(
                "data_type must be 'pop', 'stoch' or 'emp', got %r"
                % (data_type,))

        return self.__forward__(x1, x2, z=z, compute_grad=compute_grad)

    def batch_forward(self, x,
                      data_type='pop', compute_grad=False):
        """Evaluate the objectives at several points.

        Args:
            x: iterable of points; in the no-gradient path it is sliced as
                ``x[:, 0]`` / ``x[:, 1]`` and so must be 2-D there.
            data_type: ``'pop'`` or ``'emp'``. Only used when
                ``compute_grad`` is false.
            compute_grad: when true, ``data_type`` is ignored and, for every
                point, the 'pop' values, 'pop' gradient and 'emp' gradient
                are computed.

        Returns:
            With ``compute_grad``: the stacked triple ``(f, g, g_emp)``.
            Otherwise: ``f`` of shape (N, 3).

        Raises:
            ValueError: if ``compute_grad`` is false and ``data_type`` is
                neither ``'pop'`` nor ``'emp'``.
        """
        if compute_grad:
            f = []
            g = []
            g_emp = []
            for x_ in x:
                f_, g_ = self.forward(x_, compute_grad=True,
                                      data_type='pop', batch_size=1)
                f.append(f_)
                g.append(g_)
                # Only the gradient of the empirical objective is kept.
                _, g_emp_ = self.forward(x_, compute_grad=True,
                                         data_type='emp', batch_size=1)
                g_emp.append(g_emp_)

            return torch.stack(f), torch.stack(g), torch.stack(g_emp)

        x1 = x[:, 0]
        x2 = x[:, 1]

        if data_type == 'pop':
            z = (0, 0)
        elif data_type == 'emp':
            # for plotting empirical objective
            z = self.emp_data_mean
        else:
            raise ValueError(
                "data_type must be 'pop' or 'emp' when compute_grad is "
                "False, got %r" % (data_type,))

        return self.__forward__(x1, x2, z=z, compute_grad=False)