import torch
import numpy as np
import scipy
import scipy.stats
import matplotlib.pyplot as plt
from functools import partial
import gymnasium as gym
import miniworld

from miniworld_env import (
    MiniworldOptPolicy,
    MiniworldRandCommit,
    MiniworldRandPolicy,
    MiniworldTransformerController,
    MiniworldEnvVec,
)

# Prefer the GPU when CUDA is usable; otherwise every tensor in this module lives on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def process_trajs_into_batch(trajs, lnr):
    """Collate a list of trajectory dicts into a single dict of batched tensors.

    Each element of ``trajs`` must provide the keys ``'state'``, ``'action'``,
    ``'pose'``, ``'angle'``, ``'rollin_obs'``, ``'rollin_next_obs'``,
    ``'rollin_poses'``, ``'rollin_angles'``, ``'rollin_us'``, ``'rollin_rs'``,
    ``'rollin_next_poses'`` and ``'rollin_next_angles'``.

    ``'rollin_obs'`` and ``'rollin_next_obs'`` are paths handed to ``np.load``;
    every loaded frame is passed through ``lnr.transform`` and the results are
    stacked with ``torch.stack`` (so ``lnr.transform`` must return tensors of
    a common shape).

    Args:
        trajs: iterable of trajectory dicts as described above.
        lnr: object exposing ``transform(obs)`` for a single observation.

    Returns:
        Dict of float tensors on ``device``. Actions and rewards are exposed
        under two names each (``'rollin_us'`` / ``'rollin_actions'`` and
        ``'rollin_rs'`` / ``'rollin_rewards'``); the two entries hold equal
        values but are separate tensors.
    """

    def _load_frames(path):
        # Load the saved observation array and transform it frame by frame.
        frames = np.load(path)
        transformed = [lnr.transform(frame) for frame in frames]
        return torch.stack(transformed).float().to(device)

    def _as_tensor(values):
        # Go through np.array first so a list of arrays becomes one tensor.
        return torch.tensor(np.array(values)).float().to(device)

    # (batch key, traj key) for the query fields, in the order they are read.
    query_fields = (
        ('states', 'state'),
        ('actions', 'action'),
        ('poses', 'pose'),
        ('angles', 'angle'),
    )
    # Context fields whose traj key is reused when collecting them.
    context_fields = (
        'rollin_poses',
        'rollin_angles',
        'rollin_us',
        'rollin_rs',
        'rollin_next_poses',
        'rollin_next_angles',
    )

    collected = {batch_key: [] for batch_key, _ in query_fields}
    collected.update({key: [] for key in context_fields})
    obs_per_traj = []
    next_obs_per_traj = []

    for traj in trajs:
        for batch_key, traj_key in query_fields:
            collected[batch_key].append(traj[traj_key])

        obs_per_traj.append(_load_frames(traj['rollin_obs']))
        next_obs_per_traj.append(_load_frames(traj['rollin_next_obs']))

        for key in context_fields:
            collected[key].append(traj[key])

    batch = {
        'rollin_obs': torch.stack(obs_per_traj),
        'rollin_next_obs': torch.stack(next_obs_per_traj),
        'rollin_us': _as_tensor(collected['rollin_us']),
        'rollin_actions': _as_tensor(collected['rollin_us']),
        'rollin_rs': _as_tensor(collected['rollin_rs']),
        'rollin_rewards': _as_tensor(collected['rollin_rs']),
        'rollin_poses': _as_tensor(collected['rollin_poses']),
        'rollin_angles': _as_tensor(collected['rollin_angles']),
        'rollin_next_poses': _as_tensor(collected['rollin_next_poses']),
        'rollin_next_angles': _as_tensor(collected['rollin_next_angles']),
        'states': _as_tensor(collected['states']),
        'actions': _as_tensor(collected['actions']),
        'poses': _as_tensor(collected['poses']),
        'angles': _as_tensor(collected['angles']),
    }
    return batch


def _flatten_context(context, feature_shapes, num_envs, n_episodes=None):
    """Merge the first ``n_episodes`` episodes of every context buffer into one sequence axis.

    ``n_episodes=None`` uses every stored episode; ``0`` yields empty sequences.
    """
    return {
        key: buffer[:, :n_episodes].reshape(num_envs, -1, *feature_shapes[key])
        for key, buffer in context.items()
    }


def _to_float_tensor(value):
    """Return ``value`` as a float tensor on ``device`` (existing tensors are not re-wrapped)."""
    tensor = value if torch.is_tensor(value) else torch.tensor(value)
    return tensor.float().to(device)


def deploy_online_vec(vec_env, controller, learner=False, **kwargs):
    """Run ``controller`` in ``vec_env`` for ``Heps`` episodes and return per-episode returns.

    The controller is conditioned on a context of at most
    ``ctx_rollouts = H // horizon`` past episodes. The context starts empty and
    is filled one episode at a time (warm-up); after that it becomes a sliding
    window in which the oldest episode is dropped for every new one.

    Args:
        vec_env: vectorised environment exposing ``num_envs`` and
            ``deploy_eval(controller, ...)``, which returns the 8-tuple
            ``(obs, poses, angles, actions, next_obs, next_poses, next_angles,
            rewards)``.
        controller: policy exposing ``set_batch(batch)`` and ``save_video``;
            when ``save_video`` is truthy its ``filename_template`` is set to
            ``partial(kwargs['filename'], ep=<episode index>)``.
        learner: when true, observed transitions are written into the context.
            Rewards are written into the context regardless of this flag.
        **kwargs: must contain ``Heps``, ``H``, ``horizon``,
            ``include_partial_hist`` and ``grow_context``; ``filename`` is only
            read when ``controller.save_video`` is truthy.

    Returns:
        ``np.ndarray`` of shape ``(num_envs, Heps)`` holding, for each
        environment and episode, the rewards summed over the last axis.
    """
    Heps = kwargs['Heps']
    H = kwargs['H']
    horizon = kwargs['horizon']
    include_partial_hist = kwargs['include_partial_hist']
    grow_context = kwargs['grow_context']
    assert H % horizon == 0

    ctx_rollouts = H // horizon

    num_envs = vec_env.num_envs
    dx = (3, 25, 25)
    du = 4

    # Trailing feature shape of every context buffer. The first seven keys are
    # in the same order as the first seven outputs of `vec_env.deploy_eval`.
    feature_shapes = {
        'rollin_obs': dx,
        'rollin_poses': (2,),
        'rollin_angles': (2,),
        'rollin_actions': (du,),
        'rollin_next_obs': dx,
        'rollin_next_poses': (2,),
        'rollin_next_angles': (2,),
        'rollin_rewards': (1,),
    }
    transition_keys = tuple(key for key in feature_shapes if key != 'rollin_rewards')
    context = {
        key: torch.zeros((num_envs, ctx_rollouts, horizon, *shape)).float().to(device)
        for key, shape in feature_shapes.items()
    }

    cum_means = []

    # Warm-up: episode i only sees the i episodes collected so far. Capped at
    # Heps so that no more episodes are run (and returned) than were requested;
    # previously Heps < ctx_rollouts still produced ctx_rollouts columns.
    for i in range(min(ctx_rollouts, Heps)):
        controller.set_batch(_flatten_context(context, feature_shapes, num_envs, i))
        if controller.save_video:
            controller.filename_template = partial(kwargs['filename'], ep=i)

        *transitions, rs_lnr = vec_env.deploy_eval(controller)
        if learner:
            for key, value in zip(transition_keys, transitions):
                context[key][:, i] = _to_float_tensor(value)
        context['rollin_rewards'][:, i] = _to_float_tensor(rs_lnr[:, :, None])

        cum_means.append(np.sum(rs_lnr, axis=-1))

    # Steady state: the context is full, used as one sequence of length
    # H = ctx_rollouts * horizon, and shifted by one episode per iteration.
    for h_ep in range(ctx_rollouts, Heps):
        controller.set_batch(_flatten_context(context, feature_shapes, num_envs))
        if controller.save_video:
            controller.filename_template = partial(kwargs['filename'], ep=h_ep)

        *transitions, rs_lnr = vec_env.deploy_eval(
            controller,
            include_partial_hist=include_partial_hist,
            grow_context=grow_context)

        cum_means.append(np.sum(rs_lnr, axis=-1))

        # Roll in new data by dropping the oldest episode and appending the new one.
        if learner:
            for key, value in zip(transition_keys, transitions):
                newest = _to_float_tensor(value)
                context[key] = torch.cat((context[key][:, 1:], newest[:, None]), dim=1)

        # Random-commit policy needs updated rewards even when it is not the learner.
        newest_rs = _to_float_tensor(rs_lnr[:, :, None])
        context['rollin_rewards'] = torch.cat(
            (context['rollin_rewards'][:, 1:], newest_rs[:, None]), dim=1)

    return np.stack(cum_means, axis=1)


def _mean_and_sem(returns):
    """Return the mean and standard error of ``returns`` over axis 0 (the env axis)."""
    return np.mean(returns, axis=0), scipy.stats.sem(returns, axis=0)


def _plot_mean_with_sem(n_points, means, sems, color, label):
    """Draw a mean curve with a +/- one standard-error band on the current figure."""
    plt.plot(means, color=color, label=label)
    plt.fill_between(np.arange(n_points), means - sems, means + sems, color=color, alpha=0.2)


def online_vec(eval_trajs, model, **kwargs):
    """Evaluate the learner against optimal, random and random-commit baselines online.

    One ``'<env_name>FixedInit-v0'`` environment is created per evaluation
    index and assigned task id ``8000 + i_eval``. The learner, optimal and
    random controllers are each rolled out with ``deploy_online_vec``; the
    random-commit baseline is derived from the random and optimal results.
    All curves are drawn on the current matplotlib figure.

    Args:
        eval_trajs: indexable with at least ``n_eval`` entries.
        model: passed to ``MiniworldTransformerController``.
        **kwargs: must contain ``Heps``, ``H``, ``n_eval``, ``horizon``,
            ``envname`` and ``filename`` (a ``str.format`` template whose
            ``controller`` field is filled here), plus whatever
            ``deploy_online_vec`` reads.

    Returns:
        Dict with keys ``'lnr'``, ``'opt'``, ``'rand_commit'`` and ``'rand'``,
        each an array of per-env, per-episode returns.

    Raises:
        ValueError: if ``envname`` does not start with a known prefix.
    """
    Heps = kwargs['Heps']
    H = kwargs['H']
    n_eval = kwargs['n_eval']
    horizon = kwargs['horizon']
    filename_template = kwargs['filename']
    assert H % horizon == 0

    envname = kwargs['envname']
    if envname.startswith('mini_two_boxes'):
        env_name = 'MiniWorld-OneRoomS6FastMulti'
    elif envname.startswith('mini_three_boxes'):
        env_name = 'MiniWorld-OneRoomS6FastMultiThreeBoxes'
    elif envname.startswith('mini_four_boxes'):
        env_name = 'MiniWorld-OneRoomS6FastMultiFourBoxes'
    elif envname.startswith('mini_blue'):
        env_name = 'MiniWorld-OneRoomS6FastMultiBlue'
    else:
        raise ValueError("Invalid envname")

    envs = []
    for i_eval in range(n_eval):
        print(f"Eval traj: {i_eval}")
        # deploy_online_vec does not read these two keys; they are kept so the
        # kwargs forwarded below are unchanged.
        kwargs['traj'] = eval_trajs[i_eval]
        kwargs['i_eval'] = i_eval

        env = gym.make(f'{env_name}FixedInit-v0')
        env.set_task(env_id=8000 + i_eval)
        envs.append(env)

    vec_env = MiniworldEnvVec(envs)

    # Learner
    print("Evaluating learner")
    lnr_kwargs = kwargs.copy()
    lnr_filename_template = partial(filename_template.format, controller='lnr')
    lnr_kwargs['filename'] = lnr_filename_template
    lnr_controller = MiniworldTransformerController(
        model, batch_size=n_eval, sample=True, save_video=False, filename_template=lnr_filename_template)
    cum_means_lnr = deploy_online_vec(vec_env, lnr_controller, learner=True, **lnr_kwargs)

    all_means_lnr = np.array(cum_means_lnr)
    means_lnr, sems_lnr = _mean_and_sem(all_means_lnr)

    # Optimal policy
    print("Evaluating optimal policy")
    opt_kwargs = kwargs.copy()
    opt_kwargs['Heps'] = 1
    opt_filename_template = partial(filename_template.format, controller='opt')
    opt_kwargs['filename'] = opt_filename_template
    opt_controller = MiniworldOptPolicy(
        vec_env, batch_size=n_eval, save_video=False, filename_template=opt_filename_template)
    cum_means_opt = deploy_online_vec(vec_env, opt_controller, **opt_kwargs)
    # The optimal return is tiled across all Heps episodes. Only the first
    # episode column is tiled: np.repeat repeats every column, so an input
    # with more than one column would produce more than Heps columns.
    cum_means_opt = np.repeat(cum_means_opt[:, :1], Heps, axis=-1)

    all_means_opt = np.array(cum_means_opt)
    means_opt, sems_opt = _mean_and_sem(all_means_opt)

    # Random policy
    print("Evaluating random policy")
    rand_kwargs = kwargs.copy()
    # Give the random baseline a callable template like the other controllers;
    # deploy_online_vec wraps kwargs['filename'] in functools.partial, which
    # rejects a plain string.
    rand_kwargs['filename'] = partial(filename_template.format, controller='rand')
    rand_controller = MiniworldRandPolicy(vec_env, batch_size=n_eval)
    cum_means_rand = deploy_online_vec(vec_env, rand_controller, **rand_kwargs)

    all_means_rand = np.array(cum_means_rand)
    means_rand, sems_rand = _mean_and_sem(all_means_rand)

    print("Computing random-commit performance from opt and rand controllers")
    cum_means_rand_commit = cum_means_rand.copy()    # (n_eval, Heps)
    for i_eval in range(n_eval):
        # Episodes after the first one with a positive random return take the
        # optimal return instead.
        found_goal_ts = np.flatnonzero(cum_means_rand[i_eval] > 0)
        if found_goal_ts.size > 0:
            found_goal_t = found_goal_ts[0]
            cum_means_rand_commit[i_eval, found_goal_t+1:] = cum_means_opt[i_eval, found_goal_t+1:]

    all_means_rand_commit = np.array(cum_means_rand_commit)
    means_rand_commit, sems_rand_commit = _mean_and_sem(all_means_rand_commit)

    # plot individual curves
    for i in range(n_eval):
        plt.plot(all_means_lnr[i], color='blue', alpha=0.2)
        plt.plot(all_means_opt[i], color='green', alpha=0.2)
        plt.plot(all_means_rand_commit[i], color='red', alpha=0.2)
        plt.plot(all_means_rand[i], color='orange', alpha=0.2)

    # plot the mean curves with a standard-error band
    _plot_mean_with_sem(Heps, means_lnr, sems_lnr, 'blue', 'LNR')
    _plot_mean_with_sem(Heps, means_opt, sems_opt, 'green', 'Optimal')
    _plot_mean_with_sem(Heps, means_rand_commit, sems_rand_commit, 'red', 'Rand Commit')
    _plot_mean_with_sem(Heps, means_rand, sems_rand, 'orange', 'Rand')

    plt.legend()
    plt.xlabel('t')
    plt.ylabel('Average Reward')
    plt.title(f'Online Evaluation on {n_eval} envs')

    baselines = {
        'lnr': all_means_lnr,
        'opt': all_means_opt,
        'rand_commit': all_means_rand_commit,
        'rand': all_means_rand,
    }
    return baselines