#!/usr/bin/env python
from numpy.random import seed, triangular
import torch
import torch.nn as nn
from torch.nn import Sequential
from torch.optim import Adam
import torch.nn.functional as F
torch.set_default_tensor_type(torch.FloatTensor)
from torch.utils.tensorboard import SummaryWriter

import numpy as np 
import sys, copy, argparse, itertools
import ipdb as pdb
from sklearn.metrics import confusion_matrix

import matplotlib.pyplot as plt
from matplotlib.pyplot import savefig
from matplotlib.backends.backend_pdf import PdfPages

import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

from envs.double_integrator import DoubleIntegrator, VisualizeDI
from envs.dubins_car import DubinsCar

def _str2bool(value):
    """Convert a command-line string to a bool for ``argparse``.

    ``type=bool`` is a known argparse pitfall: ``bool("False")`` is True, so
    ``--render False`` would turn rendering ON. Common false-like spellings
    ('', '0', 'false', 'f', 'no', 'n', 'off'; case-insensitive) map to False;
    every other string still maps to True, as it did before.
    """
    return str(value).strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')


parser = argparse.ArgumentParser(description='')
parser.add_argument('--tau', type=float, default=0.1, metavar='G',
                    help='soft (Polyak) target-network update coefficient (default: 0.1)')
parser.add_argument('--gamma_rate', type=float, default=2.2, metavar='G',
                    help='end exponent k of the discount schedule gamma = 1 - 0.15**k (default: 2.2)')
parser.add_argument('--seed', type=int, default=None, metavar='N',
                    help='random seed (default: drawn at random)')
parser.add_argument('--render', type=_str2bool, default=False,
                    help='save evaluation plots and value arrays (default: False)')
parser.add_argument('--env_name', type=str, default='DoubleIntegrator', metavar='N',
                    help='environment: DoubleIntegrator or DubinsCar (default: DoubleIntegrator)')
parser.add_argument('--exp_name', type=str, default='DQN', metavar='N',
                    help='experiment name (default: DQN); main() currently builds the run name from the other flags instead')
parser.add_argument('--w_bl', type=int, default=0, metavar='N',
                    help='1 to learn Q relative to the l(x) baseline, 0 otherwise (default: 0)')
parser.add_argument('--double', type=int, default=0, metavar='N',
                    help='1 to train twin Q-networks (DoubleQNetwork), 0 for a single network (default: 0)')
parser.add_argument('--alpha', type=float, default=0.05, metavar='G',
                    help='currently unused (default: 0.05)')
# Parsed at import time; note that DQN_Agent.__init__ reads the module-level
# ``args.gamma_rate`` directly.
args = parser.parse_args()


def mlp(sizes, activation=nn.ReLU, output_activation=nn.Identity):
    """Stack fully connected layers into an ``nn.Sequential``.

    ``sizes`` lists the layer widths, input first and output last. Every
    ``nn.Linear`` except the final one is followed by ``activation()``; the
    final one is followed by ``output_activation()``. Fewer than two sizes
    yields an empty Sequential.
    """
    last = len(sizes) - 2
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        act_cls = output_activation if idx == last else activation
        modules.extend([nn.Linear(fan_in, fan_out), act_cls()])
    return nn.Sequential(*modules)

class QNetwork():
    """MLP estimate of the state-action value Q(s, a).

    Input: float tensor of shape (BS, n_obs + n_ctrl) holding [s, a].
    Output: tensor of shape (BS, 1).
    """

    def __init__(self, env, env_name, exp_name = 'MLP', hiddens=None, lr=0.001, iters_save_model=1e6, iters_update_target=1e6):
        """
        Args:
            env: environment; only ``env.n_obs`` and ``env.n_ctrl`` are read here.
            env_name: directory prefix for checkpoints (``{env_name}/ckp/``).
            exp_name: checkpoint filename prefix.
            hiddens: hidden-layer widths (any iterable of ints); defaults to
                [64, 64, 32]. A private copy is stored, so neither a caller's
                list nor a default list is ever shared between instances.
            lr: Adam learning rate.
            iters_save_model, iters_update_target: stored for the caller;
                not used inside this class.
        """
        self.env = env
        self.env_name = env_name
        self.hiddens = [64, 64, 32] if hiddens is None else list(hiddens)
        self.lr = lr
        self.iters_save_model = iters_save_model
        self.iters_update_target = iters_update_target
        self.model = mlp([self.env.n_obs + self.env.n_ctrl] + self.hiddens + [1])
        self.optimizer = Adam(self.model.parameters(), lr=self.lr)
        self.exp_name = exp_name

    def forward(self, input):
        """Return Q-values of shape (BS, 1) for a (BS, n_obs + n_ctrl) input."""
        return self.model.forward(input)

    def update(self, X, y):
        """Take one optimizer step on MSE(model(X), y) and return the detached loss."""
        self.optimizer.zero_grad()
        y_pred = self.model.forward(X)
        loss = F.mse_loss(y_pred, y)
        loss.backward()
        self.optimizer.step()
        return loss.detach()

    def save_model_weights(self, iters):
        """Save the model state_dict to ``{env_name}/ckp/{exp_name}_{iters}``.

        The directory must already exist; ``torch.save`` does not create it.
        """
        torch.save(self.model.state_dict(), f'{self.env_name}/ckp/{self.exp_name}_{iters}')

    def load_model_weights(self, iters):
        """Load the state_dict written by ``save_model_weights(iters)`` into the model."""
        self.model.load_state_dict(torch.load(f'{self.env_name}/ckp/{self.exp_name}_{iters}'))
        

class DoubleQNetwork(QNetwork):
    """Two independent Q-network heads trained on the same targets.

    Input: (BS, n_obs + n_ctrl). ``forward`` returns the element-wise maximum
    of the two heads' (BS, 1) outputs.
    """

    def __init__(self, env, env_name, exp_name = 'DoubleQ', hiddens=None, lr=0.001, iters_save_model=1e6, iters_update_target=1e6):
        """Arguments are as for ``QNetwork``; ``hiddens`` defaults to [64, 64, 32]."""
        # Resolve the default here (a fresh list per call) so the second head
        # uses exactly the widths passed to the parent for ``self.model``.
        hiddens = [64, 64, 32] if hiddens is None else list(hiddens)
        super().__init__(env, env_name, exp_name=exp_name, hiddens=hiddens, lr=lr,
                         iters_save_model=iters_save_model,
                         iters_update_target=iters_update_target)
        # ``self.model`` (first head) and ``self.exp_name`` are set by QNetwork.__init__.
        self.model2 = mlp([self.env.n_obs + self.env.n_ctrl] + hiddens + [1])
        # Replace the parent's single-head optimizer with one covering both heads.
        params = itertools.chain(self.model.parameters(), self.model2.parameters())
        self.optimizer = Adam(params, lr=self.lr)

    def forward(self, input):
        """Return the element-wise max of the two heads, shape (BS, 1)."""
        q1_pred = self.model(input)
        q2_pred = self.model2(input)
        return torch.maximum(q1_pred, q2_pred)

    def update(self, X, q_target):
        """Take one optimizer step on the summed MSE of both heads; return the detached loss."""
        self.optimizer.zero_grad()
        q1_pred = self.model.forward(X)
        q2_pred = self.model2.forward(X)
        loss = F.mse_loss(q1_pred, q_target) + F.mse_loss(q2_pred, q_target)
        loss.backward()
        self.optimizer.step()
        return loss.detach()

class Replay_Memory():
    """Bounded replay buffer of (state, action, cost, next_state, done) tuples.

    Once ``memory_size`` transitions are stored, each new one overwrites the
    oldest, so the buffer always holds the latest ``memory_size`` transitions.
    """

    def __init__(self, memory_size=25000, burn_in=5000):
        """
        Args:
            memory_size: maximum number of transitions retained (<= 0 keeps none).
            burn_in: stored for the owning agent's burn-in routine; not used
                by this class.
        """
        self.memory_size = memory_size
        self.burn_in = burn_in
        # Transitions are tuples (S, A, R, S', D).
        self.storage = []
        # Ring-buffer write position, used once ``storage`` is full.
        self._next = 0

    def sample_batch(self, batch_size=32):
        """Return ``batch_size`` stored transitions, drawn uniformly with replacement.

        Raises ValueError (from ``np.random.choice``) if the buffer is empty.
        """
        rand_idx = np.random.choice(len(self.storage), batch_size)
        return [self.storage[i] for i in rand_idx]

    def append(self, transition):
        """Store ``transition``, evicting the oldest one when the buffer is full."""
        if self.memory_size <= 0:
            # Nothing can be retained.
            return
        if len(self.storage) < self.memory_size:
            self.storage.append(transition)
        else:
            # O(1) in-place overwrite of the oldest entry, instead of
            # re-slicing (copying) the whole list on every append.
            self.storage[self._next] = transition
            self._next = (self._next + 1) % self.memory_size

class DQN_Agent():
    """Fits a Q-network to a discounted safety (reachability) target.

    Data are collected with uniformly random actions. Bootstrapped targets use
    ``env.safe_policy`` at the next state and a Polyak-averaged target network.
    Every 1000 steps the learned value function is evaluated on a fixed state
    grid and compared with a precomputed ground-truth array.
    """

    def __init__(self, env, env_name = 'DoubleIntegrator',
                            exp_name='MLP',
                            policy = 'safe',
                            tau = 0.1,
                            render = True,
                            umode = 'max',
                            w_bl = False,
                            double = False):
        """Build the networks, fill the replay buffer and load the ground truth.

        Args:
            env: environment object. This class reads ``reset``, ``step``,
                ``lx``, ``safe_policy``, ``n_obs``, ``n_ctrl``, ``u_min`` and
                ``u_max`` from it.
            env_name: 'DoubleIntegrator' or 'DubinsCar'. Selects the step
                budget, hidden sizes, ground-truth file and evaluation grid.
            exp_name: run name used for checkpoints, plots and TensorBoard.
            policy: stored as ``self.policy``; not read by this class.
            tau: Polyak coefficient of the soft target-network update.
            render: if true, save evaluation plots/arrays to disk.
            umode: stored as ``self.umode``; not read by this class.
            w_bl: if true, learn Q relative to the l(x) baseline.
            double: if true, use ``DoubleQNetwork`` instead of ``QNetwork``.

        Note: the discount schedule reads the module-level ``args.gamma_rate``.
        """
        self.env = env
        self.env_name = env_name
        self.umode = umode
        self.epsilon = 0.5
        # The extra step makes the final index (25000 / 50000) a multiple of
        # 1000, so training ends with an evaluation and a checkpoint save.
        self.num_steps = 25001 if env_name == 'DoubleIntegrator' else 50001
        # Discount schedule with one entry per 100 steps:
        # gamma = 1 - 0.15**k, k evenly spaced from 1 to args.gamma_rate
        # (so it starts at 0.85). ``train`` additionally clips it to 0.99.
        self.gamma_schedule = 1-np.logspace(1, args.gamma_rate, self.num_steps//100 + 1, base = 0.15)
        self.gamma = self.gamma_schedule[0]
        # 1e6 exceeds num_steps, which de facto turns off hard target updates.
        self.iters_update_target = 1e6 # 500
        # Checkpoint only at the last step.
        self.iters_save_model = self.num_steps-1

        self.batch_size = 64
        self.num_updates = 1

        self.exp_name = exp_name
        self.tau = tau
        self.render = render
        self.with_baseline = w_bl
        self.double = double

        lr = 0.002 if self.with_baseline else 0.001
        hiddens = [16, 16] if env_name == 'DoubleIntegrator' else [64, 64, 32]

        network_cls = DoubleQNetwork if double else QNetwork
        self.current_network = network_cls(env, env_name=env_name,
                                           lr=lr, hiddens=hiddens,
                                           iters_save_model=self.iters_save_model,
                                           iters_update_target=self.iters_update_target,
                                           exp_name=self.exp_name)

        # The target network starts as an exact copy and is Polyak-averaged in ``train``.
        self.target_network = copy.deepcopy(self.current_network)
        self.target_network.iters_save_model = 1e6 ## Do not Save

        self.memory = Replay_Memory(memory_size = self.num_steps+5000)
        # Use the burn-in method to initialize the replay memory.
        self.burn_in_memory()

        # Stored only; none of these is read by this class.
        self.policy = policy
        self.target_noise = 0.05
        self.noise_clip = 0.1

        ## For performance evaluation: ground-truth value grids.
        if self.env_name == 'DoubleIntegrator':
            self.gt = np.load('envs/V_di.npy')
        elif self.env_name == 'DubinsCar':
            self.gt = np.load('envs/V_dubins.npy')
            # Subsample to match the evaluation grid in ``eval_performance``.
            self.gt = self.gt[::6, ::6, ::6]
        else:
            print('NOT Implemented')
        self.tb_logger = SummaryWriter(f"{self.env_name}/runs/{self.exp_name}")

    def _soft_update(self, source, target):
        """Polyak-average ``source`` parameters into ``target`` in place: t <- (1-tau) t + tau s."""
        with torch.no_grad():
            for p, p_targ in zip(source.parameters(), target.parameters()):
                p_targ.data.mul_(1-self.tau)
                p_targ.data.add_(self.tau * p.data)

    def train(self):
        """Run ``self.num_steps`` steps of random-action data collection and Q-learning.

        Each step does the following:
        - an optional checkpoint save;
        - a soft target update;
        - a discount-schedule refresh (every 100 steps);
        - one environment transition with a uniform random action;
        - ``num_updates`` minibatch regressions.
        Evaluation runs every 1000 steps.
        """
        state = self.env.reset()

        for i in range(self.num_steps):
            if i > 0 and i % self.iters_save_model == 0:
                self.current_network.save_model_weights(i)

            # Hard target updates are disabled; the target network tracks the
            # current network through Polyak averaging instead.
            self._soft_update(self.current_network.model, self.target_network.model)
            if self.double:
                self._soft_update(self.current_network.model2, self.target_network.model2)

            if i % 100 == 0:
                self.gamma = min(0.99, self.gamma_schedule[i // 100])

            # Behaviour policy: a uniform random action within the control bounds.
            action = np.random.uniform(self.env.u_min, self.env.u_max)
            nextstate, cost, is_terminal = self.env.step(action)
            self.memory.append((state, action, cost, nextstate, is_terminal))
            if is_terminal:
                nextstate = self.env.reset()
            state = nextstate

            for _ in range(self.num_updates):
                batch = self.memory.sample_batch(self.batch_size)
                X_train, y_train = self._get_training_samples(batch)
                self.current_network.update(X_train, y_train)

            if i % 1000 == 0:
                print(i)
                self.eval_performance(i)

        # Make sure the last logged metrics reach disk.
        self.tb_logger.flush()

    def _get_training_samples(self, batch):
        """Build (inputs, targets) for one regression step from a list of transitions.

        Returns:
            X: float tensor (BS, n_obs + n_ctrl) of [s, a] pairs.
            target: float tensor (BS, 1). Without the baseline it is
                (1-gamma) l(s) + gamma min(l(s), Q'(s', pi(s'))), and l(s)
                at terminal transitions. With the baseline it is
                gamma min(0, Q'(s', pi(s')) + l(s') - l(s)), and 0 at terminal
                transitions. Here Q' is the target network and pi is
                ``env.safe_policy``.
        """
        s1, a1, _cost, s2, done = list(zip(*batch))
        s1, a1, s2 = np.array(s1), np.array(a1), np.array(s2)
        # Force a boolean mask. If the env reports ``done`` as 0/1 integers,
        # indexing with them would select rows 0 and 1 rather than the
        # terminal rows.
        done = torch.as_tensor(np.array(done, dtype=bool))

        lx = self.env.lx(s1)
        lx = torch.tensor(lx.reshape(-1, 1)).float()

        ## Use fixed safe policy:
        a2 = self.env.safe_policy(s2)
        with torch.no_grad():
            next_q_values = self.target_network.forward(torch.tensor(np.concatenate([s2, a2], axis = -1)).float())

        if self.with_baseline:
            lx2 = self.env.lx(s2)
            lx2 = torch.tensor(lx2.reshape(-1, 1)).float()

            target = self.gamma * torch.minimum(torch.zeros_like(lx), next_q_values + lx2 - lx)
            target[done] = 0
        else:
            target = (1-self.gamma) * lx + self.gamma * torch.minimum(lx, next_q_values)
            target[done] = lx[done] ## if Done target = lx

        if self.env.n_ctrl == 1:
            a1 = np.expand_dims(a1, -1)
        return torch.tensor(np.concatenate([s1, a1], axis = -1)).float(), target

    def _grid_value(self, s):
        """Return the scalar value estimate at a single state ``s`` of shape (1, n_obs)."""
        a = self.env.safe_policy(s)
        x = torch.tensor(np.concatenate([s, a], axis = -1)).float()
        # Evaluation only: skip building an autograd graph.
        with torch.no_grad():
            v = self.current_network.forward(x).item()
        if self.with_baseline:
            # In baseline mode the regression target is offset by l(s)
            # (see ``_get_training_samples``), so add it back. ``.item()``
            # turns a size-1 array into a float, avoiding NumPy's deprecated
            # implicit array-to-scalar conversion.
            v += np.asarray(self.env.lx(s)).item()
        return v

    def eval_performance(self, no_eps):
        """Evaluate the learned value on a fixed state grid, optionally save it, and log metrics.

        The value at grid state s is Q(s, safe_policy(s)), plus l(s) in
        baseline mode. ``no_eps`` (the training step) tags saved files and
        TensorBoard scalars.

        Raises:
            NotImplementedError: for environments without an evaluation grid.
        """
        if self.env_name == 'DoubleIntegrator':
            v_axis = np.linspace(-2.0, 2.0, 21)
            x_axis = np.linspace(-1.0, 1.0, 21)
            value = np.zeros((21, 21))
            # Rows index velocity v, columns index position x.
            for i, v in enumerate(v_axis):
                for j, x in enumerate(x_axis):
                    value[i, j] = self._grid_value(np.array([x, v]).reshape(1, -1))
            if self.render:
                fig, ax = plt.subplots(figsize = (5, 5))
                VisualizeDI(value, fig, ax, saveName = f"{self.env_name}/plots/{self.exp_name}_{no_eps}.pdf")
                # Release the figure; otherwise one more stays open per evaluation.
                plt.close(fig)

        elif self.env_name == 'DubinsCar':
            x_axis = np.linspace(-3.0, 3.0, 11)
            y_axis = np.linspace(-3.0, 3.0, 11)
            theta_axis = np.linspace(-np.pi, np.pi, 11)
            value = np.zeros((11, 11, 11))
            for i, x in enumerate(x_axis):
                for j, y in enumerate(y_axis):
                    for k, theta in enumerate(theta_axis):
                        # The heading is encoded as (sin, cos).
                        s = np.array([x, y, np.sin(theta), np.cos(theta)]).reshape(1, -1)
                        value[i, j, k] = self._grid_value(s)
        else:
            raise NotImplementedError(f"No evaluation grid for env_name={self.env_name!r}")

        if self.render:
            np.save(f"{self.env_name}/plots/{self.exp_name}_{no_eps}.npy", value)
        self.compute_metrics(value, no_eps)

    def compute_metrics(self, value, no_eps):
        """Log sign-agreement accuracy and the confusion matrix against the ground truth.

        A state counts as "positive" when its value is negative. ``value`` is
        compared element-wise with ``self.gt``.
        """
        accuracy = np.mean((value<0) == (self.gt<0))
        self.tb_logger.add_scalar('accuracy', accuracy, no_eps)
        print(accuracy)
        # Fixing ``labels`` keeps the matrix 2x2 even when only one class
        # occurs. Otherwise sklearn returns a 1x1 matrix and the 4-way unpack fails.
        tn, fp, fn, tp = confusion_matrix((self.gt<0).ravel(), (value<0).ravel(), labels=[False, True]).ravel()
        self.tb_logger.add_scalar('cm/TN', tn, no_eps)
        self.tb_logger.add_scalar('cm/FP', fp, no_eps)
        self.tb_logger.add_scalar('cm/FN', fn, no_eps)
        self.tb_logger.add_scalar('cm/TP', tp, no_eps)

    def burn_in_memory(self):
        """Fill the replay buffer with random-action episodes.

        Whole episodes are collected until more than ``self.memory.burn_in``
        transitions have been stored.
        """
        total_steps = 0
        while total_steps <= self.memory.burn_in:
            state = self.env.reset()
            is_terminal = False
            while not is_terminal:
                # NOTE(review): actions here are drawn from [-1, 1], whereas
                # ``train`` uses [env.u_min, env.u_max]; confirm these agree.
                action = np.random.uniform(-1, 1)
                nextstate, cost, is_terminal = self.env.step(action)
                self.memory.append((state, action, cost, nextstate, is_terminal))
                state = nextstate
                total_steps += 1

def main(args):
    """Build the environment and agent from the parsed CLI ``args`` and train.

    Raises:
        ValueError: if ``args.env_name`` is not a supported environment.
    """
    if args.env_name == "DoubleIntegrator":
        env = DoubleIntegrator()
        ## Goal: Stay within [-1, 1] on x-axis
        umode = 'max'
    elif args.env_name == "DubinsCar":
        env = DubinsCar()
        ## Goal: Reach the circle regardless of yaw
        umode = 'min'
    else:
        # Without this branch ``env`` would be unbound below (UnboundLocalError).
        raise ValueError(f"Unknown env_name {args.env_name!r}; expected 'DoubleIntegrator' or 'DubinsCar'")
    env.reset()

    seed = np.random.randint(255) if args.seed is None else args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    exp_name = 'DDQN' if args.double else 'DQN'
    save_name = f"{exp_name}_{args.w_bl}_tau={args.tau}_gamma={args.gamma_rate}_seed={seed}"
    agent = DQN_Agent(env=env,
                      env_name=args.env_name,
                      umode=umode,
                      exp_name=save_name,
                      tau=args.tau,
                      render=args.render,
                      w_bl=args.w_bl,
                      double=args.double)
    agent.train()

if __name__ == "__main__":
    # Script entry point: train using the module-level parsed ``args``.
    main(args)

