#!/usr/bin/env python
from numpy.random import seed, triangular
import torch
import torch.nn as nn
from torch.nn import Sequential
from torch.optim import Adam
import torch.nn.functional as F
torch.set_default_tensor_type(torch.FloatTensor)
from torch.utils.tensorboard import SummaryWriter

import numpy as np 
import sys, copy, itertools
import ipdb as pdb
from sklearn.metrics import confusion_matrix

import matplotlib.pyplot as plt
from matplotlib.pyplot import savefig
from matplotlib.backends.backend_pdf import PdfPages

from envs.double_integrator import DoubleIntegrator, VisualizeDI
from envs.dubins_car import DubinsCar
from DQN import QNetwork, DQN_Agent

import argparse
def _str2bool(value):
    """Convert a command-line token such as ``true``/``false``/``1``/``0`` to bool.

    ``type=bool`` would turn every non-empty string (including ``"False"``)
    into ``True``, so boolean flags are parsed explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``bool("")``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface of the training script. The arguments are parsed at
# module level because the ``__main__`` guard passes ``args`` to ``main``.
parser = argparse.ArgumentParser(description='CQL')
parser.add_argument('--tau', type=float, default=0.1, metavar='G',
                    help='value forwarded as tau to the agent (default: 0.1)')
parser.add_argument('--alpha', type=float, default=0.05, metavar='G',
                    help='weight of the conservative loss term (default: 0.05)')
parser.add_argument('--seed', type=int, default=None, metavar='N',
                    help='random seed (default: None, a seed is drawn at random)')
parser.add_argument('--render', type=_str2bool, default=False,
                    help='save image (true/false, default: false)')
parser.add_argument('--env_name', type=str, default='DoubleIntegrator', metavar='N',
                    help='environment name: DoubleIntegrator or DubinsCar '
                         '(default: DoubleIntegrator)')
parser.add_argument('--exp_name', type=str, default='CQL', metavar='N',
                    help='experiment name (default: CQL)')
args = parser.parse_args()


class ConservativeQ(QNetwork):
    """Q-network whose update adds a conservative (CQL-style) term to the loss.

    Inputs: (BS, n_state+n_ctrl)
    Output: V, (BS, 1)

    The training loss is
    ``0.5 * MSE(Q(s, a), q_target) + sign * alpha * mean(Q(s, a*) - Q(s, a))``
    where ``a*`` is ``env.safe_policy(s)`` and ``sign`` depends on the
    environment name.
    """

    # Sign of the conservative term for each supported environment.
    _SIGN_BY_ENV = {
        'DubinsCar': 1,          ## Control FP for Reach
        'DoubleIntegrator': -1,  ## Control FN for Avoid
    }

    def __init__(self, env, env_name, alpha, exp_name = 'DoubleQ', hiddens=None, lr=0.001, iters_save_model=1e6, iters_update_target=1e6):
        """Build the network through ``QNetwork`` and store the CQL settings.

        Args:
            env: environment; ``update`` reads ``env.n_obs`` and calls
                ``env.safe_policy``.
            env_name: environment name, selects the sign of the CQL term.
            alpha: weight of the conservative term.
            exp_name: experiment name forwarded to ``QNetwork``.
            hiddens: hidden layer sizes forwarded to ``QNetwork``;
                ``None`` means ``[64, 64, 32]``.
            lr: learning rate forwarded to ``QNetwork``.
            iters_save_model: forwarded to ``QNetwork``.
            iters_update_target: forwarded to ``QNetwork``.
        """
        # A fresh list per instance avoids sharing one mutable default
        # between every network constructed without ``hiddens``.
        if hiddens is None:
            hiddens = [64, 64, 32]
        super().__init__(env, env_name, exp_name = exp_name, hiddens=hiddens, lr=lr, iters_save_model=iters_save_model, iters_update_target=iters_update_target)
        self.alpha = alpha
        # ``None`` marks an environment without a defined sign; ``update``
        # reports it explicitly instead of failing on a missing attribute.
        self.sign = self._SIGN_BY_ENV.get(self.env_name)

    def update(self, X, q_target):
        """Run one optimizer step on a batch and return the detached loss.

        Args:
            X: batch of ``[s, a]`` rows. It must be a CPU tensor that does
                not require grad, because the states are converted with
                ``.numpy()``.
            q_target: regression targets for ``Q(s, a)``.

        Returns:
            The scalar loss, detached from the autograd graph.

        Raises:
            ValueError: if no CQL sign is defined for ``self.env_name``.
        """
        if self.sign is None:
            raise ValueError(
                f"ConservativeQ has no CQL sign for env_name={self.env_name!r}; "
                f"supported: {sorted(self._SIGN_BY_ENV)}"
            )

        ## X = [s, a]
        s = X[..., :self.env.n_obs]
        s_np = s.numpy()
        # Action chosen by the environment's safe policy in the same states.
        a_star = self.env.safe_policy(s_np)
        X_current = torch.tensor(np.concatenate([s_np, a_star], axis = -1)).float()

        self.optimizer.zero_grad()
        q_pred = self.model.forward(X)
        loss = 0.5 * F.mse_loss(q_pred, q_target)
        ## Additional term from CQL
        loss += self.sign * self.alpha * torch.mean(self.model(X_current)-self.model(X))
        loss.backward()
        self.optimizer.step()
        return loss.detach()
        
class CQL_Agent(DQN_Agent):
    """DQN-style agent whose current and target networks are ``ConservativeQ``."""

    def __init__(self, env, env_name = 'DoubleIntegrator',
                            exp_name='MLP',
                            policy = 'safe', 
                            tau = 0.1, 
                            render = True,
                            umode = 'max',
                            alpha = 0.1,
                            gamma = 0.99
                            ):
        """Initialise the base agent, then replace its networks.

        Args:
            env: environment; ``safe_policy`` and ``n_ctrl`` are used here.
            env_name: environment name; also selects the hidden layer sizes.
            exp_name: experiment name forwarded to ``DQN_Agent``.
            policy: accepted for interface compatibility; not used here.
            tau: forwarded to ``DQN_Agent``.
            render: forwarded to ``DQN_Agent``.
            umode: forwarded to ``DQN_Agent``.
            alpha: weight of the conservative term in ``ConservativeQ``.
            gamma: discount used to fill ``gamma_schedule``.
        """
        super().__init__(env, env_name = env_name, exp_name=exp_name, tau=tau,
         render = render, umode = umode, w_bl = False, double = False)
        lr = 2e-4 #0.001
        hiddens = [16, 16] if env_name == 'DoubleIntegrator' else [64, 64, 32]
        # NOTE(review): iters_save_model, iters_update_target, exp_name and
        # num_steps are presumably set by DQN_Agent.__init__ -- confirm.
        self.current_network = ConservativeQ(env, env_name = env_name,
                                        alpha = alpha,
                                        lr = lr, hiddens = hiddens,
                                        iters_save_model=self.iters_save_model,
                                        iters_update_target=self.iters_update_target,
                                        exp_name = self.exp_name)
            
        self.target_network = copy.deepcopy(self.current_network)
        self.target_network.iters_save_model = 1e6 ## Do not Save
        
        ## gamma is regular RL gamma
        # Constant schedule: every entry equals ``gamma``.
        self.gamma_schedule =  gamma * np.ones(self.num_steps//100 + 1)
        
    def _get_training_samples(self, batch):
        """Turn a batch of transitions into network inputs and Bellman targets.

        Args:
            batch: iterable of ``(s1, a1, c1, s2, done)`` tuples.

        Returns:
            ``(net_input, target)``: the ``[s1, a1]`` float tensor and the
            target ``c1 + gamma * Q_target(s2, safe_policy(s2))``.
        """
        ## Convention of CSC, c=1 if unsafe else 0
        # NOTE(review): the terminal flag is unpacked but never used, so the
        # backup bootstraps through terminal transitions -- confirm intended.
        s1, a1, c1, s2, _done = zip(*batch)
        s1, a1, c1, s2 = np.array(s1), np.array(a1), np.array(c1), np.array(s2)
        c1 = torch.tensor(c1.reshape(-1, 1)).float()
        
        ## a2 \sim pi, in this case optimal safety policy
        a2 = self.env.safe_policy(s2)
        with torch.no_grad():
            next_q_values = self.target_network.forward(torch.tensor(np.concatenate([s2, a2], axis = -1)).float())
        # NOTE(review): self.gamma is not assigned in this class; presumably
        # DQN_Agent maintains it (e.g. from gamma_schedule) -- confirm.
        target = c1 + self.gamma * next_q_values ## Regular Bellman backup
        
        # Scalar actions need a trailing axis before joining with the states.
        if self.env.n_ctrl == 1:
            a1 = np.expand_dims(a1, -1)
        net_input = torch.tensor(np.concatenate([s1, a1], axis = -1)).float()
        return net_input, target

    def compute_metrics(self, value, no_eps):
        """Log accuracy and confusion-matrix counts of ``value`` against ``self.gt``.

        Args:
            value: predicted values, compared elementwise with ``self.gt``.
            no_eps: step index used for the TensorBoard scalars.
        """
        ## NOTE: CSC has opposite sign convention as HJ value
        predicted = value > 0
        reference = self.gt < 0
        accuracy = np.mean(predicted == reference)
        self.tb_logger.add_scalar('accuracy', accuracy, no_eps)
        print(accuracy)
        # Fixing the label set keeps the matrix 2x2 even when only one class
        # occurs; otherwise the four-way unpack below would raise.
        tn, fp, fn, tp = confusion_matrix(reference.ravel(), predicted.ravel(),
                                          labels=[False, True]).ravel()
        self.tb_logger.add_scalar('cm/TN', tn, no_eps)
        self.tb_logger.add_scalar('cm/FP', fp, no_eps)
        self.tb_logger.add_scalar('cm/FN', fn, no_eps)
        self.tb_logger.add_scalar('cm/TP', tp, no_eps)

def main(args):
    """Build the environment and a ``CQL_Agent`` from ``args``, then train.

    Args:
        args: namespace providing ``env_name``, ``seed``, ``tau``, ``alpha``,
            ``render`` and ``exp_name``.

    Raises:
        ValueError: if ``args.env_name`` is not a supported environment.
    """
    if args.env_name == "DoubleIntegrator":
        env = DoubleIntegrator(margin = 0)
        ## Objective: Stay within [-1, 1] on x-axis
        umode = 'max'
    elif args.env_name == "DubinsCar":
        env = DubinsCar(margin = 0)
        ## Objective: Reach the circle regardless of yaw
        umode = 'min'
    else:
        # Without this branch ``env`` would be unbound further down.
        raise ValueError(
            f"unsupported env_name {args.env_name!r}; "
            "expected 'DoubleIntegrator' or 'DubinsCar'"
        )
    env.reset()

    # Draw a seed when none is given so the run can still be identified.
    run_seed = np.random.randint(255) if args.seed is None else args.seed

    torch.manual_seed(run_seed)
    np.random.seed(run_seed)

    # The --exp_name default is 'CQL', matching the previous fixed value.
    exp_name = args.exp_name
    save_name =  f"{exp_name}_tau={args.tau}_alpha={args.alpha}_seed={run_seed}"
    agent = CQL_Agent(env=env,
                env_name = args.env_name,
                umode = umode,
                exp_name = save_name,
                tau = args.tau,
                alpha = args.alpha,
                render = args.render)
    agent.train()

# Script entry point: train only when executed directly, not when imported.
if __name__ == "__main__":
    main(args)

