import numpy as np
from pandas import DataFrame
import os
import math
import pandas as pd
from utils import Config


def Normal_Dist(rng, num, dim, mean=0.0, sigma=1.0):
    """Draw ``num`` samples from an isotropic ``dim``-dimensional Gaussian.

    Args:
        rng: Random source exposing ``multivariate_normal``
            (e.g. ``np.random.RandomState``).
        num: Number of samples to draw (rows of the result).
        dim: Dimensionality of each sample (columns of the result).
        mean: Value used for every component of the mean vector.
        sigma: Value placed on the diagonal of the covariance matrix, i.e. the
            per-component variance (not the standard deviation).

    Returns:
        Array of shape ``(num, dim)``.
    """
    # np.full rather than np.zeros(dim) * mean: the latter is always the zero
    # vector, which silently discarded the requested mean.
    mean_vec = np.full(dim, mean, dtype=float)
    cov = np.identity(dim) * sigma
    return rng.multivariate_normal(mean_vec, cov, num)

def Uniform_Dist(rng, num, dim=1, low=0.0, high=1.0):
    """Draw a ``(num, dim)`` array of uniform samples between ``low`` and ``high``.

    Args:
        rng: Random source exposing ``uniform`` (e.g. ``np.random.RandomState``).
        num: Number of rows.
        dim: Number of columns.
        low: Lower bound passed to ``rng.uniform``.
        high: Upper bound passed to ``rng.uniform``.

    Returns:
        Array of shape ``(num, dim)``.
    """
    shape = (num, dim)
    return rng.uniform(low, high, size=shape)

def sigmoid(x, a=1, b=1, c=0):
    """Evaluate the generalized logistic curve ``a / (b + exp(-x)) + c``.

    With the defaults (``a=1, b=1, c=0``) this is the standard sigmoid.
    Works element-wise on NumPy arrays as well as on scalars.
    """
    denominator = b + np.exp(-x)
    scaled = a / denominator
    return scaled + c

class Generator(object):
    """Synthetic data generator for binary outcomes observed after a random delay.

    ``training`` seeds a ``np.random.RandomState``, draws the ``theta``
    parameter vectors, and produces covariates ``X``, treatments ``W``,
    latent outcomes ``Y0``/``Y1``, latent delays ``D0``/``D1`` and the observed
    time ``T``. ``testing`` reuses the same random state and ``theta`` vectors
    to draw a fresh ground-truth sample, so it must be called after
    ``training``.
    """

    # Noise scales at or below this threshold disable the random noise term;
    # only the deterministic offset ``add`` is applied.
    _MIN_NOISE_SCALE = 0.0005

    def __init__(self, cfg) -> None:
        """Store the configuration object used as the default by the methods."""
        self.cfg = cfg

    def _noise(self, rng, num, add, scale):
        """Return the additive term ``add`` plus optional Gaussian noise.

        When ``scale`` exceeds the threshold, a ``(num, 1)`` column of
        ``rng.normal`` draws multiplied by ``scale`` is added to ``add``;
        otherwise ``add`` is returned unchanged and no random numbers are
        consumed.
        """
        if scale <= self._MIN_NOISE_SCALE:
            return add
        return add + rng.normal(size=(num, 1)) * scale

    def backW(self, rng, theta_w, X, add=0, scale=0.01):
        """Sample binary treatments from a logistic model linear in ``X``.

        Args:
            rng: Random source (``normal`` and ``binomial`` are used).
            theta_w: Coefficients multiplied element-wise with ``X``.
            X: Covariate matrix; one row per sample.
            add: Constant offset added inside the sigmoid.
            scale: Standard-deviation multiplier for the Gaussian noise term.

        Returns:
            ``(P, W)``: treatment probabilities and the Bernoulli draws.
        """
        noise = self._noise(rng, len(X), add, scale)

        Z = np.sum(theta_w * X, axis=1, keepdims=True)
        Z = Z - np.mean(Z)  # centre the score across the sample
        P = sigmoid(noise + Z)
        W = rng.binomial(1, P)
        return P, W

    def backY(self, rng, theta_y, X, add=0, scale=0.01):
        """Sample binary outcomes from a logistic model in the squared covariates.

        Args:
            rng: Random source (``normal`` and ``binomial`` are used).
            theta_y: Coefficients multiplied element-wise with ``X**2``.
            X: Covariate matrix; one row per sample.
            add: Constant offset added inside the sigmoid.
            scale: Standard-deviation multiplier for the Gaussian noise term.

        Returns:
            ``(P, Y)``: outcome probabilities and the Bernoulli draws.
        """
        noise = self._noise(rng, len(X), add, scale)

        Z = np.sum(theta_y * (X ** 2), axis=1, keepdims=True)
        Z = Z - np.mean(Z)  # centre the score across the sample
        P = sigmoid(noise + Z)
        Y = rng.binomial(1, P)
        return P, Y

    def backD(self, rng, theta_d, X, add=0, scale=0.01):
        """Sample exponential delays whose parameter is log-linear in ``X``.

        Args:
            rng: Random source (``normal`` and ``exponential`` are used).
            theta_d: Coefficients multiplied element-wise with ``X``.
            X: Covariate matrix; one row per sample.
            add: Constant offset added inside the exponent.
            scale: Standard-deviation multiplier for the Gaussian noise term.

        Returns:
            ``(_lambda, D)``: the per-sample exponential parameter and the
            sampled delays.
        """
        noise = self._noise(rng, len(X), add, scale)

        Z = np.sum(theta_d * (X ** 1), axis=1, keepdims=True)
        Z = Z - np.mean(Z)  # centre the score across the sample
        _lambda = np.exp(noise + Z)
        # NOTE(review): numpy's ``exponential`` takes the *scale* (mean), so
        # ``_lambda`` acts as the mean delay here, not a rate -- confirm that
        # this is the intended parameterisation.
        D = rng.exponential(_lambda)
        return _lambda, D

    def testing(self, cfg=None, num=None, show=False):
        """Draw a ground-truth sample using the state created by ``training``.

        Args:
            cfg: Configuration object; defaults to ``self.cfg``.
            num: Number of samples; defaults to ``cfg.num``.
            show: When true, print summary statistics.

        Returns:
            ``[X, _lambda0, _lambda1, P_y0, P_y1, Y0, Y1]``.

        Raises:
            AttributeError: If ``training`` has not been called yet, since the
                random state and ``theta`` vectors do not exist.
        """
        if not hasattr(self, "rng"):
            raise AttributeError(
                "Generator.testing() requires training() to be called first: "
                "it reuses the random state and theta parameters created there."
            )
        if cfg is None:
            cfg = self.cfg
        if num is None:
            num = cfg.num

        # Continue the random stream started in training().
        rng = self.rng

        if show: print("Gen - Testing. ---------")
        # The observed covariates X
        X = Normal_Dist(rng, num, cfg.dim)
        if show: print("X-shape-({},{})".format(X.shape[0], X.shape[1]))

        # The latent outcome Y0
        P_y0, Y0 = self.backY(rng, self.theta_y0, X, cfg.y0_add, cfg.noise_scale)
        if show: print("Y0-mean-({}).".format(Y0.mean().round(2)))

        # The latent outcome Y1
        P_y1, Y1 = self.backY(rng, self.theta_y1, X, cfg.y1_add, cfg.noise_scale)
        if show: print("Y1-mean-({}).".format(Y1.mean().round(2)))

        # The latent delayed function D0
        _lambda0, D0 = self.backD(rng, self.theta_d0, X, cfg.d0_add, cfg.noise_scale)
        if show: print("D0-mean-({}), D-min-({}), D-max-({}).".format(D0.mean().round(2),D0.min().round(2),D0.max().round(2)))

        # The latent delayed function D1
        _lambda1, D1 = self.backD(rng, self.theta_d1, X, cfg.d1_add, cfg.noise_scale)
        if show: print("D1-mean-({}), D-min-({}), D-max-({}).".format(D1.mean().round(2),D1.min().round(2),D1.max().round(2)))

        data_GT = [X, _lambda0, _lambda1, P_y0, P_y1, Y0, Y1]

        return data_GT

    def training(self, cfg=None, seed=None, show=True):
        """Seed the generator, draw the model parameters and build a training set.

        Stores the random state as ``self.rng`` and the parameter vectors as
        ``self.theta_w``, ``self.theta_y0``, ``self.theta_y1``,
        ``self.theta_d0`` and ``self.theta_d1``.

        Args:
            cfg: Configuration object; defaults to ``self.cfg``.
            seed: Seed for ``np.random.RandomState``; defaults to ``cfg.seed``.
            show: When true, print summary statistics.

        Returns:
            ``(data_GT, data_train)`` where
            ``data_GT = [X, _lambda0, _lambda1, P_y0, P_y1, Y0, Y1]`` and
            ``data_train = [X, W, T, D, Y, G, PY, Lam]``.
        """
        if cfg is None:
            cfg = self.cfg
        if seed is None:
            seed = cfg.seed

        rng = np.random.RandomState(seed)
        self.rng = rng

        if show: print("Gen - Training. ---------")
        # Generate theta parameters: theta
        self.theta_w  = Uniform_Dist(rng, 1, cfg.dim, cfg.low, cfg.high)
        self.theta_y0 = Uniform_Dist(rng, 1, cfg.dim, cfg.low, cfg.high)
        self.theta_y1 = Uniform_Dist(rng, 1, cfg.dim, cfg.low, cfg.high)
        self.theta_d0 = Uniform_Dist(rng, 1, cfg.dim, -0.1, 0.1)
        self.theta_d1 = Uniform_Dist(rng, 1, cfg.dim, -0.1, 0.1)

        # The observed time T
        T = rng.exponential(1, size=(cfg.num, 1))
        if show: print("T-min-({}), T-max-({}).".format(T.min().round(2),T.max().round(2)))

        # The observed covariates X
        X = Normal_Dist(rng, cfg.num, cfg.dim)
        if show: print("X-shape-({},{})".format(X.shape[0], X.shape[1]))

        # The observed treatments W
        P_w, W = self.backW(rng, self.theta_w, X, cfg.w_add, cfg.noise_scale)
        if show: print("W-mean-({}).".format(W.mean().round(2)))

        # The latent outcome Y0
        P_y0, Y0 = self.backY(rng, self.theta_y0, X, cfg.y0_add, cfg.noise_scale)
        if show: print("Y0-mean-({}).".format(Y0.mean().round(2)))

        # The latent outcome Y1
        P_y1, Y1 = self.backY(rng, self.theta_y1, X, cfg.y1_add, cfg.noise_scale)
        if show: print("Y1-mean-({}).".format(Y1.mean().round(2)))

        # Show, for each threshold, the fraction of probabilities below it
        if show:
            threshold = [0.03, 0.1, 0.3, 0.5, 0.7, 0.9, 0.97]
            pww_str = 'W--0.03-0.97: '
            py0_str = 'Y0-0.03-0.97: '
            py1_str = 'Y1-0.03-0.97: '
            for item in threshold:
                pww_str += "({:.2f}) ".format(np.mean(P_w  < item))
                py0_str += "({:.2f}) ".format(np.mean(P_y0 < item))
                py1_str += "({:.2f}) ".format(np.mean(P_y1 < item))
            print(pww_str)
            print(py0_str)
            print(py1_str)

        # The latent delayed function D0
        _lambda0, D0 = self.backD(rng, self.theta_d0, X, cfg.d0_add, cfg.noise_scale)
        if show: print("D0-mean-({}), D-min-({}), D-max-({}).".format(D0.mean().round(2),D0.min().round(2),D0.max().round(2)))

        # The latent delayed function D1
        _lambda1, D1 = self.backD(rng, self.theta_d1, X, cfg.d1_add, cfg.noise_scale)
        if show: print("D1-mean-({}), D-min-({}), D-max-({}).".format(D1.mean().round(2),D1.min().round(2),D1.max().round(2)))

        # The Observed Outcome Y(W,D,T): select the arm indicated by W
        PY = W * P_y1 + (1-W) * P_y0
        Lam = W * _lambda1 + (1-W) * _lambda0
        D = W * D1 + (1-W) * D0
        G = W * Y1 + (1-W) * Y0

        # The outcome is only observed once the observation time exceeds the delay.
        Y = G * (T > D)
        # These summaries honour ``show`` like every other diagnostic print.
        if show:
            print("Y(W,∞,D):{}.".format(G.mean().round(2)))
            print("Y(W,T,D):{}.".format(Y.mean().round(2)))

        # Dataset
        data_GT = [X, _lambda0, _lambda1, P_y0, P_y1, Y0, Y1]

        data_train = [X, W, T, D, Y, G, PY, Lam]

        return data_GT, data_train

# Driver: for every experiment index, reseed the generator, build the
# train / valid / test splits and store each one as an .npz archive under
# ./data/<setting>/<exp>/.  Runs at module level, i.e. whenever this file is
# executed or imported.

# Archive keys, in the same order as the lists returned by Generator.
_TRAIN_KEYS = ('X', 'W', 'T', 'D', 'Y', 'G', 'P', 'L')
_GT_KEYS = ('X', 'Lam0', 'Lam1', 'P0', 'P1', 'Y0', 'Y1')

cfg = Config()
for exp in range(cfg.exps):
    print("Gen - exp: {}.".format(exp))
    # Each experiment gets its own deterministic seed.
    cfg.seed = cfg.seed_base + cfg.seed_mul * exp
    gen = Generator(cfg)
    data_GT, data_train = gen.training(show=False)
    test_GT = gen.testing(num=cfg.tnum, show=False)

    data_setting = f"{cfg.num}_{cfg.dim}_{cfg.y0_add}_{cfg.y1_add}_{cfg.noise_scale}"
    os.makedirs(os.path.dirname(f'./data/{data_setting}/{exp}/'), exist_ok=True)

    # Pair each array with its archive key and pass them as keyword arguments.
    np.savez(f'./data/{data_setting}/{exp}/train.npz', **dict(zip(_TRAIN_KEYS, data_train)))
    np.savez(f'./data/{data_setting}/{exp}/valid.npz', **dict(zip(_GT_KEYS, data_GT)))
    np.savez(f'./data/{data_setting}/{exp}/test.npz', **dict(zip(_GT_KEYS, test_GT)))

