from models import Value

import matplotlib.pyplot as plt
from argparse import Namespace
from collections import deque
from tqdm import tqdm
import pandas as pd
import numpy as np
import pickle
import ctypes
import torch
import sys
import os


class Env:
    """Vectorised two-state toy environment with state-dependent reward noise.

    ``n_envs`` copies are stepped together. Each copy holds one scalar state
    that is updated as ``(state + action) mod 2``, so with 0/1 actions the
    state stays in {0, 1}. The reward depends only on the state reached.
    """

    def __init__(self) -> None:
        self.obs_dim = 1
        self.obs_space = [0, 1]
        self.action_dim = 1
        self.action_space = [0, 1]
        # NOTE(review): these four values are not read by reset()/step() in
        # this class; they are kept because code elsewhere may rely on them.
        self.mean1, self.mean2 = 0.0, 1.0
        self.std1, self.std2 = 1.0, 2.0
        self.n_envs = 10
        # Start in the initial state so step() is usable even before reset().
        self.states = np.zeros((self.n_envs, 1))

    def reset(self):
        """Put every copy back into state 0 and return a copy of the states.

        Returns:
            Array of shape ``(n_envs, 1)`` filled with zeros.
        """
        self.states = np.zeros((self.n_envs, 1))
        return self.states.copy()

    def step(self, actions):
        """Advance all copies by one step.

        Args:
            actions: Array broadcastable against the ``(n_envs, 1)`` state array.

        Returns:
            Tuple ``(states, rewards)``: a copy of the new states with shape
            ``(n_envs, 1)`` and a reward vector of shape ``(n_envs,)``.
        """
        self.states = np.mod(self.states + actions, 2)
        flat_states = self.states.reshape(-1)
        noise = np.random.randn(self.n_envs)
        # Gaussian reward with mean 0.01*(state - 0.5) and std 0.01*(state + 2).
        rewards = 0.01*(noise*(flat_states + 2.0) + flat_states - 0.5)
        return self.states.copy(), rewards
    


# Shared environment instance and hyper-parameters used by the rest of the script.
env = Env()

args = Namespace(
    max_len=1000,                # steps per episode
    n_envs=env.n_envs,
    obs_dim=env.obs_dim,
    action_dim=env.action_dim,
    hidden_dim=64,               # consumed by Value (defined in models)
    activation='ReLU',           # consumed by Value (defined in models)
    n_critics=2,
    n_quantiles=25,              # n_quantiles*n_critics quantiles are compared with the ground truth
    n_target_quantiles=50,       # size of each projected target distribution
    discount_factor=0.99,
    lr=3e-4,
    n_epochs=20,                 # gradient steps per critic update
    gae_coeff=None,              # set per experiment before the critic is built
    # Use the first GPU when present, otherwise run on the CPU instead of
    # failing on machines without CUDA.
    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    len_replay_buffer=10000,     # total transitions kept, split evenly across envs
)




##################
# for ground truth
##################
# Monte-Carlo estimate of the discounted-return distribution obtained with
# uniformly random actions; its quantiles are the reference that the learned
# quantiles are compared against.

scores_list = []
# Same horizon as the training episodes.
max_len = args.max_len
for _ in tqdm(range(1000)):
    states = env.reset()
    scores = np.zeros(args.n_envs)
    for step_idx in range(max_len):
        actions = np.random.randint(0, 2, (args.n_envs, 1))
        next_states, rewards = env.step(actions)
        scores += rewards*(args.discount_factor**step_idx)
    scores_list.extend(scores)

# Quantile midpoints (i + 0.5)/n, one per predicted quantile.
n_quantiles = args.n_quantiles*args.n_critics
cum_prob_np = (np.arange(n_quantiles) + 0.5)/n_quantiles
df = pd.DataFrame({"data": scores_list})
# One vectorised call on the column; avoids positional `[0]` indexing on a
# label-indexed Series, which newer pandas versions no longer support.
gt_quantiles = df["data"].quantile(cum_prob_np).to_numpy()




# Native helper library. `projection` is declared as returning nothing; the
# caller reads its result from the output buffer it passes in.
lib = ctypes.CDLL('cpp_modules/main.so')
lib.projection.restype = None

def ctype_arr_convert(arr):
    """Copy an array-like into a new flat ctypes array of C doubles.

    Args:
        arr: Any array-like; it is flattened with ``np.ravel`` first.

    Returns:
        A ``ctypes.c_double`` array holding the flattened values.
    """
    flat = np.ravel(arr)
    double_array = ctypes.c_double * len(flat)
    return double_array(*flat)

def _projection(quantiles1, weight1, quantiles2, weight2):
    """Combine two weighted quantile sets through the native ``projection`` routine.

    Both inputs are copied into C double arrays and passed to
    ``lib.projection`` together with their lengths and weights. The routine
    fills an output buffer of ``args.n_target_quantiles`` doubles, which is
    returned as a numpy array.
    NOTE(review): the C++ implementation is not visible here; presumably it
    projects the weighted mixture of the two distributions onto the target
    quantiles -- confirm against cpp_modules.

    Args:
        quantiles1: First quantile set; its length must equal
            ``args.n_quantiles*args.n_critics``.
        weight1: Weight passed alongside ``quantiles1``.
        quantiles2: Second quantile set (any length).
        weight2: Weight passed alongside ``quantiles2``.

    Returns:
        1-D numpy array with ``args.n_target_quantiles`` entries.
    """
    len1, len2 = len(quantiles1), len(quantiles2)
    len_out = args.n_target_quantiles
    assert len1 == args.n_quantiles*args.n_critics

    c_out = None
    out_init = np.zeros(len_out)
    c_in1 = ctype_arr_convert(quantiles1)
    c_in2 = ctype_arr_convert(quantiles2)
    c_out = ctype_arr_convert(out_init)

    # The pointer types encode the array lengths, so the signature is
    # re-declared on every call.
    in1_ptr = ctypes.POINTER(ctypes.c_double*len1)
    in2_ptr = ctypes.POINTER(ctypes.c_double*len2)
    out_ptr = ctypes.POINTER(ctypes.c_double*len_out)
    lib.projection.argtypes = [
        ctypes.c_int, ctypes.c_double, in1_ptr,
        ctypes.c_int, ctypes.c_double, in2_ptr,
        ctypes.c_int, out_ptr,
    ]
    lib.projection(len1, weight1, c_in1, len2, weight2, c_in2, len_out, c_out)
    return np.array(c_out)

def _getQuantileTargets(rewards, dones, next_quantiles):
    """Build per-step target quantiles by a backward pass over one trajectory.

    For each step, starting from the last, the one-step target
    ``rewards[t] + discount*next_quantiles[t]`` is combined (through
    ``_projection``) with the running multi-step target carried over from the
    step after it. The weights come from ``args.gae_coeff``.

    Args:
        rewards: Per-step rewards, indexable by time step.
        dones: Per-step termination flags; ``dones[t-1]`` scales the weight
            carried from step ``t`` to step ``t-1``.
        next_quantiles: 2-D array, one row of next-state quantiles per step.

    Returns:
        Array of shape ``(next_quantiles.shape[0], args.n_target_quantiles)``.
    """
    n_steps = next_quantiles.shape[0]
    target_quantiles = np.zeros((n_steps, args.n_target_quantiles))

    # Running multi-step target and its weight, seeded from the final step.
    carried_target = rewards[-1] + args.discount_factor*next_quantiles[-1]
    carried_weight = args.gae_coeff

    for t in range(n_steps - 1, -1, -1):
        one_step = rewards[t] + args.discount_factor*next_quantiles[t]
        mixed = _projection(one_step, 1.0 - args.gae_coeff, carried_target, carried_weight)
        target_quantiles[t, :] = mixed[:]
        if t == 0:
            break
        # With a coefficient of exactly 1.0 the carried weight is left untouched.
        if args.gae_coeff != 1.0:
            carried_weight = args.gae_coeff*(1.0 - dones[t-1])*(1.0 - args.gae_coeff + carried_weight)
        carried_target = rewards[t-1] + args.discount_factor*mixed
    return target_quantiles




# One list of per-iteration distances to the ground truth for every GAE coefficient.
total_list = []

# The CDF plots are written below; make sure their directory exists.
os.makedirs("imgs", exist_ok=True)


def _predicted_quantiles(value_net):
    """Return the sorted quantiles predicted for an all-zero state and action as a 1-D array."""
    with torch.no_grad():
        zero_state = torch.zeros((1, args.obs_dim), device=args.device, dtype=torch.float32)
        zero_action = torch.zeros((1, args.action_dim), device=args.device, dtype=torch.float32)
        sorted_quantiles = torch.sort(value_net(zero_state, zero_action).reshape((1, -1)), dim=-1)[0]
    return sorted_quantiles.cpu().numpy()[0]


for gae_coeff in [0.0, 0.5, 0.9, 1.0]:
    print(f"GAE's coefficient: {gae_coeff}")
    args.gae_coeff = gae_coeff
    value = Value(args).to(args.device)
    value_optimizer = torch.optim.Adam(value.parameters(), lr=args.lr)
    replay_buffer = [deque(maxlen=int(args.len_replay_buffer/args.n_envs)) for _ in range(args.n_envs)]

    w_dists = []

    for _ in tqdm(range(30)):
        # ---- collect one episode per environment with uniformly random actions ----
        states = env.reset()
        scores = np.zeros(args.n_envs)
        for step_idx in range(args.max_len):
            actions = np.random.randint(0, 2, (args.n_envs, 1))
            next_states, rewards = env.step(actions)
            # Only the last step of the episode is flagged as done.
            dones = np.ones(args.n_envs) if step_idx == args.max_len - 1 else np.zeros(args.n_envs)
            for env_idx in range(args.n_envs):
                replay_buffer[env_idx].append([states[env_idx], actions[env_idx], rewards[env_idx], dones[env_idx], next_states[env_idx]])
            # Rebind rather than write in place: the rows stored in the buffer
            # are views of `states`, so `states[:] = next_states` would
            # overwrite every state already recorded for this episode.
            states = next_states
            scores += rewards*(args.discount_factor**step_idx)

        # ---- build quantile targets, one trajectory per environment ----
        states_list = []
        actions_list = []
        targets_list = []
        for env_idx in range(args.n_envs):
            env_trajs = list(replay_buffer[env_idx])
            states = np.array([traj[0] for traj in env_trajs])
            actions = np.array([traj[1] for traj in env_trajs])
            rewards = np.array([traj[2] for traj in env_trajs])
            dones = np.array([traj[3] for traj in env_trajs])
            next_states = np.array([traj[4] for traj in env_trajs])

            # The bootstrap values are constants for the update, so no graph is needed.
            with torch.no_grad():
                next_states_tensor = torch.tensor(next_states, device=args.device, dtype=torch.float32)
                next_pi = torch.randint(0, 2, (len(states), 1), device=args.device).to(torch.float)
                next_reward_quantiles_tensor = value(next_states_tensor, next_pi).reshape(len(states), -1) # B x NM
                next_reward_quantiles = torch.sort(next_reward_quantiles_tensor, dim=-1)[0].cpu().numpy()
            reward_targets = _getQuantileTargets(rewards, dones, next_reward_quantiles)

            targets_list.append(reward_targets)
            states_list.append(states)
            actions_list.append(actions)

        with torch.no_grad():
            states_tensor = torch.tensor(np.concatenate(states_list, axis=0), device=args.device, dtype=torch.float32)
            actions_tensor = torch.tensor(np.concatenate(actions_list, axis=0), device=args.device, dtype=torch.float32)
            targets_tensor = torch.tensor(np.concatenate(targets_list, axis=0), device=args.device, dtype=torch.float32)
            targets_tensor.unsqueeze_(dim=1) # B x 1 x kN

        # ================== Value Update ================== #
        # calculate cdf
        with torch.no_grad():
            cum_prob = (torch.arange(args.n_quantiles, device=args.device, dtype=torch.float32) + 0.5)/args.n_quantiles
            cum_prob = cum_prob.view(1, 1, -1, 1) # 1 x 1 x M x 1

        for _ in range(args.n_epochs):
            # calculate quantile regression loss for reward
            current_reward_quantiles = value(states_tensor, actions_tensor) # B x N x M
            pairwise_reward_delta = targets_tensor.unsqueeze(-2) - current_reward_quantiles.unsqueeze(-1)
            reward_value_loss = torch.mean(pairwise_reward_delta*(cum_prob - (pairwise_reward_delta.detach() < 0).float()))
            value_optimizer.zero_grad()
            reward_value_loss.backward()
            value_optimizer.step()
        # ================================================== #

        # Distance between predicted and ground-truth quantiles after this update
        # (sum of absolute differences over matching quantile indices).
        pred_quantiles = _predicted_quantiles(value)
        w_dist = np.sum(np.abs(pred_quantiles - gt_quantiles))
        w_dists.append(w_dist)

    total_list.append(w_dists)

    ###############
    # for visualize
    ###############

    pred_quantiles = _predicted_quantiles(value)

    # Draw each quantile as a vertical step that spans its probability interval.
    half_step = 0.5/n_quantiles
    gt_x = []
    pred_x = []
    pred_cdf = []
    gt_cdf = []
    for prob, pred_q, gt_q in zip(cum_prob_np, pred_quantiles, gt_quantiles):
        pred_x += [pred_q, pred_q]
        pred_cdf += [prob - half_step, prob + half_step]
        gt_x += [gt_q, gt_q]
        gt_cdf += [prob - half_step, prob + half_step]

    axis = plt.subplot(1, 1, 1)
    axis.plot(pred_x, pred_cdf, label="predict")
    axis.plot(gt_x, gt_cdf, label="GT")
    axis.grid()
    plt.legend()
    plt.savefig(f"imgs/quantile_{gae_coeff}.png")
    plt.close()



# Print mean (std) of the recorded distances over consecutive windows of
# iterations, one line per GAE coefficient, with a separator after each window.
window = 5
name_list = ['0.0', '0.5', '0.9', '1.0']
n_windows = len(total_list[0])//window
for i in range(n_windows):
    start_idx, end_idx = window*i, window*(i + 1)
    for idx, run_dists in enumerate(total_list):
        temp_list = run_dists[start_idx:end_idx]
        print(f"[{name_list[idx]}] {np.mean(temp_list):.3f} ({np.std(temp_list):.3f})")
    print('='*10)