import torch
from torch import nn, optim
from torch.nn import functional as F
from kloader import KTensorDataLoader
from data import gen_sparse_linear_classification
from models import DiagonalNet
import argparse
import yaml
import wandb
from optimizer import AdamExponent


def _build_optimizer(model, optimizer_type, lr, beta1, beta2, exponent):
    """Return the optimizer selected by ``optimizer_type`` over ``model``'s parameters.

    ``beta1``/``beta2`` are only used by 'Adam' and 'AdamE'. When either is
    None it falls back to the ``torch.optim.Adam`` default (0.9 / 0.999)
    rather than forwarding None into the optimizer constructor.

    Raises:
        ValueError: if ``optimizer_type`` is not SGD, SGDM, Adam or AdamE.
    """
    betas = (
        0.9 if beta1 is None else beta1,
        0.999 if beta2 is None else beta2,
    )
    if optimizer_type == 'SGD':
        return optim.SGD(model.parameters(), lr=lr)
    if optimizer_type == 'SGDM':
        # Heavy-ball momentum: no dampening, no weight decay, no Nesterov.
        return optim.SGD(
            model.parameters(),
            lr=lr,
            momentum=0.9,
            dampening=0,
            weight_decay=0,
            nesterov=False,
        )
    if optimizer_type == 'Adam':
        return optim.Adam(
            model.parameters(),
            lr=lr,
            betas=betas,
            eps=1e-06,
            amsgrad=False,
        )
    if optimizer_type == 'AdamE':
        return AdamExponent(
            model.parameters(),
            lr=lr,
            betas=betas,
            eps=1e-06,
            exponent=exponent,
            weight_decay=0,
        )
    raise ValueError(f"unknown optimizer_type: {optimizer_type!r}")


def train_diagonal_net(
    d: int,
    k: int,
    delta: float,
    n_train: int,
    lr: float,
    optimizer_type: str,
    beta1: float = None,
    beta2: float = None,
    n_test: int = 10000,
    data_seed: int = 24181325235,
    L: int = 2,
    alpha: float = 1.0,
    steps: int = 1_000_000,
    eval_first: int = 1000,
    eval_period: int = 1000,
    batch_size: int = 10_000_000,
    eval_batch_size: int = 1000,
    exponent: float = 0.5,
    save_dir: str = None,
    save_freq: int = -1,
):
    """Train a ``DiagonalNet`` on synthetic data with per-step label noise, logging to wandb.

    Data comes from ``gen_sparse_linear_classification(data_seed, n_train,
    n_test, d, k, 'cuda')``. At every training step fresh noise drawn
    uniformly from {-delta, +delta} is added to the batch targets, and the MSE
    between the model's first output column and the noisy targets is minimized.

    Args:
        d, k: data-generator parameters (``k <= d`` is required).
        delta: magnitude of the +/- label noise added at each training step.
        n_train, n_test: number of training / test samples.
        lr: learning rate.
        optimizer_type: one of 'SGD', 'SGDM', 'Adam', 'AdamE'.
        beta1, beta2: Adam-style betas ('Adam' and 'AdamE' only); None falls
            back to the ``torch.optim.Adam`` defaults 0.9 / 0.999.
        data_seed: seed passed to the data generator.
        L, alpha: forwarded to ``DiagonalNet``.
        steps: total number of optimizer steps to take.
        eval_first: evaluate at every step while ``cur_step <= eval_first``.
        eval_period: after that, evaluate every ``eval_period`` steps.
        batch_size: training batch size (capped at the number of training samples).
        eval_batch_size: batch size used for evaluation passes.
        exponent: forwarded to ``AdamExponent`` ('AdamE' only).
        save_dir: directory receiving ``model-{step}.pt`` checkpoints.
        save_freq: checkpoint every ``save_freq`` steps; checkpointing is off
            when ``save_freq <= 0`` or ``save_dir`` is None.

    Returns:
        (final_train_loss, final_test_loss): sample-weighted mean MSE on the
        training and test sets (without the per-step label noise) after the
        last optimizer step.

    Requires a CUDA device and a ``.config.yml`` providing ``wandb_entity``.
    """
    import os

    assert k <= d, "k must be less than or equal to d"
    assert optimizer_type in ['SGD', 'SGDM', 'Adam', 'AdamE'], "optimizer_type must be SGD, SGDM, Adam or AdamE"
    assert torch.cuda.is_available()
    torch.set_default_device('cuda')

    with open('.config.yml') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    # Record every hyperparameter that affects the run so it is reproducible from wandb.
    my_config = {
        'd': d, 'k': k, 'delta': delta, 'n_train': n_train, 'n_test': n_test,
        'lr': lr, 'optimizer_type': optimizer_type, 'beta1': beta1, 'beta2': beta2,
        'steps': steps, 'exponent': exponent, 'L': L, 'alpha': alpha,
        'batch_size': batch_size, 'data_seed': data_seed,
    }

    name = f"{optimizer_type}-N{n_train}-D{d}-K{k}-LR{lr}-DT{delta}-ST{steps}"
    if optimizer_type == 'AdamE':
        name += f"-E{exponent}"
    run = wandb.init(
        project="diagonal-net-loss-AdamE",
        entity=config['wandb_entity'],
        name=name,
        config=my_config
    )

    # Save a snapshot of all code under the current directory.
    wandb.run.log_code(".")

    # Checkpointing needs both a positive frequency and a target directory;
    # without a directory, torch.save would try to write to "None/model-*.pt".
    should_save = save_freq > 0 and save_dir is not None
    if should_save:
        os.makedirs(save_dir, exist_ok=True)

    train_data, test_data = gen_sparse_linear_classification(data_seed, n_train, n_test, d, k, 'cuda')
    train_loader = KTensorDataLoader(train_data, batch_size=min(batch_size, train_data[0].shape[0]), shuffle=True, drop_last=True)
    train_loader_for_eval = KTensorDataLoader(train_data, batch_size=eval_batch_size, shuffle=False, drop_last=False)
    test_loader = KTensorDataLoader(test_data, batch_size=eval_batch_size, shuffle=False, drop_last=False)

    model = DiagonalNet(alpha=alpha, L=L, dimD=d)
    wandb.watch(model)

    steps_per_epoch = len(train_loader)
    print('steps per epoch:', steps_per_epoch)

    # Ceiling division: enough epochs to cover `steps`; the last may be partial.
    total_epochs = (steps + steps_per_epoch - 1) // steps_per_epoch

    criterion = F.mse_loss
    optimizer = _build_optimizer(model, optimizer_type, lr, beta1, beta2, exponent)

    @torch.no_grad()
    def eval_model(loader):
        """Return the per-sample mean MSE of the model over every batch of ``loader``."""
        total_loss = 0.0
        n = 0
        for batch_x, batch_y in loader.iter():
            out = model(batch_x)[:, 0]
            n += batch_x.shape[0]
            # criterion averages within a batch; re-weight by batch size so a
            # smaller final batch does not skew the overall mean.
            total_loss += criterion(out, batch_y).item() * batch_x.shape[0]
        return total_loss / n

    @torch.no_grad()
    def get_model_stats():
        """Return the L2 norm of each named parameter and of all parameters combined."""
        stats = {}
        total_norm2 = 0.0
        for param_name, param in model.named_parameters():
            cur_norm2 = (param ** 2).sum().item()
            stats[f'norm/{param_name}'] = cur_norm2 ** 0.5
            total_norm2 += cur_norm2
        stats['total_norm'] = total_norm2 ** 0.5
        return stats

    def evaluate_and_log(extra):
        """Evaluate train/test loss in eval mode, log them to wandb with ``extra``, return both."""
        model.eval()
        train_loss = eval_model(train_loader_for_eval)
        test_loss = eval_model(test_loader)
        log = {'eval_train/loss': train_loss, 'eval_test/loss': test_loss}
        log.update(get_model_stats())
        log.update(extra)
        wandb.log(log)
        model.train()
        return train_loss, test_loss

    model.train()

    cur_step = 0
    # Epochs are numbered 1..total_epochs inclusive so the loop can reach `steps`.
    for eid in range(1, total_epochs + 1):
        for bid, (batch_x, batch_y) in train_loader.enum():
            # The final epoch may be partial: stop after exactly `steps` updates.
            if cur_step >= steps:
                break

            if cur_step % eval_period == 0 or cur_step <= eval_first:
                train_loss, test_loss = evaluate_and_log(
                    {'epoch': eid, 'train/step_in_epoch': bid, 'train/step': cur_step}
                )
                ## print log, delete later
                print(f"Step {cur_step} / {steps}, Epoch {eid}, Batch {bid}")
                print(f"Train loss: {train_loss:.9f}, Test loss: {test_loss:.9f}")

            optimizer.zero_grad(set_to_none=True)
            out = model(batch_x)[:, 0]
            # For each batch, draw fresh iid noise uniformly from {-delta, +delta}.
            signs = (torch.randint(0, 2, size=batch_y.shape) * 2 - 1).to(batch_y.device).float()
            noise = delta * signs
            # Alternative: uniform noise on [-delta, delta] via
            # torch.empty_like(batch_y).uniform_(-delta, delta)
            loss = criterion(out, batch_y + noise)
            loss.backward()

            # Saved before optimizer.step(), so model-{cur_step}.pt holds the
            # weights after exactly `cur_step` updates.
            if should_save and cur_step % save_freq == 0:
                print(f"Saving model at step {cur_step}")
                torch.save(model.state_dict(), f"{save_dir}/model-{cur_step}.pt")

            optimizer.step()
            cur_step += 1

    # Evaluate and log the final weights before closing the run so the end
    # state is recorded in wandb as well as returned.
    final_train_loss, final_test_loss = evaluate_and_log({'train/step': cur_step})
    run.finish()
    return final_train_loss, final_test_loss


if __name__ == "__main__":
    # Command-line entry point: parse hyperparameters and launch one training run.
    # Parsed values are passed through unchanged; the defaults reproduce the
    # standard configuration (d=10000, k=50, delta=0.5, n_train=420, lr=0.01).
    parser = argparse.ArgumentParser(description='Train diagonal net')
    parser.add_argument('--d', type=int, default=10000, help='dimension of data')
    parser.add_argument('--k', type=int, default=50, help='rank of data')
    parser.add_argument('--delta', type=float, default=0.5, help='noise level')
    parser.add_argument('--n_train', type=int, default=420, help='number of training samples')
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
    parser.add_argument(
        '--optimizer_type', type=str, default='SGD',
        choices=['SGD', 'SGDM', 'Adam', 'AdamE'], help='optimizer type',
    )
    parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for Adam optimizer')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
    # Defaults below match train_diagonal_net's own defaults.
    parser.add_argument('--exponent', type=float, default=0.5, help='exponent for the AdamE optimizer')
    parser.add_argument('--steps', type=int, default=1_000_000, help='number of optimizer steps')
    args = parser.parse_args()

    train_diagonal_net(
        d=args.d,
        k=args.k,
        delta=args.delta,
        n_train=args.n_train,
        lr=args.lr,
        optimizer_type=args.optimizer_type,
        beta1=args.beta1,
        beta2=args.beta2,
        exponent=args.exponent,
        steps=args.steps,
        save_dir=f"../save/diagonalnet-{args.optimizer_type}-n{args.n_train}",
        save_freq=10000,
    )