import os
import torch
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def dfTotensor(df):
    r"""
    Functs: - turn a pandas object into a float32 torch Tensor
            - the underlying numpy array (df.values) is wrapped by torch.from_numpy
              and then cast with .float()
    """
    array = df.values
    tensor = torch.from_numpy(array)
    return tensor.float()


def print_nnmodule(f):
    r"""
    Functs: - walk through the layers stored in f.function and, for every
              nn.Linear layer, print its weight followed by its bias
            - layers of any other type are skipped
    """
    linear_layers = (layer for layer in f.function if isinstance(layer, nn.Linear))
    for layer in linear_layers:
        print(layer.weight)
        print(layer.bias)


class EarlyStopping():
    """
    Early stopping to stop the training when the loss has stopped changing
    for a certain number of consecutive epochs.

    The comparison is made on the ABSOLUTE difference between the incoming
    loss and the stored reference loss, so a loss that moves by more than
    ``min_delta`` in either direction resets the counter and becomes the new
    reference. In other words this is a plateau detector.
    """

    def __init__(self, patience, min_delta, verbose=False):
        """
        :param patience: how many consecutive stalled epochs to wait before
               raising the ``early_stop`` flag
        :param min_delta: the loss must move by more than this amount
               (in absolute value) for the epoch to count as a change
        :param verbose: if True, print the counter on every stalled epoch
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.verbose = verbose

    def __call__(self, val_loss):
        """
        Record one new loss value and update ``counter`` / ``early_stop``.

        :param val_loss: the loss observed at the current epoch
        """
        # the very first loss only initialises the reference value
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        if abs(self.best_loss - val_loss) > self.min_delta:
            # significant change: take it as the new reference and start over
            self.best_loss = val_loss
            self.counter = 0
            return

        # Everything else is a stalled epoch. This deliberately includes a
        # change exactly equal to min_delta and a NaN loss (every comparison
        # with NaN is False); previously both cases were silently ignored, so
        # a NaN loss could never trigger the stop.
        self.counter += 1
        if self.verbose:
            print(f"INFO: Early stopping counter {self.counter} of {self.patience}")
        if self.counter >= self.patience:
            print('INFO: Early stopping')
            self.early_stop = True

class Lasso(nn.Module):
    r"""
    Functs: - Lasso feature selection layer
            - holds one learnable weight vector omega of shape (p,), initialised to ones
            - the input X has shape (n,p): n samples, p features; each feature
              column is rescaled by its own entry of omega
    """

    def __init__(self, p):
        super().__init__()
        initial_weights = torch.ones(p)
        self.omega = nn.Parameter(initial_weights)
        # flag flipped by the caller once the optimisation of omega is over
        self.trained = False

    def l1norm(self, ):
        r"""
        Functs: - return the L1 norm of omega
        """
        return self.omega.norm(p=1)

    def forward(self, X):
        r"""
        Functs: - element-wise product of omega with X (broadcast over rows)
        """
        return torch.mul(self.omega, X)


class NonLinearF(nn.Module):
    r"""
    One-hidden-layer regressor: Linear(in_dim, hidden) -> Sigmoid -> Linear(hidden, out_dim).

    Functs: - use .fit to train the model,
            - use .forward (for torch.tensor) or .predict (for pd.DataFrame) to do inference (or training of f_J_theta)
    """

    def __init__(self, in_dim, baseinit, hidden=1, out_dim=1, verbose=True):
        r"""
        :param in_dim: number of input features
        :param baseinit: if truthy, every Linear layer gets weight=1 and bias=0
               instead of the default random initialisation
        :param hidden: width of the hidden layer
        :param out_dim: number of outputs
        :param verbose: if True, .fit prints the loss and learning rate periodically
        """
        super(NonLinearF, self).__init__()

        self.in_dim = in_dim
        self.function = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.Sigmoid(),
            nn.Linear(hidden, out_dim),
        )
        self.trained = False
        self.verbose = verbose

        if baseinit:
            for m in self.function:
                if isinstance(m, nn.Linear):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def fit(self, covariates, target, opt, num_iters, lr, step=None, gamma=None):
        r"""
        Functs: - full-batch MSE regression of target on covariates
                - afterwards the network is frozen (requires_grad=False) and self.trained is set

        :param covariates: pd.DataFrame or torch.Tensor of inputs
        :param target: pd.DataFrame / pd.Series or torch.Tensor of regression targets
        :param opt: 'SGD' selects torch.optim.SGD, anything else selects Adam
        :param num_iters: the loop runs num_iters + 1 optimisation steps
        :param lr: learning rate
        :param step: step_size of an optional StepLR scheduler (None = no scheduler)
        :param gamma: decay factor of the StepLR scheduler; when None while step
               is given, StepLR's documented default of 0.1 is used
        """
        # convert each argument on its own, so a DataFrame can be mixed with a tensor
        if isinstance(covariates, pd.DataFrame):
            covariates = dfTotensor(covariates)
        if isinstance(target, (pd.DataFrame, pd.Series)):
            target = dfTotensor(target)

        # a previous call to fit froze the parameters; re-enable gradients so
        # that the model can be fitted again instead of failing in backward()
        for param in self.function.parameters():
            param.requires_grad_(True)

        if opt == 'SGD':
            optimizer = torch.optim.SGD(self.function.parameters(), lr=lr)
        else:
            optimizer = torch.optim.Adam(self.function.parameters(), lr=lr)

        scheduler = None
        if step is not None:
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=step, gamma=0.1 if gamma is None else gamma)

        loss_func = nn.MSELoss()

        # report roughly five times per run; max(1, ...) avoids a modulo by
        # zero when num_iters < 5
        log_every = max(1, num_iters // 5)

        for itera in range(num_iters + 1):
            prediction = self.function(covariates)
            loss = loss_func(prediction, target)
            if self.verbose and itera % log_every == 0:
                print('iteration: {:d}, loss: {:.3f}, lr: {:.4f}'.format(
                    int(itera), float(loss), optimizer.param_groups[0]['lr']))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if scheduler is not None:
                scheduler.step()

        # freeze the fitted network so later backward passes through it
        # do not update (or accumulate gradients on) its weights
        self.trained = True
        for param in self.function.parameters():
            param.requires_grad = False

    def forward(self, covariates):
        r"""
        Functs: - apply the network to a torch.Tensor and return a torch.Tensor
        """
        return self.function(covariates)

    def predict(self, covariates):
        r"""
        Functs: - apply the network to a pd.DataFrame and return a numpy array
                - no check is made that the model has been trained
        """
        covariates = dfTotensor(covariates)
        return self.function(covariates).detach().numpy()


class LassoOptSet():
    r"""
    Under the graphical assumption of G4567_X10 (@ denotes changed due to X10->X9, X10->X5)

    Functs: - For a given training split, e.g. 1456,
                1. read from BASE/1456.csv (train) and every file in BASE/1456/ (test sets)
                2. estimate structural equations for each variable that is a descendant of X_M,
                   in this case f_regen7,6,2,4,3,5
                3. Lasso-based optimal subset searching
                    - init a learnable Lasso filter, lassoFilter, with parameter \omega size (n_S=8)
                    - shuffle X_do*={X_1,3,4,5,6,7,8,9} and regenerate their descendants={X_2}, call the data P0
                    - init a NonLinear network, f_S_prime, predY=f_S_prime(sel_X_S_prime.detach(),X_M)
                      with sel_X_S_prime = lassoFilter(X_S)
                      train f_S_prime in P0 by MSE(Y,predY) until convergence, fix it

                    - init a NonLinear network, f_J_theta, replace X_M={X_1,X_8,X_9} with f_J_theta(PA_M={X_6,7,10,11,Y})
                      and regenerate its descendants={X_7,6,2,4,3,5}, call the data P1
                      in P1, predict Y by the fixed f_S_prime(sel_X_S_prime.detach(),X_M), and compute the loss as MSE(Y,predY)
                      optimize f_J_theta by -loss until convergence, fix it

                    - use the fixed f_J_theta to regenerate data again,
                      and use the fixed f_S_prime to calculate the MSE loss again,
                      use this re-computed loss, added with L1 penalty, to optimize parameter of lassoFilter, i.e., \omega.

                4. Note that there are 3 MSE losses, name them loss_S_prime, loss_J_theta, and MSE_lasso, respectively

            - We expect the optimized lassoFilter has a similar choice on optimal subset
                 and the predictor f_S_prime(sel_X_S_prime.detach(),X_M) has similar maxMSE performance in test sets.
    """

    # data directory used when the constructor is not given base_dir
    DEFAULT_BASE = '/home/anonymous/data/CausallyInvariant_output/ADNIB2FAQ/FindOptSets/'

    def __init__(self, trainsplit, params, seed=1234, need_norm=True, verbose=False, base_dir=None):
        r"""
        Functs: - load the train/test data of one split, optionally standardise it,
                  fit the six regeneration networks and build the shuffled data P0

        :param trainsplit: name of the split; <base>/<trainsplit>.csv is the training file
               and every file inside <base>/<trainsplit>/ is read as one test set
        :param params: dict of hyper-parameters; the keys read by this class are
               'patience', 'regenBaseinit', 'regenOpt', 'fSBaseinit', 'fSOpt', 'fSIter',
               'fSLr', 'fSstep', 'fSgamma', 'fJBaseinit', 'fJLr', 'fJIter',
               'lassoLr', 'lassoLam', 'lassoIter'
        :param seed: random_state used when shuffling the do* variables
        :param need_norm: if True, standardise S, M and Y columns of every DataFrame
        :param verbose: if True, print progress information
        :param base_dir: data directory; None keeps the historical default DEFAULT_BASE
        """
        self.params = params

        self.S = ['X2', 'X3', 'X4', 'X5', 'X6', 'X7', 'X10', 'X11']
        self.M = ['X1', 'X8', 'X9']
        self.doStar = ['X1', 'X3', 'X4', 'X5', 'X6', 'X7', 'X8', 'X9']

        self.verbose = verbose

        # one list per tracked quantity, appended to once per lasso iteration
        self.log = {'iter': [], 'hstar': [], 'l1norm': [], 'omega': [], 'maxmse': []}

        BASE = self.DEFAULT_BASE if base_dir is None else base_dir
        self.trainsplit = trainsplit
        self.seed = seed

        trainfilename = os.path.join(BASE, '{}.csv'.format(trainsplit))
        self.trainDF = pd.read_csv(trainfilename)

        self.earlystoper = EarlyStopping(patience=self.params['patience'], min_delta=0.01, verbose=False)

        self.need_norm = need_norm
        if self.need_norm:
            self._standardize(self.trainDF)

        self.testfolder = os.path.join(BASE, '{}'.format(trainsplit))
        self.testDFs = list()
        # sorted() makes the order of self.testDFs independent of the file system
        for filename in sorted(os.listdir(self.testfolder)):
            testDF = pd.read_csv(os.path.join(self.testfolder, filename))
            if self.need_norm:
                # each test set is standardised with its OWN mean / std
                self._standardize(testDF)
            self.testDFs.append(testDF)

        # we need to estimate 6 f_regens, i.e. 2,3,4,5,6,7
        self.f_regen7 = NonLinearF(in_dim=1, baseinit=self.params['regenBaseinit'], hidden=1, verbose=False)
        self.f_regen6 = NonLinearF(in_dim=1, baseinit=self.params['regenBaseinit'], hidden=1, verbose=False)
        self.f_regen2 = NonLinearF(in_dim=5, baseinit=self.params['regenBaseinit'], hidden=5, verbose=False)
        self.f_regen4 = NonLinearF(in_dim=1, baseinit=self.params['regenBaseinit'], hidden=1, verbose=False)
        self.f_regen3 = NonLinearF(in_dim=2, baseinit=self.params['regenBaseinit'], hidden=1, verbose=False)
        self.f_regen5 = NonLinearF(in_dim=3, baseinit=self.params['regenBaseinit'], hidden=1, verbose=False)

        if self.verbose:
            print('Train split: {}'.format(self.trainsplit))
        self.estimate_f_regen7()
        self.estimate_f_regen6()
        self.estimate_f_regen2()
        self.estimate_f_regen4()
        self.estimate_f_regen3()
        self.estimate_f_regen5()

        self.sample_P0()

    def _standardize(self, df):
        r"""
        Functs: - in place, replace every column of S, M and Y in df by (column - mean) / std,
                  with mean and std computed on df itself
        """
        for var in self.S + self.M + ['Y']:
            mean = df[[var]].mean().values[0]
            std = df[[var]].std().values[0]
            df[[var]] = (df[[var]] - mean) / std

    def _fit_regen(self, model, parents, child, num_iters):
        r"""
        Functs: - fit one structural equation child = model(parents) on self.trainDF
        """
        X = self.trainDF[parents]
        Y = self.trainDF[[child]]

        model.fit(X, Y, opt=self.params['regenOpt'], num_iters=num_iters, lr=0.1)

    def estimate_f_regen7(self, ):
        r"""
        Functs: - learn X_7 = self.f_regen7(X_8)
        """
        self._fit_regen(self.f_regen7, ['X8'], 'X7', num_iters=2000)

    def estimate_f_regen6(self, ):
        r"""
        Functs: - learn X_6 = self.f_regen6(X_7)
        """
        self._fit_regen(self.f_regen6, ['X7'], 'X6', num_iters=2000)

    def estimate_f_regen2(self, ):
        r"""
        Functs: - learn X_2 = self.f_regen2(X_1,X_6,X_7,X_8,Y)
        """
        self._fit_regen(self.f_regen2, ['X1', 'X6', 'X7', 'X8', 'Y'], 'X2', num_iters=5000)

    def estimate_f_regen4(self, ):
        r"""
        Functs: - learn X_4 = self.f_regen4(X_2)
        """
        self._fit_regen(self.f_regen4, ['X2'], 'X4', num_iters=2000)

    def estimate_f_regen3(self, ):
        r"""
        Functs: - learn X_3 = self.f_regen3(X_1,X_4)
        """
        self._fit_regen(self.f_regen3, ['X1', 'X4'], 'X3', num_iters=5000)

    def estimate_f_regen5(self, ):
        r"""
        Functs: - learn X_5 = self.f_regen5(X_1,X_3,X_10)
        """
        self._fit_regen(self.f_regen5, ['X1', 'X3', 'X10'], 'X5', num_iters=2000)

    def sample_P0(self, ):
        r"""
        Functs: - shuffle X_do*={X1,3,4,5,6,7,8,9}, regenerate X2, call the data as self.P0
        """
        # shuffle @X_do*
        # NOTE(review): every column is sampled with the same random_state, so
        # presumably all do* columns receive the same permutation - confirm
        # that this joint shuffle is what is intended.
        self.P0 = self.trainDF.copy(deep=True)
        for variable in self.doStar:
            self.P0.loc[:, variable] = self.P0.loc[:, variable].sample(frac=1, random_state=self.seed).values

        # regenerate X2 by f_regen2(X1,X6,X7,X8,Y)
        X = self.P0.loc[:, ['X1', 'X6', 'X7', 'X8', 'Y']]
        predX = self.f_regen2.predict(X)

        self.P0.loc[:, 'X2'] = predX

    def estimate_f_S_prime(self, lassoFilter):
        r"""
        Functs: - sample from data distribution P0 and fit f_S_prime in it
                - the filtered X_S is detached, so fitting f_S_prime never
                  propagates gradients into lassoFilter
                - return the fitted (and frozen) f_S_prime
        """
        # prepare filtered data
        X_S = dfTotensor(self.P0.loc[:, self.S])
        filtedX_S = lassoFilter(X_S).detach()

        X_M = dfTotensor(self.P0.loc[:, self.M])

        XX = torch.cat([filtedX_S, X_M], dim=1)
        YY = dfTotensor(self.P0[['Y']])

        # init and train f_S_prime
        f_S_prime = NonLinearF(in_dim=len(self.S) + len(self.M), baseinit=self.params['fSBaseinit'], hidden=5,
                               verbose=False)
        f_S_prime.fit(XX, YY, opt=self.params['fSOpt'], num_iters=self.params['fSIter'], lr=self.params['fSLr'],
                      step=self.params['fSstep'], gamma=self.params['fSgamma'])

        return f_S_prime

    def sample_P1(self, f_J_theta):
        r"""
        Functs: - sample from distribution P1
                - first do X_M = J_theta(PA_M), then, regenerate its descendants
                - return the data as [X_S,X_M,Y] format (torch Tensors)
        """
        X11 = dfTotensor(self.trainDF[['X11']])
        X10 = dfTotensor(self.trainDF[['X10']])

        Y = dfTotensor(self.trainDF[['Y']])
        PA_XM = dfTotensor(self.trainDF[['X6', 'X7', 'X10', 'X11', 'Y']])

        # replace X_M by f_J_theta(PA_XM)
        # pred_XM is a torch.Tensor with shape (B,3)
        pred_XM = f_J_theta(PA_XM)

        pred_X1 = pred_XM[:, 0].unsqueeze(1)
        pred_X8 = pred_XM[:, 1].unsqueeze(1)
        pred_X9 = pred_XM[:, 2].unsqueeze(1)

        # regenerate descendants of X1,X8,X9
        pred_X7 = self.f_regen7(pred_X8)
        pred_X6 = self.f_regen6(pred_X7)
        pred_X2 = self.f_regen2(torch.cat([pred_X1, pred_X6, pred_X7, pred_X8, Y], dim=1))
        pred_X4 = self.f_regen4(pred_X2)
        pred_X3 = self.f_regen3(torch.cat([pred_X1, pred_X4], dim=1))
        pred_X5 = self.f_regen5(torch.cat([pred_X1, pred_X3, X10], dim=1))

        # column order must match self.S and self.M respectively
        X_S = torch.cat([pred_X2, pred_X3, pred_X4, pred_X5, pred_X6, pred_X7, X10, X11], dim=1)
        X_M = torch.cat([pred_X1, pred_X8, pred_X9], dim=1)

        return [X_S, X_M, Y]

    def estimate_f_J_theta(self, lassoFilter, f_S_prime):
        r"""
        Functs: - sample from data distribution P1, filter X_S in P1 with Lasso filter
                - predY = f_S_prime(filtedP1)
                - compute negative MSE, and update f_J_theta's parameters
                - return the trained (and frozen) f_J_theta
        Warning: - lassoFilter is not detached here, so each backward pass also
                   accumulates a gradient on lassoFilter.omega; callers that
                   optimise omega must clear its gradient before their own backward
        """
        f_J_theta = NonLinearF(in_dim=5, baseinit=self.params['fJBaseinit'], hidden=8, out_dim=3)
        optimizer = torch.optim.SGD(f_J_theta.parameters(), lr=self.params['fJLr'])
        loss_func = nn.MSELoss()

        num_iters = self.params['fJIter']
        for itera in range(num_iters + 1):
            # sample from P1 distribution
            X_S, X_M, Y = self.sample_P1(f_J_theta)
            filtedX_S = lassoFilter(X_S)
            pred_Y = f_S_prime(torch.cat([filtedX_S, X_M], dim=1))
            # we want to maximize the MSE, so minimise its negative
            loss = - loss_func(Y, pred_Y)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # fix the trained network
        f_J_theta.trained = True
        for param in f_J_theta.function.parameters():
            param.requires_grad = False

        return f_J_theta

    def estimate(self, ):
        r"""
        Functs: - Lasso based opt subset searching
                - stores the final filter and predictor in self.lassoFilter / self.f_S_prime
                - return self.log, a dict with keys 'iter', 'hstar', 'l1norm', 'omega', 'maxmse'
        Warning: - take special care of the input order of Neural Networks
        """
        # init learnable Lasso filter
        lassoFilter = Lasso(p=len(self.S))
        lassoOpt = torch.optim.Adam(lassoFilter.parameters(), lr=self.params['lassoLr'])
        loss_func = nn.MSELoss()

        lambdaa = self.params['lassoLam']
        lassoIters = self.params['lassoIter']
        # report roughly ten times per run; max(1, ...) avoids a modulo by
        # zero when lassoIters < 10
        log_every = max(1, lassoIters // 10)

        for lassoIter in range(lassoIters + 1):
            # init and train predY = f_S_prime(sel_X_S_prime.detach(),X_M) on P0
            f_S_prime = self.estimate_f_S_prime(lassoFilter)
            # init f_J_theta, construct P1 by regenerating X_M and its descendants,
            # and train f_J_theta on P1 by MSE(predY,Y)
            f_J_theta = self.estimate_f_J_theta(lassoFilter, f_S_prime)
            # use the trained f_J_theta to generate data again, use the trained f_S_prime to compute the MSE loss again
            # add with L1 norm and update LassoFilter's parameter
            X_S, X_M, Y = self.sample_P1(f_J_theta)
            filtedX_S = lassoFilter(X_S)
            pred_Y = f_S_prime(torch.cat([filtedX_S, X_M], dim=1))

            mse = loss_func(Y, pred_Y)
            l1norm = lassoFilter.l1norm()
            loss = mse + lambdaa * l1norm

            # Training f_J_theta back-propagated -MSE through lassoFilter on
            # every inner iteration, leaving those gradients accumulated in
            # omega.grad. Clear them so that the update below is driven only
            # by the re-computed loss (MSE + L1 penalty).
            lassoOpt.zero_grad()
            loss.backward()
            lassoOpt.step()
            lassoOpt.zero_grad()

            testMSE = self.test(lassoFilter, f_S_prime)
            if self.verbose and lassoIter % log_every == 0:
                print('iteration: {:d}, hstar: {:.5f}, L1: {:.3f}, maxMSE: {:.3f}, Omega: {}'.format(
                    int(lassoIter), float(mse), float(l1norm), testMSE,
                    lassoFilter.omega.detach().numpy()))
            self.log['iter'].append(lassoIter)
            self.log['hstar'].append(mse.detach().numpy().item())
            self.log['l1norm'].append(l1norm.detach().numpy().item())
            # copy(): the numpy view shares memory with omega, which keeps changing
            self.log['omega'].append(lassoFilter.omega.detach().numpy().copy())
            self.log['maxmse'].append(testMSE)

            # early stop to save time and avoid collapse
            self.earlystoper(self.log['hstar'][-1])
            if self.earlystoper.early_stop:
                break

        lassoFilter.trained = True
        lassoFilter.omega.requires_grad = False

        self.lassoFilter = lassoFilter
        self.f_S_prime = f_S_prime

        return self.log

    def test(self, lassoFilter, f_S_prime):
        r"""
        Functs: - predict on every test set and return the largest MSE among them
        """
        error_log = list()
        # evaluation only: no autograd graph is needed
        with torch.no_grad():
            for testDF in self.testDFs:
                filtedX_S = lassoFilter(dfTotensor(testDF[self.S]))
                X_M = dfTotensor(testDF[self.M])
                X_test = torch.cat([filtedX_S, X_M], dim=1)
                Y_test = dfTotensor(testDF[['Y']])

                Y_pred = f_S_prime(X_test)
                mse = torch.mean((Y_test - Y_pred) ** 2).item()

                error_log.append(mse)

        return np.max(error_log)

def Lasso_OptSet(trainsplit, params):
    r"""
    Functs: - a wrapper for Lasso based optimal subset selection
            - applies the split-specific hyper-parameter overrides, builds a
              LassoOptSet for the split and returns the log of its .estimate()

    :param trainsplit: name of the training split, e.g. '245'
    :param params: dict of hyper-parameters; must contain 'verbose' plus every
           key LassoOptSet reads. The dict is NOT modified.
    :return: the log dict returned by LassoOptSet.estimate()
    """

    # for 245, 2345: lambda 2, lassoLr 0.1
    # for 2356: lambda 8, lassoLr 0.05
    # for other subsets: lambda 2, lassoLr 0.05

    # Work on a shallow copy: writing the overrides into the caller's dict
    # would make them leak into later calls that reuse the same params
    # (e.g. lassoLr staying at 0.1 for every split processed after '245').
    params = dict(params)

    if trainsplit in ('245', '2345'):
        params['lassoLr'] = 0.1
    if trainsplit == '2356':
        params['lassoLam'] = 8

    optset = LassoOptSet(trainsplit, params, verbose=params['verbose'])
    return optset.estimate()