import os
from rvs import dataset, step
from torch import nn
import torch
import torch.distributions as D
import torch.nn.functional as F
import torch.optim as optim
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
from tqdm import tqdm
from rvs.util import SegmentDataset,\
    get_trajs, \
    return_labels, \
    FastTensorDataLoader
from wandb.sdk.wandb_run import Run


class QuantileReturnPredictor(nn.Module):
    """Ensemble of MLPs that predict ``num_q`` return quantiles per input.

    Inputs are standardized with ``(obs - obs_mean) / obs_std``, flattened to
    ``(batch, -1)``, fed through ``ens_size`` independent 4-layer ReLU MLPs, and
    each head's output is multiplied by ``value_scale``.

    ``obs_mean`` / ``obs_std`` may be Python floats, numpy arrays or tensors.
    They are stored as float32, non-persistent buffers, so ``.cuda()`` /
    ``.to()`` move them with the module while ``state_dict()`` keys stay the
    same as the parameter-only layout.
    """

    def __init__(self, obs_size, hidden_size=512, num_q=20, ens_size=5, obs_mean=0., obs_std=1., value_scale=1.):
        super().__init__()
        self.num_q = num_q
        self.models = nn.ModuleList()
        # clone() so later in-place edits to the caller's tensor cannot leak in.
        self.register_buffer(
            'obs_mean', torch.as_tensor(obs_mean, dtype=torch.float32).clone(), persistent=False)
        self.register_buffer(
            'obs_std', torch.as_tensor(obs_std, dtype=torch.float32).clone(), persistent=False)
        self.value_scale = value_scale
        for _ in range(ens_size):
            self.models.append(nn.Sequential(nn.Linear(obs_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, hidden_size),
                                             nn.ReLU(), nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, num_q)))

    def forward(self, obs):
        """Return quantile predictions of shape ``(batch, num_q, ens_size)``."""
        obs = (obs - self.obs_mean) / self.obs_std
        obs = obs.reshape(obs.shape[0], -1)
        x = [m(obs) * self.value_scale for m in self.models]
        # Ensemble members are stacked on the last axis.
        return torch.stack(x, dim=-1)


def huber_quantile(x, k=1.):
    """Elementwise Huber penalty normalized by its threshold ``k``.

    Returns ``0.5 * x**2 / k`` where ``|x| < k`` and ``|x| - 0.5 * k``
    elsewhere, i.e. the standard Huber loss divided by ``k``. This form is
    continuous (value and slope) at ``|x| == k`` for every ``k`` and tends to
    ``|x|``, which is returned directly when ``k`` is numerically zero. For the
    default ``k = 1`` it matches the plain Huber loss.
    """
    if np.allclose(0, k):
        return x.abs()
    magnitude = x.abs()
    # Previously the linear branch was k * (|x| - 0.5k), which is discontinuous
    # at |x| == k whenever k != 1 because the quadratic branch is divided by k.
    return torch.where(magnitude < k, 0.5 * x.pow(2) / k, magnitude - 0.5 * k)


def top_quantile_returns(trajs, ret_model, value_range=(-np.inf, np.inf), value_fn=False, last_n_quantiles=5, ensemble='mean', device='cuda'):
    """Score every step of each trajectory with an optimistic return estimate.

    For each trajectory the model is evaluated on ``traj.obs`` (when
    ``value_fn``) or on ``traj.obs`` concatenated with
    ``traj.continuous_actions``. Its output is clamped to ``value_range``, then
    reduced over the ensemble axis (last axis) by ``ensemble`` (``'mean'`` or
    ``'min'``). The last ``last_n_quantiles`` quantiles are averaged.

    Assumes ``ret_model`` returns ``(T, num_q, ensemble)``, as
    ``QuantileReturnPredictor`` does.

    Args:
        device: where the trajectory tensors are placed (default ``'cuda'``,
            matching the previous hard-coded ``.cuda()``).

    Returns:
        A list with one numpy array per trajectory, holding one value per row
        of the model input.

    Raises:
        ValueError: if ``ensemble`` is not ``'mean'`` or ``'min'``.
    """
    if ensemble == 'mean':
        def reduce_ensemble(q):
            return q.mean(-1)
    elif ensemble == 'min':
        def reduce_ensemble(q):
            return q.min(-1)[0]
    else:
        # Was a bare `raise`, which only produced "No active exception to reraise".
        raise ValueError(f"Unknown ensemble reduction {ensemble!r}; expected 'mean' or 'min'")

    ret_model.eval()
    rets = []
    with torch.no_grad():
        for traj in tqdm(trajs):
            model_input = torch.tensor(np.array(traj.obs)).to(device)
            if not value_fn:
                act = torch.tensor(np.array(traj.continuous_actions)).to(device)
                model_input = torch.cat((model_input, act), dim=-1)
            quantiles = torch.clamp(ret_model(model_input), value_range[0], value_range[1])
            top = reduce_ensemble(quantiles)[:, -last_n_quantiles:]
            rets.append(top.mean(-1).cpu().numpy())
    return rets


def qr_augmented_return_labels(traj, traj_sampled_rets, traj_labels, discount_factor, initial_ret=(lambda _: 0), value_fn=False):
    """Greedily relabel a trajectory's returns-to-go with model estimates.

    Walks the trajectory backwards, accumulating the discounted return seeded
    by ``initial_ret(traj)``. At each step the running return is raised to the
    model's estimate when that is larger:

    * ``value_fn`` and not the last step: ``r_t + discount * sampled[t + 1]``;
    * otherwise: ``sampled[t]``.

    The raised value then feeds the earlier steps. ``traj_labels`` is accepted
    for interface parity with ``augmented_return_labels`` and is not used.

    Returns:
        A list of per-step labels in forward time order.
    """
    reward_seq = traj.rewards
    horizon = len(reward_seq)
    running = initial_ret(traj)
    backwards = []
    for t in range(horizon - 1, -1, -1):
        r_t = float(reward_seq[t])
        running = running * discount_factor + r_t
        if value_fn and t + 1 < horizon:
            candidate = r_t + discount_factor * traj_sampled_rets[t + 1]
        else:
            candidate = traj_sampled_rets[t]
        running = max(running, candidate)
        backwards.append(running)
    return backwards[::-1]


def augmented_return_labels(traj, traj_sampled_rets, traj_labels, discount_factor, value_fn=False, initial_ret=(lambda _: 0)):
    """Relabel a trajectory's returns-to-go, bootstrapping at a single step.

    The pivot step is where the model estimate most exceeds the best existing
    label, i.e. ``argmax(sampled - max(traj_labels, axis=0))``. Returns are
    accumulated backwards from ``initial_ret(traj)``. Only at the pivot is the
    running return raised to the model estimate:

    * ``value_fn`` and not the last step: at ``t == pivot - 1``, using
      ``r_t + discount * sampled[t + 1]``;
    * otherwise: at ``t == pivot``, using ``sampled[t]``.

    Returns:
        A list of per-step labels in forward time order.
    """
    reward_seq = traj.rewards
    running = initial_ret(traj)
    horizon = len(reward_seq)
    best_existing = np.array(traj_labels).max(axis=0)
    pivot = np.argmax(traj_sampled_rets - best_existing)

    backwards = []
    for t in reversed(range(horizon)):
        r_t = float(reward_seq[t])
        running = running * discount_factor + r_t
        if value_fn and t + 1 < horizon:
            if t + 1 == pivot:
                running = max(running, r_t + discount_factor * traj_sampled_rets[t + 1])
        elif t == pivot:
            running = max(running, traj_sampled_rets[t])
        backwards.append(running)

    backwards.reverse()
    return backwards


def train_quantile_return_predictor(train_dataset,
                                    learning_rate,
                                    batch_size,
                                    epochs,
                                    hidden_size=512,
                                    learning_rate_scheduler=None,
                                    wandb_run=None,
                                    value_fn=False,
                                    huber_k=1.,
                                    num_q=20,
                                    value_range=(-np.inf, np.inf),
                                    value_scale=1.,
                                    store_dataset_gpu=False,
                                    ):
    """Train a ``QuantileReturnPredictor`` ensemble with the quantile-Huber loss.

    Args:
        train_dataset: yields ``(obs, act, label)`` samples and provides
            ``mean_std()`` / ``mean_std_acts()`` for input normalization (and
            ``convert_to_tensor_dataset()`` when ``store_dataset_gpu``).
        learning_rate, batch_size, epochs: AdamW / loader / loop settings.
        hidden_size: MLP width.
        learning_rate_scheduler: ``'cosine'``, ``'linear'`` or None; stepped
            once per epoch.
        wandb_run: if truthy, the mean epoch loss is logged as
            ``bootstrap_loss``.
        value_fn: if True the model sees only observations; otherwise
            observations concatenated with actions.
        huber_k: threshold passed to ``huber_quantile``.
        num_q: number of predicted quantiles.
        value_range: accepted for interface compatibility; not used here.
        value_scale: output multiplier forwarded to the model.
        store_dataset_gpu: use ``FastTensorDataLoader`` on CUDA instead of a
            multi-worker ``DataLoader``.

    Returns:
        The trained model, on CUDA.
    """
    obs, act, lab = next(iter(train_dataset))

    # The last entry of mean_std() is dropped. NOTE(review): presumably it is
    # not an observation feature; confirm against SegmentDataset.mean_std.
    obs_mean, obs_std = train_dataset.mean_std()
    obs_mean, obs_std = obs_mean[:-1], obs_std[:-1]
    if value_fn:
        obs_size = np.prod(obs.shape)
    else:
        # Actions are concatenated to observations, so their stats are appended.
        obs_size = np.prod(obs.shape) + np.prod(act.shape)
        act_mean, act_std = train_dataset.mean_std_acts()
        obs_mean = torch.cat((obs_mean, act_mean))
        obs_std = torch.cat((obs_std, act_std))

    model = QuantileReturnPredictor(
        obs_size, hidden_size=hidden_size, num_q=num_q, obs_mean=obs_mean, obs_std=obs_std, value_scale=value_scale).cuda()
    # Quantile midpoints tau_k = (2k + 1) / (2 * num_q), shaped (1, num_q, 1) to
    # broadcast against predictions of shape (batch, num_q, ensemble).
    tau = torch.Tensor((2 * torch.arange(num_q) + 1) /
                       (2.0 * num_q)).reshape(1, -1, 1).cuda()

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    if learning_rate_scheduler == 'cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    elif learning_rate_scheduler == 'linear':
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, lambda e: 1 - e / epochs)
    else:
        scheduler = None

    # Get num cpu cores available on slurm
    num_workers = int(os.environ.get('SLURM_CPUS_PER_TASK', 2))
    print(f'Using {num_workers} workers')
    if store_dataset_gpu:
        train_dataset = train_dataset.convert_to_tensor_dataset()
        train_dataloader = FastTensorDataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, device='cuda')
    else:
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, num_workers=num_workers)

    model.train()

    for epoch in range(epochs):
        loss_sum = 0.0
        num_batches = 0

        pbar = tqdm(train_dataloader, total=int(len(train_dataloader)))
        for obs, act, lab in pbar:
            obs = obs.cuda()
            lab = lab.cuda()

            if value_fn:
                pred_ret = model(obs)
            else:
                pred_ret = model(torch.cat((obs, act.cuda()), dim=-1))

            # Quantile-Huber loss. Assumes lab is (batch,), so the residuals
            # broadcast to (batch, num_q, ensemble) and are weighted by
            # |tau - 1{residual < 0}|.
            diff = lab[:, None, None] - pred_ret
            loss = huber_quantile(diff, k=huber_k) * \
                (tau - (diff.detach() < 0).float()).abs()
            loss = loss.mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            num_batches += 1
            pbar.set_description(
                f'Epoch {epoch} | Loss: {loss_sum / num_batches:.4f}')

        if scheduler is not None:
            scheduler.step()
        if wandb_run:
            # max(..., 1) guards against an empty loader (previously NameError on
            # the loop index, or a stale value from an earlier epoch).
            wandb_run.log({'bootstrap_loss': loss_sum / max(num_batches, 1),
                           'bootstrap_epoch': epoch})

    return model


def bootstrapped_dataset_quantile(dataset: Dict[str, np.ndarray],
                                  discount_factor: float,
                                  model_args: Dict[str, Any],
                                  n_iter: int,
                                  seed: int,
                                  wandb_run: Optional[Run] = None,
                                  env=None,
                                  store_dataset_gpu=False,):
    """Create a SegmentDataset, then iteratively train a return model and bootstrap returns.

    Each of the ``n_iter`` rounds trains a quantile return model on the current
    labels, scores every trajectory with ``top_quantile_returns``, and appends
    one relabeled return sequence per trajectory. The relabeling uses
    ``qr_augmented_return_labels`` for ``relabel_style == 'greedy'`` and
    ``augmented_return_labels`` for ``'singlepoint'``.

    A final model (``value_fn=True``) is then trained on the last labels. It is
    saved to ``wandb_run.dir`` when a run is given.

    Returns:
        The relabeled ``SegmentDataset``.

    Raises:
        ValueError: if ``env`` is None, or ``model_args["relabel_style"]`` is
            neither ``'greedy'`` nor ``'singlepoint'``.
    """
    hidden_size = model_args['hidden_size']
    learning_rate = model_args['learning_rate']
    learning_rate_scheduler = model_args['learning_rate_scheduler']
    batch_size = model_args['batch_size']
    epochs = model_args['epochs']
    reward_preprocessing = model_args['reward_preprocessing']
    value_fn = model_args["value_fn"]
    last_n_quantiles = model_args["quantiles"]
    ensemble = model_args["ensemble"]

    rewards = dataset["rewards"]
    value_range = (-np.inf, np.inf)

    trajs = get_trajs(dataset)

    if env is None:
        raise ValueError("bootstrapped_dataset_quantile requires an env")

    if step.is_antmaze_env(env) and reward_preprocessing == "antmaze":
        horizon_scale = 1 - discount_factor + 1e-8
        value_range = (rewards.min() / horizon_scale,
                       rewards.max() / horizon_scale)
        failure_ret = max(-1000., -1 / horizon_scale)

        def initial_ret(traj):
            # Seed with 0 when the last reward is > -0.5, otherwise with a
            # capped negative tail value.
            return 0 if traj.rewards[-1] > -0.5 else failure_ret
    elif reward_preprocessing == "conservative" and discount_factor <= 0.99:
        # NOTE(review): only the antmaze branch narrows value_range, so this
        # seeds every trajectory with -inf. Confirm that is intended.
        def initial_ret(_):
            return value_range[0]
    else:
        def initial_ret(_):
            return 0

    print(value_range)

    def label_fn(traj):
        return return_labels(traj, discount_factor, initial_ret)

    seg_dataset = SegmentDataset(trajs, label_fn)
    seg_dataset.seed = seed
    seg_dataset.continuous_actions = True

    # Get the min and max return in the dataset
    rets = np.concatenate([label_fn(traj) for traj in trajs])
    ret_min, ret_max = rets.min(), rets.max()
    value_scale = ret_max - ret_min
    print("ret_min: {}, ret_max: {}".format(ret_min, ret_max))

    for j in range(n_iter):
        if model_args["only_sample_last"]:
            seg_dataset.label_to_sample = j

        # Validate before training so a typo does not cost a full training run.
        relabel_style = model_args["relabel_style"]
        if relabel_style not in ('greedy', 'singlepoint'):
            raise ValueError(
                f"Unknown relabel_style {relabel_style!r}; expected 'greedy' or 'singlepoint'")

        # Train a quantile return model on the current labels.
        return_model = train_quantile_return_predictor(seg_dataset, learning_rate, batch_size, epochs, hidden_size,
                                                       learning_rate_scheduler, wandb_run, value_fn=value_fn, value_range=value_range, value_scale=value_scale, store_dataset_gpu=store_dataset_gpu)

        # Score every state (or state-action) in the dataset with the model.
        sampled_rets = top_quantile_returns(
            seg_dataset.trajs, return_model, value_range=value_range, value_fn=value_fn, last_n_quantiles=last_n_quantiles, ensemble=ensemble)

        # Augment our dataset with some bootstrapped return-to-gos
        for i, traj_labels in enumerate(seg_dataset.traj_labels):
            traj = seg_dataset.trajs[i]
            if relabel_style == 'greedy':
                augmented_label = qr_augmented_return_labels(
                    traj, sampled_rets[i], traj_labels, discount_factor, initial_ret, value_fn=value_fn)
            else:
                # Keywords are required: augmented_return_labels takes value_fn
                # *before* initial_ret, so passing initial_ret positionally
                # landed it in value_fn.
                augmented_label = augmented_return_labels(
                    traj, sampled_rets[i], traj_labels, discount_factor,
                    value_fn=value_fn, initial_ret=initial_ret)
            traj_labels.append(augmented_label)
            seg_dataset.traj_labels[i] = traj_labels

        seg_dataset.cache_tensors()

    seg_dataset.label_to_sample = n_iter

    # Train the final model on the latest labels and save it.
    return_model = train_quantile_return_predictor(seg_dataset, learning_rate, batch_size, epochs, hidden_size,
                                                   learning_rate_scheduler, wandb_run, value_fn=True, value_range=value_range, value_scale=value_scale, store_dataset_gpu=store_dataset_gpu)
    if wandb_run is not None:
        filename = os.path.join(wandb_run.dir, 'return_model.pkl')
        print("Saving final return model to", filename)
        torch.save(return_model, filename)
        wandb_run.save(filename)
    else:
        print("No wandb run given; not saving the final return model")

    if model_args["only_sample_last_policy"]:
        seg_dataset.label_to_sample = n_iter
    else:
        seg_dataset.label_to_sample = None
    seg_dataset.continuous_actions = False

    return seg_dataset
