
# coding: utf-8

import numpy as np
from scipy.linalg import sqrtm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer, SGD
from torch.utils.data import DataLoader, Sampler, TensorDataset
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torchsummary import summary
import sys 
import pandas as pd
import argparse 
import os
import time

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# from IPython.core.display import display, HTML
# display(HTML("<style>.container { width:90% !important; }</style>"))

# Command-line interface. Both options are mandatory: the experiment loop below
# needs a known algorithm (to pick `alpha` and the optimizer) and an output
# directory for the per-seed log files.
parser = argparse.ArgumentParser()
parser.add_argument("--algo", type=str, required=True, choices=["ULD", "HFHR"],
                    help="algorithm to use -- ULD or HFHR")
parser.add_argument("--output_path", type=str, required=True, help="output path")
args = parser.parse_args()

def getDataLoader(path='data/parkinsons/parkinsons.csv', test_size=0.2, random_state=0,
                  train_batch_size=None, test_batch_size=None):
    """Load the Parkinson's CSV and return ``(train_loader, test_loader)``.

    The label is the ``status`` column. The features are every other column
    except ``name``. Rows are split with ``train_test_split`` first. A
    ``StandardScaler`` is then fit on the training rows only and applied to
    both splits, so no test-set statistics leak into the standardization.

    Args:
        path: CSV file to read.
        test_size: fraction of rows held out for testing.
        random_state: seed for the train/test split.
        train_batch_size: batch size of the train loader; ``None`` means one
            batch containing the whole training split.
        test_batch_size: batch size of the test loader; ``None`` means one
            batch containing the whole test split.

    Returns:
        Two ``DataLoader`` objects over ``TensorDataset(float32 features, int64 labels)``.
    """
    df = pd.read_csv(path)
    y = df['status'].values
    X = df.drop(columns=['name', 'status']).values

    # train-test split (the chosen rows depend only on the row count and random_state)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state)

    # Fit the scaler on the training split only; fitting on all rows (as before
    # the split) would leak the test set's mean/std into training.
    scaler = StandardScaler().fit(X_train)
    X_train = scaler.transform(X_train)
    X_test = scaler.transform(X_test)

    print('number of features: {}\n'.format(X_train.shape[1]))
    print('number of train data: {}'.format(len(X_train)))
    print('number of test data: {}'.format(len(X_test)))

    def dataloader(X, y, batch_size):
        # One vectorized conversion instead of stacking a tensor per row.
        tensor_x = torch.as_tensor(X, dtype=torch.float32)
        tensor_y = torch.as_tensor(y, dtype=torch.long)
        return DataLoader(TensorDataset(tensor_x, tensor_y), batch_size=batch_size)

    if train_batch_size is None:
        train_batch_size = len(X_train)
    if test_batch_size is None:
        test_batch_size = len(X_test)

    train_loader = dataloader(X_train, y_train, batch_size=train_batch_size)
    test_loader = dataloader(X_test, y_test, batch_size=test_batch_size)

    return train_loader, test_loader

# Build the train/test loaders once at import time, from the default dataset location.
train_loader, test_loader = getDataLoader(path='data/parkinsons/parkinsons.csv')

class HFHR(Optimizer):
    """Single-step sampler with an auxiliary momentum per parameter.

    Every parameter ``q`` carries a momentum buffer ``p``. The buffers are
    stored per param group under ``'momentums'`` and start at zero. Using the
    gradient already stored in ``q.grad``, ``step`` applies::

        q <- q + h * (p - alpha * grad) + sqrt(2 * alpha * h) * xi_1
        p <- exp(-gamma * h) * p - (1 - exp(-gamma * h)) / gamma * grad
             + sqrt(1 - exp(-2 * gamma * h)) * xi_2

    ``xi_1`` and ``xi_2`` are independent standard normals. With ``alpha=0``
    the position update is noise-free.

    Args:
        params: iterable of parameters or of parameter-group dicts.
        h: step size.
        gamma: friction coefficient.
        alpha: weight of the extra gradient/noise term in the position update.
        device: device for the momentum buffers and the noise. ``None`` (the
            default) uses each parameter's own device.
    """

    def __init__(self, params, h=0.1, gamma=2.0, alpha=1.0, device=None):
        defaults = dict(h=h, gamma=gamma, alpha=alpha, device=device)
        super(HFHR, self).__init__(params, defaults)
        for group in self.param_groups:
            # zeros_like(..., device=None) keeps each buffer on its parameter's device.
            group['momentums'] = [torch.zeros_like(param, device=group['device'])
                                  for param in group['params']]

    def step(self, closure=None):
        """Advance every ``(q, p)`` pair by one step using the gradients in ``q.grad``.

        Args:
            closure: optional callable that re-evaluates the loss and calls
                ``backward``. It is called under ``torch.enable_grad()`` before
                the update, following the ``torch.optim.Optimizer.step``
                convention.

        Returns:
            The closure's return value, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        with torch.no_grad():
            for group in self.param_groups:
                h = group['h']
                gamma = group['gamma']
                alpha = group['alpha']
                device = group['device']

                # Coefficients are the same for every parameter in the group.
                decay = np.exp(-gamma * h)
                grad_coef = (1 - decay) / gamma
                q_noise_scale = np.sqrt(2 * alpha * h)
                p_noise_scale = np.sqrt(1 - np.exp(-2 * gamma * h))

                for q, p in zip(group['params'], group['momentums']):
                    # The position update reads the momentum from BEFORE this step.
                    noise1 = torch.randn_like(q, device=device)
                    q.data.add_(h * (p - alpha * q.grad) + q_noise_scale * noise1)

                    # The momentum update reuses q.grad, i.e. the gradient at the
                    # pre-update position (it is not recomputed here).
                    noise2 = torch.randn_like(p, device=device)
                    p.data = decay * p - grad_coef * q.grad + p_noise_scale * noise2
        return loss

class HFHR2(Optimizer):
    """Splitting sampler with a linear sub-flow (``phi``) and a gradient sub-flow (``psi``).

    Every parameter ``q`` carries a momentum buffer ``p``. The buffers are
    stored per param group under ``'momentums'`` and start at zero. The caller
    composes the sub-steps, e.g. ``phi -> (compute grad) -> psi -> phi``:

    * ``phi_flow_exact`` advances over time ``h/2``::

          q <- q + (1 - exp(-gamma h/2)) / gamma * p + xi_q
          p <- exp(-gamma h/2) * p + xi_p

      ``(xi_q, xi_p)`` are correlated Gaussians whose covariance is set up in
      ``prepare_M``.
    * ``psi_flow_approx`` takes one step of size ``h`` with the gradient already
      stored in ``q.grad``::

          q <- q - alpha * h * grad + sqrt(2 * alpha * h) * xi
          p <- p - h * grad

    Args:
        params: iterable of parameters or of parameter-group dicts.
        h: step size.
        gamma: friction coefficient.
        alpha: weight of the gradient/noise term in the position update of ``psi``.
        device: device for the momentum buffers and the noise. ``None`` (the
            default) uses each parameter's own device.

    Note:
        ``self.M`` is computed once from the constructor's ``h`` and ``gamma``.
        Per-group overrides of ``h``/``gamma`` are used for the drift terms but
        are NOT reflected in the noise transform ``M``.
    """

    def __init__(self, params, h=0.1, gamma=2.0, alpha=1.0, device=None):
        defaults = dict(h=h, gamma=gamma, alpha=alpha, device=device)
        super(HFHR2, self).__init__(params, defaults)
        for group in self.param_groups:
            # zeros_like(..., device=None) keeps each buffer on its parameter's device.
            group['momentums'] = [torch.zeros_like(param, device=group['device'])
                                  for param in group['params']]

        # phi_flow_exact advances by h/2, so the noise transform is built for t = h/2.
        self.prepare_M(h / 2, gamma)

    def prepare_M(self, t, gamma):
        """Set ``self.M`` to the symmetric square root of the 2x2 noise covariance over time ``t``.

        The entries are the variance of the position noise (``var_L``), the
        variance of the momentum noise (``var_K``) and their covariance
        (``E_LK``). Direct integration shows they match the linear SDE
        ``dq = p dt, dp = -gamma p dt + sqrt(2 gamma) dW``. Because
        ``M @ M = Sigma`` and ``M`` is symmetric, ``M @ (z_q, z_p)`` with
        independent standard normals has covariance ``Sigma``.

        NOTE(review): the ``var_L`` numerator is a difference of O(1) terms
        whose result is about ``(2/3) (gamma t)^3``. It therefore loses relative
        precision (and can go non-positive) when ``gamma * t`` is very small.
        """
        var_L = (2*t*gamma+4*np.exp(-gamma*t)-np.exp(-2*gamma*t)-3) / gamma**2
        var_K = 1 - np.exp(-2*gamma*t)
        E_LK = (1 - np.exp(-gamma*t))**2 / gamma
        self.M = sqrtm(np.array([[var_L, E_LK], [E_LK, var_K]]))

    def phi_flow_exact(self):
        """Advance every ``(q, p)`` over time ``h/2`` along the linear flow, adding correlated noise."""
        with torch.no_grad():
            for group in self.param_groups:
                h = group['h']
                gamma = group['gamma']
                device = group['device']

                decay = np.exp(-gamma*h/2)
                drift = (1 - decay) / gamma

                for q, p in zip(group['params'], group['momentums']):
                    noise_q_tmp = torch.randn_like(q, device=device)
                    noise_p_tmp = torch.randn_like(p, device=device)

                    # Correlate the two independent draws through M.
                    noise_q = self.M[0][0] * noise_q_tmp + self.M[0][1] * noise_p_tmp
                    noise_p = self.M[1][0] * noise_q_tmp + self.M[1][1] * noise_p_tmp

                    # `.data =` replaces the values without bumping autograd's version
                    # counter; the q update reads the momentum from before this call.
                    q.data = q + drift * p + noise_q
                    p.data = decay * p + noise_p

    def psi_flow_approx(self):
        """Take one step of size ``h`` along the gradient flow, using the gradients in ``q.grad``."""
        with torch.no_grad():
            for group in self.param_groups:
                h = group['h']
                alpha = group['alpha']
                device = group['device']

                noise_scale = np.sqrt(2 * alpha * h)

                for q, p in zip(group['params'], group['momentums']):
                    noise = torch.randn_like(q, device=device)
                    q.data = q - alpha * q.grad * h + noise_scale * noise
                    p.data = p - q.grad * h

class Net(nn.Module):
    """Two-layer MLP: ``Linear(dim_input, dim_hidden) -> ReLU -> Linear(dim_hidden, dim_output)``.

    Inputs with more than two dimensions are viewed as ``(-1, dim_input)``
    before the first layer. No activation is applied to the output.
    """

    def __init__(self, dim_input=22, dim_hidden=50, dim_output=2):
        super().__init__()

        self.dim_input = dim_input
        # Creation order (fc1 before fc2) fixes the order of random initialization.
        self.fc1 = nn.Linear(dim_input, dim_hidden)
        self.fc2 = nn.Linear(dim_hidden, dim_output)

    def forward(self, x):
        flat = x.view(-1, self.dim_input) if x.dim() > 2 else x
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)
    
class MultipleNets(nn.Module):
    """Container holding ``n_samples`` separately-initialized ``Net`` instances.

    Each network is registered as submodule ``'net-<i>'``, so its parameters
    are reached by ``parameters()``, ``.to()`` and similar calls. The networks
    are also kept in the plain list ``self.nets``. ``forward`` returns a list
    with one output per network, in order.
    """

    def __init__(self, n_samples=1, dim_input=22, dim_hidden=50, dim_output=2):
        super().__init__()
        self.n_samples = n_samples
        self.dim_input = dim_input
        self.dim_hidden = dim_hidden
        self.dim_output = dim_output

        self.nets = []
        for idx in range(n_samples):
            member = Net(dim_input, dim_hidden, dim_output)
            self.nets.append(member)
            # A plain list does not register submodules, so register each one explicitly.
            self.add_module('net-{}'.format(idx), member)

    def forward(self, x):
        outputs = []
        for member in self.nets:
            # Calls .forward directly (bypassing nn.Module.__call__ hooks), as before.
            outputs.append(member.forward(x))
        return outputs

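# Debug aid (disabled): print a layer summary of a single network. Net expects
# 22 input features by default, so the summary input shape is (1, 22).
# net = Net().to(device)
# summary(net, (1, 22))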

def train(model, device, train_loader, criterion, optimizer):
    """Run one pass of sampler updates over every batch of ``train_loader``.

    The potential for a batch is the sum over networks of
    ``criterion(logit, target)``. When the module-level ``reg`` is positive,
    the prior term ``reg / 2 * sum(||theta||^2)`` over all parameters of
    ``model`` is added.

    Per batch:
      * ``HFHR2``: ``phi_flow_exact`` -> gradient at the NEW position ->
        ``psi_flow_approx`` -> ``phi_flow_exact``.
      * any other optimizer (e.g. ``HFHR``): gradient -> ``optimizer.step()``.

    Args:
        model: module whose forward returns an iterable of per-network logits.
        device: device to move each batch to.
        train_loader: iterable of ``(data, target)`` batches.
        criterion: per-network loss function.
        optimizer: ``HFHR2`` or any optimizer exposing ``zero_grad``/``step``.
    """
    model.train()
    split_scheme = isinstance(optimizer, HFHR2)

    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        if split_scheme:
            # The first phi half-step moves the parameters, so it must run BEFORE
            # the forward pass. psi needs the gradient at the post-phi position;
            # a graph built before phi would give the gradient at the stale
            # position (or an inconsistent mix of old activations and new weights).
            optimizer.phi_flow_exact()

        optimizer.zero_grad()

        logits = model(data)
        loss = sum(criterion(logit, target) for logit in logits)
        if reg > 0:
            prior = reg / 2 * sum(param.pow(2).sum() for param in model.parameters())
            loss = loss + prior

        loss.backward()

        if split_scheme:
            optimizer.psi_flow_approx()
            optimizer.phi_flow_exact()
        else:
            optimizer.step()
  
def _ensemble_metrics(model, device, loader, criterion):
    """Return ``(accuracy, mean_loss)`` of the model's ensemble over all batches of ``loader``.

    Accuracy uses the argmax of the per-network softmax probabilities summed
    over networks. ``mean_loss`` is ``criterion`` averaged over networks and
    weighted by batch size across batches. When the loader yields a single
    batch, it equals that batch's value. Assumes ``criterion`` averages over
    its batch (``reduction='mean'``); confirm if that changes.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    n_correct = 0
    n_seen = 0
    weighted_loss = 0.0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        logits = model(data)
        prob = sum(F.softmax(logit, dim=1) for logit in logits)
        pred = prob.argmax(dim=1).long()
        n_correct += torch.eq(pred, target).sum().item()
        per_net = torch.stack([criterion(logit, target) for logit in logits])
        batch = target.shape[0]
        weighted_loss += per_net.mean().item() * batch
        n_seen += batch
    if n_seen == 0:
        raise ValueError('loader yielded no samples')
    return n_correct / n_seen, weighted_loss / n_seen


def evaluate(model, device, train_loader, test_loader, verbose=True):
    """Log ensemble error and per-network loss on the train and test sets.

    Reads the module-level globals ``eval_criterion``, ``epoch`` and ``logger``.
    The value labelled "Log-Likelihood" is ``eval_criterion`` averaged over
    networks. With ``nn.CrossEntropyLoss`` (as this script configures it),
    that is the mean NEGATIVE log-likelihood.

    Args:
        model: module whose forward returns an iterable of per-network logits.
        device: device to move each batch to.
        train_loader: loader over the training set (every batch is evaluated).
        test_loader: loader over the test set (every batch is evaluated).
        verbose: also print the line that is written to ``logger``.
    """
    model.eval()

    with torch.no_grad():
        train_acc, train_llh = _ensemble_metrics(model, device, train_loader, eval_criterion)
        test_acc, test_llh = _ensemble_metrics(model, device, test_loader, eval_criterion)

    template = 'Epoch: {}\t{} Error: {:.2f}%\t{} Log-Likelihood: {:.4f}\t{} Error: {:.2f}%\t{} Log-Likelihood: {:.4f}\n'
    info = template.format(epoch, 'Train', (1 - train_acc) * 100, 'Train', train_llh, 'Test', (1 - test_acc) * 100, 'Test', test_llh)
    if verbose:
        print(info)
    logger.write(info)

# Deterministic cuDNN kernels so that runs with the same seed repeat on GPU.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

n_samples = 100  # number of networks sampled side by side in one MultipleNets
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epoches = 500    # passes over the training data per experiment
reg = 1          # prior weight read by train(): loss += reg/2 * ||theta||^2
h = 2e-2         # sampler step size
gamma = 20       # friction passed to both samplers

if args.algo == 'ULD':
    alpha = 0
elif args.algo == 'HFHR':
    alpha = 1
else:
    # Without this guard `alpha` and `optimizer` would stay undefined and the
    # run would fail later with a NameError.
    parser.error("--algo must be 'ULD' or 'HFHR'")

if args.output_path is None:
    parser.error('--output_path is required')

n_exp = 10  # independent experiments, seeded 0..n_exp-1

# makedirs also creates missing parent directories and accepts an existing one.
os.makedirs(args.output_path, exist_ok=True)

start = time.time()
for i in range(n_exp):
    torch.manual_seed(i)

    # `logger`, `epoch` and `eval_criterion` stay module-level globals because
    # evaluate() reads them. The `with` closes the log file even if a run fails.
    with open(f'{args.output_path}/HFHR({alpha}-{h}-{epoches}-{n_samples}-{i}).txt', 'w+') as logger:
        logger.write('--- Experiment Configuration ---\n\n')
        logger.write('dataset: parkinsons\n')
        logger.write(f'n_samples: {n_samples}\n')
        logger.write(f'reg: {reg}\n')
        logger.write(f'h: {h}\n')

        logger.write('\n' + '-' * 80 + '\n\n')

        model = MultipleNets(n_samples=n_samples, dim_input=22, dim_hidden=10, dim_output=2).to(device)
        # One param group per network, so each network gets its own momentum buffers.
        params = [{'params': net.parameters()} for net in model.nets]
        criterion = nn.CrossEntropyLoss(reduction='sum')
        eval_criterion = nn.CrossEntropyLoss()

        # Pass `device` explicitly so the momentum buffers and the noise live on
        # the same device as the model, whatever the optimizers' own default is.
        if args.algo == 'ULD':
            optimizer = HFHR(params, h=h, alpha=alpha, gamma=gamma, device=device)
        else:
            optimizer = HFHR2(params, h=h, alpha=alpha, gamma=gamma, device=device)

        # evaluate() runs before each training pass, so the state after the final
        # pass is not logged.
        for epoch in range(epoches):
            evaluate(model, device, train_loader, test_loader, verbose=False)
            train(model, device, train_loader, criterion, optimizer)

    print(f'{args.algo}: seed={i} finishes.')

end = time.time()
print(f"\naverage time per exp: {(end - start) / n_exp}")