import sys
import os
import time
project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../'))
sys.path.insert(0, project_dir)
from torch.utils.tensorboard import SummaryWriter
import policies
from policies import *
from datagen import *
from rex_quadrotor import RexQuadrotor
from my_envs.cartpole import CartpoleEnv
from envs import PendulumEnv, IntegratorEnv
import ipdb
import qpth.qp_wrapper as mpc
import utils
import math
import numpy as np
import torch
import torch.autograd as autograd


# Tensors created without an explicit device (e.g. the `torch.tensor([])`
# buffers used by the rollout helpers below) are allocated on CUDA, so this
# script expects a CUDA device to be available.
torch.set_default_device('cuda')
# Compact NumPy printing (4 decimals, no scientific notation) for debug output.
np.set_printoptions(precision=4, suppress=True)


def seeding(seed=0):
    """Seed NumPy's global RNG and PyTorch's RNG with the same ``seed``."""
    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)


def eval_policy(args, env, policy, gt_trajs):
    """Evaluate ``policy`` in closed loop on ``env`` and report the return.

    A batch of reference trajectories is sampled from ``gt_trajs``, moved to
    ``args.device`` and (for the pendulum / cartpole environments)
    unnormalized. The rollout itself starts from ``env.reset`` and is
    delegated to ``eval_streaming_policy``.

    Args:
        args: config namespace; reads ``bsz``, ``H``, ``T``, ``device``,
            ``env`` and the ``data_noise_*`` fields here (the rollout helper
            reads further fields).
        env: environment providing ``reset`` plus whatever the rollout uses.
        policy: policy module; switched to eval mode.
        gt_trajs: demonstration data accepted by ``sample_trajectory``.

    Returns:
        Whatever ``eval_streaming_policy`` returns.
    """
    policy.eval()
    # NOTE(review): a bare `torch.no_grad()` call is a no-op (it only builds
    # the context manager, never enters it), so gradients were never
    # disabled here. It is not replaced by a `with torch.no_grad():` block
    # because the policy / MPC solver may rely on autograd internally —
    # TODO confirm and wrap the rollout if that is safe.

    traj_sample = sample_trajectory(gt_trajs, args.bsz, args.H, args.T)
    traj_sample = {k: v.to(args.device) for k, v in traj_sample.items()}

    if args.env == "pendulum":
        traj_sample["state"] = utils.unnormalize_states_pendulum(traj_sample["state"])
        traj_sample["obs"] = utils.unnormalize_states_pendulum(traj_sample["obs"])
    elif args.env in ("cartpole1link", "cartpole2link"):
        traj_sample["state"] = utils.unnormalize_states_cartpole_nlink(traj_sample["state"])
        # The pendulum branch unnormalizes obs exactly like state, so obs
        # shares the state layout; use the cartpole unnormalizer for it too.
        traj_sample["obs"] = utils.unnormalize_states_cartpole_nlink(traj_sample["obs"])

    # The noisy observations are not consumed by the rollout below. The call
    # is kept because it may draw random numbers, and removing it would shift
    # the RNG stream seen by env.reset.
    noise_utils.corrupt_observation(
        traj_sample["obs"], args.data_noise_type, args.data_noise_std, args.data_noise_mean)

    gt_actions = traj_sample["action"]
    gt_states = traj_sample["state"]
    gt_mask = traj_sample["mask"]

    state = env.reset(bsz=args.bsz).float()
    return eval_streaming_policy(args, env, policy, state, gt_states, gt_actions, gt_mask)

def eval_vanilla_policy(args, env, policy, state, gt_states, gt_actions, gt_mask,
                        nruns=200):
    """Closed-loop rollout that re-plans at every step without warm starting.

    At each of ``nruns`` steps the policy is queried on the clipped current
    state. The first action of its last plan (``policy_out["trajs"][-1]``)
    is applied, and the state is propagated through ``env.dynamics`` in
    float64 and cast back to float32. Unlike ``eval_streaming_policy``, the
    propagated state is not clipped and no reward is computed.

    Args:
        args: config namespace; ``qp_solve`` and ``lastqp_solve`` are read.
        env: provides ``state_clip`` and ``dynamics``.
        policy: callable returning a dict with a ``"trajs"`` list of
            ``(state_net, state, action)`` triples.
        state: initial state batch, indexed as ``(bsz, nx)``.
        gt_states, gt_actions, gt_mask: reference data forwarded to the policy.
        nruns: number of control steps (defaults to 200).

    Returns:
        ``(state_hist, input_hist)``. ``state_hist`` stacks the states along
        dim 1 including the initial one, giving ``(bsz, nruns + 1, nx)``.
        ``input_hist`` stacks the applied inputs, giving ``(bsz, nruns, nu)``.
        If ``nruns == 0``, ``input_hist`` is an empty 1-D tensor.
    """
    # Collect per-step tensors and stack once (repeated cat is quadratic).
    states = [state.detach()]
    inputs = []

    for _ in range(nruns):
        obs_in = env.state_clip(state.clone())
        policy_out = policy(obs_in, gt_states, gt_actions, gt_mask,
                            qp_solve=args.qp_solve, lastqp_solve=args.lastqp_solve)
        _, _, nominal_action = policy_out["trajs"][-1]

        u = nominal_action[:, 0, :]
        state = env.dynamics(state.to(torch.float64), u.to(torch.float64)).to(torch.float32)
        # Detach so the autograd graph does not chain through every step;
        # values are unchanged, only graph memory is released.
        state = state.detach()
        states.append(state)
        inputs.append(u.detach())

    state_hist = torch.stack(states, dim=1)
    input_hist = torch.stack(inputs, dim=1) if inputs else torch.tensor([])
    return state_hist, input_hist


def eval_streaming_policy(args, env, policy, state, gt_states, gt_actions, gt_mask,
                          nruns=200):
    """Run a warm-started closed-loop rollout, then print and return the mean return.

    At each of ``nruns`` steps the policy plans from the clipped current state.
    Only the first action of its last plan (``policy_out["trajs"][-1]``) is
    applied. The state is propagated through ``env.dynamics`` in float64,
    cast back to float32 and clipped. The policy is called with
    ``warm_start=False`` on the first step and ``warm_start=True`` afterwards.

    Args:
        args: config namespace; ``qp_solve`` and ``lastqp_solve`` are read.
        env: provides ``state_clip``, ``dynamics`` and ``reward``.
        policy: callable returning a dict with a ``"trajs"`` list of
            ``(state_net, state, action)`` triples.
        state: initial state batch, indexed as ``(bsz, nx)``.
        gt_states, gt_actions, gt_mask: reference data forwarded to the policy.
        nruns: number of control steps (defaults to 200).

    Returns:
        Mean over trajectories of ``env.reward(...).sum(dim=-1)``.
        Trajectories whose summed reward is NaN (diverged) are excluded, and
        the number of NaN reward entries is printed. The result is NaN if
        every trajectory diverged.
    """
    # Histories live on the CPU and are stacked once at the end
    # (repeated torch.cat would be quadratic in nruns).
    state_steps = [state.detach().cpu()]
    input_steps = []

    for step in range(nruns):
        obs_in = env.state_clip(state.clone())
        policy_out = policy(obs_in, gt_states, gt_actions, gt_mask,
                            qp_solve=args.qp_solve, lastqp_solve=args.lastqp_solve,
                            warm_start=step > 0)
        _, _, nominal_action = policy_out["trajs"][-1]

        u = nominal_action[:, 0, :]
        state = env.dynamics(state.to(torch.float64), u.to(torch.float64)).to(torch.float32)
        # Detach so the autograd graph does not chain through all nruns steps
        # (values are unchanged; only graph memory is released).
        state = env.state_clip(state).detach()
        state_steps.append(state.cpu())
        input_steps.append(u.detach().cpu())

    state_hist = torch.stack(state_steps, dim=1)
    input_hist = (torch.stack(input_steps, dim=1) if input_steps
                  else torch.tensor([]).cpu())

    rewards = env.reward(state_hist[:, 1:], input_hist)
    returns = rewards.sum(dim=-1)
    diverged = returns.isnan()
    if diverged.any():
        # Count NaNs before excluding them; report the same per-trajectory
        # return as the NaN-free branch, averaged over surviving trajectories.
        num_nan = rewards.isnan().sum()
        mean_return = returns[~diverged].mean()
        print("return : ", mean_return, "num_nan : ", num_nan)
    else:
        mean_return = returns.mean()
        print("return : ", mean_return)
    return mean_return

def check_param_sensitivity(args, env, policy, gt_trajs):
    """Sensitivity of the tracking-MPC state trajectory to the state reference.

    A batch of demonstrations is sampled, moved to ``args.device`` and
    unnormalized like in ``eval_policy``. The tracking MPC is then
    differentiated with respect to the demonstrated states, used as the
    reference ``x_ref``, with a zero action reference. For a numerical
    cross-check, ``compute_finite_diff_jacobian`` accepts the same arguments.

    Args:
        args: config namespace; reads ``bsz``, ``H``, ``T``, ``device``,
            ``env`` and the ``data_noise_*`` fields.
        env: unused; kept for interface compatibility.
        policy: object exposing ``tracking_mpc``.
        gt_trajs: demonstration data accepted by ``sample_trajectory``.

    Returns:
        The Jacobian tensor returned by ``compute_analytical_jacobian``.
    """
    traj_sample = sample_trajectory(gt_trajs, args.bsz, args.H, args.T)
    traj_sample = {k: v.to(args.device) for k, v in traj_sample.items()}

    if args.env == "pendulum":
        traj_sample["state"] = utils.unnormalize_states_pendulum(traj_sample["state"])
        traj_sample["obs"] = utils.unnormalize_states_pendulum(traj_sample["obs"])
    elif args.env in ("cartpole1link", "cartpole2link"):
        traj_sample["state"] = utils.unnormalize_states_cartpole_nlink(traj_sample["state"])
        # obs shares the state layout (see the pendulum branch), so it needs
        # the cartpole unnormalizer as well.
        traj_sample["obs"] = utils.unnormalize_states_cartpole_nlink(traj_sample["obs"])

    # The noisy observations are unused here. The call is kept because it may
    # draw random numbers, and dropping it would change the RNG stream.
    noise_utils.corrupt_observation(
        traj_sample["obs"], args.data_noise_type, args.data_noise_std, args.data_noise_mean)

    gt_actions = traj_sample["action"]
    gt_states = traj_sample["state"]
    gt_mask = traj_sample["mask"]

    x_t = gt_states[:, 0, :]
    # Leaf tensor we differentiate against; it also feeds xu_ref, so
    # gradients through both references are captured.
    x_ref = gt_states.clone().requires_grad_(True)
    u_ref = torch.zeros_like(gt_actions)
    xu_ref = torch.cat((x_ref, u_ref), dim=2)
    return compute_analytical_jacobian(x_t, x_ref, u_ref, gt_mask, policy, xu_ref)


def compute_analytical_jacobian(x_t, x_ref, u_ref, gt_mask, policy, xu_ref, al_iters=12):
    """Jacobian of the tracking-MPC state trajectory w.r.t. ``x_ref`` via autograd.

    The MPC is re-initialized and solved once. One reverse-mode pass is then
    taken per scalar output of each batch element's state trajectory, with
    the graph retained between passes. That is
    ``B * states[0].numel()`` backward passes in total.

    Args:
        x_t: initial states, passed to ``reinitialize`` and to the MPC.
        x_ref: state reference; must be part of the autograd graph
            (e.g. created with ``requires_grad_(True)``).
        u_ref: action reference forwarded to the MPC.
        gt_mask: mask forwarded to ``reinitialize``.
        policy: object exposing ``tracking_mpc``.
        xu_ref: joint reference forwarded to the MPC. If it is built from
            ``x_ref``, gradients through it are included as well.
        al_iters: forwarded to the MPC as ``al_iters`` (defaults to 12).

    Returns:
        Tensor of shape ``(B, *states.shape[1:], *x_ref.shape[1:])`` where
        ``jacs[b, out..., inp...] = d states[b, out...] / d x_ref[b, inp...]``.
        Cross-batch derivatives are not computed.
    """
    policy.tracking_mpc.reinitialize(x_t, gt_mask)
    states, _, _ = policy.tracking_mpc(x_t, xu_ref, x_ref, u_ref, model_call=None, al_iters=al_iters)

    ref_shape = x_ref.shape[1:]
    jacs = []
    for b in range(states.shape[0]):
        flat_states = states[b].reshape(-1)
        rows = [
            torch.autograd.grad(flat_states[k], x_ref, retain_graph=True)[0][b]
            for k in range(flat_states.shape[0])
        ]
        # Output axes come from the MPC states and input axes from x_ref.
        # These may differ, e.g. when the trajectory includes the initial state.
        jacs.append(torch.stack(rows, dim=0).reshape(*states[b].shape, *ref_shape))
    return torch.stack(jacs, dim=0)

def compute_finite_diff_jacobian(x_t, x_ref, u_ref, gt_mask, policy, xu_ref, eps=1e-2,
                                 perturb_xu_ref=True, al_iters=12):
    """Central finite-difference Jacobian of the tracking-MPC states w.r.t. ``x_ref``.

    For every reference entry ``(t, i)`` the MPC is re-initialized and solved
    with that entry shifted by ``+eps`` and by ``-eps``. The column is
    ``(states(+eps) - states(-eps)) / (2 * eps)``. All batch elements are
    perturbed simultaneously, so cross-batch coupling, if any, is not
    separated.

    Args:
        x_t: initial states, passed to ``reinitialize`` and to the MPC.
        x_ref: state reference, shape ``(B, horizon, state_dim)``.
        u_ref: action reference forwarded to the MPC unchanged.
        gt_mask: mask forwarded to ``reinitialize``.
        policy: object exposing ``tracking_mpc``.
        xu_ref: joint reference forwarded to the MPC.
        eps: finite-difference step.
        perturb_xu_ref: when True (default) and ``xu_ref`` is given,
            ``xu_ref[:, t, i]`` is shifted together with ``x_ref[:, t, i]``.
            This assumes the first ``state_dim`` columns of ``xu_ref`` are
            ``x_ref``, as produced by ``torch.cat((x_ref, u_ref), dim=2)`` in
            ``check_param_sensitivity``. It matches the total derivative
            computed by ``compute_analytical_jacobian``.
        al_iters: forwarded to the MPC as ``al_iters`` (defaults to 12).

    Returns:
        Tensor of shape ``(*states.shape, horizon, state_dim)``, detached from
        the autograd graph.
    """
    batch_size, horizon, state_dim = x_ref.shape

    def _solve(x_ref_in, xu_ref_in):
        # Reinitialize before every solve so each run starts from the same state.
        policy.tracking_mpc.reinitialize(x_t, gt_mask)
        states, _, _ = policy.tracking_mpc(x_t, xu_ref_in, x_ref_in, u_ref,
                                           model_call=None, al_iters=al_iters)
        return states

    def _perturbed_solve(t, i, delta):
        # Solve with reference entry (t, i) shifted by delta.
        x_ref_p = x_ref.clone()
        x_ref_p[:, t, i] += delta
        xu_ref_p = xu_ref
        if perturb_xu_ref and xu_ref is not None:
            xu_ref_p = xu_ref.clone()
            xu_ref_p[:, t, i] += delta
        return _solve(x_ref_p, xu_ref_p)

    jacs = None
    for t in range(horizon):
        for i in range(state_dim):
            diff = (_perturbed_solve(t, i, eps) - _perturbed_solve(t, i, -eps)) / (2 * eps)
            # Detach so jacs does not hold every solve's autograd graph.
            diff = diff.detach()
            if jacs is None:
                jacs = diff.new_zeros(*diff.shape, horizon, state_dim)
            jacs[..., t, i] = diff

    if jacs is None:  # empty horizon or state_dim: nothing was perturbed
        jacs = torch.zeros(batch_size, horizon, state_dim, horizon, state_dim, device=x_ref.device)
    return jacs
