import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data as data_utils
import numpy as np
from torchvision import datasets
from torchvision import transforms
from methods.strategies import *
import argparse
from misc import utils
from sklearn.datasets import fetch_covtype
from sklearn.preprocessing import LabelEncoder
import tqdm
from itertools import chain
import os
import pandas as pd
from backpack.custom_module.scale_module import ScaleModule
import math

# Command-line interface: bandit strategy, dataset, network size and
# training/interaction schedule.
parser = argparse.ArgumentParser()
strats = ['random', 'constant001', 'constant01', 'constant1', 'constant10', 'neuralucb', 'marglik_posthoc', 'marglik_online']
parser.add_argument('--strategy', choices=strats, default='marglik_posthoc')
parser.add_argument('--dataset', choices=['bean', 'letter', 'magic', 'avila', 'pendigits', 'covertype'], default='covertype')
parser.add_argument('--n_hidden_layers', type=int, default=2)
parser.add_argument('--width', type=int, default=100)
parser.add_argument('--ntk_param', default=False, action='store_true')
parser.add_argument('--n_epochs', type=int, default=500)
parser.add_argument('--T', type=int, default=10000)
parser.add_argument('--T_burnin', type=int, default=10)
parser.add_argument('--retrain_every', type=int, default=100)
parser.add_argument('--cuda', default=False, action='store_true')
parser.add_argument('--silent', default=False, action='store_true')
parser.add_argument('--save_only_gamma', default=False, action='store_true')
parser.add_argument('--randseed', type=int, default=9999)
args = parser.parse_args()

# --ntk_param is only accepted together with --strategy=neuralucb.
# Report the misuse through argparse instead of `assert`: assertions are
# stripped under `python -O`, which would let the invalid combination through,
# whereas parser.error prints the usage message and exits with status 2.
if args.ntk_param and args.strategy != 'neuralucb':
    parser.error('--ntk_param must be only used with --strategy=neuralucb')

# Seed both NumPy's and PyTorch's global generators from the same CLI seed.
np.random.seed(args.randseed)
torch.manual_seed(args.randseed)
DEVICE = 'cuda' if args.cuda else 'cpu'

# Dataset
# Every branch yields `dataset` (features, class label), the feature
# dimension D and the number of classes/arms K.
if args.dataset == 'covertype':
    X, y = fetch_covtype(data_home='data/', random_state=args.randseed, shuffle=True, return_X_y=True)
    # fetch_covtype encodes the cover types as 1..7, but arms are indexed
    # 0..K-1 in the rest of this script (e.g. `torch.randint(K, ...)` and
    # `ctx[action]`). Remap the labels to contiguous 0-based indices so that
    # `action == best_action` can hold for every class.
    classes, y = np.unique(y, return_inverse=True)
    dataset = data_utils.TensorDataset(torch.tensor(X).float(), torch.tensor(y).long())
    K = len(classes)  # 7 cover types
    D = X.shape[1]    # 54 features
else:
    dataset, D, K = utils.load_raw_uci_dset(args.dataset)

# Hold out T_burnin points for the initial random-action phase; the remaining
# points are streamed one at a time (batch_size=1) in shuffled order.
dataset_burnin, dataset_rest = data_utils.random_split(dataset, (args.T_burnin, len(dataset)-args.T_burnin))
burnin_loader = data_utils.DataLoader(dataset_burnin, batch_size=1, shuffle=True)
dataloader = data_utils.DataLoader(dataset_rest, batch_size=1, shuffle=True)

# Base neural net
def get_net():
    """Build a new MLP mapping a (K*D)-dimensional context to a scalar output.

    The architecture is Linear -> ReLU, repeated ``args.n_hidden_layers``
    times with ``args.width`` hidden units, followed by a final Linear layer
    with one output. Reads the module-level ``args``, ``K`` and ``D``.

    With ``--ntk_param`` every Linear layer is followed by a
    ``ScaleModule(1/sqrt(fan_in))`` and all Linear weights and biases are
    re-drawn from N(0, 1); otherwise an ``nn.Identity`` is placed in the same
    position and PyTorch's default initialization is kept.

    Returns:
        nn.Sequential: the freshly constructed network.
    """
    def after_linear(fan_in):
        # Module placed right after a Linear layer with `fan_in` inputs.
        if args.ntk_param:
            return ScaleModule(1/math.sqrt(fan_in))
        return nn.Identity()

    # Modules are instantiated in forward order so that the global RNG is
    # consumed in a fixed, reproducible sequence.
    layers = [nn.Linear(K*D, args.width), after_linear(K*D), nn.ReLU()]
    for _ in range(args.n_hidden_layers-1):
        layers.append(nn.Linear(args.width, args.width))
        layers.append(after_linear(args.width))
        layers.append(nn.ReLU())
    layers.append(nn.Linear(args.width, 1))
    layers.append(after_linear(args.width))

    net = nn.Sequential(*layers)

    if not args.ntk_param:
        return net

    # Follow initialization of Jacot et al.
    def weights_init(m):
        # Matches on the class name, i.e. any module whose class name contains 'Linear'.
        if 'Linear' in m.__class__.__name__:
            torch.nn.init.normal_(m.weight, mean=0., std=1.)
            torch.nn.init.normal_(m.bias, mean=0., std=1.)

    net.apply(weights_init)
    return net


if args.strategy != 'random':
    # Burn-in phase: play uniformly random arms on the held-out burn-in points
    # to collect the initial training set (train_X, train_Y).
    burnin_contexts, burnin_rewards = [], []

    for feat, best_action in burnin_loader:
        feat = torch.flatten(feat.squeeze())
        assert feat.shape == (D,)
        assert best_action.shape == (1,)
        ctx = utils.features_to_contexts(feat, K)  # (K, D*K)

        # Random arm; reward is 1 iff it matches the true class.
        action = torch.randint(K, size=(1,))
        reward = int((action == best_action).item())

        # `action` has shape (1,), so ctx[action] keeps a leading dimension of 1.
        burnin_contexts.append(ctx[action])
        burnin_rewards.append(reward)

    # Stack into the tensors handed to the main strategy after burn-in.
    train_X = torch.cat(burnin_contexts, dim=0)
    train_Y = torch.tensor(burnin_rewards).float().unsqueeze(-1)

    assert train_X.shape == (args.T_burnin, D*K)
    assert train_Y.shape == (args.T_burnin, 1)

# Instantiate the strategy selected with --strategy.
# Fixed exploration coefficients for the 'constant*' strategies.
_CONSTANT_GAMMAS = {
    'constant001': 0.01,
    'constant01': 0.1,
    'constant1': 1,
    'constant10': 10,
}

if args.strategy == 'random':
    # The random baseline needs neither a network nor burn-in data
    # (train_X / train_Y are not defined in this case).
    strategy = RandomStrategy(K)
else:
    # Keyword arguments shared by all network-based strategies.
    _common = dict(n_epochs=args.n_epochs, device=DEVICE)

    if args.strategy in _CONSTANT_GAMMAS:
        strategy = ConstantUCBStrategy(
            K, get_net, train_X, train_Y,
            gamma=_CONSTANT_GAMMAS[args.strategy], **_common
        )
    elif args.strategy == 'neuralucb':
        strategy = NeuralUCBStrategy(K, get_net, train_X, train_Y, **_common)
    elif args.strategy in ('marglik_posthoc', 'marglik_online'):
        strategy = MarglikUCBStrategy(
            K, get_net, train_X, train_Y,
            online=(args.strategy == 'marglik_online'), **_common
        )

# Main interaction loop: one data point per round, for exactly args.T rounds
# (or fewer if the data runs out first).
pbar = tqdm.tqdm(dataloader, total=args.T) if not args.silent else dataloader
total_regret = 0  # Total regret

# Log results (one entry per round)
gammas, rewards, regrets = [], [], []

for t, (feat, best_action) in enumerate(pbar, start=1):
    feat = feat.squeeze().flatten()
    ctx = utils.features_to_contexts(feat, K)

    # Let the strategy choose an arm for this round's contexts
    action = strategy.pull_lever(ctx.to(DEVICE))

    # 0/1 reward: 1 iff the chosen arm is the true class; regret is its complement
    reward = 1 if action == best_action else 0
    regret = 1 - reward
    total_regret += regret

    if args.strategy != 'random':
        gamma_t = strategy.get_gamma(t)
        # Feed the new observation to the strategy; ask for a retrain every
        # `retrain_every` rounds
        retrain = (t % args.retrain_every == 0)
        strategy.condition_on_observations(
            ctx[action].reshape(1, -1),
            torch.tensor(reward).float().reshape(1, 1),
            retrain=retrain
        )
        desc_str = f'[t = {t}, total_regret = {total_regret}, gamma_t = {gamma_t:.2f}]'
    else:
        # The random baseline has no exploration coefficient; log 0
        gamma_t = 0
        desc_str = f'[t = {t}, total_regret = {total_regret}]'

    # Log
    gammas.append(gamma_t)
    rewards.append(reward)
    regrets.append(regret)

    if not args.silent:
        pbar.set_description(desc_str)

    # `t` is 1-based, so stop once round T has been processed. Comparing with
    # `>` here would play and log T+1 rounds, which also disagrees with the
    # progress bar's total=args.T.
    if t >= args.T:
        break

# Save results
dir_name = f'results/{args.dataset}'
# exist_ok avoids the check-then-create race when several runs (e.g. different
# seeds) start at the same time and all try to create the directory.
os.makedirs(dir_name, exist_ok=True)

results = {'gammas': gammas, 'rewards': rewards, 'regrets': regrets}

# NeuralUCB runs are additionally keyed by network width and parametrization.
if args.strategy == 'neuralucb':
    fname = f'neuralucb_{args.width}{"_ntp" if args.ntk_param else ""}_{args.randseed}.npy'
else:
    fname = f'{args.strategy}_{args.randseed}.npy'

# np.save stores the dict as a pickled 0-d object array; read it back with
# np.load(path, allow_pickle=True).item().
np.save(f'{dir_name}/{fname}', results)
