import numpy as np
import torch
from torch.utils.data import DataLoader
from .dynamic_net import Drnet, TR
from cauml.utils import set_seed, cat

# criterion
def criterion(out, y, alpha=0.5, epsilon=1e-6):
    """Return the outcome squared error minus an alpha-weighted mean log term.

    Args:
        out: Indexable model output; ``out[0]`` is passed through ``log``
            (the ``g`` head) and ``out[1]`` is compared against ``y``
            (the ``Q`` head).
        y: Observed outcomes; squeezed before the comparison.
        alpha: Weight applied to the mean of ``log(out[0] + epsilon)``.
        epsilon: Small constant added inside the log to avoid ``log(0)``.

    Returns:
        A scalar tensor: ``mean((out[1] - y)**2) - alpha * mean(log(out[0] + epsilon))``.
    """
    residual = out[1].squeeze() - y.squeeze()
    squared_error = (residual ** 2).mean()
    mean_log_term = torch.log(out[0] + epsilon).mean()
    return squared_error - alpha * mean_log_term

def criterion_TR(out, trg, y, beta=1., epsilon=1e-6):
    """Return the beta-weighted targeted-regularization squared error.

    Args:
        out: Indexable model output; ``out[0]`` is g and ``out[1]`` is Q.
        trg: Output of the targeted-regularization module; divided by
            ``g + epsilon`` to form the correction term.
        y: Observed outcomes.
        beta: Multiplier applied to the mean squared residual.
        epsilon: Small constant added to g to avoid division by zero.

    Returns:
        A scalar tensor: ``beta * mean((y - trg / (g + epsilon) - Q)**2)``.
    """
    g = out[0].squeeze()
    q = out[1].squeeze()
    correction = trg.squeeze() / (g + epsilon)
    perturbed_residual = y.squeeze() - correction - q
    return beta * (perturbed_residual ** 2).mean()

class TarNet(object):
    """Train a ``Drnet`` model (optionally with a ``TR`` module) and predict outcomes.

    ``fit`` builds and trains the networks from a configuration dict;
    ``predict``, ``ITE``, ``ATE`` and ``estimation`` evaluate the trained
    model. When ``isTargetReg`` is truthy, predictions are ``Q`` plus the
    targeted-regularization output divided by ``g + 1e-6``.
    """

    def __init__(self) -> None:
        # Default hyper-parameters; replace wholesale via set_Configuration()
        # or override per call through fit(config=...).
        self.config = {
            'methodName': 'TarNet',
            'device': 'mps',
            'epochs': 30,
            'verbose': 10,
            'batch_size': 500,
            'shuffle': 1,
            'wd': 5e-3,
            'tr_wd': 5e-3,
            'momentum': 0.9,
            'cfg_density': [(82, 120, 1, 'relu'), (120, 100, 1, 'relu')],
            'num_grid': 10,
            'cfg': [(100, 100, 1, 'relu'), (100, 1, 1, 'id')],
            'isenhance': 0,
            'isTargetReg': 1,
            'init_lr': 0.01,
            'alpha': 0.5,
            'tr_init_lr': 0.001,
            'beta': 1.0,
            'tr_knots': list(np.arange(0.1, 1, 0.1)),
            'tr_degree': 2,
            'seed': 2022,
        }

    def set_Configuration(self, config):
        """Replace the stored configuration dict with ``config``."""
        self.config = config

    def fit(self, data, exp=-1, config=None):
        """Train the model (and, if enabled, the targeted-regularization module).

        Args:
            data: Dataset object. It must provide ``tensor()`` and
                ``to(device)``, be accepted by ``DataLoader``, and yield
                batches indexable by ``'t'``, ``'x'`` and ``'y'``.
            exp: Unused; kept for interface compatibility.
            config: Optional configuration dict with the same keys as
                ``self.config``. Defaults to ``self.config``.

        Side effects:
            Stores ``self.data``, ``self.model``, ``self.isTargetReg`` and,
            when targeted regularization is enabled, ``self.TargetReg``.
            Prints the last batch loss every ``verbose`` epochs; a falsy
            ``verbose`` disables printing.
        """
        if config is None:
            config = self.config

        # optimizer
        epochs = config['epochs']
        verbose = config['verbose']
        batch_size = config['batch_size']
        shuffle = config['shuffle']
        device = config['device']
        wd = config['wd']
        tr_wd = config['tr_wd']
        momentum = config['momentum']

        # network / loss
        cfg_density = config['cfg_density']
        num_grid = config['num_grid']
        cfg = config['cfg']
        isenhance = config['isenhance']
        isTargetReg = config['isTargetReg']
        init_lr = config['init_lr']
        alpha = config['alpha']
        tr_init_lr = config['tr_init_lr']
        beta = config['beta']
        tr_knots = config['tr_knots']
        tr_degree = config['tr_degree']

        self.isTargetReg = isTargetReg

        set_seed(config['seed'])
        data.tensor()
        data.to(device)
        self.data = data
        # The config stores shuffle as 0/1; DataLoader documents a bool.
        train_loader = DataLoader(data, batch_size=batch_size, shuffle=bool(shuffle))

        #### Build Network ####
        model = Drnet(cfg_density, num_grid, cfg, isenhance=isenhance)
        model._initialize_weights()
        model.to(device)
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=init_lr,
            momentum=momentum,
            weight_decay=wd,
            nesterov=True,
        )

        if isTargetReg:
            TargetReg = TR(tr_degree, tr_knots)
            TargetReg._initialize_weights()
            TargetReg.to(device)
            tr_optimizer = torch.optim.SGD(
                TargetReg.parameters(), lr=tr_init_lr, weight_decay=tr_wd
            )

        # Stays None if the loader never yields a batch, so the progress
        # report below cannot hit an unbound name.
        loss = None
        for epoch in range(epochs):
            for inputs in train_loader:
                t = inputs['t'].to(device)
                x = inputs['x'].to(device)
                y = inputs['y'].to(device)

                if isTargetReg:
                    # Step 1: update the main network on the combined loss.
                    optimizer.zero_grad()
                    out = model.forward(t, x)
                    trg = TargetReg(t)
                    loss = criterion(out, y, alpha=alpha) + criterion_TR(out, trg, y, beta=beta)
                    loss.backward()
                    optimizer.step()

                    # Step 2: update the targeted-regularization module
                    # against the freshly updated network.
                    tr_optimizer.zero_grad()
                    out = model.forward(t, x)
                    trg = TargetReg(t)
                    tr_loss = criterion_TR(out, trg, y, beta=beta)
                    tr_loss.backward()
                    tr_optimizer.step()
                else:
                    optimizer.zero_grad()
                    out = model.forward(t, x)
                    loss = criterion(out, y, alpha=alpha)
                    loss.backward()
                    optimizer.step()

            # `verbose` of 0/None means "silent" instead of dividing by zero.
            if verbose and epoch % verbose == 0 and loss is not None:
                print('current epoch: ', epoch)
                print('loss: ', loss.detach())

        self.model = model
        if isTargetReg:
            self.TargetReg = TargetReg

    def _resolve_inputs(self, data, t, x):
        """Fill in missing ``t``/``x`` from ``data`` (default: ``self.data.test``).

        ``data`` and ``self.data`` are only touched when ``t`` or ``x`` is
        missing, so explicit tensors work without a dataset.
        """
        if t is None or x is None:
            if data is None:
                data = self.data.test
            if x is None:
                x = data.x
            if t is None:
                t = data.t
        return t, x

    def predict(self, data=None, t=None, x=None):
        """Predict outcomes for treatments ``t`` and covariates ``x``.

        Args:
            data: Optional object exposing ``.x`` and ``.t``; used only for
                whichever of ``t``/``x`` is not given. Defaults to
                ``self.data.test``.
            t: Optional treatment tensor.
            x: Optional covariate tensor.

        Returns:
            A NumPy array of shape ``(n, 1)``.
        """
        t, x = self._resolve_inputs(data, t, x)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            out = self.model.forward(t, x)
            prediction = out[1].squeeze()
            if self.isTargetReg:
                tr_out = self.TargetReg(t)
                g = out[0].squeeze()
                prediction = prediction + tr_out / (g + 1e-6)

        return prediction.reshape(-1, 1).detach().cpu().numpy()

    def ITE(self, data=None, t=None, x=None):
        """Return predictions at treatment 0, treatment 1 and the given ``t``.

        Missing ``t``/``x`` are resolved as in ``predict``. The zero and one
        treatments are built as ``t - t`` and ``t - t + 1`` so they keep the
        shape, dtype and device of ``t``.

        Returns:
            Tuple ``(ITE_0, ITE_1, ITE_t)`` of ``(n, 1)`` NumPy arrays.
        """
        t, x = self._resolve_inputs(data, t, x)

        ITE_0 = self.predict(t=t - t, x=x)
        ITE_1 = self.predict(t=t - t + 1, x=x)
        ITE_t = self.predict(t=t, x=x)

        return ITE_0, ITE_1, ITE_t

    def ATE(self, data=None, t=None, x=None):
        """Return ``(mean(ITE_1 - ITE_0), mean(ITE_t - ITE_0))``."""
        ITE_0, ITE_1, ITE_t = self.ITE(data, t, x)

        return np.mean(ITE_1 - ITE_0), np.mean(ITE_t - ITE_0)

    def estimation(self, data):
        """Return predictions at zero treatment and at ``data.t`` for ``data.x``."""
        x = data.x
        t = data.t
        return self.predict(t=t - t, x=x), self.predict(t=t, x=x)
    