from dotmap import DotMap
import yaml
import gpytorch
import torch
import os
import random
import numpy as np

from gp_model import VariationalMultitaskGPModel
from envs import gp_sampled, navigation, maze
from plotting import save_data, plot_metrics


# Registry mapping the `env` name given in the config to the environment
# module/callable imported from `envs`. `train` looks an entry up with
# `ENVS[cfg.env]` and calls it as `ENVS[cfg.env](cfg, sa_grid)`.
ENVS = {
    "gp_sampled": gp_sampled,
    "navigation": navigation,
    "maze": maze,
}

def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, PyTorch's default
    generator and the generators of all CUDA devices with the same value.

    Args:
        seed: Integer seed passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # The generators are independent of one another, so the order in which
    # they are seeded does not matter.
    for apply_seed in seeders:
        apply_seed(seed)

def load_save_config(config_path):
    """Load a YAML config and store a copy of it in the run directory.

    The YAML file is parsed with ``yaml.safe_load`` and wrapped in a
    ``DotMap``. ``cfg.save_dir`` is then replaced by
    ``<save_dir>/<exp_name>``, that directory is created if necessary, and
    the resulting config is written to ``config.yaml`` inside it. Because
    the dump happens after the reassignment, the saved copy contains the
    already-joined ``save_dir``.

    Args:
        config_path: Path of the YAML configuration file to read.

    Returns:
        The ``DotMap`` config with ``save_dir`` pointing at the run directory.
    """
    with open(config_path) as source:
        raw_config = yaml.safe_load(source)
    cfg = DotMap(raw_config)

    # Every experiment gets its own sub-directory named after `exp_name`.
    run_dir = os.path.join(cfg.save_dir, cfg.exp_name)
    cfg.save_dir = run_dir
    os.makedirs(run_dir, exist_ok=True)

    snapshot_path = os.path.join(run_dir, "config.yaml")
    with open(snapshot_path, "w") as target:
        yaml.dump(cfg.toDict(), target)
    return cfg

def setup_spaces(cfg):
    """Discretize the state and action spaces and build index lookup tables.

    Each state dimension is discretized into ``cfg.grid_size`` evenly spaced
    points in ``[0, 1]`` (rounded to 4 decimals so that the values can be used
    as exact dictionary keys). The state grid is the Cartesian product of
    that axis over ``cfg.state_dim`` dimensions. When ``cfg.env == "maze"``,
    only states whose cell in a hard-coded 12x12 map has value 0 are kept.

    The action grid consists of the nine 2-D moves in ``{-1, 0, 1}^2`` scaled
    by the spacing between neighbouring grid points.

    Args:
        cfg: Config object providing ``grid_size``, ``device``, ``state_dim``
            and ``env``.

    Returns:
        A tuple ``(state_grid, action_grid, sa_grid, sa_index_map,
        state_index_map)`` where

        * ``state_grid`` has shape ``(num_states, state_dim)``,
        * ``action_grid`` has shape ``(9, 2)``,
        * ``sa_grid`` has shape ``(num_states * 9, state_dim + 2)``; rows are
          ordered state-major, i.e. all actions of state 0 come first,
        * ``sa_index_map`` maps ``tuple(row)`` of ``sa_grid`` to its row index,
        * ``state_index_map`` maps ``tuple(row)`` of ``state_grid`` to its row
          index.
    """
    state_space = torch.linspace(0, 1, cfg.grid_size, device=cfg.device).round(decimals=4)
    # `cartesian_prod` returns a 1-D tensor when given a single input, so force
    # the (num_states, state_dim) layout; this is a no-op for state_dim >= 2.
    state_grid = torch.cartesian_prod(
        *[state_space for _ in range(cfg.state_dim)]
    ).reshape(-1, cfg.state_dim)

    if cfg.env == "maze":  # filter out states that are in wall regions
        # Cells equal to 0 form a square ring; cells equal to 1 are dropped.
        maze_map = torch.ones((12, 12), device=cfg.device)
        maze_map[2:10, 2:4] = 0
        maze_map[2:4, 2:10] = 0
        maze_map[8:10, 2:10] = 0
        maze_map[2:10, 8:10] = 0

        last_cell = maze_map.shape[0] - 1
        # NOTE(review): `.long()` truncates rather than rounds, so a grid
        # value such as 0.0909 (1/11 rounded to 4 decimals) lands in cell 0,
        # not cell 1. Kept as-is; confirm this matches the maze environment.
        cells = torch.clamp((state_grid * last_cell).long(), 0, last_cell)
        valid_mask = maze_map[cells[:, 0], cells[:, 1]] == 0
        state_grid = state_grid[valid_mask]

    action_scale = state_space[1] - state_space[0]
    action_grid = torch.tensor(
        [[-1, -1], [-1, 0], [-1, 1], [0, -1], [0, 0], [0, 1], [1, -1], [1, 0], [1, 1]],
        device=cfg.device, dtype=torch.float32
    ) * action_scale

    # Pair every state with every action (state-major ordering).
    expanded_states = state_grid.repeat_interleave(action_grid.shape[0], dim=0)
    expanded_actions = action_grid.repeat((state_grid.shape[0], 1))
    sa_grid = torch.cat((expanded_states, expanded_actions), dim=-1)

    # Convert each grid to nested Python lists once instead of calling
    # `.tolist()` on every row tensor; the resulting keys are identical.
    sa_index_map = {tuple(row): i for i, row in enumerate(sa_grid.tolist())}
    state_index_map = {tuple(row): i for i, row in enumerate(state_grid.tolist())}
    return state_grid, action_grid, sa_grid, sa_index_map, state_index_map

def compute_V_opt(cfg, state_grid, action_grid, sa_index_map, f_true_values):
    """Compute optimal finite-horizon values by backward induction.

    For every step ``h`` (from ``cfg.episode_length - 1`` down to 0) and every
    state ``s``::

        V[h, s] = max_a ( reward(s, a) + cfg.gamma * V[h + 1, s'(s, a)] )

    where ``s'(s, a)`` is the state of ``state_grid`` closest (in squared
    Euclidean distance) to the true next state, and ``V[episode_length] = 0``.

    The reward and the next-state index of a state-action pair do not depend
    on ``h``, so they are looked up once and reused for every step instead of
    being recomputed inside the horizon loop.

    Args:
        cfg: Config object providing ``episode_length``, ``gamma`` and
            ``device``.
        state_grid: Tensor of shape ``(num_states, state_dim)``.
        action_grid: Tensor of shape ``(num_actions, action_dim)``.
        sa_index_map: Dict mapping ``tuple(state + action)`` to the flat
            state-action index used to index ``f_true_values``.
        f_true_values: Mapping with entries ``"next_state"`` and ``"reward"``,
            each indexable by the flat state-action index (presumably tensors
            of shape ``(num_sa, state_dim)`` and ``(num_sa,)`` -- confirm
            against the environment modules).

    Returns:
        Tensor of shape ``(cfg.episode_length + 1, num_states)`` holding the
        optimal value of every state at every step.
    """
    num_states = state_grid.shape[0]
    num_actions = action_grid.shape[0]
    V_opt = torch.zeros((cfg.episode_length + 1, num_states), device=cfg.device)

    # Horizon-independent quantities: one reward and one successor index per
    # (state, action) pair.
    rewards = torch.zeros((num_states, num_actions), dtype=V_opt.dtype, device=cfg.device)
    next_state_idx = torch.zeros((num_states, num_actions), dtype=torch.long, device=cfg.device)
    for i, s in enumerate(state_grid):
        for j, a in enumerate(action_grid):
            discrete_sa_idx = sa_index_map[tuple(torch.cat((s, a)).tolist())]
            next_s = f_true_values["next_state"][discrete_sa_idx]
            rewards[i, j] = f_true_values["reward"][discrete_sa_idx]
            # Snap the true next state to its nearest grid state.
            next_state_idx[i, j] = torch.argmin(torch.sum((state_grid - next_s) ** 2, dim=1))

    # Backward induction, vectorized over all states and actions.
    for h in reversed(range(cfg.episode_length)):
        q_values = rewards + cfg.gamma * V_opt[h + 1][next_state_idx]
        V_opt[h] = q_values.max(dim=1).values
    return V_opt

def train_gp_thompson_sampling(
        cfg,
        state_grid,
        action_grid,
        sa_grid,
        sa_index_map,
        state_index_map,
        f_true_values,
        V_opt
    ):

    """Run Thompson sampling with a variational multitask GP for one trial.

    A fresh GP model, likelihood and Adam optimizer are created, then for each
    of ``cfg.num_episodes`` episodes the function:

    1. draws one sample of (next state, reward) for every row of ``sa_grid``
       from ``likelihood(model(sa_grid))``;
    2. plans on the sampled model with finite-horizon backward induction;
    3. rolls out the greedy policy for ``cfg.episode_length`` steps, taking
       the actual transitions and rewards from ``f_true_values``;
    4. records the regret against ``V_opt``;
    5. appends the episode's transitions to the training set and performs
       ``cfg.max_gp_updates`` optimizer steps on the negative variational ELBO.

    Args:
        cfg: Experiment config (reads ``kernel``, ``state_dim``, ``action_dim``,
            ``num_inducing_points``, ``use_coregionalization``, ``env``,
            ``use_full_likelihood_rank``, ``gp_optim_lr``, ``num_episodes``,
            ``episode_length``, ``grid_size``, ``gamma``, ``max_gp_updates``
            and ``device``).
        state_grid: Tensor of discrete states, one row per state.
        action_grid: Tensor of discrete actions, one row per action.
        sa_grid: Tensor of all state-action pairs. The planning step reshapes
            per-row quantities to ``(num_states, num_actions)``, so rows must
            be ordered state-major (as ``setup_spaces`` builds them).
        sa_index_map: Dict from ``tuple(state + action)`` to row index in
            ``sa_grid`` / ``f_true_values``.
        state_index_map: Dict from ``tuple(state)`` to row index in
            ``state_grid``.
        f_true_values: Mapping with ``"next_state"`` and ``"reward"`` entries
            indexed by the flat state-action index.
        V_opt: Optimal values; ``V_opt[0, s]`` is used as the regret baseline.

    Returns:
        Tuple ``(cumulative_regret, regret, episode_return, gp_loss)`` of 1-D
        tensors of length ``cfg.num_episodes``.
    """
    # Initialize new GP model for modeling reward-dynamics function.
    # It has `state_dim + 1` outputs: the next-state dimensions plus the reward
    # (see how `f_sample` is split below).
    # NOTE(review): "fixed_gp" is not a key of ENVS in this file, so
    # `overwrite_lmc_coeffs` is False for every registered env -- confirm.
    model = VariationalMultitaskGPModel(
        kernel=cfg.kernel,
        input_dim=cfg.state_dim + cfg.action_dim,
        num_models=cfg.state_dim + 1,
        num_inducing_points=cfg.num_inducing_points,
        use_coregionalization=cfg.use_coregionalization,
        overwrite_lmc_coeffs=True if cfg.env == "fixed_gp" else False,
    ).to(cfg.device)
    likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(
        num_tasks=cfg.state_dim + 1,
        rank=cfg.state_dim + 1 if cfg.use_full_likelihood_rank else 0,
    ).to(cfg.device)
    # A single optimizer is shared by all episodes, so its state carries over
    # from one GP update phase to the next.
    optimizer = torch.optim.Adam(list(model.parameters()) + list(likelihood.parameters()), lr=cfg.gp_optim_lr)
    # Start in eval mode: the first thing each episode does is sample.
    model.eval()
    likelihood.eval()
    
    # Per-episode metrics returned to the caller.
    cumulative_regret = torch.zeros(cfg.num_episodes, device=cfg.device)
    regret = torch.zeros(cfg.num_episodes, device=cfg.device)
    episode_return = torch.zeros(cfg.num_episodes, device=cfg.device)
    gp_loss = torch.zeros(cfg.num_episodes, device=cfg.device)
    for episode in range(cfg.num_episodes):

        # Sample a reward-dynamics model over the whole state-action grid from
        # the current model (before any update in the first episode).
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            f_sample = likelihood(model(sa_grid)).sample()
        # Last column is the reward, the remaining columns are the next state.
        f_next_states = f_sample[:, :-1]
        f_rewards = f_sample[:, -1]

        # Compute Q-values using vectorized value iteration.
        # First snap each sampled next state to the grid: clamp to [0, 1],
        # round to the nearest of the `grid_size` points, then round to 4
        # decimals so the tuple can be used as a key of `state_index_map`.
        next_s_discrete = (
            torch.round(torch.clamp(f_next_states, 0, 1) * (cfg.grid_size - 1)) / (cfg.grid_size - 1)
        ).round(decimals=4)
        key_list = [tuple(next_s.tolist()) for next_s in next_s_discrete]
        if cfg.env == "maze":
            discrete_next_s_idx = []
            for i, k in enumerate(key_list):
                if k in state_index_map:
                    discrete_next_s_idx.append(state_index_map[k])
                else:
                    # The sampled next state is not in the state grid:
                    # fall back to the current state of this state-action row,
                    # i.e. treat the move as staying in place.
                    s = sa_grid[i, :cfg.state_dim]
                    s_rounded = tuple(((torch.round(s * (cfg.grid_size - 1)) / (cfg.grid_size - 1)).round(decimals=4)).tolist())
                    discrete_next_s_idx.append(state_index_map[s_rounded])
            discrete_next_s_idx = torch.tensor(discrete_next_s_idx, device=cfg.device)
        else:
            # Every snapped key is expected to exist in `state_index_map`;
            # a missing key raises KeyError here.
            discrete_next_s_idx = torch.tensor([state_index_map[k] for k in key_list], device=cfg.device)
        # V_hat[episode_length] stays 0 (terminal value); every Q_hat[h] is
        # overwritten in the loop below, so the -inf fill never survives.
        V_hat = torch.zeros((cfg.episode_length + 1, state_grid.shape[0]), device=cfg.device)
        Q_hat = torch.full((cfg.episode_length, state_grid.shape[0], action_grid.shape[0]), -float("inf"), device=cfg.device)
        for h in reversed(range(cfg.episode_length)):
            V_next = V_hat[h + 1, discrete_next_s_idx]
            Q_vals = f_rewards + cfg.gamma * V_next
            # Relies on sa_grid rows being ordered state-major.
            Q_hat[h] = Q_vals.view(state_grid.shape[0], action_grid.shape[0])
            V_hat[h] = Q_hat[h].max(dim=1).values

        # Roll out the greedy policy w.r.t. Q_hat for one episode, starting
        # from a uniformly random state and using the true rewards/dynamics.
        V_hat_pi = 0  # discounted return actually obtained by the policy
        discount = 1
        initial_state_idx = torch.randint(0, state_grid.shape[0], (1,)).item()
        s_idx = initial_state_idx
        episode_inputs = []
        episode_targets = []
        for h in range(cfg.episode_length):
            a_idx = torch.argmax(Q_hat[h, s_idx]).item()
            sa = torch.cat((state_grid[s_idx], action_grid[a_idx]))
            discrete_sa_idx = sa_index_map[tuple(sa.tolist())]
            next_s = f_true_values["next_state"][discrete_sa_idx]
            reward = f_true_values["reward"][discrete_sa_idx]
            # `episode_return` is the undiscounted sum of rewards, whereas
            # `V_hat_pi` applies the discount factor.
            episode_return[episode] += reward
            V_hat_pi += discount * reward
            discount *= cfg.gamma
            # Training pair: input (state, action) -> target (next state, reward).
            episode_inputs.append(sa.unsqueeze(0))
            episode_targets.append(torch.cat((next_s, reward.unsqueeze(0))).unsqueeze(0))
            # Snap the true next state to the grid to obtain the next state index.
            next_s_discrete = torch.round(torch.clamp(next_s, 0, 1) * (cfg.grid_size - 1)) / (cfg.grid_size - 1)
            s_idx = state_index_map[tuple(next_s_discrete.round(decimals=4).tolist())]
        episode_inputs = torch.cat(episode_inputs, dim=0)
        episode_targets = torch.cat(episode_targets, dim=0)

        # Compute regret: optimal value of the initial state minus the
        # discounted return achieved in this episode.
        regret[episode] = V_opt[0, initial_state_idx] - V_hat_pi
        cumulative_regret[episode] = cumulative_regret[episode - 1] + regret[episode] if episode > 0 else regret[episode]

        # Update GP model on all transitions collected so far (the training
        # set grows by one episode each time; `train_inputs`/`train_targets`
        # are first defined in episode 0).
        iteration = 0
        train_inputs = torch.cat((train_inputs, episode_inputs), dim=0) if episode > 0 else episode_inputs
        train_targets = torch.cat((train_targets, episode_targets), dim=0) if episode > 0 else episode_targets
        loss_function = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_inputs.shape[0])
        model.train()
        likelihood.train()
        # Full-batch gradient steps on the negative ELBO.
        while iteration < cfg.max_gp_updates:
            optimizer.zero_grad()
            output = model(train_inputs)
            loss = -loss_function(output, train_targets)
            loss.backward()
            optimizer.step()
            iteration += 1
        # Back to eval mode for the next episode's sampling step.
        model.eval()
        likelihood.eval()
        # Records the loss of the last update step only.
        # NOTE(review): if cfg.max_gp_updates is 0, `loss` is never assigned
        # and this line raises NameError.
        gp_loss[episode] = loss.item()

    return cumulative_regret, regret, episode_return, gp_loss

def train():
    """Run the full experiment described by ``config.yaml``.

    Loads (and snapshots) the config, builds the discrete state/action
    spaces, seeds the RNGs, creates the environment and computes its optimal
    values. It then runs ``cfg.num_trials`` independent Thompson-sampling
    trials. When ``cfg.env`` is ``"gp_sampled"`` the environment and its
    optimal values are re-created at the start of every trial. After each
    trial the metrics collected so far are saved and plotted.
    """
    cfg = load_save_config("config.yaml")
    print(f"Running experiment: {cfg.exp_name}")

    # Discretize state and action space
    spaces = setup_spaces(cfg)
    state_grid, action_grid, sa_grid, sa_index_map, state_index_map = spaces

    # Initialize the environment - reward and dynamics structure
    seed_everything(cfg.seed)
    make_env = ENVS[cfg.env]
    f_true_values = make_env(cfg, sa_grid)
    print("Running value iteration to compute V_opt")
    V_opt = compute_V_opt(cfg, state_grid, action_grid, sa_index_map, f_true_values)

    # One row per trial, one column per episode, for each tracked metric.
    metrics_shape = (cfg.num_trials, cfg.num_episodes)
    cumulative_regret, regret, episode_return, gp_loss = (
        torch.zeros(metrics_shape, device=cfg.device) for _ in range(4)
    )

    for trial in range(cfg.num_trials):
        print(f"Running trial {trial}")

        if cfg.env == "gp_sampled":
            # Sample a new environment from the same GP
            f_true_values = make_env(cfg, sa_grid)
            V_opt = compute_V_opt(cfg, state_grid, action_grid, sa_index_map, f_true_values)

        # Perform Thompson sampling
        trial_metrics = train_gp_thompson_sampling(
            cfg, state_grid, action_grid, sa_grid, sa_index_map, state_index_map, f_true_values, V_opt
        )
        cumulative_regret[trial], regret[trial], episode_return[trial], gp_loss[trial] = trial_metrics

        # Persist and plot after every trial; the second argument is the
        # number of trials completed so far.
        completed_trials = trial + 1
        save_data(cfg, completed_trials, cumulative_regret, regret, episode_return, gp_loss)
        plot_metrics(cfg, completed_trials, cumulative_regret, regret, episode_return, gp_loss)


# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train()