import numpy as np
import torch
from dataclasses import dataclass
import pickle as pkl
import os

from tqdm import trange, tqdm
import absl.app
import absl.flags

from .frozen_lake import FrozenLakeEnv, frozen_lake_env_from_string, frozen_lake_policy_from_string
from .policies import TabularPolicy, LinearPolicy, POPLinearPolicy, EpsilonGreedyPolicy
from .q_models import LinearQModel, POPQ, LinearGModel, CQL

from JaxRL.utils import define_flags_with_default, set_random_seed, get_user_flags


def sample_trajectories(env, policy, n=1000, progress_bar=False):
    """Roll out ``policy`` in ``env`` for ``n`` steps and record every transition.

    The rollout is a single continuous stream: whenever ``env.step`` reports
    termination, the environment is reset and the reset state is stored as the
    next state, so ``states[i + 1]`` is then the first state of a new episode.

    Args:
        env: object providing ``reset()`` and ``step(action)``; ``step`` must
            return a 4-tuple ``(next_state, reward, done, info)``.
        policy: object providing ``act(state)``.
        n: number of environment steps to record.
        progress_bar: show a tqdm progress bar while sampling.

    Returns:
        ``(states, actions, rewards, dones)`` where ``states`` has length
        ``n + 1`` (int32) and the other three arrays have length ``n``.
    """
    states = np.zeros([n + 1], dtype=np.int32)
    actions = np.zeros([n], dtype=np.int32)
    rewards = np.zeros([n], dtype=np.float32)
    dones = np.zeros([n], dtype=bool)

    states[0] = env.reset()
    steps = trange(n, desc='Sampling Trajectories', disable=not progress_bar)
    for t in steps:
        actions[t] = policy.act(states[t])
        next_state, reward, terminated, _ = env.step(actions[t])
        rewards[t] = reward
        if terminated:
            dones[t] = True
            # Start a fresh episode; its first state takes the next-state slot.
            next_state = env.reset()
        states[t + 1] = next_state
    return states, actions, rewards, dones


@dataclass(frozen=True)
class FrozenLakeDataset:
    """Immutable container of transitions ``(s, a, sp, r, done)``, one array per field."""

    s: np.ndarray  # states
    a: np.ndarray  # actions taken in ``s``
    sp: np.ndarray  # successor states
    r: np.ndarray  # rewards
    done: np.ndarray  # termination flags

    def to_torch(self, device=None):
        """Return a new dataset whose fields are torch tensors on ``device``.

        Index-like fields (``s``, ``a``, ``sp``) become ``int64``, rewards
        become ``float32`` and termination flags become ``bool``.
        """
        def convert(array, dtype):
            return torch.from_numpy(array).to(dtype=dtype, device=device)

        return FrozenLakeDataset(
            s=convert(self.s, torch.int64),
            a=convert(self.a, torch.int64),
            sp=convert(self.sp, torch.int64),
            r=convert(self.r, torch.float32),
            done=convert(self.done, torch.bool),
        )


def generate_dataset(env, policy, n=1000, progress_bar=False):
    """Sample ``n`` transitions with ``policy`` and pack them into a ``FrozenLakeDataset``.

    The state stream returned by ``sample_trajectories`` has ``n + 1`` entries;
    it is split into current states (all but the last) and next states (all
    but the first).
    """
    trajectory = sample_trajectories(env, policy, n, progress_bar=progress_bar)
    states, actions, rewards, dones = trajectory
    current_states, next_states = states[:-1], states[1:]
    return FrozenLakeDataset(
        s=current_states,
        a=actions,
        sp=next_states,
        r=rewards,
        done=dones,
    )


# Module-level experiment setup. NOTE: everything below runs at import time
# (it builds two environments and prints the data policy).

# 4x4 map description handed to ``frozen_lake_env_from_string``.
# Presumably S = start, F = frozen, H = hole, G = goal — TODO confirm against
# the ``frozen_lake`` module.
frzmap = [
    "FFFS",
    "FHFH",
    "FFFF",
    "FFHG"
]

# Environment used for data collection and training (``loop=True``); many
# functions in this file read it as a module global.
env = frozen_lake_env_from_string(
    frzmap, loop=True, slippery=0.1
)
num_states = env.observation_space.n
num_actions = env.action_space.n
# Reward matrix flattened to a column with one row per (state, action) pair.
R = env.get_reward_matrix().reshape((num_states * num_actions, 1))

# Behaviour policy shown at import time: one arrow per map cell, built with
# epsilon=0.5. ``main`` builds its own copy from the ``dataset_epsilon`` flag.
data_policy = frozen_lake_policy_from_string(
    [
        "↓←←←",
        "↓↑↑↑",
        "↓→→↓",
        "→↑↑↑",
    ],
    epsilon=.5,
)

print("Data Policy and Distribution:")
print(env.render_policy(data_policy))

# Same map with ``loop=False``; passed to ``compute_performance`` by the
# ``train_*`` functions to score the learned policies.
performance_env = frozen_lake_env_from_string(
    frzmap, loop=False, slippery=0.1
)

def get_P_policy(policy, P):
    """Build the policy-induced transition matrix over state-action pairs.

    ``P`` is the ``S x A x S`` environment transition tensor, either a NumPy
    array or a torch tensor. The result is ``(S*A) x (S*A)`` with entry
    ``[s * A + a, sp * A + ap] = P[s, a, sp] * policy.dist(sp)[ap]``.

    ``policy.dist`` is called once with all state indices (a NumPy range for
    NumPy input, otherwise an int64 torch range on ``P``'s device) and must
    return a 2-D table indexed ``[state, action]``.
    """
    n_s, n_a, _ = P.shape

    if isinstance(P, np.ndarray):
        state_ids = np.arange(n_s)
    else:
        state_ids = torch.arange(n_s, device=P.device, dtype=torch.int64)
    action_probs = policy.dist(state_ids)

    # Broadcast to shape (S, A, S, A), then flatten both (state, action) pairs.
    joint = P[..., None] * action_probs[None, None, :, :]
    size = n_s * n_a
    return joint.reshape([size, size])


def compute_F_s(policy, P, Phi):
    """Compute F(s) = E_{s' ~ P(s'|s)}[F(s, s')] for every state-action pair.

    With ``phi`` a row of ``Phi`` and ``phi_next`` the matching row of
    ``P_policy @ Phi`` (the expected next feature under ``policy``), the block
    for that row is the ``2k x 2k`` matrix::

        [[phi phi^T,       phi phi_next^T],
         [phi_next phi^T,  phi phi^T     ]]

    Returns a torch tensor with one such block per row of ``Phi``.
    """
    next_features = get_P_policy(policy, P) @ Phi

    phi_col = Phi[:, :, None]
    A = phi_col @ Phi[:, None, :]
    B = phi_col @ next_features[:, None, :]

    top = torch.cat([A, B], dim=2)
    bottom = torch.cat([B.transpose(2, 1), A], dim=2)
    return torch.cat([top, bottom], dim=1)

def compute_expected_F(mu, policy, P, Phi):
    """Compute E_{s ~ mu}[F(s)]: the ``mu``-weighted sum of the blocks from ``compute_F_s``.

    ``mu`` holds one weight per state-action pair (per row of ``Phi``).
    """
    per_pair_F = compute_F_s(policy, P, Phi)
    weights = mu[:, None, None]
    return (weights * per_pair_F).sum(axis=0)


def get_policy_distribution(env, policy):
    """Return the state-action distribution induced by ``policy`` in ``env``.

    Takes the eigenvector of the transposed policy transition matrix whose
    eigenvalue has the largest real part, keeps its real part and rescales it
    to sum to one.
    """
    transition = get_P_policy(policy, env.get_transition_matrix())
    eigenvalues, eigenvectors = np.linalg.eig(transition.T)
    dominant = np.argmax(np.real(eigenvalues))
    stationary = np.real(eigenvectors[:, dominant])
    return stationary / stationary.sum()


def visit_map(env, dataset):
    """Count occurrences of each state in ``dataset.s``, arranged in the shape of ``env.map``."""
    n_states = env.observation_space.n
    visits = np.bincount(dataset.s, minlength=n_states)
    return visits.reshape(env.map.shape)


def get_q_table(q_model):
    """Tabulate ``q_model`` into a ``(num_states, num_actions)`` NumPy array.

    ``q_model(s)`` is called once per state of the module-level ``env`` and
    its result fills row ``s``.
    """
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    table = np.zeros((n_states, n_actions))
    for state in range(n_states):
        table[state, :] = q_model(state)
    return table


# Compute the TD-error:
def compute_TD_error(env, q_table, policy=None, gamma=0.99):
    """Return the summed squared Bellman residual of ``q_table``.

    The target for ``(s, a)`` is ``R[s, a] + gamma * sum_sp P[s, a, sp] * V(sp)``.
    ``V(sp)`` is ``max_a q_table[sp, a]`` when ``policy`` is ``None``;
    otherwise it is the expectation of ``q_table[sp]`` under ``policy.dist(sp)``.
    """
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    rewards = env.get_reward_matrix()
    transitions = env.get_transition_matrix()

    # The value of a successor state does not depend on (s, a): compute it once.
    if policy is None:
        next_values = [np.max(q_table[sp]) for sp in range(n_states)]
    else:
        next_values = [np.sum(q_table[sp] * policy.dist(sp)) for sp in range(n_states)]

    target_q = np.zeros((n_states, n_actions))
    for s in range(n_states):
        for a in range(n_actions):
            target_q[s, a] = rewards[s, a]
            for sp, value in enumerate(next_values):
                target_q[s, a] += transitions[s, a, sp] * gamma * value

    residual = target_q - q_table
    return np.sum(residual ** 2)


def compute_Q_error(env, dataset, reference_q_table, q_table):
    """Return the squared Q-value error averaged under the dataset's empirical distribution.

    Each ``(s, a)`` pair is weighted by how often it occurs in ``dataset``;
    pairs that never occur contribute nothing.
    """
    n_actions = env.action_space.n
    n_pairs = env.observation_space.n * n_actions

    flat_index = dataset.s * n_actions + dataset.a
    visits = np.bincount(flat_index, minlength=n_pairs)
    weights = visits / np.sum(visits)

    squared_gap = (reference_q_table - q_table) ** 2
    return np.sum(squared_gap.flatten() * weights)


# Compute the performance of the policy:
def compute_performance(env, policy, atol=1e-6, max_iters=10000):
    """Evaluate ``policy`` by undiscounted policy evaluation on ``env``.

    Iterates ``Q(s, a) = sum_sp P[s, a, sp] * (R[s, a] + (1 - done[s, a]) * V(sp))``
    with ``V(sp) = sum_a policy.dist(sp)[a] * Q(sp, a)`` until two successive
    tables agree within ``atol`` (``np.allclose``) or ``max_iters`` sweeps
    have run.

    Args:
        env: environment providing ``reset``, ``get_transition_matrix``,
            ``get_reward_matrix``, ``get_terminal_matrix`` and the
            ``observation_space`` / ``action_space`` sizes.
        policy: object providing ``dist(state)`` (action probabilities).
        atol: absolute tolerance of the convergence test.
        max_iters: maximum number of evaluation sweeps.

    Returns:
        The expected return from the state given by ``env.reset()``, or
        ``-np.inf`` if the evaluation produces a non-finite value.
    """
    start_s = env.reset()

    P = env.get_transition_matrix()
    R = env.get_reward_matrix()
    done = env.get_terminal_matrix()
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    q_table = np.zeros((n_states, n_actions))

    for _ in range(max_iters):
        # V(sp) depends only on the successor state, so compute it once per sweep.
        v_next = [(policy.dist(sp) * q_table[sp]).sum() for sp in range(n_states)]

        q_table_prime = np.zeros_like(q_table)
        for s in range(n_states):
            for a in range(n_actions):
                for sp in range(n_states):
                    q_table_prime[s, a] += P[s, a, sp] * (R[s, a] + (1 - done[s, a]) * v_next[sp])

        # Check the freshly computed table. The previous table is finite by
        # construction, so checking it could never detect divergence.
        if not np.isfinite(q_table_prime).all():
            return -np.inf
        if np.allclose(q_table, q_table_prime, atol=atol):
            break
        q_table = q_table_prime

    return (q_table[start_s].flatten() * policy.dist(start_s).flatten()).sum()


def dataset_occupancy(env, dataset):
    """Return the empirical frequency of each state-action pair in ``dataset``.

    The result is a flat array of length ``num_states * num_actions`` indexed
    by ``s * num_actions + a`` and summing to one.
    """
    n_actions = env.action_space.n
    n_pairs = env.observation_space.n * n_actions
    flat_index = dataset.s * n_actions + dataset.a
    visits = np.bincount(flat_index, minlength=n_pairs)
    return visits / np.sum(visits)


# Device used by ``train`` when its ``device`` argument is None: the first
# CUDA device if one is available, otherwise the CPU.
default_train_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def train(
    q_model,
    policy,
    dataset=None,
    iters=1024,
    batch_size=256,
    gamma=0.99,
    progress_bar=False,
    log_freq=int(1e3),
    dataset_update_freq=int(1e2),
    policy_update_freq=1,
    policy_update_iters=1,
    secondary_update_freq=-1,
    secondary_update_iters=1,
    bc=False,
    device=None,
):
    """Run the actor-critic training loop on the module-level ``env``.

    Args:
        q_model: critic providing ``update``; when ``secondary_update_freq > 0``
            it must also provide ``secondary_update``, ``update_reweight``,
            ``reweight`` and ``Phi``. Unused for the Q-update when ``bc`` is set.
        policy: actor providing ``act``, ``dist`` and ``update``.
        dataset: a ``FrozenLakeDataset`` for offline training. ``None`` selects
            online training, where a fresh dataset is sampled with the current
            policy every ``dataset_update_freq`` steps.
        iters: number of training steps.
        batch_size: mini-batch size.
        gamma: discount factor forwarded to the update calls and diagnostics.
        progress_bar: show a tqdm progress bar.
        log_freq: log and print diagnostics every this many steps.
        dataset_update_freq: online resampling period (in steps).
        policy_update_freq: run the policy update every this many steps.
        policy_update_iters: policy updates per policy-update step.
        secondary_update_freq: period of the secondary update; disabled if <= 0.
        secondary_update_iters: secondary updates per secondary-update step.
        bc: behaviour cloning; skips the Q-update and the Q diagnostics.
        device: torch device for batches; defaults to ``default_train_device``.

    Returns:
        Dict mapping metric names to lists with one entry per logged step.

    Raises:
        ValueError: if ``bc`` is requested without a dataset.
    """
    log = {
        'step': [],
        'TD-error': [],
        'q_loss': [],
        'Q-error': []
    }
    if hasattr(q_model, 'g_model'):
        log['g_loss'] = []
        log['pop_obj'] = []

    if device is None:
        device = default_train_device

    online = dataset is None

    def sample_batch(d):
        # Uniform mini-batch (with replacement), plus the policy's action at
        # each sampled next state.
        idx = torch.randint(len(d.s), (batch_size,))

        s = d.s[idx, ...]
        a = d.a[idx, ...]
        sp = d.sp[idx, ...]
        r = d.r[idx, ...]
        done = d.done[idx, ...]
        ap = policy.act(sp)

        return s, a, r, sp, ap, done

    if not online:
        data_dist = dataset_occupancy(env, dataset)
        torch_dataset = dataset.to_torch(device=device)
    elif bc:
        raise ValueError("Behavior cloning not implemented for online training.")

    P = torch.tensor(env.get_transition_matrix(), device=device, dtype=torch.float32)
    update_info = {}
    secondary_update_info = {}
    policy_update_info = {}
    for step in trange(iters, desc=f'Training {"online" if online else "offline"}', disable=not progress_bar):
        if online and step % dataset_update_freq == 0:
            # Refresh the on-policy data and its state-action distribution.
            dataset = generate_dataset(env, policy, n=int(batch_size*dataset_update_freq), progress_bar=False)
            data_dist = get_policy_distribution(env, policy)
            torch_dataset = dataset.to_torch(device=device)

        if secondary_update_freq > 0 and step % secondary_update_freq == 0:
            for _ in range(secondary_update_iters):
                s, a, r, sp, ap, done = sample_batch(torch_dataset)
                secondary_update_info = q_model.secondary_update(s, a, r, sp, ap, done, gamma=gamma)
            if secondary_update_info is None:
                # The eigenvalue diagnostics below need a dict to write into.
                secondary_update_info = {}

            q_model.update_reweight()
            # Smallest eigenvalue of the expected F matrix under the data
            # distribution (mu) and under the reweighted distribution (q).
            F_s = compute_F_s(policy, P, q_model.Phi).detach().cpu().numpy()
            F_mu = (data_dist[:, None, None] * F_s).sum(0)
            q = q_model.reweight.detach().cpu().numpy() * data_dist
            q = q / q.sum()
            F_q = (q[:, None, None] * F_s).sum(0)
            secondary_update_info['F_mu_min_eig'] = np.linalg.eigvalsh(F_mu).min()
            secondary_update_info['F_q_min_eig'] = np.linalg.eigvalsh(F_q).min()

        s, a, r, sp, ap, done = sample_batch(torch_dataset)
        if not bc:
            update_info = q_model.update(s, a, r, sp, ap, done, gamma=gamma)

        if policy_update_freq == 1 or (step + 1) % policy_update_freq == 0:
            for _ in range(policy_update_iters):
                policy_update_info = policy.update(q_model, s, a, r, sp, done, gamma=gamma, bc=bc)

        if step % log_freq == 0:
            log['step'].append(step)

            # Merge the latest info dicts; an update that returned None (or an
            # empty dict) contributes nothing.
            info = {}
            for part in (update_info, secondary_update_info, policy_update_info):
                if part:
                    info.update(part)
            for k, v in info.items():
                log.setdefault(k, []).append(v)

            if not bc:
                # Tabulate the critic once and reuse it for both diagnostics.
                q_table = get_q_table(q_model)
                td_error = compute_TD_error(env, q_table, policy=policy, gamma=gamma)
                log['TD-error'].append(td_error)
                print(f"Step {step}: TD-error = {td_error:.4f}", end='')
                ref_q_table = compute_opt_q(env, gamma=gamma, policy=policy)
                q_error = compute_Q_error(env, dataset, ref_q_table, q_table)
                log['Q-error'].append(q_error)
                print(f", Q-error = {q_error:.4f}", end='')
                if hasattr(q_model, 'g_model'):
                    state, action = np.meshgrid(np.arange(num_states), np.arange(num_actions), indexing="ij")
                    state = torch.tensor(state.flatten(), dtype=torch.long, device=device)
                    action = torch.tensor(action.flatten(), dtype=torch.long, device=device)
                    g_table = q_model.g_model(state, action).detach().cpu().numpy()
                    log['pop_obj'].append((data_dist * np.exp(2 * g_table)).sum())

            if policy_update_info:
                print('')
                print("Policy update info: ", end='')
                for k, v in policy_update_info.items():
                    print(f"{k} = {v:.6f}", end=', ')

            if secondary_update_info:
                print('')
                print("Secondary update info: ", end='')
                for k, v in secondary_update_info.items():
                    print(f"{k} = {v:.6f}", end=', ')

            print('')
            print(env.render_policy(policy))

    return log


def compute_opt_q(
        env,
        gamma=0.99,
        epsilon=0.0,
        atol=1e-6,
        policy=None
):
    """Compute a Q-table for ``env`` by iterating the Bellman backup to a fixed point.

    Without ``policy`` the successor value is
    ``(1 - epsilon) * max_a Q(sp, a) + epsilon * mean_a Q(sp, a)``, i.e. the
    optimal Q-function for ``epsilon == 0``. With ``policy`` the successor
    value is the expectation of ``Q(sp, .)`` under ``policy.dist``, and
    ``epsilon`` is ignored.

    Args:
        env: environment providing ``get_transition_matrix``,
            ``get_reward_matrix`` and the space sizes.
        gamma: discount factor.
        epsilon: mixing weight of the uniform-action value (``policy is None`` only).
        atol: absolute tolerance of the ``np.allclose`` convergence test.
        policy: optional policy to evaluate; ``policy.dist`` is called once
            with all state indices.

    Returns:
        ``(num_states, num_actions)`` array of Q-values.

    Note:
        The loop has no iteration cap; it stops only once two successive
        tables agree within tolerance.
    """
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    P = env.get_transition_matrix()
    R = env.get_reward_matrix()
    q_table = np.zeros((n_states, n_actions))

    # Size the policy table from ``env`` itself rather than a module global,
    # so the function is correct for any environment passed in.
    policy_table = None if policy is None else policy.dist(np.arange(n_states))

    while True:
        # The successor value depends only on sp: compute it once per sweep.
        if policy is None:
            v_next = [
                (1 - epsilon) * q_table[sp].max() + epsilon * q_table[sp].mean()
                for sp in range(n_states)
            ]
        else:
            v_next = [(policy_table[sp] * q_table[sp]).sum() for sp in range(n_states)]

        q_table_prime = np.zeros_like(q_table)
        for s in range(n_states):
            for a in range(n_actions):
                for sp in range(n_states):
                    q_table_prime[s, a] += P[s, a, sp] * (R[s, a] + gamma * v_next[sp])

        if np.allclose(q_table, q_table_prime, atol=atol):
            break
        q_table = q_table_prime
    return q_table


def train_on_policy(k=63, device=None, iters=int(2e5)):
    """Train a linear Q-model and linear policy online on the module-level ``env``.

    Features are drawn from the global NumPy RNG (uniform in [-1, 1], each row
    scaled to unit norm), so seeding is the caller's responsibility.

    Args:
        k: feature dimension.
        device: torch device for features and training batches; defaults to
            ``default_train_device``.
        iters: number of training steps.

    Returns:
        The training log from ``train`` plus ``'policy_return'``, the policy's
        return on ``performance_env``.
    """
    if device is None:
        device = default_train_device

    Phi = np.random.uniform(-1, 1, (env.observation_space.n * env.action_space.n, k))
    Phi /= np.linalg.norm(Phi, axis=-1, keepdims=True)
    Phi_torch = torch.tensor(Phi, device=device, dtype=torch.float32)

    linear_q = LinearQModel(env.observation_space, env.action_space, Phi_torch, learning_rate=2e-2)
    policy = LinearPolicy(env.observation_space, env.action_space, Phi_torch, lr=1e-2,
                          use_automatic_entropy_tuning=True, target_entropy=0.25, alpha_lr=1e-4)
    # Forward the device so batches live on the same device as the features.
    log = train(linear_q, policy, iters=iters, batch_size=32, progress_bar=True, log_freq=int(1e3),
                dataset_update_freq=int(1e2), policy_update_freq=1, device=device)

    log['policy_return'] = compute_performance(performance_env, policy)

    return log


def train_vanilla(dataset, k=63, device=None, iters=int(2e5)):
    """Train a linear Q-model and linear policy offline on ``dataset``.

    Features are drawn from the global NumPy RNG (uniform in [-1, 1], each row
    scaled to unit norm), so seeding is the caller's responsibility.

    Args:
        dataset: ``FrozenLakeDataset`` of transitions to train on.
        k: feature dimension.
        device: torch device for features and training batches; defaults to
            ``default_train_device``.
        iters: number of training steps.

    Returns:
        The training log from ``train`` plus ``'policy_return'``, the policy's
        return on ``performance_env``.
    """
    if device is None:
        device = default_train_device

    Phi = np.random.uniform(-1, 1, (env.observation_space.n * env.action_space.n, k))
    Phi /= np.linalg.norm(Phi, axis=-1, keepdims=True)
    Phi_torch = torch.tensor(Phi, device=device, dtype=torch.float32)

    linear_q = LinearQModel(env.observation_space, env.action_space, Phi_torch, learning_rate=2e-2)
    policy = LinearPolicy(env.observation_space, env.action_space, Phi_torch, lr=1e-2,
                          use_automatic_entropy_tuning=True, target_entropy=0.25, alpha_lr=1e-4)
    # Forward the device so batches live on the same device as the features.
    log = train(linear_q, policy, dataset, iters=iters, batch_size=32, progress_bar=True,
                log_freq=int(1e3), policy_update_freq=1, device=device)

    log['policy_return'] = compute_performance(performance_env, policy)

    return log


class LinearExactImportanceSamplingQModel(LinearQModel):
    """Linear Q-model with importance weights computed from exact distributions.

    The weights are the ratio between the distribution induced by ``policy``
    on the module-level ``env`` and the fixed ``sampling_dist``.
    """

    def __init__(self, observation_space, action_space, Phi, policy, sampling_dist, learning_rate=1e-1):
        """Store the policy and the sampling distribution (NumPy input becomes a float32 tensor on ``Phi``'s device)."""
        super().__init__(observation_space, action_space, Phi, learning_rate=learning_rate)
        self.policy = policy
        if isinstance(sampling_dist, np.ndarray):
            sampling_dist = torch.tensor(sampling_dist, dtype=torch.float32, device=self.Phi.device)
        self.sampling_dist = sampling_dist

    def update_reweight(self):
        """Recompute ``self.reweight`` as the clipped ratio policy distribution / sampling distribution."""
        target_dist = get_policy_distribution(env, self.policy)
        target_dist = torch.tensor(target_dist, dtype=torch.float32, device=self.Phi.device)
        # The 1e-6 guards against division by zero for unvisited pairs; the
        # ratio is then kept inside [1e-2, 1e2].
        ratio = target_dist / (self.sampling_dist + 1e-6)
        self.reweight = torch.clip(ratio, 1e-2, 1e2)


def train_is(dataset, k=63, device=None, iters=int(2e5)):
    """Train offline with the exact importance-sampling Q-model.

    Features are drawn from the global NumPy RNG (uniform in [-1, 1], each row
    scaled to unit norm), so seeding is the caller's responsibility. The
    sampling distribution is the empirical state-action occupancy of
    ``dataset``; the reweighting is refreshed every 50 steps.

    Args:
        dataset: ``FrozenLakeDataset`` of transitions to train on.
        k: feature dimension.
        device: torch device for features and training batches; defaults to
            ``default_train_device``.
        iters: number of training steps.

    Returns:
        The training log from ``train`` plus ``'policy_return'``, the policy's
        return on ``performance_env``.
    """
    if device is None:
        device = default_train_device

    Phi = np.random.uniform(-1, 1, (env.observation_space.n * env.action_space.n, k))
    Phi /= np.linalg.norm(Phi, axis=-1, keepdims=True)
    Phi_torch = torch.tensor(Phi, device=device, dtype=torch.float32)

    policy = LinearPolicy(env.observation_space, env.action_space, Phi_torch, lr=1e-2,
                          use_automatic_entropy_tuning=True, target_entropy=0.25, alpha_lr=1e-4)
    is_linear_q = LinearExactImportanceSamplingQModel(env.observation_space, env.action_space, Phi_torch, policy,
                                                      dataset_occupancy(env, dataset), learning_rate=1e-2)
    # Forward the device so batches live on the same device as the features.
    log = train(is_linear_q, policy, dataset, iters=iters, batch_size=32, progress_bar=True,
                log_freq=int(1e3), policy_update_freq=1, secondary_update_freq=int(5e1), device=device)

    log['policy_return'] = compute_performance(performance_env, policy)

    return log


def train_pop(dataset, k=63, device=None, iters=int(2e5)):
    """Train offline with the POP Q-model (``POPQ``) and ``POPLinearPolicy``.

    Features are drawn from the global NumPy RNG (uniform in [-1, 1], each row
    scaled to unit norm), so seeding is the caller's responsibility. The
    secondary update runs on every training step.

    Args:
        dataset: ``FrozenLakeDataset`` of transitions to train on.
        k: feature dimension.
        device: torch device for features and training batches; defaults to
            ``default_train_device``.
        iters: number of training steps.

    Returns:
        The training log from ``train`` plus ``'policy_return'``, the policy's
        return on ``performance_env``.
    """
    if device is None:
        device = default_train_device

    Phi = np.random.uniform(-1, 1, (env.observation_space.n * env.action_space.n, k))
    Phi /= np.linalg.norm(Phi, axis=-1, keepdims=True)
    Phi_torch = torch.tensor(Phi, device=device, dtype=torch.float32)

    policy = POPLinearPolicy(
        env.observation_space, env.action_space, Phi_torch, lr=1e-4,
        use_automatic_entropy_tuning=True, target_entropy=0.5, alpha_lr=1e-4,
        use_automatic_kl_tuning=False, beta_multiplier=2e-1,
    )
    g_model = LinearGModel(env.observation_space, env.action_space, Phi_torch, lr=1e-4)
    pop_q = POPQ(env.observation_space, env.action_space, Phi_torch, g_model, env.get_terminal_matrix().flatten(),
                 rank=4, q_lr=1e-3, dual_lr=1e-3, pop_margin=-1e-3)
    # Forward the device so batches live on the same device as the features.
    log = train(pop_q, policy, dataset, iters=iters, batch_size=32, progress_bar=True, log_freq=int(1e3),
                policy_update_freq=1, secondary_update_freq=1, secondary_update_iters=1, device=device)

    log['policy_return'] = compute_performance(performance_env, policy)

    return log


def train_bc(dataset, k=63, device=None, iters=int(2e5)):
    """Train a linear policy on ``dataset`` by behaviour cloning (no Q-model).

    Features are drawn from the global NumPy RNG (uniform in [-1, 1], each row
    scaled to unit norm), so seeding is the caller's responsibility.

    Args:
        dataset: ``FrozenLakeDataset`` of transitions to clone.
        k: feature dimension.
        device: torch device for features and training batches; defaults to
            ``default_train_device``.
        iters: number of training steps.

    Returns:
        The training log from ``train`` plus ``'policy_return'``, the policy's
        return on ``performance_env``.
    """
    if device is None:
        device = default_train_device

    Phi = np.random.uniform(-1, 1, (env.observation_space.n * env.action_space.n, k))
    Phi /= np.linalg.norm(Phi, axis=-1, keepdims=True)
    Phi_torch = torch.tensor(Phi, device=device, dtype=torch.float32)

    policy = LinearPolicy(env.observation_space, env.action_space, Phi_torch, lr=1e-3, alpha_multiplier=0.0)
    # Forward the device so batches live on the same device as the features.
    log = train(None, policy, dataset, iters=iters, batch_size=32, progress_bar=True, log_freq=int(1e3),
                bc=True, device=device)

    log['policy_return'] = compute_performance(performance_env, policy)

    return log


def train_cql(dataset, k=63, device=None, iters=int(2e5)):
    """Train offline with the ``CQL`` Q-model and a linear policy.

    Features are drawn from the global NumPy RNG (uniform in [-1, 1], each row
    scaled to unit norm), so seeding is the caller's responsibility.

    Args:
        dataset: ``FrozenLakeDataset`` of transitions to train on.
        k: feature dimension.
        device: torch device for features and training batches; defaults to
            ``default_train_device``.
        iters: number of training steps.

    Returns:
        The training log from ``train`` plus ``'policy_return'``, the policy's
        return on ``performance_env``.
    """
    if device is None:
        device = default_train_device

    Phi = np.random.uniform(-1, 1, (env.observation_space.n * env.action_space.n, k))
    Phi /= np.linalg.norm(Phi, axis=-1, keepdims=True)
    Phi_torch = torch.tensor(Phi, device=device, dtype=torch.float32)

    linear_q = CQL(env.observation_space, env.action_space, Phi_torch, learning_rate=1e-3, alpha_prime=5e0)
    policy = LinearPolicy(env.observation_space, env.action_space, Phi_torch, lr=1e-3,
                          use_automatic_entropy_tuning=True, target_entropy=0.25, alpha_lr=1e-3)
    # Forward the device so batches live on the same device as the features.
    log = train(linear_q, policy, dataset, iters=iters, batch_size=32, progress_bar=True,
                log_freq=int(1e3), device=device)

    log['policy_return'] = compute_performance(performance_env, policy)

    return log


def merge_datasets(dataset1, dataset2, p=0.5):
    """Build a mixed dataset from random, non-repeating subsets of two datasets.

    With ``n`` the size of the smaller dataset, ``int(n * (1 - p))``
    transitions are drawn from ``dataset1`` and ``int(n * p)`` from
    ``dataset2``; ``p`` is therefore the share contributed by ``dataset2``.
    Rows from ``dataset1`` come first in the result. Sampling uses the global
    NumPy RNG.
    """
    n = min(len(dataset1.s), len(dataset2.s))
    take_from_first = int(n * (1 - p))
    take_from_second = int(n * p)
    idx1 = np.random.choice(len(dataset1.s), size=take_from_first, replace=False)
    idx2 = np.random.choice(len(dataset2.s), size=take_from_second, replace=False)

    def pick(field):
        # Selected rows of one field, dataset1 first.
        return np.concatenate([getattr(dataset1, field)[idx1], getattr(dataset2, field)[idx2]])

    return FrozenLakeDataset(**{field: pick(field) for field in ('s', 'a', 'r', 'sp', 'done')})


# Command-line flags (with defaults) read by ``main``:
#   alg                    - 'on_policy', 'vanilla', 'is', 'pop', 'bc' or 'cql'.
#   proportion_opt_dataset - passed as ``p`` to ``merge_datasets``.
#   num_training_iters     - training steps (cast to int in ``main``).
#   dataset_size           - transitions sampled per dataset (cast to int).
#   feature_dim            - feature dimension ``k`` of the random features.
#   dataset_epsilon        - epsilon of both data-collection policies.
#   seed                   - seeds NumPy (seed) and torch (seed + 123).
#   output_dir             - directory the result pickle is written to.
FLAGS_DEF = define_flags_with_default(
    alg='pop',
    proportion_opt_dataset=0.5,
    num_training_iters=2e5,
    dataset_size=1e6,
    feature_dim=63,

    dataset_epsilon=0.5,

    seed=0,
    output_dir='small_scale/results'
)


def main(argv):
    """Run one experiment selected by the command-line flags and pickle its log.

    Samples a sub-optimal and a near-optimal dataset on the module-level
    ``env``, trains the algorithm named by ``--alg`` and writes
    ``{'log': ..., 'variant': ...}`` to ``--output_dir``.

    Args:
        argv: positional command-line arguments supplied by ``absl.app.run``
            (unused).

    Raises:
        ValueError: if ``--alg`` names an unknown algorithm.
    """
    FLAGS = absl.flags.FLAGS
    variant = get_user_flags(FLAGS, FLAGS_DEF)

    alg = FLAGS.alg
    proportion_opt_dataset = float(FLAGS.proportion_opt_dataset)
    num_training_iters = int(FLAGS.num_training_iters)
    dataset_size = int(FLAGS.dataset_size)
    feature_dim = int(FLAGS.feature_dim)

    dataset_epsilon = float(FLAGS.dataset_epsilon)

    offline_trainers = {
        'vanilla': train_vanilla,
        'is': train_is,
        'pop': train_pop,
        'bc': train_bc,
        'cql': train_cql,
    }
    # Fail fast: reject an unknown algorithm before the expensive data sampling.
    if alg != 'on_policy' and alg not in offline_trainers:
        raise ValueError(f'Unknown algorithm {alg}')

    seed = int(FLAGS.seed)
    np.random.seed(seed)
    torch.manual_seed(seed + 123)

    # Behaviour policy for the sub-optimal data: one arrow per map cell.
    data_policy = frozen_lake_policy_from_string(
        [
            "↓←←←",
            "↓↑↑↑",
            "↓→→↓",
            "→↑↑↑",
        ],
        epsilon=dataset_epsilon,
    )

    opt_policy = env.shortest_path_policy()

    sub_opt_dataset = generate_dataset(env, data_policy, n=dataset_size, progress_bar=True)

    opt_dataset = generate_dataset(env,
                                   EpsilonGreedyPolicy(opt_policy, env.action_space, epsilon=dataset_epsilon),
                                   n=dataset_size, progress_bar=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if alg == 'on_policy':
        log = train_on_policy(k=feature_dim, device=device, iters=num_training_iters)

        save_file = f'{alg}_seed_{seed}.pkl'
    else:
        # ``merge_datasets`` takes a fraction ``p`` of its SECOND argument, so
        # the optimal dataset goes second for the flag to mean what it says.
        dataset = merge_datasets(sub_opt_dataset, opt_dataset, p=proportion_opt_dataset)
        log = offline_trainers[alg](dataset, k=feature_dim, device=device, iters=num_training_iters)

        save_file = f'{alg}_p_{proportion_opt_dataset}_seed_{seed}.pkl'

    save_data = {
        'log': log,
        'variant': variant,
    }

    # Save the results using pickle
    os.makedirs(FLAGS.output_dir, exist_ok=True)
    with open(os.path.join(FLAGS.output_dir, save_file), 'wb') as f:
        pkl.dump(save_data, f)


# Script entry point: hand ``main`` to absl's application runner.
if __name__ == '__main__':
    absl.app.run(main)
