import click
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.linalg import qr
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

# Seed the global NumPy and PyTorch RNGs once, at import time, so every random
# draw made through them (dataset, weight init, augmentation, batch shuffling)
# is reproducible from run to run.
_SEED = 0
np.random.seed(_SEED)
torch.manual_seed(_SEED)

def rand_ortho(dim):
    """Return a random ``dim x dim`` orthogonal matrix as a NumPy array.

    The result is the Q factor of the QR decomposition of a matrix whose
    entries are i.i.d. standard normal draws from NumPy's global RNG.
    """
    gaussian = np.random.randn(dim, dim)
    q_factor, _ = qr(gaussian)
    return q_factor

def get_dataset(dataset_type=None, input_dim=None, output_dim=None, dataset_size=None):
    """Create a synthetic, noise-free linear regression dataset.

    Only ``dataset_type='synth'`` is supported. A weight matrix ``W`` of shape
    ``(output_dim, input_dim)`` with entries uniform in ``[1, 2)`` is drawn.
    Then ``dataset_size`` input vectors with entries uniform in ``[0, 1)`` are
    drawn, with targets ``y = W x``. All draws use NumPy's global RNG.

    Returns:
        ``(dataset, inputs, outputs)``: a ``TensorDataset`` over the two
        tensors; ``inputs`` has shape ``(dataset_size, input_dim)`` and
        ``outputs`` has shape ``(dataset_size, output_dim)``. Both use
        torch's default float dtype.

    Raises:
        ValueError: for any other ``dataset_type``. A bare ``assert`` would be
            stripped under ``python -O`` and silently return None.
    """
    if dataset_type != 'synth':
        raise ValueError('unsupported dataset_type: {!r}'.format(dataset_type))

    weights = np.random.random(size=(output_dim, input_dim)) + 1.0
    # A single (dataset_size, input_dim) draw is filled row by row, so it
    # consumes the RNG in the same order as drawing one input vector at a time.
    inputs = np.random.random(size=(dataset_size, input_dim))
    outputs = inputs @ weights.T

    # Convert whole arrays at once; building a tensor from a Python list of
    # ndarrays is very slow.
    dtype = torch.get_default_dtype()
    tensor_inputs = torch.tensor(inputs, dtype=dtype)
    tensor_outputs = torch.tensor(outputs, dtype=dtype)
    dataset = TensorDataset(tensor_inputs, tensor_outputs)
    return dataset, tensor_inputs, tensor_outputs

def _decayed(value, step, params, scale):
    """Apply the decay schedule described by *params* to *value* at *step*.

    ``params['decay_type']`` selects the schedule:
    ``'exp'``   -> ``value * exp(-step * decay_rate)``;
    ``'power'`` -> ``value * (step / scale + 1) ** -decay_rate``;
    anything else returns *value* unchanged (``decay_rate`` is then not read).
    """
    decay_type = params['decay_type']
    if decay_type == 'exp':
        return value * math.exp(-step * params['decay_rate'])
    if decay_type == 'power':
        return value * ((step / scale + 1) ** (-params['decay_rate']))
    return value


def get_augmented(input, step, augment_type=None, augment_params=None):
    """Return an augmented version of the 2-D batch *input* for training step *step*.

    Args:
        input: 2-D tensor (``torch.mm`` and ``input.size(1)`` require it).
            Presumably rows are samples, as yielded by the DataLoader in main.
        step: Current training step; drives the decay schedule.
        augment_type: ``'gaussian'``, ``'rand-proj'``, or anything else for no
            augmentation (then *input* itself is returned).
        augment_params: Dict with ``'decay_type'`` and ``'decay_rate'`` plus
            ``'init_var'`` (gaussian) or ``'init_1mg'`` (rand-proj).

    'gaussian' adds element-wise independent zero-mean normal noise. Its
    variance is ``init_var`` decayed over steps.

    'rand-proj' orthogonally projects every row onto a random subspace of
    dimension ``int((1 - mg) * dim)``, where ``mg`` is ``init_1mg`` decayed
    over steps. That dimension is clamped to ``[0, dim]``.

    Raises:
        ValueError: from ``math.sqrt`` if the decayed variance is negative.
    """
    if augment_type == 'gaussian':
        var = _decayed(augment_params['init_var'], step, augment_params, 50)
        noise = torch.normal(torch.zeros_like(input), math.sqrt(var))
        return input + noise

    if augment_type == 'rand-proj':
        mg = _decayed(augment_params['init_1mg'], step, augment_params, 50)
        dim = input.size(1)
        # Clamp: when mg > 1 the raw value is negative, and a negative slice
        # start would zero only the *last* columns instead of all of them.
        proj_dim = min(max(int((1 - mg) * dim), 0), dim)
        # rand_ortho returns a float64 NumPy array. Build the tensor with the
        # input's dtype and device so the matmuls also work for CUDA batches.
        rot = torch.as_tensor(rand_ortho(dim), dtype=input.dtype, device=input.device)
        rotated = torch.mm(input, rot)
        rotated[:, proj_dim:] = 0
        # rot @ rot.T == I, so this keeps only the component in the span of
        # the first proj_dim columns of rot.
        return torch.mm(rotated, rot.t())

    return input

def get_lr(lr_params, step):
    """Return the learning rate to use at *step*.

    ``lr_params`` holds ``'init_lr'``, ``'decay_type'`` and ``'decay_rate'``:
    ``'exp'`` gives ``init_lr * exp(-step * decay_rate)``;
    ``'power'`` gives ``init_lr * (1 + step / 20) ** -decay_rate``;
    any other decay type keeps ``init_lr`` constant.
    """
    base = lr_params['init_lr']
    kind = lr_params['decay_type']
    if kind == 'exp':
        return base * math.exp(-step * lr_params['decay_rate'])
    if kind == 'power':
        return base * ((1 + step / 20) ** -lr_params['decay_rate'])
    return base
    
class LinearModel(torch.nn.Module):
    """Bias-free linear map ``x -> x @ W.T`` with ``W`` initialised uniform in [0, 1]."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim, bias=False)
        # Replace nn.Linear's default init with U(0, 1) entries (in place,
        # without recording autograd history).
        torch.nn.init.uniform_(self.linear.weight, 0.0, 1.0)

    def forward(self, x):
        """Apply the linear layer to *x*."""
        return self.linear(x)
        
def _parse_schedule(spec, init_key):
    """Parse a ``"<init>,<decay_type>,<decay_rate>"`` option string into a dict.

    Args:
        spec: Comma-separated option value, e.g. ``'0.1,exp,0.01'``.
        init_key: Key under which the first (initial-value) field is stored,
            e.g. ``'init_lr'``, ``'init_var'`` or ``'init_1mg'``.

    Returns:
        ``{init_key: float, 'decay_type': str, 'decay_rate': float}``.
        Fields beyond the third are ignored.

    Raises:
        click.BadParameter: if fewer than three fields are given or a numeric
            field is not a valid float.
    """
    fields = spec.split(',')
    if len(fields) < 3:
        raise click.BadParameter(
            'expected "<init>,<decay_type>,<decay_rate>", got {!r}'.format(spec))
    try:
        return {init_key: float(fields[0]),
                'decay_type': fields[1],
                'decay_rate': float(fields[2])}
    except ValueError as err:
        raise click.BadParameter(
            'non-numeric field in {!r}: {}'.format(spec, err)) from err


@click.command()
@click.option('--dataset_type', type=str, default='synth')
@click.option('--input_dim', type=int, default=400)
@click.option('--output_dim', type=int, default=1)
@click.option('--dataset_size', type=int, default=100)
@click.option('--augment_type', type=str, default='gaussian')
@click.option('--augment_params', type=str, default='0.1,exp,0.01')
@click.option('--lr', type=float, default=0.1)
@click.option('--lr_params', type=str, default='0.1,exp,0.01')
@click.option('--steps', type=int, default=1000)
@click.option('--batch_size', type=int, default=-1)
# Required: results are written with df.to_hdf(out_file) after training, which
# cannot succeed without a path. Failing up front avoids wasting the run.
@click.option('--out_file', type=str, required=True)
@click.option('--cuda/--no-cuda', default=False)
def main(dataset_type=None, input_dim=None, output_dim=None, dataset_size=None,
         augment_type=None, augment_params=None, lr=None, lr_params=None, steps=None,
         batch_size=None, out_file=None, cuda=None):
    """Train a bias-free linear model with SGD on augmented synthetic data.

    Before each update, the script records four metrics for the current
    weights W:

    - the full-data mean squared error;
    - the norm of the part of W orthogonal to the span of the training inputs;
    - ||Wmin - W||, where Wmin = Y pinv(X) is the minimum-norm least-squares
      solution;
    - ||W||.

    The series are saved to OUT_FILE with pandas ``to_hdf`` (key 'df').
    """
    lr_params = _parse_schedule(lr_params, 'init_lr')
    if augment_type == 'gaussian':
        augment_params = _parse_schedule(augment_params, 'init_var')
    elif augment_type == 'rand-proj':
        augment_params = _parse_schedule(augment_params, 'init_1mg')

    dataset, inputs, outputs = get_dataset(dataset_type=dataset_type,
                                           input_dim=input_dim,
                                           output_dim=output_dim,
                                           dataset_size=dataset_size)
    # Column layout: X is (input_dim, n_samples), Y is (output_dim, n_samples).
    X = torch.transpose(inputs, 0, 1)
    Y = torch.transpose(outputs, 0, 1)
    n_samples = X.size(1)
    # Orthogonal projector onto the column space of X (reduced QR).
    Q, _ = torch.linalg.qr(X)
    proj = torch.mm(Q, torch.transpose(Q, 0, 1))
    # pinv equals (X^T X)^{-1} X^T when X has full column rank, but stays
    # well defined when dataset_size > input_dim (X^T X is singular then).
    Wmin = torch.mm(Y, torch.linalg.pinv(X))
    wmin_norm = torch.norm(Wmin).item()

    print(torch.norm(Y - torch.mm(Wmin, X)))

    # The data is an in-memory TensorDataset, so load in-process. With
    # num_workers > 0 and no persistent workers, every epoch starts a new
    # worker process, and here every epoch is a single step. Pinned memory
    # only helps host-to-GPU copies.
    if batch_size == -1:
        data_loader = DataLoader(dataset, batch_size=dataset_size, shuffle=True,
                                 num_workers=0, pin_memory=cuda)
    else:
        sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=batch_size)
        data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                                 num_workers=0, pin_memory=cuda)
    criterion = nn.MSELoss()
    if cuda:
        criterion = criterion.cuda()
    model = LinearModel(input_dim, output_dim)
    if cuda:
        model = model.cuda()
    # NOTE: the learning rate is overwritten from lr_params before every
    # optimizer.step(), so the --lr value given here never takes effect.
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.0, weight_decay=0.0)

    # NOTE(review): in mini-batch mode, inputs and targets are scaled by
    # sqrt(N/B) and the loss by B/N. The intent of this rescaling is not
    # documented here; confirm it before changing it.
    minibatch = batch_size > 0
    if minibatch:
        sample_scale = math.sqrt(dataset_size / batch_size)
        loss_scale = batch_size / dataset_size

    loss_vals = []
    orth_vals = []
    dist_vals = []
    wnorm_vals = []
    step = 0
    # `step < steps` also terminates for steps <= 0. The old `step == steps`
    # exit test never fired in that case, so the loop ran forever.
    while step < steps:
        for batch_input, batch_target in data_loader:
            if cuda:
                batch_input, batch_target = batch_input.cuda(), batch_target.cuda()
            aug_input = get_augmented(batch_input, step, augment_type=augment_type,
                                      augment_params=augment_params)
            if minibatch:
                aug_input = aug_input * sample_scale
                batch_target = batch_target * sample_scale
            output = model(aug_input)
            loss = criterion(output, batch_target)
            if minibatch:
                loss = loss * loss_scale

            # Full-training-set metrics for the weights *before* this update.
            weights = model.linear.weight.detach().cpu()
            loss_val = torch.norm(Y - torch.mm(weights, X)).item() ** 2 / n_samples
            para = torch.mm(weights, proj)
            para_val = torch.norm(para).item()
            orth_val = torch.norm(weights - para).item()
            dist = torch.norm(Wmin - weights).item()
            wnorm = torch.norm(weights).item()
            loss_vals.append(loss_val)
            dist_vals.append(dist)
            orth_vals.append(orth_val)
            wnorm_vals.append(wnorm)

            optimizer.zero_grad()
            loss.backward()
            current_lr = get_lr(lr_params, step)
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr
            optimizer.step()

            if step % 100 == 0:
                print('Step {:4} loss_val {:10.3f} batch_loss {:8.3f} para_val {:.3f} '
                      'orth_val {:.3f} dist {:.3f} Wnorm {:.3f} Wmin {:.3f}'.format(
                          step, loss_val, loss.item(), para_val, orth_val, dist,
                          wnorm, wmin_norm))
            step += 1
            if step >= steps:
                break

    df = pd.DataFrame({'Loss': loss_vals,
                       'Ortho Norm': orth_vals,
                       'Dist': dist_vals,
                       'WNorm': wnorm_vals})
    df.to_hdf(out_file, key='df')

# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
