import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import json
from datetime import datetime
import copy
from itertools import combinations,cycle
import scipy.stats as stats
import random

def plot(recorder, save=True, date_time=None, base_dir=None):
    r"""
    Functs: - optionally dump `recorder` to <base_dir>/<date_time>/recorder.json
            - plot, for every S', the test error (mean over train splits of the max test MSE)
              in the order of increasing h*(S'), optionally saved as spearman.pdf
            - compute Spearman's rank correlation between h*(S') and that test error

    Args: - recorder: dict keyed by S'; recorder[S']['h_stars'] is a list (one entry per
                      train split) of (num_logs, 1) arrays and recorder[S']['test_errors']
                      a list of floats, as produced by Find_OptSets
          - save: if True, write recorder.json and spearman.pdf
          - date_time: output sub-folder name; defaults to the current time as yymmdd_HHMMSS
          - base_dir: output root; defaults to the original results folder

    Returns: - Spearman's rho between h*(S') and the test error (NaN for fewer than 2 sets)
    """
    if base_dir is None:
        base_dir = '/home/anonymous/data/CausallyInvariant_output/ADNIB2FAQ/FindOptSets_results/ours'
    if save:
        if date_time is None:
            date_time = datetime.now().strftime("%Y%m%d_%H%M%S")[2:]
        save_path = os.path.join(base_dir, date_time)
        os.makedirs(save_path, exist_ok=True)
        print(save_path)
        # deep-copy so the caller's numpy arrays are not replaced by lists
        json_dict = copy.deepcopy(recorder)
        for key in json_dict:
            json_dict[key]['h_stars'] = [i.tolist() for i in json_dict[key]['h_stars']]

        with open(os.path.join(save_path, 'recorder.json'), 'w') as json_file:
            json_file.write(json.dumps(json_dict, indent=4))

    names = list(recorder.keys())
    hstar = np.array([
        # concat splits -> (num_logs, num_splits); .T and mean over splits; keep the last logged value
        np.mean(np.concatenate(recorder[name]['h_stars'], axis=1).T, axis=0)[-1]
        for name in names
    ])
    testerr = np.array([np.mean(recorder[name]['test_errors']) for name in names])

    # stable sort keeps ties in insertion order, matching the previous list.sort()
    order = np.argsort(hstar, kind='stable')
    testerr_hstar = testerr[order]

    # Spearman's rho is the correlation of the RANKS of the paired values (scipy ranks them);
    # correlating the two argsort permutations, as done before, is not the rank correlation
    if len(names) >= 2:
        spearman = stats.spearmanr(hstar, testerr)[0]
    else:
        spearman = float('nan')

    plt.figure()
    plt.plot(list(range(len(names))), testerr_hstar)

    plt.xlabel(r'$f_S$ in the order of increasing $h*(S)$')
    plt.ylabel('Max Test MSE')
    if save:
        plt.savefig(os.path.join(save_path, 'spearman.pdf'), bbox_inches='tight')

    return spearman



def dfTotensor(df):
    r"""
    Functs: - turn a DataFrame into a float32 torch.Tensor with the same rows/columns layout
    """
    as_array = df.values
    tensor = torch.from_numpy(as_array)
    return tensor.float()


def print_nnmodule(f):
    r"""
    Functs: - print the weight, then the bias, of every nn.Linear layer found in f.function

    Args: - f: an object whose `.function` attribute iterates over nn.Module layers
               (e.g. the nn.Sequential held by NonLinearF)
    """
    linear_layers = [layer for layer in f.function if isinstance(layer, nn.Linear)]
    for layer in linear_layers:
        print(layer.weight)
        print(layer.bias)


class NonLinearF():
    r"""
    Functs: - define a non-linear fitting function
            - FC(in_dim,hidden)
            - Sigmoid()
            - FC(hidden,out_dim)
            - basically, hidden=5, baseinit=True, Sigmoid non-linearity is an okay setting

    Args: - in_dim: number of input columns
          - baseinit: if True, every Linear layer starts with weight=1 and bias=0;
                      otherwise PyTorch's default initialisation is kept
          - hidden: width of the hidden layer
          - out_dim: number of output columns
          - verbose: if True, fit() prints the loss about every num_iters // 10 iterations

    Notes: - fit() and predict() take a pd.DataFrame (converted by dfTotensor) or a
             torch.Tensor; predict() returns a np.ndarray
           - after fit() the parameters are frozen (requires_grad=False); simplepredict()
             still lets autograd track gradients w.r.t. its inputs
    """

    def __init__(self, in_dim, baseinit, hidden=1, out_dim=1, verbose=False):
        super().__init__()

        self.in_dim = in_dim
        modules = [nn.Linear(in_dim, hidden),
                   nn.Sigmoid(),
                   nn.Linear(hidden, out_dim)]

        self.function = nn.Sequential(*modules)
        self.trained = False
        self.verbose = verbose

        if baseinit:
            for m in self.function:
                if isinstance(m, nn.Linear):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    @staticmethod
    def _as_tensor(data):
        r"""
        Functs: - return `data` as a float32 torch.Tensor (DataFrames go through dfTotensor)
        """
        if isinstance(data, torch.Tensor):
            return data.float()
        return dfTotensor(data)

    def fit(self, covariates, target, num_iters=2000, lr=0.1, step=None, gamma=None):
        r"""
        Functs: - train the model on covariates~target with full-batch SGD on the MSE loss
                - performs num_iters + 1 optimisation steps, then freezes the parameters

        Args: - covariates: pd.DataFrame or torch.Tensor with shape (B, in_dim)
              - target: pd.DataFrame or torch.Tensor with shape (B, out_dim)
              - num_iters, lr: number of iterations and SGD learning rate
              - step, gamma: if step is given, lr is multiplied by gamma every `step`
                iterations (StepLR); gamma defaults to 0.1 when left as None
        """
        covariates = self._as_tensor(covariates)
        target = self._as_tensor(target)

        # a previous fit() froze the parameters; unfreeze them so re-fitting works
        # instead of failing in loss.backward()
        for param in self.function.parameters():
            param.requires_grad = True

        optimizer = torch.optim.SGD(self.function.parameters(), lr=lr)
        scheduler = None
        if step is not None:
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=step, gamma=0.1 if gamma is None else gamma)

        loss_func = nn.MSELoss()
        # max(1, ...) avoids a ZeroDivisionError when num_iters < 10
        log_every = max(1, num_iters // 10)

        for itera in range(num_iters + 1):
            prediction = self.function(covariates)
            loss = loss_func(prediction, target)
            if self.verbose and itera % log_every == 0:
                print('iteration: {:d}, loss: {:.3f}, lr: {:.4f}'.format(
                    int(itera), float(loss), optimizer.param_groups[0]['lr']))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if scheduler is not None:
                scheduler.step()

        self.trained = True
        for param in self.function.parameters():
            param.requires_grad = False

    def predict(self, covariates):
        r"""
        Args: - covariates: pd.DataFrame or torch.Tensor with shape B,in_dim
              - output: np.ndarray with shape B,out_dim
        """
        covariates = self._as_tensor(covariates)
        assert self.trained, 'NonLinearF must be trained befored prediction'
        assert covariates.shape[1] == self.in_dim, 'Mistake input dim for Non-Linear F, expect: {}, get: {}'.format(
            self.in_dim,
            covariates.shape[1])
        with torch.no_grad():
            output = self.function(covariates)
        return output.detach().numpy()

    def simplepredict(self, covariates):
        r"""
        Functs: - a simple version of predict
                - accept torch.Tensor as input and return torch.Tensor (no no_grad, so
                  gradients can flow back to the inputs)
        """
        assert self.trained, 'NonLinearF must be trained befored prediction'

        return self.function(covariates)


class SimpleNonLinearF(nn.Module):
    r"""
    Functs: - a bare Linear -> Sigmoid -> Linear network implemented as an nn.Module
            - unlike NonLinearF it offers NO .fit() / .predict(); callers drive the optimisation
            - only the layers and forward() are defined
            - specifically prepared for the estimation of f_J_theta

    Args: - in_dim, hidden, out_dim: layer widths
          - baseinit: if True, every Linear layer is set to weight=1, bias=0
    """

    def __init__(self, in_dim, baseinit, hidden, out_dim):
        super().__init__()

        self.in_dim = in_dim
        self.function = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.Sigmoid(),
            nn.Linear(hidden, out_dim),
        )
        # set to True externally once the optimisation has finished
        self.trained = False

        if not baseinit:
            return
        linear_layers = (layer for layer in self.function if isinstance(layer, nn.Linear))
        for layer in linear_layers:
            nn.init.constant_(layer.weight, 1)
            nn.init.constant_(layer.bias, 0)

    def forward(self, covariates):
        outputs = self.function(covariates)
        return outputs


class OptNodeSets():
    r"""
    Under the graphical assumption of G4567_X10 (@ denotes changed due to X10->X9, X10->X5)

    Functs: - For a given training split, e.g. 1456,
                1. read from BASE/1456.csv
                2. for a given S' in S_all
                    - Estimation of f_S'
                        1. in the original distribution, learn X_i=f(PA_i) where X_i is a descendant of X_M in the induced graph
                           this includes:
                               - X_2=f(X_1,X_6,X_7,X_8,Y)
                        2. shuffle @X_do*={X_1,X_3,X_4,X_5,X_6,X_7,X_8,X_9} and regenerate X_2 by f
                        3. train Y = f_S'(X_S',do(X_M)) on the regenerated samples, where X_M={X_1,X_8,X_9}
                    - Estimation of h*(S')
                        1. generate samples from P(J_theta)
                             - replace X_8 by J_theta(PA_8 = X_11,Y)
                               replace X_1 by J_theta(PA_1 = X_6)
                               replace @X_9 by J_theta(PA_9 = X_7,X_10)
                               (implemented as ONE f_J_theta(X_6,X_7,X_10,X_11,Y) with outputs X_1,X_8,X_9)
                             - regenerate X_i by f(PA_i) for X_i a descendant of X_1,X_8,X_9 in the induced graph
                               this includes:
                                   - X_7 = f(X_8)
                                   - X_6 = f(X_7)
                                   - X_2 = f(X_1,X_6,X_7,X_8,Y)
                                   - X_4 = f(X_2)
                                   - X_3 = f(X_1,X_4)
                                   - @X_5 = f(X_1,X_3,X_10)
                                   - @X_10 is not a descendant in the induced graph anymore
                        2. calculate Y_hat = f_S'(X_S',do(X_M)), where X_S',X_M~P(J_theta)
                        3. compute negMSELoss -||Y-Y_hat|| and optimize over \theta
                3. for the given S' and the trained f_S'
                    - read the 3/4 test-sets from BASE/1456/*.csv
                    - Y_hat = f_S'(X_S',X_M)
                    - compute maxMSEError ||Y-Y_hat|| among the 3/4 test-sets
                4. Return [key=S']: {negMSELoss, maxMSEError}
            - Theoretically, in every training split, those S' with minimum negMSELoss should also have minimum maxMSEError

    Notes:  - We estimate 6 regen functions, i.e. f_regen2,3,4,5,6,7 (f_regen10 is only built on
              demand by estimate_f_regen10(), which is not called in __init__)
            - We estimate 1 f_S_prime and 1 f_J_theta with 3 outputs (X_1,X_8,X_9)
    """

    DEFAULT_BASE = '/home/anonymous/data/CausallyInvariant_output/ADNIB2FAQ/FindOptSets/'
    # columns standardised when need_norm=True
    NORM_VARS = ['X{}'.format(ind + 1) for ind in range(11)] + ['Y']

    def __init__(self, trainsplit, params, seed=1234, need_norm=True, base_dir=None):
        r"""
        Args: - trainsplit: a str, e.g. '1456'; the training data is <base_dir>/<trainsplit>.csv
                and the test sets are the *.csv files in <base_dir>/<trainsplit>/
              - params: dict with keys fSiters, fSlr, fSstep, fSgamma (training of f_S')
                and fJiters, fJlr (optimisation of f_J_theta)
              - seed: random_state used when shuffling X_do* in estimate_f_S_prime()
              - need_norm: standardise X1..X11 and Y (the train set and every test set
                each with its own mean/std)
              - base_dir: data root; defaults to DEFAULT_BASE
        """
        BASE = self.DEFAULT_BASE if base_dir is None else base_dir
        self.trainsplit = trainsplit
        self.params = params
        self.seed = seed

        trainfilename = os.path.join(BASE, '{}.csv'.format(trainsplit))
        self.trainDF = pd.read_csv(trainfilename)

        self.need_norm = need_norm
        if self.need_norm:
            self._standardize(self.trainDF)

        self.testfolder = os.path.join(BASE, '{}'.format(trainsplit))

        # init NonLinearFs
        # we need to estimate 6 f_regens, i.e. 2,3,4,5,6,7
        self.f_regen7 = NonLinearF(in_dim=1, baseinit=True, hidden=1)
        self.f_regen6 = NonLinearF(in_dim=1, baseinit=True, hidden=1)
        self.f_regen2 = NonLinearF(in_dim=5, baseinit=True, hidden=5)
        self.f_regen4 = NonLinearF(in_dim=1, baseinit=True, hidden=1)
        self.f_regen3 = NonLinearF(in_dim=2, baseinit=True, hidden=1)
        self.f_regen5 = NonLinearF(in_dim=3, baseinit=True, hidden=1)

        print('Train split: {}'.format(self.trainsplit))
        self.estimate_f_regen7()
        self.estimate_f_regen6()
        self.estimate_f_regen2()
        self.estimate_f_regen4()
        self.estimate_f_regen3()
        self.estimate_f_regen5()

    @classmethod
    def _standardize(cls, df):
        r"""
        Functs: - standardise every column of NORM_VARS in place with df's OWN mean and
                  (pandas default, ddof=1) std
        """
        for var in cls.NORM_VARS:
            mean = df[[var]].mean().values[0]
            std = df[[var]].std().values[0]
            df[[var]] = (df[[var]] - mean) / std
        return df

    def _fit_regen(self, model, inputs, output, num_iters):
        r"""
        Functs: - fit trainDF[output] = model(trainDF[inputs]) with lr=0.1
        """
        X = self.trainDF[inputs]
        Y = self.trainDF[[output]]
        model.fit(X, Y, num_iters=num_iters, lr=0.1)

    def estimate_f_regen7(self, ):
        r"""
        Functs: - learn X_7 = self.f_regen7(X_8)
        """
        self._fit_regen(self.f_regen7, ['X8'], 'X7', num_iters=2000)

    def estimate_f_regen6(self, ):
        r"""
        Functs: - learn X_6 = self.f_regen6(X_7)
        """
        self._fit_regen(self.f_regen6, ['X7'], 'X6', num_iters=2000)

    def estimate_f_regen2(self, ):
        r"""
        Functs: - learn X_2 = self.f_regen2(X_1,X_6,X_7,X_8,Y)
        """
        self._fit_regen(self.f_regen2, ['X1', 'X6', 'X7', 'X8', 'Y'], 'X2', num_iters=5000)

    def estimate_f_regen4(self, ):
        r"""
        Functs: - learn X_4 = self.f_regen4(X_2)
        """
        self._fit_regen(self.f_regen4, ['X2'], 'X4', num_iters=2000)

    def estimate_f_regen3(self, ):
        r"""
        Functs: - learn X_3 = self.f_regen3(X_1,X_4)
        """
        self._fit_regen(self.f_regen3, ['X1', 'X4'], 'X3', num_iters=5000)

    def estimate_f_regen5(self, ):
        r"""
        Functs: - learn X_5 = self.f_regen5(X_1,X_3,X_10)
        """
        self._fit_regen(self.f_regen5, ['X1', 'X3', 'X10'], 'X5', num_iters=2000)

    def estimate_f_regen10(self, ):
        r"""
        Functs: - learn X_10 = self.f_regen10(X_5,X_9)

        Notes: - large train/test error detected ~ 0.9
               - reason for such a big error may be the fact that X10 is gender,
                 a binary variable that should not be decided by X9,X5
               - not called by __init__; f_regen10 is created here on first use
        """
        if not hasattr(self, 'f_regen10'):
            self.f_regen10 = NonLinearF(in_dim=2, baseinit=True, hidden=1)
        self._fit_regen(self.f_regen10, ['X5', 'X9'], 'X10', num_iters=5000)

    def estimate_f_S_prime(self, ):
        r"""
        Functs: - sample from p* by shuffling X_do*=X1,X3,X4,X5,X6,X7,X8,X9
                - regenerate X2 by f_regen2(X1,X6,X7,X8,Y)
                - train Y=f_S_prime(X_S_prime,do X_M) in p*
        """
        print('Estimating f_S_prime ...')

        # shuffle @X_do*=X1,X3,X4,X5,X6,X7,X8,X9
        shufTrainDF = self.trainDF.copy()
        for variable in ['X1', 'X3', 'X4', 'X5', 'X6', 'X7', 'X8', 'X9']:
            # NOTE(review): every column uses the same random_state and has the same length, so
            # all of them receive the SAME row permutation: their joint distribution is kept and
            # only their dependence on the unshuffled columns (X2, X10, X11, Y) is broken.
            # Confirm this is intended; independent shuffles would need a different seed per column.
            shufTrainDF.loc[:, variable] = shufTrainDF.loc[:, variable].sample(frac=1, random_state=self.seed).values

        # regenerate X2 by f_regen2(X1,X6,X7,X8,Y)
        X = shufTrainDF.loc[:, ['X1', 'X6', 'X7', 'X8', 'Y']]
        shufTrainDF.loc[:, 'X2'] = self.f_regen2.predict(X)

        # train Y=f_S_prime(X_S',do X1,X8,X9)
        XX = shufTrainDF[self.S_prime + ['X1', 'X8', 'X9']]
        YY = shufTrainDF[['Y']]

        self.f_S_prime.fit(XX, YY, num_iters=self.params['fSiters'], lr=self.params['fSlr'],
                           step=self.params['fSstep'], gamma=self.params['fSgamma'])

    def estimate_hstar_S_prime(self):
        r"""
        Functs: - estimate h*(S')
                1. generate samples from P(J_theta)
                    - replace X1,X8,X9 by f_J_theta(X6,X7,X10,X11,Y)
                    - regenerate their descendants by f_regen7,6,2,4,3,5 (X10, X11 stay observed)
                2. calculate Y_hat = f_S_prime(X_S', do X1,X8,X9)
                3. compute negMSELoss = -||Y-Y_hat|| and optimize over \theta with SGD

        Returns: - list of MSE(Y, Y_hat) values (positive), logged at iteration 0 and every
                   max(1, fJiters // 5) iterations, each taken before that iteration's update
        """
        print('Estimating hstar_S_prime ...')

        # observed inputs that are not regenerated
        X11 = dfTotensor(self.trainDF[['X11']])
        X10 = dfTotensor(self.trainDF[['X10']])

        Y = dfTotensor(self.trainDF[['Y']])
        PA_XM = dfTotensor(self.trainDF[['X6', 'X7', 'X10', 'X11', 'Y']])

        num_iters = self.params['fJiters']
        # only f_J_theta is optimised; f_regen* and f_S_prime were frozen by their fit()
        optimizer = torch.optim.SGD(self.f_J_theta.parameters(), lr=self.params['fJlr'])
        loss_func = nn.MSELoss()
        loss_log = list()
        # max(1, ...) avoids a ZeroDivisionError when fJiters < 5
        log_every = max(1, num_iters // 5)

        for itera in range(num_iters + 1):
            # replace X_M by f_J_theta(PA_XM); pred_XM has 3 columns (X1, X8, X9)
            pred_XM = self.f_J_theta(PA_XM)

            pred_X1 = pred_XM[:, 0].unsqueeze(1)
            pred_X8 = pred_XM[:, 1].unsqueeze(1)
            pred_X9 = pred_XM[:, 2].unsqueeze(1)

            # regenerate descendants of X1,X8,X9
            pred_X7 = self.f_regen7.simplepredict(pred_X8)
            pred_X6 = self.f_regen6.simplepredict(pred_X7)

            pred_X2 = self.f_regen2.simplepredict(torch.cat([pred_X1, pred_X6, pred_X7, pred_X8, Y], dim=1))

            pred_X4 = self.f_regen4.simplepredict(pred_X2)
            pred_X3 = self.f_regen3.simplepredict(torch.cat([pred_X1, pred_X4], dim=1))

            pred_X5 = self.f_regen5.simplepredict(torch.cat([pred_X1, pred_X3, X10], dim=1))

            # predict Y by f_S_prime, with inputs ordered as in estimate_f_S_prime: S' then X1,X8,X9
            X_S_all = {'X2': pred_X2, 'X3': pred_X3, 'X4': pred_X4, 'X5': pred_X5,
                       'X6': pred_X6, 'X7': pred_X7, 'X10': X10, 'X11': X11}
            X_S_prime_X_M = [X_S_all[S] for S in self.S_prime] + [pred_X1, pred_X8, pred_X9]

            pred_Y = self.f_S_prime.simplepredict(torch.cat(X_S_prime_X_M, dim=1))

            # we want to maximize the loss, so add a negative
            loss = - loss_func(Y, pred_Y)
            if itera % log_every == 0:
                loss_log.append(- float(loss.detach()))

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        self.f_J_theta.trained = True
        for param in self.f_J_theta.function.parameters():
            param.requires_grad = False

        return loss_log

    def estimate(self, S_prime):
        r"""
        Functs: - since f_regen7,6,2,4,3,5 stay unchanged in the same trainsplit when S_prime changes,
                  they are estimated in __init__ and frozen (requires_grad=False)
                - in this function, for a given S_prime (a list of str), we estimate f_S_prime and f_J_theta

        Returns: - the loss log of estimate_hstar_S_prime()
        """
        assert 'X1' not in S_prime, 'Mistake: S_prime must not contain any X_M'
        assert 'X8' not in S_prime, 'Mistake: S_prime must not contain any X_M'
        assert 'X9' not in S_prime, 'Mistake: S_prime must not contain any X_M'

        print('Train_split: {}, S_prime: {}'.format(self.trainsplit, ','.join(S_prime)))
        self.S_prime = S_prime

        # Y = f_S_prime(X_S_prime,do X_M), and X_M contains 3 variables X_1,X_8,X_9
        self.f_S_prime = NonLinearF(in_dim=len(S_prime) + 3, baseinit=True, hidden=5, verbose=False)

        # a single f_J_theta: X1,X8,X9 = @f_J_theta(X6,X7,X10,X11,Y)
        self.f_J_theta = SimpleNonLinearF(in_dim=5, baseinit=True, hidden=8, out_dim=3)

        self.estimate_f_S_prime()
        negMSELosses = self.estimate_hstar_S_prime()

        return negMSELosses

    def test(self, S_prime):
        r"""
        Functs: - after estimate() on the same S_prime
                - predict on every *.csv test-set in self.testfolder with the trained f_S_prime
                - return the maximum MSE over those test-sets

        Raises: - FileNotFoundError if the test folder contains no .csv file
        """
        assert S_prime == self.S_prime

        assert self.f_S_prime.trained
        assert self.f_J_theta.trained

        # only read .csv files (anything else would break pd.read_csv); sorted for determinism
        test_files = sorted(name for name in os.listdir(self.testfolder) if name.endswith('.csv'))
        if not test_files:
            raise FileNotFoundError('no test .csv files found in {}'.format(self.testfolder))

        error_log = list()
        for filename in test_files:
            testDF = pd.read_csv(os.path.join(self.testfolder, filename))

            if self.need_norm:
                # NOTE(review): each test set is standardised with its OWN mean/std, not the
                # training statistics - confirm this per-environment normalisation is intended
                self._standardize(testDF)

            X_test = dfTotensor(testDF[self.S_prime + ['X1', 'X8', 'X9']])
            Y_test = dfTotensor(testDF[['Y']])

            Y_pred = self.f_S_prime.simplepredict(X_test)
            mse = torch.mean((Y_test - Y_pred) ** 2).item()

            error_log.append(mse)

        return np.max(error_log)


def Find_OptSets(trainsplit, params=None, S_prime_all=None, need_norm=True):
    r"""
    Functs: - given a trainsplit (which is a str)
            - create an OptNodeSets object and search for the h* of every S_prime in this trainsplit
            - this function returns a recorder, which is a dict, recording h* and maxTestErrors

    Args: - trainsplit: str, e.g. '1456'
          - params: optimisation settings forwarded to OptNodeSets (keys fSlr, fSiters,
                    fSstep, fSgamma, fJlr, fJiters); defaults to the settings used so far
          - S_prime_all: list of candidate S' (each a list of str); defaults to every
                         non-empty subset of the stable nodes
          - need_norm: forwarded to OptNodeSets

    Returns: - dict keyed by ','.join(S') ('empty' for an empty S'), each value holding
               'h_stars' (list of (num_logs, 1) arrays) and 'test_errors' (list of floats)
    """
    stableNodes = ['X2', 'X3', 'X4', 'X5', 'X6', 'X7', 'X10', 'X11']
    if S_prime_all is None:
        # every non-empty subset of the stable nodes (2^n - 1 sets), by increasing size
        S_prime_all = [list(sett)
                       for size in range(1, len(stableNodes) + 1)
                       for sett in combinations(stableNodes, size)]

    if params is None:
        params = {'fSlr': 0.25, 'fSiters': 5000, 'fSstep': 4000, 'fSgamma': 0.4,
                  'fJlr': 0.25, 'fJiters': 2000}

    def _save_name(S_prime):
        return ','.join(S_prime) if len(S_prime) > 0 else 'empty'

    # record the results
    recorder = {_save_name(S_prime): {'h_stars': list(), 'test_errors': list()}
                for S_prime in S_prime_all}

    findoptsets = OptNodeSets(trainsplit, params, need_norm=need_norm)

    for S_prime in S_prime_all:
        entry = recorder[_save_name(S_prime)]
        h_star = findoptsets.estimate(S_prime)
        entry['h_stars'].append(np.array(h_star)[:, np.newaxis])
        entry['test_errors'].append(findoptsets.test(S_prime))

    return recorder