import argparse
import os
import pickle
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm

from env import Network, MDP, MDP_hard
from learner_dqn import DQN
from learner_rbmle import RBMLE
from learner_ucb import UCB
np.set_printoptions(precision=3)


def record_suffix(learner_info, args, episode):
    """Return the shared tail of every per-episode record filename.

    The result has the form
    ``'<learner_info>_env_<env>_S_<S>_A_<A>_K_<K>_gamma_<gamma>_temperature_<temperature>_T_<T>_seed_<seed>_episode_<episode>.pkl'``;
    callers prepend a metric-specific prefix such as ``'regret_'``.
    """
    return '{}_env_{}_S_{}_A_{}_K_{}_gamma_{}_temperature_{}_T_{}_seed_{}_episode_{}.pkl'.format(
        learner_info,
        args.env,
        args.S,
        args.A,
        args.K,
        args.gamma,
        args.temperature,
        args.T,
        args.seed,
        episode,
    )


def sigma_norm_sq(diff, sigma):
    """Return ``diff^T @ sigma @ diff``, the squared sigma-weighted norm of *diff*."""
    return np.matmul(diff, np.matmul(sigma, diff))


def theta_diagnostics(learner, true_theta, is_ucb):
    """Return the six theta-estimation error metrics recorded at one time step.

    Args:
        learner: learner exposing ``sigma`` and ``theta``; for UCB also
            ``theta_list`` and ``b``, otherwise also ``theta_hat``.
        true_theta: the environment's true parameter vector.
        is_ucb: whether *learner* is the UCB learner.

    Returns:
        ``[min_A, mean_A, min_2, mean_2, hat_A, hat_2]`` where ``*_A`` is the
        squared sigma-weighted error and ``*_2`` the squared Euclidean error.
        For UCB with a non-empty ``theta_list`` the first four aggregate over
        the candidate thetas; otherwise they all use ``learner.theta`` (so
        min == mean). The ``hat`` entries use ``sigma^{-1} b`` for UCB and
        ``learner.theta_hat`` otherwise.
    """
    sigma = learner.sigma
    # len() works for both a list of vectors and a 2-D array; comparing an
    # ndarray with [] does not.
    if is_ucb and len(learner.theta_list) > 0:
        diffs = np.asarray(learner.theta_list) - true_theta
        # Row-wise d_i^T sigma d_i without a Python loop.
        sigma_errs = np.sum(np.matmul(diffs, sigma) * diffs, axis=1)
        l2_errs = np.sum(diffs ** 2, axis=1)
        metrics = [np.min(sigma_errs), np.mean(sigma_errs),
                   np.min(l2_errs), np.mean(l2_errs)]
    else:
        diff = learner.theta - true_theta
        sigma_err = sigma_norm_sq(diff, sigma)
        # Squared, to match the UCB branch and the theta_hat metric below.
        l2_err = np.linalg.norm(diff) ** 2
        metrics = [sigma_err, sigma_err, l2_err, l2_err]

    if is_ucb:
        # Ridge estimate sigma^{-1} b; solve() avoids forming the inverse.
        theta_hat = np.linalg.solve(sigma, learner.b)
    else:
        theta_hat = learner.theta_hat
    hat_diff = theta_hat - true_theta
    metrics.append(sigma_norm_sq(hat_diff, sigma))
    metrics.append(np.linalg.norm(hat_diff) ** 2)
    return metrics


def dump_pickle(path, obj):
    """Pickle *obj* to *path*, overwriting any existing file."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


if __name__ == '__main__':
    # define arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--S', type=int, help='dimension of state space', default=3)
    parser.add_argument('--A', type=int, help='dimension of action space', default=2)
    parser.add_argument('--K', type=int, help='dimension of theta', default=3)
    parser.add_argument('--gamma', type=float, help='discount factor', default=0.9)
    parser.add_argument('--seed', type=int, help='random seed', default=1)
    parser.add_argument('--temperature', type=float, help='temperature of softmax', default=1e-02)
    parser.add_argument('--T', type=int, help='number of time steps', default=500)
    parser.add_argument('--Episodes', type=int, help='number of episodes', default=10)
    parser.add_argument('--lamdba', type=float, help='regularization parameter for ucb', default=1.0)
    parser.add_argument('--delta', type=float, help='with probability at least 1-delta', default=1.0/500**2)
    parser.add_argument('--alpha', type=str, help='reward bias for rbmle', default="0.5")
    parser.add_argument('--epsilon', type=float, help='epsilon greedy threshold for dqn', default=0.05)
    parser.add_argument('--eta', type=float, help='exploration parameter for ucb and rbmle', default=1.0)
    parser.add_argument('--lr', type=float, help='learning rate for dqn', default=0.1)
    parser.add_argument('--hidden_size', type=int, help='hidden size for dqn', default=100)
    parser.add_argument('--batch_size', type=int, help='batch size for dqn', default=32)
    parser.add_argument('--learner', type=str, help='learner type', default='RBMLE')
    parser.add_argument('--env', type=str, help='env type', default='normal')
    parser.add_argument('--solving_approach', type=str, help='solving approach of RBMLE, exact or approximated', default='exact')
    parser.add_argument('--verbose', action='store_true', help='print the true and estimated theta after every training step')
    args = parser.parse_args()

    # Create the output folders up front so a long run cannot fail at the first dump.
    os.makedirs('./diff_theta', exist_ok=True)
    os.makedirs('./record', exist_ok=True)

    # Fixed seed (independent of --seed) so the true parameters and the
    # environment are identical across runs; --seed only affects the episodes.
    np.random.seed(11)
    torch.manual_seed(11)
    random.seed(11)

    theta = np.random.rand(args.K)
    theta = theta / np.sum(theta)  # normalize so the entries sum to 1
    # NOTE(review): the phi network temperature is hard-coded to 0.1, while
    # --temperature (default 1e-2) only appears in the record filenames;
    # confirm which value is intended.
    func = Network(args.S + args.A, (args.S, args.K), temperature=0.1)  # define the phi function.
    if args.env == 'normal':
        env = MDP(args.S, args.A, func, theta)
    elif args.env == 'hard':
        env = MDP_hard(args.S)
    else:
        raise ValueError(args.env)

    for e in tqdm.tqdm(range(args.Episodes)):
        # Per-episode seed; a negative --seed leaves the RNG streams unseeded.
        if args.seed >= 0:
            np.random.seed(args.seed + e)
            torch.manual_seed(args.seed + e)
            random.seed(args.seed + e)

        state = env.reset()

        # define the learner.
        if args.learner == 'RBMLE':
            learner = RBMLE(args.S, args.A, args.K, env.R, args.gamma, env.func, args.alpha, args.eta)
            learner_info = 'learner_{}_alpha_{}_eta_{}'.format(args.learner, args.alpha, args.eta)
        elif args.learner == 'UCB':
            beta = args.eta * 1 / (1-args.gamma) * np.sqrt(args.K*np.log((args.lamdba*((1-args.gamma)**2) + args.T*args.K)/(args.delta*args.lamdba*((1-args.gamma)**2))))
            U = int(np.log(args.T/(1-args.gamma))/(1-args.gamma)) + 1
            learner = UCB(args.S, args.A, args.K, env.R, args.gamma, env.func, args.lamdba, beta, U)
            learner_info = 'learner_{}_lamdba_{}_delta_{}_eta_{}'.format(args.learner, args.lamdba, args.delta, args.eta)
        elif args.learner == 'DQN':
            learner = DQN(args.S, args.A, args.gamma, args.hidden_size, args.lr, args.epsilon, args.batch_size)
            learner_info = 'learner_{}_hidden_size_{}_lr_{}_epsilon_{}_batch_size_{}'.format(args.learner, args.hidden_size, args.lr, args.epsilon, args.batch_size)
        else:
            raise ValueError(args.learner)

        # Theta-error metrics need a linear model (sigma). Learners without one
        # (presumably DQN -- NOTE(review): confirm) skip them, and their
        # diff_theta records are dumped as empty lists.
        tracks_theta = hasattr(learner, 'sigma')

        regrets = []
        times = []
        diff_thetas = [[] for _ in range(6)]
        for _ in range(args.T):
            # Regret of the current learner's policy at the current state.
            regret = env.v_star[env.state] - env.value_iteration(learner)[env.state]

            start = time.perf_counter()
            action = learner.select(state)
            time_select = time.perf_counter() - start

            next_state, reward = env.step(action)

            start = time.perf_counter()
            learner.train(state, action, reward, next_state)
            time_train = time.perf_counter() - start

            # Logging happens outside the timed region so it does not inflate time_train.
            if args.verbose:
                print(env.theta, getattr(learner, 'theta', None))

            if args.learner == 'RBMLE':
                temp, _ = learner.compute_mle()
                if temp is not None:
                    learner.theta_hat = temp

            if tracks_theta:
                metrics = theta_diagnostics(learner, env.theta, args.learner == 'UCB')
                for series, value in zip(diff_thetas, metrics):
                    series.append(value)

            state = next_state
            regrets.append(regret)
            times.append(time_select + time_train)

        suffix = record_suffix(learner_info, args, e)
        for i, series in enumerate(diff_thetas):
            dump_pickle('./diff_theta/diff_theta_{}_{}'.format(i, suffix), series)
        dump_pickle('./record/regret_' + suffix, regrets)
        dump_pickle('./record/time_' + suffix, times)
