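"""IRM (Invariant Risk Minimization) training utilities, adapted from the reference code in Arjovsky et al., "Invariant Risk Minimization" (https://arxiv.org/pdf/1907.02893.pdf)."""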

import torch
from torch._C import device
from torch.autograd import grad

from sklearn import linear_model

from .ERM import ERM


class LinearModel(torch.nn.Module):
    """Bias-free linear map ``x -> x @ W.T`` with Xavier-uniform initial weights.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        layer = torch.nn.Linear(input_dim, output_dim, bias=False)
        # Replace PyTorch's default Linear init with Xavier-uniform.
        torch.nn.init.xavier_uniform_(layer.weight)
        self.linear = layer

    def forward(self, x):
        """Apply the linear map to ``x``."""
        return self.linear(x)



class IRM(ERM):
    """IRM trainer for a bias-free linear model.

    ``train`` minimises ``lam * sum_e mean_loss_e + sum_e penalty_e`` over the
    environments ``e``. Here ``penalty_e`` is the product of the gradients of
    the loss w.r.t. a scalar dummy multiplier of the predictions. The gradients
    are computed separately on the even- and odd-indexed samples of the
    environment.
    """

    def __init__(self, input_dim, output_dim, lam, type='regression', arch='linear', task='simulation', device="cpu"):
        """Set up the per-sample loss and the linear model.

        Args:
            input_dim: number of input features.
            output_dim: number of model outputs.
            lam: weight on the ERM (mean loss) term; the IRM penalty has weight 1.
            type: 'regression' (per-sample MSE) or 'classification'
                (per-sample BCE-with-logits).
            arch: forwarded to ``ERM.__init__``. ``generate_mask`` reads
                ``self.arch``, presumably set there.
            task: 'mnist' makes ``train`` use Adam; anything else uses SGD with
                weight decay 1.0.
            device: torch device used by ``train`` and ``predict``.

        Raises:
            NotImplementedError: if ``type`` is not supported.
        """
        super(IRM, self).__init__(input_dim=input_dim, output_dim=output_dim, type=type, arch=arch)

        self.type = type
        if self.type == 'regression':
            self.loss = torch.nn.MSELoss(reduction="none")
        elif self.type == 'classification':
            self.loss = torch.nn.BCEWithLogitsLoss(reduction="none")
        else:
            raise NotImplementedError('Not Implemented... (type={!r})'.format(type))

        self.device = device
        # Keep the model on the same device that train()/predict() move data to.
        self.model = LinearModel(input_dim, output_dim).to(device)
        self.lam = lam
        self.task = task

    def combine_envs(self, envs):
        """Concatenate the ``(X, y)`` pairs of all environments along dim 0.

        Returns:
            ``X`` reshaped to ``(-1, X.shape[1])`` and ``y`` reshaped to a
            column ``(-1, 1)``.
        """
        X = torch.cat([env[0] for env in envs], dim=0)
        y = torch.cat([env[1] for env in envs], dim=0)
        return X.reshape(-1, X.shape[1]), y.reshape(-1, 1)

    def compute_penalty(self, losses, dummy):
        """Return the IRM penalty for one environment's per-sample ``losses``.

        The penalty is the product of the gradients w.r.t. ``dummy`` of the mean
        loss over the even-indexed and the odd-indexed samples. Both gradients
        are computed with ``create_graph=True`` so the penalty can itself be
        back-propagated. With fewer than two losses, one half is empty and its
        mean is NaN.
        """
        g1 = grad(losses[0::2].mean(), dummy, create_graph=True)[0]
        g2 = grad(losses[1::2].mean(), dummy, create_graph=True)[0]
        return (g1 * g2).sum()

    def train(self, envs, epochs, lr=1e-3, renew=False, verbose=False):
        """Fit ``self.model`` with the IRM objective using full-batch steps.

        Args:
            envs: non-empty sequence of ``(x, y)`` pairs, one per environment;
                each array is converted with ``torch.Tensor``.
            epochs: number of optimisation steps (one per pass over all envs).
            lr: learning rate.
            renew: unused; kept for interface compatibility.
            verbose: print the weighted ERM error and the penalty every 100
                epochs.

        Raises:
            ValueError: if ``envs`` is empty.
        """
        # Convert and move the data to the device once, not on every epoch.
        envs = [
            (torch.Tensor(x).to(self.device), torch.Tensor(y).to(self.device))
            for x, y in envs
        ]
        if not envs:
            raise ValueError('IRM.train needs at least one environment')

        print('fine-tuning...')

        # The model may have been replaced through set_model(); put its
        # parameters on the data's device before building the optimiser.
        self.model.to(self.device)
        if self.task == 'mnist':
            opt = torch.optim.Adam(self.model.parameters(), lr)
        else:
            opt = torch.optim.SGD([{'params': self.model.parameters(), 'lr': lr}], weight_decay=1.)

        # Scalar multiplier on the predictions. The penalty is built from loss
        # gradients w.r.t. it; it is not part of the optimiser.
        dummy_w = torch.ones(1, device=self.device, requires_grad=True)

        erm_curve = []
        grd_curve = []

        for epoch in range(epochs):
            error = 0
            penalty = 0
            for x_e, y_e in envs:
                pred = self.model(x_e)
                error_e = self.loss((pred * dummy_w).reshape(y_e.shape), y_e)
                penalty += self.compute_penalty(error_e, dummy_w)
                error += error_e.mean()

            opt.zero_grad()
            (self.lam * error + penalty).backward()
            opt.step()

            if verbose and epoch % 100 == 0:
                print('erm error', self.lam * error.detach().cpu().numpy())
                print('grad penalty', penalty.detach().cpu().numpy())

            erm_curve.append(error.detach().cpu().numpy())
            grd_curve.append(penalty.detach().cpu().numpy())

        print(' fine-tuning: ')
        # With epochs == 0 there is no curve to summarise.
        if erm_curve:
            print('erm: ', erm_curve[0], ' -> ', erm_curve[-1])
            print('grad: ', grd_curve[0], ' -> ', grd_curve[-1])

    def predict(self, X):
        """Return ``self.model(X)`` as a NumPy array, without tracking gradients."""
        X = torch.Tensor(X).to(self.device)
        self.model.to(self.device)
        with torch.no_grad():
            return self.model(X).detach().cpu().numpy()

    def set_model(self, model):
        """Replace the wrapped model; returns ``self`` for chaining."""
        self.model = model
        return self

    def generate_mask(self, the_candidates):
        """Build a 0/1 mask from the absolute (first-layer) model weights.

        The absolute weights are normalised to sum to 1 and sorted in ascending
        order. For a threshold ``the``, entries whose ascending cumulative sum is
        ``>= 1 - the`` are kept. These are the largest-weight entries, covering
        roughly a ``the`` share of the total weight. The selection size is
        printed for every candidate.

        Args:
            the_candidates: non-empty iterable of thresholds.

        Returns:
            CPU float tensor mask for the LAST threshold in ``the_candidates``.

        Raises:
            ValueError: if ``the_candidates`` is empty.
        """
        the_candidates = list(the_candidates)
        if not the_candidates:
            raise ValueError('generate_mask needs at least one threshold')

        if self.arch == 'linear':
            # LinearModel keeps its weights in ``.linear``; fall back to the
            # model itself for modules that expose ``.weight`` directly.
            layer = getattr(self.model, 'linear', self.model)
            # NOTE(review): for output_dim > 1 this flattens the (out, in)
            # weights, giving one entry per weight rather than per input
            # feature (the branch below sums over outputs) — confirm intent.
            weights = layer.weight.detach().abs().reshape(-1)
        else:
            weights = self.model.fea_module[0].weight.detach().abs().sum(dim=0).reshape(-1)

        weights = weights / weights.sum()
        sorted_w, idx = torch.sort(weights)
        pre_sum = torch.cumsum(sorted_w, dim=0)

        mask = None
        for the in the_candidates:
            important_idx = torch.where(pre_sum >= (1. - the))
            mask = torch.zeros(len(pre_sum))
            # The mask lives on the CPU; bring the indices there too.
            mask[idx[important_idx].cpu()] = 1.
            print('the: ', the)
            print('select {} from {}'.format(mask.sum(), len(mask)))

        return mask
    
        

def compute_penalty(losses, dummy):
    """Return the IRM penalty for a 1-D tensor of per-sample ``losses``.

    The samples are split into even- and odd-indexed halves. The gradient of
    each half's mean loss w.r.t. ``dummy`` is taken with ``create_graph=True``,
    so the result stays differentiable. The penalty is the summed product of
    the two gradients.
    """
    halves = (losses[0::2], losses[1::2])
    g_even, g_odd = (grad(half.mean(), dummy, create_graph=True)[0] for half in halves)
    return torch.sum(g_even * g_odd)


def example_1(n=10000, d=2, env=1):
    """Sample one environment of the toy regression problem.

    ``x`` is Gaussian noise scaled by ``env``. The target ``y = x + noise*env``.
    ``z = y + noise``, whose noise scale does not depend on ``env``.

    Returns:
        Features ``cat([x, z], dim=1)`` of shape ``(n, 2*d)`` and target
        ``y`` summed over its ``d`` columns, shape ``(n, 1)``.
    """
    def scaled_noise(scale):
        return torch.randn(n, d) * scale

    cause = scaled_noise(env)
    effect = cause + scaled_noise(env)
    spurious = effect + torch.randn(n, d)

    features = torch.cat([cause, spurious], dim=1)
    target = torch.sum(effect, dim=1, keepdim=True)
    return features, target

if __name__ == '__main__':
    # Toy IRM run: a per-feature gate ``fs`` is clipped to [0, 1] after every
    # step. It multiplies the inputs before a linear read-out ``phi``. Training
    # uses two example_1 environments that differ in their ``env`` scale.
    erm_weight = 1e-5

    phi = torch.nn.Parameter(torch.ones(4, 1))
    fs = torch.nn.Parameter(torch.ones(4, 1))
    dummy_w = torch.nn.Parameter(torch.Tensor([1.0]))

    opt = torch.optim.SGD([phi, fs], lr=1e-3)
    mse = torch.nn.MSELoss(reduction="none")
    environments = [example_1(env=0.1), example_1(env=1.0)]

    for iteration in range(100000):
        error = 0
        penalty = 0
        for x_e, y_e in environments:
            gated_x = x_e * fs.reshape(1, -1)
            error_e = mse(gated_x @ phi * dummy_w, y_e)
            penalty += compute_penalty(error_e, dummy_w)
            error += error_e.mean()

        opt.zero_grad()
        (erm_weight * error + penalty).backward()
        opt.step()

        # Project the gate back onto [0, 1] outside of autograd.
        with torch.no_grad():
            fs.clamp_(0., 1.)

        if iteration % 1000 == 0:
            raw_error = error.detach().cpu().numpy()
            print('feature selection: ', fs)
            print('phi: ', phi)
            print('erm error', erm_weight * raw_error)
            print('raw erm error', raw_error)
            print('grad penalty', penalty.detach().cpu().numpy())
    