import torch
import torch.nn as nn

import scipy
import scipy.io
import numpy as np
from numpy import random

# Lower bound passed to torch.clamp before taking the logarithm in
# Toy.__compute_f12__, so the argument of log() never reaches zero.
LOWER = 0.000005

def get_data(n):
    """Draw samples from a standard normal distribution.

    Args:
        n: Forwarded unchanged as the ``size`` argument of
            ``numpy.random.normal``.

    Returns:
        The NumPy array of samples produced by ``numpy.random.normal``.
    """
    # loc/scale are spelled out; they equal NumPy's defaults (0 and 1).
    samples = random.normal(loc=0.0, scale=1.0, size=n)
    return samples

def get_datas(n):
    """Draw ``n`` samples from a 2-D standard normal and split the coordinates.

    Args:
        n: Forwarded as the ``size`` argument of
            ``numpy.random.multivariate_normal``.

    Returns:
        A pair ``(Z1, Z2)``: the first and second coordinate of every sample,
        each flattened to one dimension.
    """
    # Zero mean and identity covariance: the two coordinates are independent.
    samples = np.random.multivariate_normal(
        mean=[0, 0],
        cov=[[1, 0], [0, 1]],
        size=n,
    )
    first_coord = np.reshape(samples[:, 0], -1)
    second_coord = np.reshape(samples[:, 1], -1)
    return first_coord, second_coord

class Toy(nn.Module):
    """Two-objective toy problem over a 2-D point ``(x1, x2)``.

    Each objective blends a log-shaped term (active for ``x2 > 0``) with a
    quadratic term (active for ``x2 < 0``). The quadratic part additionally
    carries linear terms weighted by a data vector ``z``; which ``z`` is used
    is selected through ``data_type``:

    * ``'pop'``   -- ``z = (0, 0)``, i.e. no data-dependent term;
    * ``'emp'``   -- ``z`` is the mean of the data set loaded in ``__init__``;
    * ``'stoch'`` -- ``z`` is the mean of a random mini-batch of that data
      set (only accepted by ``forward``).
    """

    def __init__(self, args):
        """Load the empirical data set and set the problem constants.

        Args:
            args: Object exposing ``fd_eps``, the finite-difference step used
                by ``__compute_grad__``.
        """
        super().__init__()
        # NOTE: not referenced anywhere else in this class.
        self.centers = torch.Tensor([
            [-3.0, 0],
            [3.0, 0]])

        # Number of empirical samples that 'stoch' mini-batches index into.
        # NOTE(review): hard-coded; presumably matches the number of samples
        # stored in the .mat file below -- confirm.
        self.num_data = 20
        self.data_mean = 0  # mean of population data distribution
        self.data_sig = 1.  # std of population distribution

        # The path is relative to the current working directory.
        data = scipy.io.loadmat('./data/synthetic_NC_data.mat')
        z1 = data['z1'].reshape(-1)
        z2 = data['z2'].reshape(-1)
        # Row 0 holds the z1 samples, row 1 the z2 samples.
        # (Alternative, currently unused: draw the data set from
        # N(data_mean, data_sig^2) with shape (num_data, 2).)
        self.emp_data_set = torch.stack([torch.from_numpy(z1),
                                         torch.from_numpy(z2)])
        self.emp_data_mean = torch.mean(self.emp_data_set, dim=1).view([2, 1])

        # Offsets of the quadratic terms in __compute_f12__.
        self.p1 = 3.5
        self.p2 = -3.5
        self.p3 = -1

        self.fd_eps = args.fd_eps

    def __forward__(self, x1, x2, z=(0, 0), compute_grad=False):
        """Evaluate both objectives and, optionally, their gradients.

        Args:
            x1: First coordinate(s) of the evaluation point(s).
            x2: Second coordinate(s) of the evaluation point(s).
            z: Indexable pair of data weights (see ``__compute_f12__``).
            compute_grad: If true, also return the finite-difference
                gradient from ``__compute_grad__``.

        Returns:
            ``f`` with the two objective values concatenated along the last
            dimension, or ``(f, g)`` when ``compute_grad`` is true.
        """
        f1, f2 = self.__compute_f12__(x1, x2, z)
        f = torch.cat([f1.view(-1, 1), f2.view(-1, 1)], -1)
        if compute_grad:
            g = self.__compute_grad__(x1, x2, z)
            return f, g
        return f

    def __compute_f12__(self, x1, x2, z=(0, 0)):
        """Return the two objective values ``(f1, f2)`` at ``(x1, x2)``.

        Args:
            x1: First coordinate(s).
            x2: Second coordinate(s).
            z: Indexable pair; ``z[0]`` weights ``x1`` and ``z[1]`` weights
                ``x2`` in the linear terms added to the quadratic parts.
        """
        # Log-shaped parts; the clamp keeps the log argument away from zero.
        f1 = torch.clamp((0.5 * (-x1 - 7) - torch.tanh(-x2)).abs(), LOWER).log() + 6
        f2 = torch.clamp((0.5 * (-x1 + 3) + torch.tanh(-x2) + 2).abs(), LOWER).log() + 6
        # Blend weights: c1 is non-zero only for x2 > 0, c2 only for x2 < 0.
        c1 = torch.clamp(torch.tanh((x2) * 0.5), 0)
        c2 = torch.clamp(torch.tanh(-(x2) * 0.5), 0)

        # Quadratic parts, offset by p1 / p2 along x1 and by p3 along x2.
        f1_sq = ((-x1 + self.p1).pow(2) + 0.5 * (-x2 + self.p3).pow(2)) / 10 - 20
        f2_sq = ((-x1 + self.p2).pow(2) + 0.5 * (-x2 + self.p3).pow(2)) / 10 - 20

        # The z[0] term enters the two objectives with opposite signs,
        # the z[1] term with the same sign.
        f1 = f1 * c1 + (f1_sq - 2 * z[0] * x1 - 5.5 * z[1] * x2) * c2
        f2 = f2 * c1 + (f2_sq + 2 * z[0] * x1 - 5.5 * z[1] * x2) * c2
        return f1, f2

    def __compute_grad__(self, x1, x2, z=(0, 0)):
        """Approximate the gradients by forward finite differences.

        Uses step ``self.fd_eps``. The result is laid out as
        ``[[df1/dx1, df2/dx1], [df1/dx2, df2/dx2]]``.

        NOTE(review): the result is built with ``torch.Tensor([...])`` from
        the individual difference quotients, so it is a fresh tensor that is
        not connected to the autograd graph; this presumably expects each
        quotient to hold a single element -- confirm for tensor-valued ``z``.
        """
        f11_plus, f21_plus = self.__compute_f12__(x1 + self.fd_eps, x2, z)
        f12_plus, f22_plus = self.__compute_f12__(x1, x2 + self.fd_eps, z)
        f1, f2 = self.__compute_f12__(x1, x2, z)
        g11 = (f11_plus - f1) / self.fd_eps
        g12 = (f12_plus - f1) / self.fd_eps
        g21 = (f21_plus - f2) / self.fd_eps
        g22 = (f22_plus - f2) / self.fd_eps
        g = torch.Tensor([[g11, g21], [g12, g22]])
        return g

    def forward(self, x, compute_grad=False, data_type='pop', batch_size=1):
        """Evaluate the objectives at a single point ``x = (x1, x2)``.

        Args:
            x: Indexable with ``x[0]`` and ``x[1]`` as the two coordinates.
            compute_grad: If true, also return the finite-difference gradient.
            data_type: ``'pop'``, ``'emp'`` or ``'stoch'`` (see class
                docstring).
            batch_size: Mini-batch size for ``'stoch'``; indices are drawn
                with replacement from ``range(self.num_data)``.

        Returns:
            ``f``, or ``(f, g)`` when ``compute_grad`` is true.

        Raises:
            ValueError: If ``data_type`` is not one of the supported values.
        """
        x1 = x[0]
        x2 = x[1]

        if data_type == 'pop':
            z = [0, 0]
        elif data_type == 'stoch':
            batch_idx = np.random.choice(self.num_data, batch_size)
            batch = self.emp_data_set[:, batch_idx]  # 2 x batch_size
            z = torch.mean(batch, dim=1).view([2, 1])
        elif data_type == 'emp':
            z = self.emp_data_mean
        else:
            raise ValueError(
                "data_type must be 'pop', 'stoch' or 'emp', got %r"
                % (data_type,))

        return self.__forward__(x1, x2, z=z, compute_grad=compute_grad)

    def batch_forward(self, x,
                      data_type='pop', compute_grad=False):
        """Evaluate the objectives at a batch of points.

        Args:
            x: Batch of points. With ``compute_grad`` it is iterated point by
                point; otherwise it is indexed as ``x[:, 0]`` / ``x[:, 1]``.
            data_type: ``'pop'`` or ``'emp'``. Only used when
                ``compute_grad`` is false.
            compute_grad: If true, return population values together with
                both population and empirical gradients.

        Returns:
            With ``compute_grad``: ``(f, g, g_emp)``, the stacked per-point
            population values, population gradients and empirical gradients.
            Otherwise: ``f`` evaluated on the whole batch at once.

        Raises:
            ValueError: If ``compute_grad`` is false and ``data_type`` is
                neither ``'pop'`` nor ``'emp'``.
        """
        if compute_grad:
            f, g, g_emp = [], [], []
            for x_ in x:
                f_, g_ = self.forward(x_, compute_grad=True,
                                      data_type='pop', batch_size=1)
                # Only the gradient of the empirical objective is kept.
                _, g_emp_ = self.forward(x_, compute_grad=True,
                                         data_type='emp', batch_size=1)
                f.append(f_)
                g.append(g_)
                g_emp.append(g_emp_)

            # torch.stack copies its inputs into a new tensor.
            return torch.stack(f), torch.stack(g), torch.stack(g_emp)

        x1 = x[:, 0]
        x2 = x[:, 1]

        if data_type == 'pop':
            z = [0, 0]
        elif data_type == 'emp':
            # Used for plotting the empirical objective.
            z = self.emp_data_mean
        else:
            raise ValueError(
                "data_type must be 'pop' or 'emp', got %r" % (data_type,))

        return self.__forward__(x1, x2, z=z, compute_grad=False)