import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import os
import random
import numpy as np
import matplotlib.pyplot as plt
# Silence all Python warnings for the rest of the run. ``np.warnings`` was only
# an incidental alias of the standard-library ``warnings`` module and is not
# available in recent NumPy releases, so the stdlib module is used directly.
import warnings
warnings.filterwarnings('ignore')

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tensorboardX import SummaryWriter

from cqr.datasets import datasets
from models import LinearModel, MLP, PinballLoss
from utils import estimate_sigma, generate_pseudo_labels

import argparse
# Command-line interface. Every option keeps its original name, type and
# default; the help texts were completed so that `--help` documents all flags.
parser = argparse.ArgumentParser()

# --- data ---
parser.add_argument("--data_path", default="./cqr/datasets/", type=str,
                    help="Root path for all datasets")
parser.add_argument("--dataset", default="community", type=str,
                    help="Name of dataset (blog_data|bike|bio|community|concrete|star|"
                         "facebook_{1,2}|meps_{19,20,21}), or 'linear' for synthetic "
                         "data generated from a linear model")
parser.add_argument("--batch_size", default=64, type=int,
                    help="batch size")
parser.add_argument("--kappa_subset", default=None, type=float,
                    help="if set, train on only min(d / kappa_subset, n_train) examples "
                         "(d = number of features); the remaining training examples "
                         "form a validation set")
parser.add_argument("--n", default=5000, type=int,
                    help="number of samples of the synthetic 'linear' dataset")
parser.add_argument("--d", default=1000, type=int,
                    help="number of features of the synthetic 'linear' dataset")
parser.add_argument("--sigma_z", default=0.5, type=float,
                    help="noise scale of the synthetic 'linear' dataset")
parser.add_argument("--seed", default=2, type=int,
                    help="random seed")
parser.add_argument("--test_ratio", default=0.2, type=float,
                    help="ratio for test split")
parser.add_argument("--alpha", default=0.9, type=float,
                    help="desired quantile level")

# --- model ---
parser.add_argument("--model", default="linear", type=str,
                    help="model (linear|mlp)")
parser.add_argument("--depth", default=100, type=int,
                    help="depth passed to the MLP")
parser.add_argument("--width", default=64, type=int,
                    help="hidden dimension passed to the MLP")
parser.add_argument("--freeze_reps", action='store_true',
                    help="freeze representations in MLP (only train top linear layer)")
parser.add_argument("--pseudo_labels", action="store_true",
                    help="generate pseudo-labels from linear models and retrain a linear model")

# --- optimisation of the main model ---
parser.add_argument("--lr", default=1e-3, type=float,
                    help="SGD learning rate")
parser.add_argument("--momentum", default=0.9, type=float,
                    help="SGD momentum")
parser.add_argument("--epochs", default=900, type=int,
                    help="number of training epochs")
parser.add_argument("--decay_factor", default=0.1, type=float,
                    help="factor the learning rate is multiplied by at every decay step")
parser.add_argument("--decay_per_epoch", default=300, type=int,
                    help="number of epochs between two learning-rate decay steps")

# --- optimisation of the pseudo-label model ---
parser.add_argument("--lr_pseudo", default=1e-3, type=float,
                    help="SGD learning rate of the pseudo-label model")
parser.add_argument("--momentum_pseudo", default=0.9, type=float,
                    help="SGD momentum of the pseudo-label model")
parser.add_argument("--epochs_pseudo", default=900, type=int,
                    help="number of training epochs of the pseudo-label model")

# --- reporting ---
parser.add_argument("--disp_per_epoch", default=50, type=int,
                    help="print metrics every this many epochs")
parser.add_argument("--save_path", default='./runs', type=str,
                    help="path for saving results")
args = parser.parse_args()

# Seed every random number generator in use so that runs are repeatable; the
# same seed is also kept for the train/test split performed later.
random_state_train_test = args.seed
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(args.seed)

# Run on the GPU when one is available (seeding its generators as well),
# otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    torch.cuda.manual_seed_all(args.seed)

# create tensorboard
# The run name encodes the model configuration, the quantile level and the
# seed, so each configuration logs into its own directory under
# <save_path>/<dataset>/.
if args.model == 'linear':
    kappa_tag = f"kappa={args.kappa_subset}|" if args.kappa_subset is not None else ""
    if args.pseudo_labels:
        name = f"linear_pseudo|{kappa_tag}alpha={args.alpha}|seed={args.seed}"
    else:
        name = f"linear|{kappa_tag}alpha={args.alpha}|seed={args.seed}"
elif args.model == 'mlp':
    freeze_flag = "_freeze" if args.freeze_reps else ""
    name = f"mlp{freeze_flag}_depth={args.depth}_hidden={args.width}|alpha={args.alpha}|seed={args.seed}"
else:
    # Without this branch `name` would stay undefined and the script would
    # stop with an unhelpful NameError on the next line.
    raise ValueError(f"Unknown model '{args.model}' (expected 'linear' or 'mlp')")
exp_path = os.path.join(args.save_path, args.dataset, name)
writer = SummaryWriter(logdir=exp_path)


if args.dataset != "linear":
    # load the dataset
    X, y = datasets.GetDataset(args.dataset, args.data_path)
    # Standardize x and y
    x_mean, x_std = np.mean(X, axis=0, keepdims=True), np.std(X, axis=0, keepdims=True)
    # A constant feature has zero standard deviation; dividing by it would turn
    # the whole column into NaN. Scale such columns by 1 so they become zeros.
    x_std[x_std == 0] = 1.0
    X = (X - x_mean) / x_std
    y_mean, y_std = np.mean(y), np.std(y)
    # Same guard for a constant target.
    if y_std == 0:
        y_std = 1.0
    y = (y - y_mean) / y_std
else:
    # Synthetic data from linear model: a random unit-norm weight vector,
    # standard normal features, and Gaussian noise scaled by sigma_z.
    w_star = np.random.randn(args.d)
    w_star /= np.linalg.norm(w_star)
    X = np.random.randn(args.n, args.d)
    y = np.dot(X, w_star) + args.sigma_z * np.random.randn(args.n)

# Hold out a fraction `test_ratio` of the samples for testing. The split is
# seeded, so the same seed always yields the same partition.
split = train_test_split(X, y,
                         test_size=args.test_ratio,
                         random_state=random_state_train_test)
# make sure every part is a numpy array
x_train, x_test, y_train, y_test = (np.asarray(part) for part in split)

# Optional sub-sampling: keep (number of features / kappa_subset) randomly
# chosen training examples, capped at the training-set size, and move all
# remaining training examples to a validation set.
if args.kappa_subset is not None:
    n_available = x_train.shape[0]
    n_subset = np.minimum(int(x_train.shape[1] / args.kappa_subset), n_available)
    perms = np.random.permutation(n_available)
    inds_train = perms[:n_subset]
    inds_val = perms[n_subset:]
    # build the validation split first, before x_train / y_train are overwritten
    x_val = x_train[inds_val]
    y_val = y_train[inds_val]
    x_train = x_train[inds_train]
    y_train = y_train[inds_train]

# Sizes of the final training set: number of examples and number of features.
n_train = x_train.shape[0]
in_shape = x_train.shape[1]

# Report the problem size on stdout ...
print("Dataset: %s" % args.dataset)
print("Dimensions: train set (n=%d, p=%d) ; test set (n=%d, p=%d)" %
      (n_train, in_shape, x_test.shape[0], x_test.shape[1]))
# ... and record it, together with the quantile level, in tensorboard.
for tag, value in (("summary/n", n_train),
                   ("summary/d", in_shape),
                   ("alpha", args.alpha)):
    writer.add_scalar(tag, value)

# create dataloaders
# Only the training loader is shuffled; the test loader keeps a fixed order.
train_dataset = TensorDataset(torch.Tensor(x_train), torch.Tensor(y_train))
test_dataset = TensorDataset(torch.Tensor(x_test), torch.Tensor(y_test))
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
if args.kappa_subset is not None:
    # Use torch.Tensor for the targets as well: torch.tensor would keep the
    # numpy dtype of y_val, leaving the validation targets in a different
    # dtype than every other tensor built here.
    val_dataset = TensorDataset(torch.Tensor(x_val), torch.Tensor(y_val))


def train(epoch, net, criterion, verbose=False):
    """Run one optimisation pass over the module-level ``train_loader``.

    Uses the module-level ``optimizer`` and ``device``.

    Args:
        epoch: Epoch index; only shown in the verbose banner.
        net: Model to optimise. Its output is squeezed before being passed
            to the loss.
        criterion: Callable ``criterion(outputs, targets)`` returning the
            batch loss.
        verbose: If true, print a banner and running statistics every
            10 batches.

    Returns:
        ``(loss, coverage)``: the batch losses averaged with batch-size
        weights, and the fraction of examples whose target is less than or
        equal to the prediction.
    """
    if verbose:
        print('\nEpoch: %d' % epoch)

    loss_sum = 0.0
    n_seen = 0
    n_covered = 0.0
    for step, (inputs, targets) in enumerate(train_loader, start=1):
        optimizer.zero_grad()
        inputs = inputs.to(device)
        targets = targets.to(device)

        # forward / backward / parameter update
        outputs = net(inputs).squeeze()
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # running statistics, weighted by the actual batch size
        batch_size = inputs.shape[0]
        loss_sum += loss.item() * batch_size
        n_seen += batch_size
        n_covered += (targets <= outputs).sum().float().item()

        if verbose and step % 10 == 0:
            print(f"Batch [{step}/{len(train_loader)}]: "
                  f"Loss: {loss_sum/n_seen:.6f}, Coverage: {100.*n_covered/n_seen:.3f}")
    return loss_sum / n_seen, 1.*n_covered/n_seen


def test(epoch, net, criterion, verbose=False):
    """Evaluate ``net`` on the module-level ``test_loader``.

    Evaluation is read-only: no gradients are computed and no optimizer step
    is taken. (An earlier version called ``optimizer.step()`` for every test
    batch, which re-applied the gradients left over from training and so
    changed the model during evaluation.)

    Uses the module-level ``device``.

    Args:
        epoch: Epoch index; accepted for symmetry with ``train`` and unused.
        net: Model to evaluate. Its output is squeezed before being passed
            to the loss.
        criterion: Callable ``criterion(outputs, targets)`` returning the
            batch loss.
        verbose: If true, print running statistics every 10 batches.

    Returns:
        ``(loss, coverage, avg_f)``: the batch losses averaged with
        batch-size weights, the fraction of examples whose target is less
        than or equal to the prediction, and the mean prediction.
    """
    test_loss = 0.0
    avg_f = 0.0
    count = 0
    covered = 0.0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs).squeeze()
            loss = criterion(outputs, targets)
            # accumulate statistics weighted by the actual batch size
            test_loss += loss.item() * inputs.shape[0]
            avg_f += outputs.mean().item() * inputs.shape[0]
            count += inputs.shape[0]
            covered += (targets <= outputs).sum().float().item()
            if verbose and (batch_idx+1) % 10 == 0:
                print(f"Batch [{batch_idx+1}/{len(test_loader)}]: "
                      f"Loss: {test_loss/count:.6f}, Coverage: {100.*covered/count:.3f}")
    return test_loss / count, 1.*covered/count, avg_f/count


# Build the model selected on the command line and move it to the chosen device.
in_dim = x_train.shape[1]
if args.model == "mlp":
    net = MLP(in_dim,
              depth=args.depth,
              hidden_dim=args.width,
              freeze_reps=args.freeze_reps)
    net = net.to(device)
elif args.model == "linear":
    net = LinearModel(in_dim)
    net = net.to(device)


# Generate pseudo-labels
# A linear model is first fitted with the squared loss; its predictions and an
# estimated sigma are then passed to generate_pseudo_labels to rebuild the
# train and test datasets that the main model is trained on.
if args.pseudo_labels:
    # The sigma estimate below needs the validation set, which only exists
    # when --kappa_subset is given. Fail before the (long) pseudo-training
    # instead of with a NameError after it.
    if args.kappa_subset is None:
        raise ValueError("--pseudo_labels requires --kappa_subset "
                         "(the validation set is used to estimate sigma)")

    net_pseudo = LinearModel(in_dim).to(device)
    # `optimizer` is a module-level name read by train(); it is re-bound to
    # the main model's optimizer after this stage.
    optimizer = optim.SGD(net_pseudo.parameters(), lr=args.lr_pseudo, momentum=args.momentum_pseudo)
    # step-wise decay: multiply the learning rate by decay_factor every
    # decay_per_epoch epochs
    lambda1 = lambda ep: np.power(args.decay_factor, ep // args.decay_per_epoch)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
    criterion = nn.MSELoss(reduction="mean")

    train_losses = list()
    test_losses = list()
    for epoch in range(args.epochs_pseudo):
        # train for one epoch
        train_loss, _ = train(epoch, net_pseudo, criterion)
        train_losses.append(train_loss)
        # test
        test_loss, _, _ = test(epoch, net_pseudo, criterion)
        test_losses.append(test_loss)
        if (epoch + 1) % args.disp_per_epoch == 0:
            print(f"\nEpoch [{epoch + 1}]")
            print(f"Training loss = {train_loss:.4f}")
            print(f"Test loss = {test_loss:.4f}")
        scheduler.step()

    # rewrite dataset with pseudo-labels
    sigma = estimate_sigma(val_dataset, net_pseudo, device)
    print(f"Estimated sigma={sigma:.6f}")
    train_dataset = generate_pseudo_labels(train_dataset, net_pseudo, device, sigma=sigma)
    test_dataset = generate_pseudo_labels(test_dataset, net_pseudo, device, sigma=sigma)
    # the loaders must be rebuilt so that train()/test() see the new datasets
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
    print("Pseudo labels generated.")


# Training
# The main model is fitted with SGD (with momentum) on the pinball loss at
# quantile level alpha.
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)


def lambda1(ep):
    # Step-wise decay: one more factor of decay_factor every decay_per_epoch epochs.
    return np.power(args.decay_factor, ep // args.decay_per_epoch)


scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
criterion = PinballLoss(quantile=args.alpha, reduction="mean")

# per-epoch history of losses and coverages
train_losses, train_covs = [], []
test_losses, test_covs = [], []
for epoch in range(args.epochs):
    step = epoch + 1  # 1-based epoch number used for logging and printing

    # one pass over the training data
    train_loss, train_cov = train(epoch, net, criterion)
    writer.add_scalar("loss/train", train_loss, step)
    writer.add_scalar("coverage/train", train_cov, step)
    train_losses.append(train_loss)
    train_covs.append(train_cov)

    # evaluation on the test data
    test_loss, test_cov, test_avg_f = test(epoch, net, criterion)
    writer.add_scalar("loss/test", test_loss, step)
    writer.add_scalar("coverage/test", test_cov, step)
    writer.add_scalar("avg_f/test", test_avg_f, step)
    test_losses.append(test_loss)
    test_covs.append(test_cov)

    # periodic console report
    if step % args.disp_per_epoch == 0:
        print(f"\nEpoch [{step}]")
        print(f"Training loss = {train_loss:.4f}, Training coverage = {train_cov:.4f}")
        print(f"Test loss = {test_loss:.4f}, Test coverage = {test_cov:.4f}, Test avg_f = {test_avg_f:.4f}")

    # advance the learning-rate schedule once per epoch
    scheduler.step()
