
# Acknowledgement: https://github.com/facebookresearch/InvarianceUnitTests/blob/main/scripts/models.py

from tkinter import Y
import torch
from torch.autograd import grad

import math

import numpy as np
import torch.nn as nn
import torch.optim as optim
from sklearn import linear_model

from .ERM import ERM


class LinearModel(torch.nn.Module):
    """Bias-free linear layer whose weight is Xavier-uniform initialised."""

    def __init__(self, input_dim, output_dim):
        """Create the ``input_dim -> output_dim`` layer and initialise it."""
        super().__init__()
        layer = torch.nn.Linear(input_dim, output_dim, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight)
        self.linear = layer

    def forward(self, x):
        """Return ``self.linear`` applied to ``x``."""
        projected = self.linear(x)
        return projected



class HRM(ERM):
    """Alternate between clustering the data and fitting a penalised model.

    ``solve`` repeatedly asks the ``McModel`` frontend to split ``(X, Y)``
    into ``num_clusters`` environments, then ``train`` fits ``self.model`` on
    those environments by minimising ``lam * mean_loss`` plus the squared
    deviation of every environment's gradient from the gradient of the mean
    loss.  The absolute model weights (capped at 1) are fed back to the
    frontend as per-feature weights for the next round.
    """

    def __init__(self, X, Y, input_dim, output_dim, lam, num_clusters=2, type='regression', device="cpu", task='simulation', arch='linear'):
        """Set up the model, the loss and the clustering/selection helpers.

        Args:
            X: feature tensor; ``X.shape[1]`` is used as the feature count.
            Y: target tensor, stored as ``self.y``.
            input_dim: input width of the linear model.
            output_dim: output width of the linear model.
            lam: multiplier applied to the mean loss in ``train``.
            num_clusters: number of environments produced by the frontend.
            type: ``'regression'`` (MSE) or ``'classification'`` (BCE with
                logits); forwarded to ``ERM.__init__``.
            device: device tensors are moved to during training.
            task: selects the optimiser in ``train``.
            arch: forwarded to ``ERM.__init__``.

        Raises:
            Exception: if ``self.type`` is neither of the supported values.
        """
        super(HRM, self).__init__(input_dim=input_dim, output_dim=output_dim, type=type, arch=arch)
        self.X, self.y = X, Y
        self.model = LinearModel(input_dim, output_dim)

        # NOTE(review): ``self.type`` is presumably set by ``ERM.__init__``,
        # which is not visible here -- confirm.
        if self.type == 'regression':
            self.loss = torch.nn.MSELoss()
        elif self.type == 'classification':
            self.loss = torch.nn.BCEWithLogitsLoss()
        else:
            raise Exception('Not Implemented...')
        self.device = device
        self.task = task

        self.mask = None

        self.lam = lam
        self.erm_model = None

        self.frontend = McModel(num_clusters, self.X, self.y)
        # The backend hyper-parameters are fixed; in particular its ``lam``
        # is 0.1 and is independent of the ``lam`` argument above.
        self.backend = MpModel(input_dim=input_dim,
                               output_dim=output_dim,
                               sigma=0.1,
                               lam=0.1,
                               alpha=1000.0,
                               hard_sum=10)
        self.domains = None
        # Per-feature weights handed to the frontend; all zero at the start.
        self.weight = torch.zeros(self.X.shape[1], dtype=torch.float32)

    def _model_weight(self):
        """Return the weight parameter of the current model.

        ``LinearModel`` keeps its weight under ``.linear.weight``; a model
        installed through ``set_model`` may expose ``.weight`` directly.
        """
        linear = getattr(self.model, 'linear', None)
        if linear is not None:
            return linear.weight
        return self.model.weight

    def combine_envs(self, envs):
        """Concatenate all ``(X, y)`` environments along the sample axis.

        Returns:
            ``(X, y)`` with ``X`` reshaped to ``(-1, X.shape[1])`` and ``y``
            reshaped to ``(-1, 1)``.
        """
        X = torch.cat([env[0] for env in envs], dim=0)
        y = torch.cat([env[1] for env in envs], dim=0)
        return X.reshape(-1, X.shape[1]), y.reshape(-1, 1)

    def solve(self, iters=5, epochs=1000, lr=1e-3, delta_threshold=250):
        """Run ``iters`` rounds of clustering followed by training.

        Args:
            iters: number of cluster/train rounds.
            epochs: training epochs per round.
            lr: learning rate passed to ``train``.
            delta_threshold: convergence threshold passed to the frontend.

        Returns:
            ``self``; ``self.weight`` and ``self.density_result`` hold the
            capped absolute weights of the last round.
        """
        self.density_result = None

        # Reshape the target like ``train`` does so that a 1-D ``y`` does not
        # broadcast against a 2-D prediction.
        pred = self.model(self.X.to(self.device))
        print('err of pre-erm: ', self.loss(pred, self.y.to(self.device).reshape(pred.shape)))

        for i in range(iters):
            environments, self.domains = self.frontend.cluster(self.weight, self.domains, True, delta_threshold=delta_threshold)

            _, density = self.train(environments, epochs=epochs, lr=lr)

            # Absolute weights, capped at 1.
            density = torch.clamp(torch.abs(density), max=1.)

            self.density_result = density.detach()
            self.weight = density.detach().to(self.device)

            print('[{}/{}] Selection Ratio is {}'.format(i, iters, self.weight))

        return self

    def train(self, envs, epochs=6000, lr=1e-3, renew=False, verbose=False):
        """Fit ``self.model`` on ``envs`` with a gradient-deviation penalty.

        Args:
            envs: iterable of ``(x, y)`` pairs, one per environment.
            epochs: number of full-batch optimisation steps.
            lr: learning rate.
            renew: unused; kept for interface compatibility.
            verbose: print the scaled loss and the penalty every 100 epochs.

        Returns:
            ``(None, weight)`` where ``weight`` is the ``.data`` of the
            model's weight parameter.
        """
        print('envs: ')
        envs_torch = []
        for x, y in envs:
            envs_torch.append((torch.Tensor(x), torch.Tensor(y)))
            print(x.shape, y.shape)
        envs = envs_torch

        print('fine-tuning...')

        # Optimiser choice depends on the task: Adam for 'mnist', SGD with
        # weight decay 1 for 'income'/'insurance', plain SGD otherwise.
        if self.task == 'mnist':
            opt = torch.optim.Adam(self.model.parameters(), lr)
        else:
            weight_decay = 1. if self.task in ('income', 'insurance') else 0.
            opt = torch.optim.SGD([{'params': self.model.parameters(), 'lr': lr}], weight_decay=weight_decay)

        num_envs = len(envs)

        erm_curve = []
        penalty_curve = []

        for epoch in range(epochs):

            losses = []
            for x_e, y_e in envs:
                x_e = x_e.to(self.device)
                y_e = y_e.to(self.device)
                pred = self.model(x_e)
                losses.append(self.loss(pred, y_e.reshape(pred.shape)))

            params = list(self.model.parameters())
            # create_graph=True keeps the gradients differentiable so the
            # penalty below can itself be back-propagated.
            gradients = [grad(loss, params, create_graph=True) for loss in losses]

            avg_loss = sum(losses) / num_envs
            avg_gradient = grad(avg_loss, params, create_graph=True)

            # Sum of squared deviations of each environment's gradient from
            # the gradient of the mean loss.
            penalty_value = 0
            for gradient in gradients:
                for gradient_i, avg_grad_i in zip(gradient, avg_gradient):
                    penalty_value += (gradient_i - avg_grad_i).pow(2).sum()

            opt.zero_grad()
            (self.lam * avg_loss + penalty_value).backward()
            opt.step()

            if verbose and epoch % 100 == 0:
                print('erm error', self.lam * avg_loss.data.cpu().numpy())
                print('grad penalty', penalty_value.data.cpu().numpy())

            erm_curve.append(avg_loss.data.cpu().numpy())
            penalty_curve.append(penalty_value.data.cpu().numpy())

        print(' fine-tuning: ')
        if erm_curve:
            # Only available when at least one epoch ran.
            print('erm: ', erm_curve[0], ' -> ', erm_curve[-1])
            print('penalty: ', penalty_curve[0], ' -> ', penalty_curve[-1])

        return None, self._model_weight().data

    def set_model(self, model):
        """Replace ``self.model`` (moved to ``self.device`` if set) and return ``self``."""
        self.model = model
        if self.device is not None:
            self.model = self.model.to(self.device)
        return self

    def generate_mask(self, the_candidates):
        """Build a 0/1 feature mask from the normalised absolute weights.

        Weights are sorted ascending and cumulatively summed; for a threshold
        ``the`` every feature whose cumulative share reaches ``1 - the`` is
        selected.  A mask is computed (and reported) for each candidate, and
        the mask of the LAST candidate is returned.

        Raises:
            ValueError: if ``the_candidates`` is empty.
        """
        weights = torch.abs(self._model_weight().data).detach().reshape(-1)
        weights = weights / weights.sum()
        sorted_w, idx = torch.sort(weights)

        pre_sum = torch.cumsum(sorted_w, dim=0)

        mask = None
        for the in the_candidates:
            keep = pre_sum >= (1. - the)
            mask = torch.zeros(len(pre_sum))
            mask[idx[keep]] = 1.
            print('the: ', the)
            print('select {} from {}'.format(mask.sum(), len(mask)))

        if mask is None:
            raise ValueError('the_candidates must contain at least one threshold')
        return mask


        
    
        

def compute_penalty(losses, dummy):
    """Return the dot product of the gradients of two interleaved loss halves.

    The mean of the even-indexed and the mean of the odd-indexed entries of
    ``losses`` are each differentiated with respect to ``dummy`` (keeping the
    graph), and the element-wise product of the two gradients is summed.
    """
    even_mean = losses[0::2].mean()
    odd_mean = losses[1::2].mean()
    grad_even = grad(even_mean, dummy, create_graph=True)[0]
    grad_odd = grad(odd_mean, dummy, create_graph=True)[0]
    return torch.sum(grad_even * grad_odd)


def example_1(n=10000, d=2, env=1):
    """Sample one synthetic environment.

    ``x`` is Gaussian noise scaled by ``env``, ``y`` is ``x`` plus noise
    scaled by ``env``, and ``z`` is ``y`` plus unit Gaussian noise.

    Returns:
        ``(features, target)`` where ``features`` is ``x`` and ``z``
        concatenated column-wise (shape ``(n, 2 * d)``) and ``target`` is the
        row sum of ``y`` (shape ``(n, 1)``).
    """
    shape = (n, d)
    x = env * torch.randn(*shape)
    y = x + env * torch.randn(*shape)
    z = y + torch.randn(*shape)
    features = torch.cat((x, z), 1)
    target = y.sum(1, keepdim=True)
    return features, target

if __name__ == '__main__':
    # Toy run: learn a 4-dimensional predictor ``phi`` together with a
    # per-feature gate ``fs`` (kept inside [0, 1]) on two synthetic
    # environments, minimising a down-weighted MSE plus ``compute_penalty``.
    phi = torch.nn.Parameter(torch.ones(4, 1))
    fs = torch.nn.Parameter(torch.ones(4, 1))
    dummy_w = torch.nn.Parameter(torch.Tensor([1.0]))

    opt = torch.optim.SGD([phi, fs], lr=1e-3)
    mse = torch.nn.MSELoss(reduction="none")
    environments = [example_1(env=scale) for scale in (0.1, 1.0)]

    for iteration in range(100000):
        error = 0
        penalty = 0
        for x_e, y_e in environments:
            # Gate the features, predict, and scale by the dummy weight that
            # the penalty differentiates with respect to.
            gated = x_e * fs.reshape(1, -1)
            error_e = mse(gated @ phi * dummy_w, y_e)
            penalty += compute_penalty(error_e, dummy_w)
            error += error_e.mean()

        opt.zero_grad()
        (1e-5 * error + penalty).backward()
        opt.step()

        # Project the gates back onto [0, 1] after every step.
        with torch.no_grad():
            fs.clamp_(min=0.0, max=1.0)

        if iteration % 1000 == 0:
            raw_error = error.data.cpu().numpy()
            print('feature selection: ', fs)
            print('phi: ', phi)
            print('erm error', 1e-5 * raw_error)
            print('raw erm error', raw_error)
            print('grad penalty', penalty.data.cpu().numpy())
            












def pretty(vector):
    """Format the flattened values of ``vector`` as ``"[+0.1234, -5.6789]"``.

    Args:
        vector: a list/tuple of numbers, or any array-like object providing
            ``reshape(-1).tolist()`` (NumPy arrays and torch tensors do).

    Returns:
        The values rendered with an explicit sign and four decimals, joined
        by ``", "`` and wrapped in square brackets.
    """
    if isinstance(vector, (list, tuple)):
        values = list(vector)
    else:
        # reshape (unlike view) also works on non-contiguous tensors, e.g. a
        # transposed weight matrix.
        values = vector.reshape(-1).tolist()
    return "[" + ", ".join("{:+.4f}".format(value) for value in values) + "]"


class LinearRegression(nn.Module):
    """Single linear layer with bias; the weight is Xavier-uniform initialised."""

    def __init__(self, input_dim, output_dim=1):
        """Build the ``input_dim -> output_dim`` layer and initialise its weight."""
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim, bias=True)
        self.weight_init()

    def weight_init(self):
        """Re-draw the layer weight in place (the bias is left untouched)."""
        weight = self.linear.weight
        torch.nn.init.xavier_uniform_(weight)

    def forward(self, x):
        """Return the layer output for ``x``."""
        output = self.linear(x)
        return output


class WeightedLasso:
    """Linear regression with a per-feature weighted L1 penalty, fit with Adam."""

    def __init__(self, X, y, weight, lam, lr=1e-3):
        """Store the data and build the model and optimiser.

        Args:
            X: feature tensor; ``X.shape[1]`` sets the model input width.
            y: regression target.
            weight: per-feature penalty weights, stored as a column vector.
            lam: multiplier of the weighted L1 term.
            lr: Adam learning rate (defaults to the previously fixed 1e-3).
        """
        self.model = LinearRegression(X.shape[1], 1)
        self.X = X
        self.y = y
        self.weight = weight.reshape(-1, 1)
        self.loss = nn.MSELoss()
        self.lam = lam
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

    def train(self, epochs=3000):
        """Re-initialise the weight and run ``epochs`` full-batch Adam steps.

        Args:
            epochs: number of optimisation steps (defaults to the previously
                fixed 3000).

        Returns:
            ``(weight, bias)`` of the fitted layer as detached CPU copies.
        """
        self.model.weight_init()

        for _ in range(epochs):
            self.optimizer.zero_grad()
            pred = self.model(self.X)
            coef = self.model.linear.weight.reshape(self.weight.shape)
            l1_term = torch.mean(torch.abs(self.weight * coef))
            # Reshape the target to the prediction shape so that a 1-D ``y``
            # cannot silently broadcast against the (n, 1) prediction.
            loss = self.loss(pred, self.y.reshape(pred.shape)) + self.lam * l1_term
            # retain_graph is kept in case ``weight``/``X`` carry a graph that
            # is shared across iterations.
            loss.backward(retain_graph=True)
            self.optimizer.step()

        layer = self.model.linear
        return layer.weight.clone().cpu().detach(), layer.bias.clone().cpu().detach()



class McModel:
    """Cluster samples by which per-cluster weighted-lasso fit explains them best."""

    def __init__(self, num_classes, X, y):
        """Store the data; ``y`` is kept as a column vector.

        Args:
            num_classes: number of clusters.
            X: feature tensor of shape ``(n_samples, n_features)``.
            y: target tensor, reshaped to ``(-1, 1)``.
        """
        self.num_classes = num_classes
        self.X = X
        self.y = y.reshape(-1, 1)
        self.center = None
        self.bias = None
        self.domain = None
        self.weights = None

    def ols(self):
        """Fit one ``WeightedLasso`` per cluster and store its weight and bias.

        NOTE(review): a cluster with no samples is fitted on empty tensors;
        the resulting values are not handled specially here -- confirm this
        cannot happen for the intended inputs.
        """
        for i in range(self.num_classes):
            index = torch.where(self.domain == i)[0]
            tempx = (self.X[index, :]).reshape(-1, self.X.shape[1])
            tempy = (self.y[index, :]).reshape(-1, 1)
            clf = WeightedLasso(tempx, tempy, self.weights, 1.0)
            self.center[i, :], self.bias[i] = clf.train()

    def cluster(self, weight, past_domains, reuse=False, delta_threshold=250, max_iters=None):
        """Iteratively reassign samples to the cluster with the smallest error.

        Each iteration refits all clusters (``ols``), computes the absolute
        residual of every sample under every cluster model and moves each
        sample to the cluster with the smallest residual.  The loop stops
        once at most ``delta_threshold`` samples changed cluster.

        Args:
            weight: per-feature penalty weights passed to ``WeightedLasso``.
            past_domains: previous assignment, used when ``reuse`` is true.
            reuse: start from ``past_domains`` instead of a random assignment.
            delta_threshold: maximum number of changed assignments at which
                the loop is considered converged.
            max_iters: optional cap on the number of iterations; ``None``
                (the default) keeps the original unbounded behaviour.

        Returns:
            ``(environments, domain)`` where ``environments`` is a list of
            ``[X_i, y_i]`` per cluster and ``domain`` the final assignment.
        """
        n_samples, n_features = self.X.shape[0], self.X.shape[1]
        self.center = torch.zeros((self.num_classes, n_features), dtype=torch.float32)
        self.bias = torch.zeros(self.num_classes, dtype=torch.float32)

        if past_domains is None or not reuse:
            self.domain = torch.tensor(np.random.randint(0, self.num_classes, n_samples))
        else:
            self.domain = past_domains
        assert self.domain.shape[0] == n_samples
        self.weights = weight

        iteration = 0
        while True:
            iteration += 1
            self.ols()

            ols_error = []
            for i in range(self.num_classes):
                coef = self.center[i].reshape(-1, 1)
                error = torch.abs(torch.mm(self.X, coef) + self.bias[i] - self.y)
                assert error.shape == (n_samples, 1)
                ols_error.append(error)
            ols_error = torch.stack(ols_error, dim=0).reshape(self.num_classes, n_samples)

            new_domain = torch.argmin(ols_error, dim=0)
            assert new_domain.shape[0] == n_samples

            # Number of samples whose cluster changed in this iteration.
            delta = torch.sum(self.domain.reshape(-1) != new_domain.reshape(-1))
            if iteration % 10 == 9:
                print("Iter %d | Delta = %d" % (iteration, delta))
            self.domain = new_domain

            if delta <= delta_threshold:
                break
            # Safety valve: without a cap the loop never ends if the
            # assignment keeps oscillating above the threshold.
            if max_iters is not None and iteration >= max_iters:
                break

        environments = []
        for i in range(self.num_classes):
            index = torch.where(self.domain == i)[0]
            tempx = (self.X[index, :]).reshape(-1, n_features)
            tempy = (self.y[index, :]).reshape(-1, 1)
            environments.append([tempx, tempy])
        return environments, self.domain







# Feature selection part
class FeatureSelector(nn.Module):
    """Per-feature gate: inputs are multiplied by ``clamp(mu + noise + 0.5, 0, 1)``.

    Noise (scaled by ``sigma``) is only added while the module is in
    training mode.
    """

    def __init__(self, input_dim, sigma):
        """Create a zero-valued ``mu`` parameter and a noise tensor of the same size."""
        super().__init__()
        self.mu = torch.nn.Parameter(0.00 * torch.randn(input_dim, ), requires_grad=True)
        self.noise = torch.randn(self.mu.size())
        self.sigma = sigma
        self.input_dim = input_dim

    def renew(self):
        """Replace ``mu`` and ``noise`` with freshly created tensors."""
        fresh_mu = 0.00 * torch.randn(self.input_dim, )
        self.mu = torch.nn.Parameter(fresh_mu, requires_grad=True)
        self.noise = torch.randn(self.mu.size())

    def forward(self, prev_x):
        """Return ``prev_x`` scaled element-wise by the (stochastic) gate."""
        # The noise buffer is resampled in place on every call; multiplying
        # by ``self.training`` zeroes it out in eval mode.
        noise_sample = self.noise.normal_()
        z = self.mu + self.sigma * noise_sample * self.training
        return prev_x * self.hard_sigmoid(z)

    def hard_sigmoid(self, x):
        """Shift ``x`` by 0.5 and clamp the result to ``[0, 1]``."""
        shifted = x + 0.5
        return torch.clamp(shifted, 0.0, 1.0)

    def regularizer(self, x):
        """Return ``0.5 * (1 + erf(x / sqrt(2)))``."""
        scaled = x / math.sqrt(2)
        return 0.5 * (1 + torch.erf(scaled))

    def _apply(self, fn):
        """Apply ``fn`` to the module and also to ``noise``, which is a plain attribute."""
        super()._apply(fn)
        self.noise = fn(self.noise)
        return self


class MpModel:
    """Linear model behind a ``FeatureSelector`` gate, trained across environments.

    The objective combines the mean environment loss, a gate-weighted
    penalty on the spread of per-environment gradients, and a term pushing
    the summed gate regulariser towards ``hard_sum``.
    """

    def __init__(self, input_dim, output_dim, sigma=1.0, lam=0.1, alpha=0.5, hard_sum = 1.0, penalty='Ours'):
        """Build the back model, the feature selector and the joint optimiser.

        Args:
            input_dim: number of input features.
            output_dim: output width of the back model.
            sigma: noise scale of the feature selector.
            lam: multiplier of the gate-sum regulariser.
            alpha: multiplier of the gradient-spread penalty.
            hard_sum: target value for the summed gate regulariser.
            penalty: stored as ``self.penalty``; not read elsewhere in this class.
        """
        self.backmodel = LinearRegression(input_dim, output_dim)
        self.loss = nn.MSELoss()
        self.featureSelector = FeatureSelector(input_dim, sigma)
        self.reg = self.featureSelector.regularizer
        self.lam = lam
        self.mu = self.featureSelector.mu
        self.sigma = self.featureSelector.sigma
        self.alpha = alpha
        self.optimizer = self._make_optimizer()
        self.penalty = penalty
        self.hard_sum = hard_sum
        self.input_dim = input_dim
        self.accumulate_mip_penalty = torch.tensor(np.zeros(10, dtype=np.float32))

    def _make_optimizer(self):
        """Return an Adam optimiser over the back model (lr 1e-3) and ``mu`` (lr 3e-4)."""
        return optim.Adam([{'params': self.backmodel.parameters(), 'lr': 1e-3},
                           {'params': self.mu, 'lr': 3e-4}])

    def renew(self):
        """Reset the gates, re-draw the back model weight and rebuild the optimiser."""
        self.featureSelector.renew()
        # ``renew`` creates a new Parameter object, so the alias and the
        # optimiser must both be refreshed.
        self.mu = self.featureSelector.mu
        self.backmodel.weight_init()
        self.optimizer = self._make_optimizer()

    def combine_envs(self, envs):
        """Concatenate all ``(X, y)`` environments along the sample axis.

        Returns:
            ``(X, y)`` with ``X`` reshaped to ``(-1, X.shape[1])`` and ``y``
            reshaped to ``(-1, 1)``.
        """
        X = torch.cat([env[0] for env in envs], dim=0)
        y = torch.cat([env[1] for env in envs], dim=0)
        return X.reshape(-1, X.shape[1]), y.reshape(-1, 1)

    def pretrain(self, envs, pretrain_epoch=100):
        """Fit the back model alone (no gating) on the pooled environments."""
        pre_optimizer = optim.Adam([{'params': self.backmodel.parameters(), 'lr': 1e-3}])
        X, y = self.combine_envs(envs)

        for _ in range(pretrain_epoch):
            # Zero the gradients through the optimiser that actually steps.
            pre_optimizer.zero_grad()
            pred = self.backmodel(X)
            loss = self.loss(pred, y.reshape(pred.shape))
            loss.backward()
            pre_optimizer.step()

    def single_forward(self, x, regularizer_flag = False):
        """Gate ``x`` and run the back model.

        With ``regularizer_flag`` set, the gated input is detached so that
        no gradient flows back into the gates through this forward pass.
        """
        gated = self.featureSelector(x)
        if regularizer_flag:
            gated = gated.clone().detach()
        return self.backmodel(gated)

    def single_iter_mip(self, envs):
        """Compute the training objective for one step.

        Args:
            envs: list of ``[x, y]`` environments.

        Returns:
            ``(total_loss, weighted_penalty, gate_reg)`` where ``gate_reg``
            is the regulariser evaluated at ``(mu + 0.5) / sigma``.
        """
        assert isinstance(envs, list)
        num_envs = len(envs)

        # Mean loss with gradients flowing through the gates.
        loss_avg = 0.0
        for x, y in envs:
            pred = self.single_forward(x)
            loss_avg += self.loss(pred, y.reshape(pred.shape)) / num_envs

        # Per-environment gradient w.r.t. the first back-model parameter
        # (the linear weight), computed on detached gated inputs.
        grad_avg = 0.0
        grad_list = []
        for x, y in envs:
            pred = self.single_forward(x, True)
            loss = self.loss(pred, y.reshape(pred.shape))
            grad_single = grad(loss, self.backmodel.parameters(), create_graph=True)[0].reshape(-1)
            grad_avg += grad_single / num_envs
            grad_list.append(grad_single)

        # Squared deviation of each environment's gradient from the mean.
        penalty = torch.tensor(np.zeros(self.input_dim, dtype=np.float32))
        for gradient in grad_list:
            penalty += (gradient - grad_avg)**2

        # Weight the per-feature penalty by the (unclamped) gate value.
        penalty_detach = torch.sum(penalty.reshape(self.mu.shape)*(self.mu+0.5))
        gate_reg = self.reg((self.mu + 0.5) / self.sigma)
        reg = (torch.sum(gate_reg)-self.hard_sum)**2
        total_loss = loss_avg + self.alpha * (penalty_detach)
        total_loss = total_loss + self.lam * reg
        return total_loss, penalty_detach, gate_reg

    def get_gates(self):
        """Return ``mu + 0.5`` formatted by ``pretty``."""
        return pretty(self.mu+0.5)

    def get_paras(self):
        """Return the back model weight formatted by ``pretty``."""
        return pretty(self.backmodel.linear.weight)

    def train(self, envs, epochs=6000, pretrain_epochs=3000):
        """Reset the model, pretrain the back model, then optimise the full objective.

        Args:
            envs: list of ``[x, y]`` environments.
            epochs: number of joint optimisation steps.
            pretrain_epochs: steps of un-gated pretraining (defaults to the
                previously fixed 3000).

        Returns:
            ``(mu + 0.5, gate_reg)`` where ``gate_reg`` comes from the last
            step (or from the freshly reset gates when ``epochs`` is 0).
        """
        self.renew()
        self.pretrain(envs, pretrain_epochs)
        # Defined up front so the return value exists even if no step runs.
        reg = self.reg((self.mu + 0.5) / self.sigma)
        for epoch in range(1,epochs+1):
            self.optimizer.zero_grad()
            loss, penalty, reg = self.single_iter_mip(envs)
            loss.backward()
            self.optimizer.step()
            # Only true on the final epoch.
            if epoch % epochs == 0:
                print("Epoch %d | Loss = %.4f | Gates %s | Theta = %s" %
                      (epoch, loss, self.get_gates(), pretty(self.backmodel.linear.weight)))
        return self.mu + 0.5, reg

def combine_envs(envs):
    """Stack every environment's features and targets along the sample axis.

    Each element of ``envs`` supplies features at index 0 and targets at
    index 1.

    Returns:
        ``(X, y)`` with ``X`` reshaped to ``(-1, X.shape[1])`` and ``y``
        reshaped to a single column.
    """
    pairs = [(env[0], env[1]) for env in envs]
    features = torch.cat([pair[0] for pair in pairs], dim=0)
    targets = torch.cat([pair[1] for pair in pairs], dim=0)
    n_features = features.shape[1]
    return features.reshape(-1, n_features), targets.reshape(-1, 1)

