import time
from policies import *
import policies
from deq_layer_utils import update_scales
import utils

def global_fwd(args, obs_in, gt_obs, gt_states, gt_actions, gt_obs_actions, gt_mask, 
        policy, coeffs, pretrain_done, i, losses_iter_nocoeff, losses_proxy_iter_nocoeff, grad_coeff_update=True):
    """Run the policy once over the full horizon and compute its loss.

    On every call where ``i % 20 == 0`` and ``grad_coeff_update`` is set, the
    per-iteration loss coefficients are re-estimated with
    ``policies.compute_grad_coeff``. When ``args.grad_coeff`` is also set, the
    estimate is blended into ``coeffs`` as ``0.2 * estimate + 0.8 * coeffs``.
    The coefficient-free losses returned by that estimate are appended to the
    caller-owned ``losses_iter_nocoeff`` / ``losses_proxy_iter_nocoeff`` lists
    (mutated in place, one list per DEQ iteration).

    Args:
        args: run configuration (reads ``qp_solve``, ``lastqp_solve``, ``deq``,
            ``deq_iter``, ``num_coeffs_per_iter``, ``grad_coeff``,
            ``scaled_output``).
        obs_in: observation fed to the policy.
        gt_obs, gt_states, gt_actions, gt_mask: ground-truth tensors passed to
            the policy and/or ``policies.compute_loss``.
        gt_obs_actions: accepted for signature parity; not used here.
        policy: the model; called as ``policy(obs, states, actions, mask, ...)``.
        coeffs: loss coefficients; presumably shaped
            ``(args.deq_iter, args.num_coeffs_per_iter)`` to match the
            ``view`` below -- TODO confirm.
        pretrain_done: QP solving is enabled only when this is truthy.
        i: iteration counter, forwarded to the policy as ``out_iter``.
        losses_iter_nocoeff, losses_proxy_iter_nocoeff: per-iteration lists
            appended to on coefficient-update calls.
        grad_coeff_update: enables the periodic coefficient re-estimation.

    Returns:
        ``(time_diff, loss_dict, policy_out, coeffs)``: wall-clock seconds of
        the policy forward call only, the dict from ``policies.compute_loss``,
        the raw policy output, and the (possibly updated) coefficients.
    """
    start = time.time()
    policy_out = policy(obs_in, gt_states, gt_actions,
                        gt_mask, out_iter=i, qp_solve=args.qp_solve and pretrain_done,
                        lastqp_solve=args.lastqp_solve and pretrain_done)
    end = time.time()

    if (i % 20 == 0) and grad_coeff_update:
        coeffs_est, losses_nocoeff, losses_proxy_nocoeff = policies.compute_grad_coeff(
            policy, gt_states, gt_actions, gt_mask, policy_out, args.deq, pretrain_done)
        coeffs_est = coeffs_est.view(args.deq_iter, args.num_coeffs_per_iter)
        if args.grad_coeff:
            # Exponential moving average smooths the noisy per-batch estimate.
            coeffs = coeffs_est * 0.2 + coeffs * 0.8
        # Plain loops: these appends are side effects, not a list being built.
        for k in range(args.deq_iter):
            losses_iter_nocoeff[k].append(losses_nocoeff[k].item())
        for k in range(args.deq_iter):
            losses_proxy_iter_nocoeff[k].append(losses_proxy_nocoeff[k].item())

    if args.scaled_output:
        update_scales(policy, policy_out["trajs"], gt_states, policy_out["init_states"], gamma=0.95)

    loss_dict = policies.compute_loss(policy, gt_states, gt_actions, gt_obs, gt_mask, policy_out,
                                      args.deq, pretrain_done, coeffs)
    time_diff = end - start
    return time_diff, loss_dict, policy_out, coeffs


def streaming_fwd(args, obs_in, gt_obs, gt_states, gt_actions, gt_obs_actions, gt_mask,
        policy, coeffs, pretrain_done, i, losses_iter_nocoeff, losses_proxy_iter_nocoeff, grad_coeff_update=False):
    """Run the policy over ``1 + args.streaming_steps`` shifted horizon windows.

    Window 0 covers time steps ``[0, args.T)`` and uses ``obs_in`` as its
    observation. Window ``j + 1`` covers ``[j + 1, j + 1 + args.T)``, takes
    ``gt_states[:, j + 1, :]`` as its observation and is run with
    ``warm_start=True``. Window outputs are folded together with
    ``policy_out_update`` and their losses with ``update_loss_dict``.

    Coefficients: window 0 uses ``coeffs[:args.deq_iter]``. Each later window
    uses the next block of ``args.str_al_iter`` rows.

    ``gt_obs``, ``gt_obs_actions``, the two ``*_nocoeff`` lists and
    ``grad_coeff_update`` are accepted for signature parity with
    ``global_fwd`` but are not used here.

    Returns:
        ``(time_diff, loss_dict, policy_out, coeffs)``: wall-clock seconds for
        all windows (forward passes and loss computation), the accumulated
        loss dict, the merged policy output, and ``coeffs`` unchanged.
    """
    start = time.time()

    def run_window(window_obs, states, actions, mask, **extra):
        # Solver flags are evaluated on every call, as in the inline form.
        return policy(window_obs, states, actions,
                      mask, out_iter=i,
                      qp_solve=args.qp_solve and pretrain_done,
                      lastqp_solve=args.lastqp_solve and pretrain_done,
                      **extra)

    def window_loss(states, actions, window_obs, mask, out, window_coeffs):
        return policies.compute_loss(policy, states, actions, window_obs, mask, out,
                                     args.deq, pretrain_done, window_coeffs)

    # Window 0: the given observation and the first args.T steps.
    first_states = gt_states[:, :args.T, :]
    first_actions = gt_actions[:, :args.T, :]
    first_mask = gt_mask[:, :args.T]

    policy_out = run_window(obs_in, first_states, first_actions, first_mask)
    loss_dict = window_loss(first_states, first_actions, obs_in, first_mask, policy_out,
                            coeffs[:args.deq_iter, :])

    # Later windows: shift by one step each time and warm-start the solver.
    for step in range(args.streaming_steps):
        offset = step + 1
        window = slice(offset, offset + args.T)
        window_obs = gt_states[:, offset, :]
        states = gt_states[:, window, :]
        actions = gt_actions[:, window, :]
        mask = gt_mask[:, window]

        window_out = run_window(window_obs, states, actions, mask, warm_start=True)
        policy_out = policy_out_update(policy_out, window_out)

        coeff_lo = args.deq_iter + step * args.str_al_iter
        coeff_hi = coeff_lo + args.str_al_iter
        step_losses = window_loss(states, actions, window_obs, mask, window_out,
                                  coeffs[coeff_lo:coeff_hi, :])
        loss_dict = update_loss_dict(loss_dict, step_losses)

    end = time.time()
    if args.scaled_output:
        update_scales(policy, policy_out["trajs"], gt_states, policy_out["init_states"], gamma=0.95)

    return end - start, loss_dict, policy_out, coeffs

def policy_out_update(policy_out, policy_out_j):
    """Fold a later streaming window's output into ``policy_out`` in place.

    Every tracked entry is combined with ``+``. For lists this concatenates;
    for tensors it adds elementwise. Which one applies depends on what the
    policy returns (not visible here). ``"trajs"`` is always merged. The
    ``"deq_stats"`` sub-entries (``fwd_err``, ``fwd_steps``, ``jac_loss``),
    ``"q_scaling"`` and ``"nominal_x_ests"`` are merged only when present in
    ``policy_out``; in that case they must also be present in
    ``policy_out_j``.

    Returns the same ``policy_out`` object.
    """
    def merge(dst, src, key):
        dst[key] = dst[key] + src[key]

    merge(policy_out, policy_out_j, "trajs")

    if "deq_stats" in policy_out:
        ours = policy_out["deq_stats"]
        theirs = policy_out_j["deq_stats"]
        for stat in ("fwd_err", "fwd_steps", "jac_loss"):
            merge(ours, theirs, stat)

    for optional_key in ("q_scaling", "nominal_x_ests"):
        if optional_key in policy_out:
            merge(policy_out, policy_out_j, optional_key)

    return policy_out

def update_loss_dict(loss_dict, loss_dict_j):
    """Accumulate ``loss_dict_j`` into ``loss_dict`` in place and return it.

    Only keys already present in ``loss_dict`` are updated, and each of them
    must exist in ``loss_dict_j``. Extra keys in ``loss_dict_j`` are ignored.
    Values are combined with ``+``.
    """
    # Rebinding values of existing keys while iterating items() is safe:
    # the dict's size never changes.
    for name, accumulated in loss_dict.items():
        loss_dict[name] = accumulated + loss_dict_j[name]
    return loss_dict

def _unnormalize_traj_sample(args, traj_sample):
    """Unnormalize the "state" and "obs" entries of ``traj_sample`` in place for ``args.env``.

    Both entries go through the same environment-specific helper.
    ``streaming_fwd`` feeds slices of the state trajectory in as observations,
    so states and observations must live in the same space. Samples from
    environments not listed here are returned unchanged.
    """
    if args.env == "pendulum":
        unnormalize = utils.unnormalize_states_pendulum
    elif args.env == "cartpole1link" or args.env == "cartpole2link":
        unnormalize = utils.unnormalize_states_cartpole_nlink
    elif args.env == "FlyingCartpole":
        unnormalize = utils.unnormalize_states_flyingcartpole
    else:
        return traj_sample
    traj_sample["state"] = unnormalize(traj_sample["state"])
    traj_sample["obs"] = unnormalize(traj_sample["obs"])
    return traj_sample


def validate_policy(args, env, policy, val_samples, coeffs, pretrain_done, total_deq_iter):
    """Evaluate ``policy`` on every validation sample and report averaged losses.

    Each sample is moved to ``args.device``, unnormalized for ``args.env`` and
    run through ``global_fwd`` (coefficient updates disabled), or through
    ``streaming_fwd`` when ``args.streaming`` is set. Residual statistics and
    the mean end loss are printed.

    Args:
        args: run configuration (reads ``device``, ``env``, ``streaming``,
            ``streaming_steps``, ``num_coeffs_per_iter``; the forward helpers
            read more).
        env: not used here; kept for interface compatibility.
        policy: the model evaluated by the forward helpers.
        val_samples: indexable collection (``len`` and ``[]``) of dicts of
            tensors with at least the keys "state", "obs", "action", "mask"
            and "obs_action".
        coeffs: per-iteration loss coefficients, or ``None`` for all-ones of
            shape ``(total_deq_iter, args.num_coeffs_per_iter)``.
        pretrain_done: forwarded to the forward helpers; gates QP solving.
        total_deq_iter: number of DEQ iterations to collect per-iteration
            statistics for.

    Returns:
        ``(mean_loss_end, losses_iter_opt, losses_iter_nn)``:
        ``mean_loss_end`` is the mean of the per-sample ``loss_end`` values
        divided by ``1 + args.streaming_steps`` when streaming. This is
        presumably a per-window average, since ``update_loss_dict`` sums
        entries across windows. The two lists hold the mean of the
        corresponding ``loss_dict`` entry for each DEQ iteration.
    """
    losses_end = []
    losses_var = []
    residuals = []
    # One list per DEQ iteration. Several of these are collected for
    # inspection only and are not returned.
    losses_iter = [[] for _ in range(total_deq_iter)]
    losses_iter_opt = [[] for _ in range(total_deq_iter)]
    losses_iter_nn = [[] for _ in range(total_deq_iter)]
    losses_iter_base = [[] for _ in range(total_deq_iter)]
    loss_iter_q = [[] for _ in range(total_deq_iter)]
    losses_iter_nocoeff = [[] for _ in range(total_deq_iter)]
    losses_proxy_iter_nocoeff = [[] for _ in range(total_deq_iter)]
    losses_iter_hist = [[] for _ in range(total_deq_iter)]
    deq_stats = {"fwd_err": [[] for _ in range(total_deq_iter)], "fwd_steps": [[] for _ in range(total_deq_iter)]}

    if coeffs is None:
        coeffs = torch.ones((total_deq_iter, args.num_coeffs_per_iter), device=args.device)

    # Index-based on purpose: val_samples may be a map-style dataset that only
    # supports len()/[] and need not raise IndexError past its end.
    for i in range(len(val_samples)):
        traj_sample = {k: v.to(args.device) for k, v in val_samples[i].items()}
        traj_sample = _unnormalize_traj_sample(args, traj_sample)

        gt_obs = traj_sample["obs"]
        obs_in = gt_obs.squeeze(1)
        gt_actions = traj_sample["action"]
        gt_states = traj_sample["state"]
        gt_mask = traj_sample["mask"]
        gt_obs_actions = traj_sample["obs_action"]

        if not args.streaming:
            _, loss_dict, policy_out, coeffs = global_fwd(args, obs_in, gt_obs, gt_states, gt_actions,
                                    gt_obs_actions, gt_mask, policy, coeffs, pretrain_done, i,
                                    losses_iter_nocoeff, losses_proxy_iter_nocoeff, grad_coeff_update=False)
        else:
            _, loss_dict, policy_out, coeffs = streaming_fwd(args, obs_in, gt_obs, gt_states, gt_actions,
                                    gt_obs_actions, gt_mask, policy, coeffs, pretrain_done, i,
                                    losses_iter_nocoeff, losses_proxy_iter_nocoeff)

        residuals.append(loss_dict["residuals"])
        losses_end.append(loss_dict["loss_end"].item())
        losses_var.append(loss_dict["losses_var"])

        has_deq_stats = "deq_stats" in policy_out
        has_x_ests = "nominal_x_ests" in policy_out
        has_q_scaling = "q_scaling" in policy_out
        for k in range(total_deq_iter):
            losses_iter[k].append(loss_dict["losses_iter"][k])
            losses_iter_opt[k].append(loss_dict["losses_iter_opt"][k])
            losses_iter_nn[k].append(loss_dict["losses_iter_nn"][k])
            losses_iter_base[k].append(loss_dict["losses_iter_base"][k])
            if has_deq_stats:
                deq_stats["fwd_err"][k].append(policy_out["deq_stats"]["fwd_err"][k].item())
                deq_stats["fwd_steps"][k].append(policy_out["deq_stats"]["fwd_steps"][k].item())
            if has_x_ests:
                losses_iter_hist[k].append(loss_dict["losses_x_ests"][k])
            if has_q_scaling:
                loss_iter_q[k].append(loss_dict["q_scaling"][k])

    residuals = torch.stack(residuals).cpu().detach().numpy()
    # NOTE(review): this label says "Losses end" but prints residual statistics.
    print("Losses end mean: ", np.mean(residuals), " median : ", np.median(residuals))
    print("Losses end: ", np.mean(losses_end))

    mean_losses_iter_opt = [np.mean(values) for values in losses_iter_opt]
    mean_losses_iter_nn = [np.mean(values) for values in losses_iter_nn]
    return np.mean(losses_end)/(1 + args.streaming_steps*float(args.streaming)), mean_losses_iter_opt, mean_losses_iter_nn