#!/usr/bin/env python
# coding: utf-8

import numpy as np
import random
import time
import torch
import torch.nn as nn
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt 
import os
# Allow more than one OpenMP runtime to be loaded in this process (presumably
# to avoid Intel OpenMP's duplicate-runtime abort when several libraries each
# ship their own copy).
os.environ.update({'KMP_DUPLICATE_LIB_OK': 'True'})

# Loss weights: w_inv scales the reconstruction (invertibility) loss and w_ind
# the CLUB independence loss in Model.step; w_mini is not read in this section.
w_inv, w_ind, w_mini = 10.0, 0.01, 0.00

print('Using weight w_inv:%.2f, w_ind:%.2f, w_mini:%.2f' % (w_inv, w_ind, w_mini))

# Use the first GPU when one is available, otherwise fall back to the CPU so
# every module below (they all call `.to(device)`) still works on CPU-only hosts.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def set_seed(seed):
    """Seed Python's `random`, NumPy's global RNG and PyTorch's RNG with `seed`."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
set_seed(1111)

# Feed-forward Network
class FFN(nn.Module):
    """Fully connected network of `n_layers` Linear layers.

    Args:
        input_dim: width of the input features.
        output_dim: width of the output.
        hidden: width of every hidden layer.
        n_layers: number of Linear layers (>= 1); with 1 the network is a
            single affine map and no activation is used.
        activation: 'relu', 'tanh' or 'sigmoid'; one instance is placed
            between consecutive Linear layers (none after the last one).

    The module is moved to the global `device` at construction.
    """
    def __init__(self, input_dim, output_dim, hidden=512, n_layers=3, activation='relu'):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden = hidden
        assert n_layers >= 1
        self.n_layers = n_layers
        assert activation in ['relu', 'tanh', 'sigmoid']
        activation_cls = {'relu': nn.ReLU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid}[activation]
        self.activation = activation_cls()

        # Widths of the chain: input -> hidden (n_layers - 1 times) -> output.
        widths = [input_dim] + [hidden] * (n_layers - 1) + [output_dim]
        stack = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            if depth:
                # The same (parameter-free) activation object is reused between layers.
                stack.append(self.activation)
            stack.append(nn.Linear(fan_in, fan_out))
        self.layers = nn.ModuleList(stack)
        self.to(device)
    def forward(self, x):
        """Apply every layer in order and return the final output."""
        for layer in self.layers:
            x = layer(x)
        return x

# Generating data
class DataGenerator():
    """Synthetic multi-view data with known latent sources.

    For every sample, seven independent standard-normal sources s_0..s_6
    (widths dims[0..6]) are drawn. Three views (widths dims[7..9]) are then
    produced by 2-layer tanh MLPs (hidden width 64) that are created here and
    only ever evaluated under torch.no_grad() in this class:
        v1 = func_v1([s_0, s_2, s_4, s_6])
        v2 = func_v2([s_1, s_2, s_5, s_6])
        v3 = func_v3([s_3, s_4, s_5, s_6])

    NOTE: the construction order of the three MLPs and the order of the
    torch.randn calls determine what a seeded run generates; reordering them
    changes the data.
    """
    # 'dims' should be a list or tuple with 10 elements (size-3 multi-variable model): the dimensions of s_0..s_6 followed by those of v1..v3
    # 'n_samples': if given, that many samples are generated once and stored in self.data for sampling()
    def __init__(self, dims, n_samples=None):
        self.dims = dims
        assert len(self.dims) == 10
        self.func_v1 = FFN(dims[0] + dims[2] + dims[4] + dims[6], dims[7], hidden = 64, n_layers=2, activation='tanh')
        self.func_v2 = FFN(dims[1] + dims[2] + dims[5] + dims[6], dims[8], hidden = 64, n_layers=2, activation='tanh')
        self.func_v3 = FFN(dims[3] + dims[4] + dims[5] + dims[6], dims[9], hidden = 64, n_layers=2, activation='tanh')
        self.static_data = False
        if not n_samples is None:
            self.static_data = True
            # List of 10 tensors on `device`: [s_0..s_6, v1, v2, v3].
            self.data = self._gen(n_samples)
            # Cursor into self.data used by sampling().
            self.pointer = 0
        # Printed labels are 0-based (s0..s6, v0..v2).
        print('Created data with dimensions:')
        for i in range(7):
            print('s%d: %d' % (i, self.dims[i]))
        for i in range(3):
            print('v%d: %d' % (i, self.dims[7+i]))
    def _batch_gen(self, batch_size=128):
        """Draw `batch_size` fresh samples.

        Returns:
            list of 10 tensors on `device`: the seven sources followed by
            v1, v2, v3, each with `batch_size` rows.
        """
        with torch.no_grad():
            s = []
            for i in range(7):
                dic = {}
                dic['raw'] = torch.randn([batch_size, self.dims[i]]).to(device)
                s.append(dic)
            v1 = self.func_v1(torch.hstack([s[0]['raw'], s[2]['raw'], s[4]['raw'], s[6]['raw']]))
            v2 = self.func_v2(torch.hstack([s[1]['raw'], s[2]['raw'], s[5]['raw'], s[6]['raw']]))
            v3 = self.func_v3(torch.hstack([s[3]['raw'], s[4]['raw'], s[5]['raw'], s[6]['raw']]))
        return [var['raw'].detach() for var in s] + [var.detach() for var in [v1, v2, v3]]
    def _gen(self, n):
        """Generate `n` fresh samples in chunks of 200 (one call when n <= 200).

        Returns:
            list of 10 tensors on `device` with `n` rows each.
        """
        data = []
        batches = []
        bs = 200
        if n <= bs:
            return self._batch_gen(n)
        for _ in range((n - 1) // bs + 1):
            batches.append(self._batch_gen(bs))
        # Stack chunk-wise per variable, then trim the last chunk's surplus rows.
        for i in range(10):
            data.append(torch.vstack([batch[i] for batch in batches])[:n])
        return data
    def sampling(self, n):
        """Return `n` samples as a list of 10 numpy arrays.

        With stored data (built with `n_samples`), consecutive calls walk
        through self.data in order and wrap around to the start once; a single
        call wraps at most once, so a request much larger than the stored set
        can come back with fewer than `n` rows. Without stored data, fresh
        samples are generated.
        """
        if self.static_data:
            output = [var[self.pointer:self.pointer+n] for var in self.data]
            # Rows still missing because the slice hit the end of the data.
            rest = n - output[0].shape[0]
            self.pointer += n
            if self.pointer >= self.data[0].shape[0]:
                self.pointer = 0
            if rest > 0:
                # Wrap around: take the missing rows from the beginning.
                additional = [var[:rest] for var in self.data]
                output = [torch.vstack([a, b]) for a,b in zip(output, additional)]
                self.pointer = rest
        else:
            output = self._gen(n)
        return [var.detach().cpu().numpy() for var in output]
    def generate(self, n):
        """Return `n` freshly generated samples (never the stored data) as 10 numpy arrays."""
        output = self._gen(n)
        return [var.detach().cpu().numpy() for var in output]

class CLUBModel(nn.Module):
    """Conditional Gaussian q(y | x) used as a CLUB mutual-information estimator.

    Two MLPs predict the mean and the log-variance of `y` from `x`.
    `forward` returns the unnormalised negative log-likelihood used to fit q,
    together with the CLUB estimate. The CLUB estimate is the variance-scaled
    squared error against batch-shuffled targets minus the one against the
    matched targets.
    """
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim, self.output_dim = input_dim, output_dim
        # Mean head first, then log-variance head (creation order fixes their
        # seeded initial weights).
        self.pred_mu = FFN(input_dim, output_dim)
        self.pred_logvar = FFN(input_dim, output_dim)
    def forward(self, x, y):
        """Return `(nll_loss, club_loss)` for a batch of row-aligned `x`, `y`."""
        mu = self.pred_mu(x)
        logvar = self.pred_logvar(x)
        # Gaussian NLL up to an additive constant and a factor of 1/2.
        nll_loss = torch.mean((mu - y) ** 2 / logvar.exp() + logvar)
        # Shuffling y across the batch breaks the pairing with x.
        shuffled_y = y[torch.randperm(y.shape[0])]
        unpaired = torch.mean((mu - shuffled_y) ** 2 / logvar.exp() / 2.0)
        paired = torch.mean((mu - y) ** 2 / logvar.exp() / 2.0)
        club_loss = unpaired - paired
        return nll_loss, club_loss

class Decoder(nn.Module):
    """MLP decoder scored by the mean squared error of its reconstruction."""
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim, self.output_dim = input_dim, output_dim
        self.pred_mu = FFN(input_dim, output_dim)
    def forward(self, x, y):
        """Return the mean over all elements of (decoder(x) - y) ** 2."""
        residual = self.pred_mu(x) - y
        return torch.mean(residual ** 2)

# Identification model
class Model(nn.Module):
    """Two-view identification model: [v1, v2] -> [s1, z, s2].

    One encoder maps the concatenated views to three latent blocks. `s1` and
    `z` must reconstruct v1, and `z` and `s2` must reconstruct v2. For each
    block, a CLUB estimator measures its dependence on the other two, and the
    encoder is trained to reduce it. The intent is that `z` captures what the
    views share and `s1` / `s2` what is specific to each.
    """
    # 'vis_dims' should be list or tuple with 2 elements, representing dimension of v1, v2, respectively
    # We assume the dimension of latent variables as d_s1 = d_v1, d_z = d_v1 + d_v2, d_s2 = d_v2
    # unless 'latent_dims' = (d_s1, d_z, d_s2) is given.
    def __init__(self, vis_dims, latent_dims=None):
        super(Model, self).__init__()
        assert len(vis_dims) == 2
        self.v1_dim, self.v2_dim = vis_dims
        if latent_dims is None:
            self.z_dim = self.v1_dim + self.v2_dim
            self.s1_dim = self.v1_dim
            self.s2_dim = self.v2_dim
        else:
            assert len(latent_dims) == 3
            self.s1_dim, self.z_dim, self.s2_dim = latent_dims
        self.enc_ml = nn.ModuleList()  # modules updated by optim_enc
        self.aux_ml = nn.ModuleList()  # CLUB predictors updated by optim_aux

        # encoder & decoder
        self.encoder = FFN(self.v1_dim + self.v2_dim, self.z_dim + self.s1_dim + self.s2_dim) # [v1, v2] -> [s1, z, s2]
        self.decoder_v1 = Decoder(self.z_dim + self.s1_dim, self.v1_dim) # [s1, z] -> v1
        self.decoder_v2 = Decoder(self.z_dim + self.s2_dim, self.v2_dim) # [z, s2] -> v2
        self.enc_ml.extend([self.encoder, self.decoder_v1, self.decoder_v2])

        # auxiliary predictors: each latent block vs. the concatenation of the other two
        self.indm_s1_zs2 = CLUBModel(self.s1_dim, self.z_dim + self.s2_dim)
        self.indm_z_s1s2 = CLUBModel(self.z_dim, self.s1_dim + self.s2_dim)
        self.indm_s2_s1z = CLUBModel(self.s2_dim, self.s1_dim + self.z_dim)
        self.aux_ml.extend([self.indm_s1_zs2, self.indm_z_s1s2, self.indm_s2_s1z])

        self.optim_enc = torch.optim.AdamW(self.enc_ml.parameters(), lr=1e-3)
        self.scheduler_enc = torch.optim.lr_scheduler.StepLR(self.optim_enc, step_size=5000, gamma=0.2)
        self.optim_aux = torch.optim.AdamW(self.aux_ml.parameters(), lr=1e-3)
        self.scheduler_aux = torch.optim.lr_scheduler.StepLR(self.optim_aux, step_size=5000, gamma=0.2)
        self.to(device)
    def forward(self, v1, v2, mode='training'):
        """Encode a batch of paired views and compute the training losses.

        Args:
            v1, v2: row-aligned batches of the two views (tensors already on
                `device`, as prepared by step()/predict()).
            mode: 'training' returns the loss dict; any other value returns
                the latent blocks. The losses are computed in both cases.

        Returns:
            dict with 'loss_inv' (mean of the two decoder MSEs), 'loss_ind'
            (mean of the three CLUB estimates) and 'loss_pred' (mean NLL of
            the three CLUB predictors); or the tuple (s1, z, s2).
        """
        latent_code = self.encoder(torch.hstack([v1, v2]))
        s1 = latent_code[:, :self.s1_dim]
        z = latent_code[:, self.s1_dim:latent_code.shape[1]-self.s2_dim]
        s2 = latent_code[:, latent_code.shape[1]-self.s2_dim:]

        loss_dic = {}

        loss_inv = (self.decoder_v1(torch.hstack([s1, z]), v1) + self.decoder_v2(torch.hstack([z, s2]), v2)) / 2
        loss_dic['loss_inv'] = loss_inv

        nll_zs2, ind_s1 = self.indm_s1_zs2(s1, torch.hstack([z, s2]))
        nll_s1s2, ind_z = self.indm_z_s1s2(z, torch.hstack([s1, s2]))
        nll_s1z, ind_s2 = self.indm_s2_s1z(s2, torch.hstack([s1, z]))
        loss_ind = (ind_s1 + ind_z + ind_s2) / 3
        loss_pred = (nll_zs2 + nll_s1s2 + nll_s1z) / 3
        loss_dic['loss_ind'] = loss_ind
        loss_dic['loss_pred'] = loss_pred

        if mode == 'training':
            return loss_dic
        else:
            return s1, z, s2

    def step(self, v1, v2):
        """Run one training iteration on a batch of paired views.

        The CLUB predictors get 5 AdamW updates on 'loss_pred'. Then the
        encoder/decoders get one update on w_inv * loss_inv + w_ind * loss_ind,
        and both StepLR schedulers advance.

        Returns:
            the loss dict of the encoder pass plus 'loss_aux' (last predictor
            loss) and 'loss_enc', all detached and on the CPU.
        """
        self.train()
        v1, v2 = torch.tensor(v1).to(device), torch.tensor(v2).to(device)

        for i in range(5):
            loss_dic = self.forward(v1, v2)
            loss_aux = loss_dic['loss_pred']
            self.optim_aux.zero_grad()
            # This backward also fills encoder gradients; they are cleared by
            # optim_enc.zero_grad() before the encoder update below.
            loss_aux.backward()
            self.optim_aux.step()

        loss_dic = self.forward(v1, v2)
        loss_enc = w_inv * loss_dic['loss_inv'] + w_ind * loss_dic['loss_ind']

        self.optim_enc.zero_grad()
        loss_enc.backward()
        self.optim_enc.step()

        self.scheduler_enc.step()
        self.scheduler_aux.step()

        loss_dic['loss_aux'] = loss_aux
        loss_dic['loss_enc'] = loss_enc
        for k, v in loss_dic.items():
            loss_dic[k] = v.detach().cpu()
        return loss_dic
    def predict(self, v1, v2):
        """Encode paired views in batches of 200.

        Returns:
            [s1, z, s2] as numpy arrays whose rows align with the input rows.
        """
        self.eval()
        v1, v2 = torch.tensor(v1).to(device), torch.tensor(v2).to(device)
        batch_size = 200
        num = v1.shape[0]
        all_pred = []
        # Inference only: without no_grad every batch would build an autograd
        # graph through the encoder, decoders and CLUB predictors.
        with torch.no_grad():
            for i in range((num - 1) // batch_size + 1):
                batch_v1 = v1[batch_size*i:batch_size*(i+1)]
                batch_v2 = v2[batch_size*i:batch_size*(i+1)]
                pred = self.forward(batch_v1, batch_v2, mode='test')
                all_pred.append([value.cpu().numpy() for value in pred])
        return [np.vstack([batch[i] for batch in all_pred]) for i in range(3)]

class Predictor(nn.Module):
    """MLP regressor used to measure how much of `y` is predictable from `x`.

    Trained by `fit_with_val` with AdamW, a ReduceLROnPlateau schedule on the
    validation loss and early stopping; scored by per-dimension R^2 in
    `evaluate`.
    """
    def __init__(self, input_dim, output_dim):
        super(Predictor, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.batch_size = 200
        # Relative improvement required by both the LR scheduler and early stopping.
        self.threshold = 1e-2
        self.model = FFN(input_dim, output_dim, hidden=1024, n_layers=3)
        self.optim = torch.optim.AdamW(self.parameters(), lr=1e-3)
        # `verbose` is left at its default (off): recent PyTorch versions
        # deprecate the argument and warn whenever it is passed explicitly.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optim,
            mode='min',
            factor=0.1,
            patience=5,
            threshold=self.threshold,
            threshold_mode='rel',
            cooldown=0,
            min_lr=0,
            eps=1e-05)
        self.to(device)
    def forward(self, x):
        return self.model(x)
    def _step(self, x, y):
        """Apply one AdamW update on a mini-batch and return its MSE on the CPU."""
        x, y = torch.tensor(x).to(device), torch.tensor(y).to(device)
        loss_fn = nn.MSELoss()
        pred_y = self.forward(x)
        loss = loss_fn(pred_y, y)  # (input, target) order
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        return loss.detach().cpu()
    def predict(self, x):
        """Predict in mini-batches without tracking gradients; returns a numpy array."""
        self.eval()
        x = torch.tensor(x).to(device)
        num = x.shape[0]
        all_pred_y = []
        with torch.no_grad():
            for i in range((num - 1) // self.batch_size + 1):
                batch_x = x[self.batch_size*i:self.batch_size*(i+1)]
                all_pred_y.append(self.forward(batch_x).cpu().numpy())
        return np.vstack(all_pred_y)
    def fit_with_val(self, x, y, silent=True):
        """Train on a random 80% of (x, y), early-stopping on the remaining 20%.

        After training, the parameters from the epoch with the best validation
        loss are loaded back (whether or not early stopping triggered).

        Returns:
            the best validation loss: per-sample MSE averaged over output
            dimensions (np.inf if no epoch ever improved, e.g. NaN losses).
        """
        self.train()
        num = x.shape[0]
        permed_index = np.random.permutation(num)
        num_train = np.round(num * 0.8).astype(np.int32)
        num_val = num - num_train
        train_x, train_y = x[permed_index[:num_train]], y[permed_index[:num_train]]
        val_x, val_y = x[permed_index[num_train:]], y[permed_index[num_train:]]
        n_epochs = 500
        es_steps = 10 # early stop steps
        es_count = 0
        es_loss = np.inf
        # state_dict() returns references to the live tensors, so the snapshot
        # must be cloned; otherwise "best" params would follow every update and
        # restoring them would be a no-op.
        best_params = {k: v.detach().clone() for k, v in self.state_dict().items()}
        best_ep = 0
        for e in range(n_epochs):
            for i in range((num_train - 1) // self.batch_size + 1):
                batch_x = train_x[self.batch_size*i:self.batch_size*(i+1)]
                batch_y = train_y[self.batch_size*i:self.batch_size*(i+1)]
                self._step(batch_x, batch_y)
            val_loss = 0.0
            with torch.no_grad():
                for i in range((num_val - 1) // self.batch_size + 1):
                    batch_x = torch.tensor(val_x[self.batch_size*i:self.batch_size*(i+1)]).to(device)
                    batch_y = torch.tensor(val_y[self.batch_size*i:self.batch_size*(i+1)]).to(device)
                    pred_y = self.forward(batch_x)
                    # Sum over samples of the per-sample mean squared error.
                    val_loss += torch.sum(torch.mean((pred_y - batch_y) ** 2, dim=-1)).cpu()
            val_loss /= num_val
            self.scheduler.step(val_loss)
            if not silent and e % 10 == 0:
                print('Epoch %d, validataion loss: %f' % (e, val_loss))
            if val_loss < es_loss * (1 - self.threshold):
                es_loss = val_loss
                es_count = 0
                best_params = {k: v.detach().clone() for k, v in self.state_dict().items()}
                best_ep = e
            else:
                es_count += 1
                if es_count == es_steps:
                    if not silent:
                        print('Early stopped at epoch %d, use params of epoch %d, loss: %f' % (e, best_ep, es_loss))
                    break
        self.load_state_dict(best_params)
        return es_loss
    def evaluate(self, x, y):
        """Return (mean MSE, mean R^2) over the output dimensions on (x, y).

        Per-dimension R^2 is clipped at 0 before averaging; the 1e-20 term
        avoids division by zero for constant target columns.
        """
        self.eval()
        pred_y = self.predict(x)
        mse = np.mean((pred_y - y) ** 2, axis=0)
        avg_y = np.mean(y, axis=0)
        r_square = 1.0 - mse / (np.mean((y - avg_y) ** 2, axis=0) + 1e-20)
        r_square[r_square<0] = 0.0
        return np.mean(mse), np.mean(r_square)
    
def normalize(x):
    """Standardise each column of `x` to zero mean and unit variance.

    Constant columns (standard deviation 0) are mapped to zeros instead of
    NaN, so a collapsed latent dimension does not poison downstream fitting.
    The dtype of `x` is preserved.
    """
    mu = np.mean(x, axis=0)
    std = np.std(x, axis=0)
    # Substitute 1 for zero std; ones_like keeps std's dtype (e.g. float32).
    std = np.where(std == 0, np.ones_like(std), std)
    return (x - mu) / std
def eval(x, y, silent=True):
    """Measure how well `x` and `y` predict each other.

    Both arrays are column-standardised. The last quarter of rows is held
    out, and one Predictor is fitted for x -> y and one for y -> x on the
    rest.

    Returns:
        (f_r2, b_r2): mean held-out R^2 of predicting y from x and x from y.
    """
    n, dim_x = x.shape
    dim_y = y.shape[-1]
    x, y = normalize(x), normalize(y)
    # Split with an explicit index: `x[:-n_test]` would be empty when
    # n_test == n // 4 == 0.
    n_train = n - n // 4
    train_x, test_x = x[:n_train], x[n_train:]
    train_y, test_y = y[:n_train], y[n_train:]
    forward_pred = Predictor(dim_x, dim_y)
    forward_pred.fit_with_val(train_x, train_y, silent)
    _, f_r2 = forward_pred.evaluate(test_x, test_y)
    backward_pred = Predictor(dim_y, dim_x)
    backward_pred.fit_with_val(train_y, train_x, silent)
    _, b_r2 = backward_pred.evaluate(test_y, test_x)
    return f_r2, b_r2

def F1(p1, p2):
    """Harmonic mean of two scores, each first raised to at least 1e-20.

    The floor keeps the denominator positive when a score is 0; it is used
    here to merge a forward and a backward R^2 into one number.
    """
    floor = 1e-20
    a = floor if floor > p1 else p1
    b = floor if floor > p2 else p2
    return 2 * a * b / (a + b)


class DataLoader():
    """Cyclic mini-batch sampler over a list of row-aligned arrays or tensors.

    Successive `sampling(n)` calls walk through the rows in order and wrap
    around to the start, so each call returns exactly `n` rows of every array.
    """
    def __init__(self, data, batch_size):
        self.data = data
        self.pointer = 0
        self.total = self.data[0].shape[0]
        # Stored for compatibility; the rows per call come from sampling(n).
        self.bs = batch_size
    def sampling(self, n):
        """Return the next `n` rows of every array, wrapping around as needed.

        Works for numpy arrays and torch tensors. The batch size is taken from
        `n`, consistently with how far the cursor advances, and wrap-around no
        longer relies on torch.vstack, which rejects numpy arrays.

        Raises:
            ValueError: if the dataset has no rows.
        """
        if self.total == 0:
            raise ValueError('cannot sample from an empty dataset')
        idx = (self.pointer + np.arange(n)) % self.total
        output = [var[torch.as_tensor(idx)] if torch.is_tensor(var) else var[idx]
                  for var in self.data]
        self.pointer = (self.pointer + n) % self.total
        return output


# dims[0..6] size the seven sources, dims[7..9] the three observed views.
var_dims = 7 * [2] + 3 * [10]
# 100k stored samples serve as the training pool; 20k freshly generated
# samples are held out for evaluation.
sampler = DataGenerator(var_dims, n_samples=100000)
test_data = sampler.generate(20000)
# Number of Model.step iterations per identify() run.
n_epochs = 20001
# For a non-deterministic run, reseed here, e.g. set_seed(int(time.time())).

def identify(train_set, test_set, ckpt_name='ckpt.pth', load_ckpt=False):
    """Train (or load) a two-view Model and return its latent codes.

    Args:
        train_set: [v1, v2], two row-aligned numpy arrays used for training.
        test_set: [s1, z, s2, v1, v2]. The first three are the ground-truth
            latent blocks used for evaluation, and their widths set the
            Model's latent_dims. The last two are the held-out input views.
        ckpt_name: path the trained state_dict is saved to / loaded from.
        load_ckpt: if True, skip training and load `ckpt_name` instead.

    Returns:
        (s1, z, s2) predicted on train_set, followed by (s1, z, s2) predicted
        on the test views; all numpy arrays.
    """
    ds1, dz, ds2, dv1, dv2 = [var.shape[1] for var in test_set]
    model = Model([dv1, dv2], latent_dims=[ds1, dz, ds2])

    if load_ckpt:
        print('Loading checkpoint ...')
        # map_location lets a checkpoint saved on one device load on `device`.
        model.load_state_dict(torch.load(ckpt_name, map_location=device))
    else:
        print('Training ...')
        bs = 100
        train_loader = DataLoader(train_set, batch_size=bs)

        t0 = time.time()
        for e in range(n_epochs):
            batch_v1, batch_v2 = train_loader.sampling(bs)
            loss_dic = model.step(batch_v1, batch_v2)

            if e % 100 == 0:
                print('Epoch %d, Encoder loss: %f, Auxiliary loss: %f Invertibility: %f, Independence: %f'
                    % (e, loss_dic['loss_enc'], loss_dic['loss_aux'], loss_dic['loss_inv'], loss_dic['loss_ind']))

            if e % 1000 == 0:
                # Periodic check: R^2 in both directions between each predicted
                # latent block and its ground truth, combined with F1.
                t1 = time.time()
                print('Time: %d' % (t1-t0))
                print('Epoch %d, Learning Rate: %f ...' % (e, model.scheduler_enc.get_last_lr()[0]))
                s1, z, s2, v1, v2 = test_set
                pred_s1, pred_z, pred_s2 = model.predict(v1, v2)

                print('Epoch %d, Evaluating ...' % e)
                s1_fr2, s1_br2 = eval(pred_s1, s1)
                print('s1->s1_gt', s1_fr2, s1_br2)
                z_fr2, z_br2 = eval(pred_z, z)
                print('z->z_gt', z_fr2, z_br2)
                s2_fr2, s2_br2 = eval(pred_s2, s2)
                print('s2->s2_gt', s2_fr2, s2_br2)
                avg_r2 = (F1(s1_fr2, s1_br2) + F1(z_fr2, z_br2) + F1(s2_fr2, s2_br2)) / 3
                print('>'*20, 'Epoch %d, Average R2-F1: %f' % (e, avg_r2))
        torch.save(model.state_dict(), ckpt_name)

    v1, v2 = train_set
    pred_s1, pred_z, pred_s2 = model.predict(v1, v2)
    _, _, _, test_v1, test_v2 = test_set
    pred_test_s1, pred_test_z, pred_test_s2 = model.predict(test_v1, test_v2)
    return pred_s1, pred_z, pred_s2, pred_test_s1, pred_test_z, pred_test_s2


# Hierarchical identification pipeline.
# Sources are named s1..s7 here (test_data[0..6]). The generator builds
# v1 from s1, s3, s5, s7; v2 from s2, s3, s6, s7; and v3 from s4, s5, s6, s7.
# Each identify() call splits a pair of (possibly composite) variables into
# [private to the first, shared, private to the second]. Its test_set lists
# the matching ground-truth sources for evaluation.
load_ckpt = False
ckpt_dir = 'kn_ckpt'
# exist_ok: no error when the directory is already there.
os.makedirs(ckpt_dir, exist_ok=True)

# Training views: the last three tensors of the stored data are v1, v2, v3.
v1, v2, v3 = [var.cpu().numpy() for var in sampler.data[-3:]]

# Stage 1: s1 is private to v1; s3, s5, s7 are shared with [v2, v3].
print('running basis model on v1 <-> [v2, v3]')
train_set = [v1, np.hstack([v2, v3])]
test_set = [test_data[0], np.hstack([test_data[i] for i in [2,4,6]]), np.hstack([test_data[i] for i in [1,3,5]]), test_data[7], np.hstack([test_data[i] for i in [8,9]])]
pred_s1, pred_s357, pred_s246, pred_test_s1, pred_test_s357, pred_test_s246 = identify(train_set, test_set, ckpt_name=os.path.join(ckpt_dir, 'v1_v2v3.pth'), load_ckpt=load_ckpt)

# Stage 2: s2 is private to v2; s3, s6, s7 are shared with [v1, v3].
print('running basis model on v2 <-> [v1, v3]')
train_set = [v2, np.hstack([v1, v3])]
test_set = [test_data[1], np.hstack([test_data[i] for i in [2,5,6]]), np.hstack([test_data[i] for i in [0,3,4]]), test_data[8], np.hstack([test_data[i] for i in [7,9]])]
pred_s2, pred_s367, pred_s145, pred_test_s2, pred_test_s367, pred_test_s145 = identify(train_set, test_set, ckpt_name=os.path.join(ckpt_dir, 'v2_v1v3.pth'), load_ckpt=load_ckpt)

# Stage 3: s4 is private to v3; s5, s6, s7 are shared with [v1, v2].
print('running basis model on v3 <-> [v1, v2]')
train_set = [v3, np.hstack([v1, v2])]
test_set = [test_data[3], np.hstack([test_data[i] for i in [4,5,6]]), np.hstack([test_data[i] for i in [0,1,2]]), test_data[9], np.hstack([test_data[i] for i in [7,8]])]
pred_s4, pred_s567, pred_s123, pred_test_s4, pred_test_s567, pred_test_s123 = identify(train_set, test_set, ckpt_name=os.path.join(ckpt_dir, 'v3_v1v2.pth'), load_ckpt=load_ckpt)

# Stage 4: split {s3,s5,s7} vs {s3,s6,s7} into s5 | {s3,s7} | s6.
print('running basis model on s357 <-> s367')
train_set = [pred_s357, pred_s367]
test_set = [test_data[4], np.hstack([test_data[i] for i in [2,6]]), test_data[5], pred_test_s357, pred_test_s367]
pred_s5, pred_s37, pred_s6, pred_test_s5, pred_test_s37, pred_test_s6 = identify(train_set, test_set, ckpt_name=os.path.join(ckpt_dir, 's357_s367.pth'), load_ckpt=load_ckpt)

# Stage 5: split {s3,s7} vs {s5,s6,s7} into s3 | s7 | {s5,s6}.
print('running basis model on s37 <-> s567')
train_set = [pred_s37, pred_s567]
test_set = [test_data[2], test_data[6], np.hstack([test_data[i] for i in [4,5]]), pred_test_s37, pred_test_s567]
pred_s3, pred_s7, pred_s56, pred_test_s3, pred_test_s7, pred_test_s56 = identify(train_set, test_set, ckpt_name=os.path.join(ckpt_dir, 's37_s567.pth'), load_ckpt=load_ckpt)

# Score every recovered source against its ground truth with the R2-F1 metric.
print('Final evaluation ...')
pred_test_s = [pred_test_s1, pred_test_s2, pred_test_s3, pred_test_s4, pred_test_s5, pred_test_s6, pred_test_s7]
r2s = []
for i in range(7):
    fr2, br2 = eval(pred_test_s[i], test_data[i])
    r2f1 = F1(fr2, br2)
    r2s.append(r2f1)
    print('s%d->s%d_gt' % (i+1, i+1), fr2, br2)
avg_r2 = np.mean(r2s)
print('Average R2-F1: %f' % avg_r2)
