import pickle
import argparse
import argparse
import random
import numpy as np
import torch
import transformers
import tqdm
from modeling_gpt2 import GPT2Model
import sys, os
import Environment.env
import Environment.function_preprocessing
import os
# from torch.utils.tensorboard import SummaryWriter
from torch import nn
from scipy.stats import rankdata
import copy
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF
from scipy.stats import rankdata
import sobol_seq

class QT(torch.nn.Module):
    """Transformer (GPT-2 backbone) Q-network over BO trajectories.

    A trajectory is encoded as the token sequence
    ``(sa_1, Q_1, r_1, sa_2, Q_2, r_2, ..., sa_n)`` where ``sa_i`` is an
    embedded state-action vector, ``Q_i`` an embedded Q-value and ``r_i`` an
    embedded reward.  Q-value predictions are read from the hidden states at
    the ``sa_i`` positions.  Sequences longer than ``n_positions`` tokens are
    left-truncated so that the most recent tokens are kept.

    Args:
        T: Episode horizon, stored as ``max_length``; :meth:`update` caps the
            reward / Q context passed to :meth:`forward` at ``T - 1`` steps.
        domain_size: Stored and used as a multiplier on the Q-value logits in
            :meth:`select_action`.
        f_num: Number of objectives; state-action vectors have
            ``3 * f_num + 1`` features.
        gamma: Discount factor of the TD loss in :meth:`update`.
        hidden_size, n_layer, n_head, n_positions, dropout: GPT-2 config.
        lr, weight_decay, optimizer, warmup_steps: Optimizer ("sgd" or
            "adam") and linear-warmup LR schedule settings.
        n_batch, update_freq, epsilon, target_update_freq: Stored only; not
            read by any method of this class.
        batch_size: Number of buffered trajectories sampled per update.
        initial_sample: Number of leading steps for which
            :meth:`select_action` picks a uniformly random action.
        device: Device on which inputs are placed.

    Raises:
        ValueError: If ``optimizer`` is neither "sgd" nor "adam".
    """

    def __init__(self, 
                 T, 
                 domain_size, 
                 f_num, 
                 gamma = 0.98, 
                 hidden_size = 128, 
                 lr = 0.01, 
                 weight_decay = 0.1, 
                 n_layer = 4, 
                 n_head = 4, 
                 n_batch = 1,
                 n_positions = 301,
                 warmup_steps=1000,
                 update_freq = 1, 
                 optimizer = "sgd",
                 dropout = 0.1,
                 epsilon = 0.1,
                 target_update_freq = 1,
                 batch_size = 16,
                 initial_sample = 10,
                 device = "cpu"):
        super(QT, self).__init__()

        self.device = device
        self.update_freq = update_freq
        self.hidden_size = hidden_size
        self.lr = lr
        self.weight_decay = weight_decay

        config = transformers.GPT2Config(
            vocab_size=1,  # doesn't matter -- we don't use the vocab
            n_embd=hidden_size,
            n_layer=n_layer,
            n_head=n_head,
            n_inner=4*hidden_size,
            activation_function='relu',
            n_positions=n_positions,
            resid_pdrop=dropout,
            attn_pdrop=dropout,
            embd_pdrop=dropout
        )
        self.n_positions = n_positions
        self.n_batch = n_batch
        self.epsilon = epsilon  # stored but not read by this class
        self.target_update_freq = target_update_freq
        self.batch_size = batch_size
        self.initial_sample = initial_sample
        self.max_length = T
        self.domain_size = domain_size
        # Module creation order below is kept as-is: it fixes both the
        # state_dict keys and the order in which random init is drawn.
        self.transformer = GPT2Model(config)
        self.state_action_dim = f_num * 3 + 1
        self.gamma = gamma
        self.embed_reward = torch.nn.Linear(1, hidden_size)
        self.embed_q_value = torch.nn.Linear(1, hidden_size)
        self.embed_time = torch.nn.Linear(1, hidden_size)
        self.embed_state_action = torch.nn.Linear(f_num * 3 + 1, hidden_size)

        self.predict_state_action = torch.nn.Linear(hidden_size, self.state_action_dim)
        self.predict_q_value = torch.nn.Linear(hidden_size, 1)
        self.predict_reward = torch.nn.Linear(hidden_size, 1)

        self.trajectory_buffer = []  # list of finished trajectories
        self.trajectory = []  # transitions of the current episode

        if optimizer == "sgd":
            self.optimizer = torch.optim.SGD(params=self.parameters(), lr=lr, weight_decay=weight_decay, momentum=0.9)
        elif optimizer == "adam":
            self.optimizer = torch.optim.AdamW(params=self.parameters(), lr=lr, weight_decay=weight_decay)
        else:
            raise ValueError("Unknown optimizer %s" % optimizer)

        # Linear warm-up from 1/warmup_steps to the full learning rate.
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer,lambda steps: min((steps+1)/warmup_steps, 1))

    @staticmethod
    def _state_action_hidden(x):
        """Return the hidden states at the state-action token positions.

        Sequences are laid out as ``(sa_1, Q_1, r_1, ..., sa_n)`` and always
        end on a state-action token, so those tokens sit every third position
        counting back from the last one.  Anchoring the stride at the end keeps
        the selection aligned after left-truncation to ``n_positions`` for any
        value of ``n_positions``; without truncation the offset is 0.
        """
        return x[:, (x.size(1) - 1) % 3::3, :]

    def forward_without_q(self, state_actions, rewards):
        """Predict a Q-value for every step of a trajectory, autoregressively.

        No Q-values are given as input: the value predicted for step ``i`` is
        fed back as the ``Q_i`` token when predicting step ``i + 1``.  This
        runs one transformer pass per step.

        Args:
            state_actions: Tensor sliced as ``[:, :i+1, :]``; the callers in
                this class pass (batch, seq_len, state_action_dim).
            rewards: Tensor sliced as ``[:, :i, :]``; the callers in this
                class pass (batch, seq_len, 1).

        Returns:
            Tensor of per-step Q-values concatenated along dim 1.
        """
        q_values = []
        pred_len = state_actions.size()[1]
        for i in range(0, pred_len):
            # NOTE(review): time embeddings are added only to the state-action
            # tokens here, while forward() also adds them to the Q and reward
            # tokens. Changing this would alter the outputs of existing
            # checkpoints, so it is left as-is -- confirm it is intended.
            time_embeddings = self.embed_time(torch.from_numpy(np.arange(i+1)).float().reshape(1,i+1,1).to(self.device))
            state_action_embeddings = self.embed_state_action(state_actions[:,:i+1,:].to(self.device)) + time_embeddings

            if i == 0:
                # Only one token (sa_1): there is no history to interleave.
                transformer_outputs = self.transformer(inputs_embeds=state_action_embeddings)
                x = transformer_outputs['last_hidden_state']
                q_values.append(self.predict_q_value(x))
            else:
                q_value_embeddings = self.embed_q_value(torch.cat(q_values, axis = 1))  # (batch, i, hidden_size)
                reward_embeddings = self.embed_reward(rewards[:,:i,:].to(self.device))  # (batch, i, hidden_size)
                # Interleave into (sa_1, Q_1, r_1, sa_2, ...) and append sa_{i+1}.
                inputs = torch.stack((state_action_embeddings[:,:i,:], q_value_embeddings, reward_embeddings), dim=2).reshape(state_action_embeddings.size(0), i*3, self.hidden_size)
                inputs = torch.cat((inputs, state_action_embeddings[:,-1:,:]), dim=1)  # (batch, 3*i + 1, hidden_size)

                # Keep only the most recent n_positions tokens.
                if inputs.size()[1] > self.n_positions:
                    inputs = inputs[:,-self.n_positions:,:]

                transformer_outputs = self.transformer(inputs_embeds=inputs)
                x = transformer_outputs['last_hidden_state']

                q_value_pred = self.predict_q_value(self._state_action_hidden(x))
                q_values.append(q_value_pred[:,-1:,:])
        return torch.cat(q_values, axis = 1)

    def forward(self, state_actions, rewards, q_values):
        """Predict Q-values given the full (state-action, Q, reward) history.

        If ``rewards`` or ``q_values`` is None, ``state_actions`` is taken to
        be a single NumPy state-action vector; it is scored on its own and a
        (1, 1, 1) tensor is returned.

        Args:
            state_actions: (batch, pred_len, state_action_dim) tensor.
            rewards: (batch, pred_len - 1, 1) tensor, or None.
            q_values: (batch, pred_len - 1, 1) tensor, or None.

        Returns:
            Q-value predictions at the state-action positions that survive
            truncation to ``n_positions``.
        """
        if rewards is None or q_values is None:
            state_actions = torch.from_numpy(state_actions).float().unsqueeze(0).unsqueeze(0).to(self.device)
            state_action_embeddings = self.embed_state_action(state_actions)
            transformer_outputs = self.transformer(inputs_embeds=state_action_embeddings)
            x = transformer_outputs['last_hidden_state']
            return self.predict_q_value(x)

        batch_size = state_actions.size(0)
        pred_len = state_actions.size(1)

        # Embed each modality with its own head, plus a shared time embedding.
        time_embeddings = self.embed_time(torch.from_numpy(np.arange(pred_len)).float().unsqueeze(1).to(self.device))  # (pred_len, hidden_size)
        state_action_embeddings = self.embed_state_action(state_actions) + torch.tile(time_embeddings[:pred_len,:], (batch_size,1,1))  # (batch, pred_len, hidden_size)
        q_value_embeddings = self.embed_q_value(q_values) + torch.tile(time_embeddings[:pred_len-1,:], (batch_size,1,1))  # (batch, pred_len - 1, hidden_size)
        reward_embeddings = self.embed_reward(rewards) + torch.tile(time_embeddings[:pred_len-1,:], (batch_size,1,1))  # (batch, pred_len - 1, hidden_size)

        # Interleave into (sa_1, Q_1, r_1, sa_2, ...) and append the last sa.
        inputs = torch.stack((state_action_embeddings[:,:pred_len-1,:], q_value_embeddings, reward_embeddings), dim=2).reshape(state_action_embeddings.size(0), (pred_len-1)*3, self.hidden_size)
        inputs = torch.cat((inputs, state_action_embeddings[:,-1:,:]), dim=1)  # (batch, 3*pred_len - 2, hidden_size)

        # Keep only the most recent n_positions tokens.
        if inputs.size()[1] > self.n_positions:
            inputs = inputs[:,-self.n_positions:,:]

        transformer_outputs = self.transformer(inputs_embeds=inputs)
        x = transformer_outputs['last_hidden_state']

        return self.predict_q_value(self._state_action_hidden(x))

    def _init_weights(self, module):
        """GPT-2 style init for Linear / Embedding / LayerNorm modules.

        Not invoked anywhere in this class.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def select_action(self, state_actions, target_network, N_m = 1):
        """Sample the index of the next candidate to evaluate.

        While fewer than ``initial_sample`` transitions have been recorded, a
        uniformly random index is returned.  Otherwise every candidate is
        appended to the current trajectory and scored with :meth:`forward`.
        The history Q-values come from ``target_network.forward_without_q``.
        An index is then sampled from a Categorical distribution whose logits
        are the last-step Q-values times ``self.domain_size``.

        Args:
            state_actions: NumPy array, one row per candidate.
            target_network: QT used to produce the history Q-values.
            N_m: Unused; kept for interface compatibility.

        Returns:
            A one-element list holding the chosen index as a Python int.
        """
        domain_size = np.shape(state_actions)[0]
        # The first `initial_sample` actions are uniformly random.
        if len(self.trajectory) < self.initial_sample:
            action = random.randint(0, domain_size - 1)
            return [action]

        previous_state_actions = torch.from_numpy(np.array([b["state_action"] for b in self.trajectory]).reshape(1,len(self.trajectory),self.state_action_dim)).float()
        previous_rewards = torch.from_numpy(np.array([b["reward"] for b in self.trajectory]).reshape(1,len(self.trajectory),1)).float()
        self.previous_q_values = target_network.forward_without_q(previous_state_actions, previous_rewards).detach()
        state_actions = torch.from_numpy(state_actions).float()

        # Same history for every candidate; candidate appended as the last step.
        batch_state_actions = torch.cat((torch.tile(previous_state_actions,(domain_size,1,1)), state_actions.unsqueeze(1)), axis=1).to(self.device)
        # (domain_size, sequence_length, state_action_dim)
        batch_rewards = torch.tile(previous_rewards, (domain_size,1,1)).to(self.device)
        # (domain_size, sequence_length - 1, 1)
        batch_q_values = torch.tile(self.previous_q_values, (domain_size,1,1))
        # (domain_size, sequence_length - 1, 1)
        q_values = self.forward(batch_state_actions, batch_rewards, batch_q_values).detach()

        # reshape(-1) (not squeeze) keeps a 1-D logits vector even when there
        # is a single candidate; Categorical rejects 0-d logits.
        logits = q_values[:,-1,:].cpu().reshape(-1).double() * self.domain_size
        action = torch.distributions.Categorical(logits=logits).sample()
        return [int(action)]

    def update(self, target_network, seed=0):
        """Compute the TD loss on a random batch of buffered trajectories.

        Targets are ``r + gamma * Q_target(next)``, where the target network's
        Q-values come from ``forward_without_q`` and the final one is zeroed.
        Also stores the 'max'-method ranks of the batch's last-step Q-values in
        ``self.ranked_list``.

        Args:
            target_network: QT producing the bootstrap Q-values.
            seed: Unused; kept for interface compatibility.

        Returns:
            The loss as a Python float.  No optimizer step is taken.
        """
        self.optimizer.zero_grad()

        batch_size = min(self.batch_size, len(self.trajectory_buffer))
        batch = random.sample(self.trajectory_buffer, batch_size)

        # (batch, traj_len, state_action_dim) and (batch, traj_len, 1)
        state_actions = torch.from_numpy(np.array([[transit['state_action'] for transit in traj] for traj in batch])).float()
        rewards = torch.from_numpy(np.array([[[transit['reward']] for transit in traj] for traj in batch])).float()
        next_q_values = target_network.forward_without_q(state_actions, rewards).detach()

        # The final transition has no successor: its bootstrap value is zero.
        next_q_values[:,-1,:] = 0

        state_actions = state_actions[:,:-1,:].to(self.device)
        rewards = rewards[:,:-1,:].to(self.device)
        next_q_values = next_q_values[:,1:,:]  # shift: Q of the following step
        q_values = self.forward(state_actions, rewards[:,:self.max_length-1,:], next_q_values[:,:self.max_length-1,:])

        # If forward() truncated to n_positions, align targets to its output.
        if q_values.size()[1] < next_q_values.size()[1]:
            next_q_values = next_q_values[:,-q_values.size()[1]:,:]
            rewards = rewards[:,-q_values.size()[1]:,:]
        loss = torch.sum((rewards[:,1:,:] + self.gamma * next_q_values[:,1:,:] - q_values[:,1:,:]) ** 2)

        # NOTE: the gradient step (loss.backward(); self.optimizer.step();
        # self.scheduler.step()) is intentionally disabled here.
        # detach(): q_values requires grad, and converting it to NumPy for
        # rankdata would otherwise raise a RuntimeError.
        self.ranked_list = rankdata(q_values[:,-1,:].detach().cpu().numpy(), method='max')

        return loss.detach().cpu().item()

    def reset_trajectory(self):
        """Start a new, empty trajectory (the buffer is left untouched)."""
        self.trajectory = []

def args_to_info(args):
    """Build the hyper-parameter identifier string for a run.

    The string starts with ``<model_episode>_<model_type>`` followed by
    ``_<name>_<value>`` segments in a fixed order.  In this file it is printed
    and compared against identifiers rebuilt from ``BOFormer_record.txt``.

    Args:
        args: Any object (e.g. ``argparse.Namespace``) exposing the attributes
            read below; values are rendered with ``format(value, '')``.

    Returns:
        The identifier string.
    """
    a = args
    return (
        f"{a.model_episode}_{a.model_type}"
        f"_domain_{a.domain_size}"
        f"_T_{a.T}"
        f"_fnum_{a.f_num}"
        f"_gamma_{a.gamma}"
        f"_hidden_{a.hidden_size}"
        f"_lr_{a.lr}"
        f"_weight_decay_{a.weight_decay}"
        f"_seed_{a.seed}"
        f"_n_layer_{a.n_layer}"
        f"_n_head_{a.n_head}"
        f"_n_positions_{a.n_positions}"
        f"_dropout_{a.dropout}"
        f"_epsilon_{a.epsilon}"
        f"_target_update_freq_{a.target_update_freq}"
        f"_batch_size_{a.batch_size}"
        f"_demo_rate_{a.demo_rate}"
        f"_buffer_size_{a.buffer_size}"
        f"_new_reward_{a.new_reward}"
        f"_sample_rate_{a.sample_rate}"
    )

def _parse_record_line(line):
    """Parse one ``BOFormer_record.txt`` line into ``(index, fields)``.

    A line is expected to be ``<index><TAB>Namespace(k1=v1, k2=v2, ...)``.
    ``fields`` maps each key to the raw text after its first ``=`` (string
    values keep their quotes).  Unlike fixed-width slicing, this tolerates a
    missing trailing newline and ``\\r\\n`` endings.

    Returns:
        ``(index_str, fields)``, or None for a blank line.

    Raises:
        ValueError: If a non-blank line has no tab separator.
    """
    line = line.strip()
    if not line:
        return None
    index, namespace_str = line.split('\t', 1)
    if namespace_str.startswith('Namespace(') and namespace_str.endswith(')'):
        namespace_str = namespace_str[len('Namespace('):-1]
    fields = {}
    for arg in namespace_str.split(', '):
        key, _, value = arg.partition('=')
        fields[key] = value
    return index, fields


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Multi-objective BO')
    parser.add_argument('--device', type=str, default="0", help='gpu device')
    parser.add_argument('--env', type=str, default='BO')
    parser.add_argument('--testing_episode', type=int, default=100)
    parser.add_argument('--T', type=int, default=100)
    parser.add_argument('--domain_size', type=int, default=1000)
    parser.add_argument('--f_num', type=int, default=2)
    parser.add_argument('--hidden_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=0.00001)
    parser.add_argument('--weight_decay', type=float, default=0.00001)
    parser.add_argument('--n_batch', type=int, default=1)
    parser.add_argument('--update_freq', type=int, default=1)
    parser.add_argument('--n_layer', type=int, default=4)
    parser.add_argument('--n_head', type=int, default=4)
    parser.add_argument('--target_update_freq', type=int, default=5)
    parser.add_argument('--n_positions', type=int, default=301)
    parser.add_argument('--warmup_steps', type=int, default=1)
    parser.add_argument('--gamma', type=float, default=0.95)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--function_type', type=str, default='ARa')
    parser.add_argument('--model_episode', type=int, default=2000)
    parser.add_argument('--model_type', type=str, default="train")
    parser.add_argument('--optimizer', type=str, default="adam")
    parser.add_argument('--dropout', type=float, default = 0.1)
    parser.add_argument('--epsilon', type=float, default = 0.1)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--sample_rate', type=float, default=0.005)
    parser.add_argument('--demo_rate', type=float, default=0.01)
    parser.add_argument('--buffer_size', type=int, default=64)
    parser.add_argument('--new_reward',type=int, default=1)
    parser.add_argument('--ls_learned_freq', type=int, default=10, help='freq of learning ls')
    parser.add_argument('--noise_level', type=float, default=0.1, help='perturbed and observation noise')
    parser.add_argument('--update_step', type=int, default=10, help='# of GD steps')
    parser.add_argument('--N_m', type=int, default=1, help='N_m')
    parser.add_argument('--N_local', type=int, default=1, help='N_local')
    parser.add_argument('--initial_sample', type=int, default=1, help='# of initial sample')
    parser.add_argument('--online_ls', type=int, default=1, help='ls in testing')
    parser.add_argument('--ls_weight', type=float, default=1, help='make ls to be smaller')

    args = parser.parse_args()
    learner_info = args_to_info(args)
    print(args.function_type, learner_info)

    # Find the record whose hyper-parameters match ours; its index names the
    # model / output directories. Stays -1 when nothing matches.
    record_index = -1
    with open('BOFormer_record.txt', 'r') as f:
        for line in f:
            parsed = _parse_record_line(line)
            if parsed is None:
                continue
            index, args_dict = parsed
            args_dict['model_episode'] = int(args.model_episode)
            args_dict['model_type'] = args.model_type
            c_info = args_to_info(argparse.Namespace(**args_dict))
            print(c_info)
            if c_info == learner_info:
                record_index = int(index)
                break

    print("record index: ", record_index)
    if args.device != "cpu":
        device = torch.device("cuda:" + args.device if torch.cuda.is_available() else "cpu")
    else:
        device = "cpu"

    torch.set_num_threads(8)
    torch.set_num_interop_threads(8)

    # Seed every RNG used below for reproducibility (seed <= 0 disables).
    if args.seed > 0:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    learner = QT(T = args.T,
                 domain_size = args.domain_size,
                 f_num = args.f_num,
                 gamma = args.gamma,
                 hidden_size = args.hidden_size,
                 lr = args.lr,
                 weight_decay = args.weight_decay,
                 n_layer = args.n_layer,
                 n_head = args.n_head,
                 n_batch = args.n_batch,
                 n_positions = args.n_positions,
                 warmup_steps = args.warmup_steps,
                 update_freq = args.update_freq,
                 optimizer = args.optimizer,
                 dropout = args.dropout,
                 epsilon = args.epsilon,
                 target_update_freq = args.target_update_freq,
                 batch_size = args.batch_size,
                 initial_sample = args.initial_sample,
                 device = device).to(device)

    if args.model_episode > 0:
        learner.load_state_dict(torch.load("./BOFormer_models/model{}/{}.pth".format(record_index, args.model_episode), map_location=device))
    learner.eval()

    # The target network is a frozen copy of the learner used for history Q-values.
    target_network = QT(T = args.T,
                 domain_size = args.domain_size,
                 f_num = args.f_num,
                 gamma = args.gamma,
                 hidden_size = args.hidden_size,
                 lr = args.lr,
                 weight_decay = args.weight_decay,
                 n_layer = args.n_layer,
                 n_head = args.n_head,
                 n_batch = args.n_batch,
                 n_positions = args.n_positions,
                 warmup_steps = args.warmup_steps,
                 update_freq = args.update_freq,
                 optimizer = args.optimizer,
                 dropout = args.dropout,
                 epsilon = args.epsilon,
                 target_update_freq = args.target_update_freq,
                 batch_size = args.batch_size,
                 initial_sample = args.initial_sample,
                 device = device).to(device)

    target_network.load_state_dict(learner.state_dict())
    target_network.eval()

    env = Environment.env.Environment(T = args.T,
                      domain_size = args.domain_size,
                      f_num = args.f_num,
                      function_type = args.function_type,
                      seed = args.seed,
                      new_reward = args.new_reward,
                      noise_level = args.noise_level,
                      ls_learned_freq = args.ls_learned_freq,
                      online_ls=args.online_ls,
                      ls_weight = args.ls_weight)

    out_dir = f'./BOFormer_testings/model{record_index}'
    os.makedirs(out_dir, exist_ok=True)

    for e in range(args.testing_episode):
        # Per-episode seed for the environment and the candidate domains.
        seed = args.seed + e * 10
        env.reset(seed=seed)
        env.history['info'] = str(args)

        # Initial observation at a uniformly random domain point. Keeping the
        # index in `action` means it is defined even when the loop below is
        # empty (T == 1).
        X = Environment.function_preprocessing.domain(args.function_type, args.domain_size, seed)
        action = random.randint(0, args.domain_size - 1)
        y_star, reward, regret = env.step(X[action])
        gp = env.fit_gp(0)
        state_actions = Environment.env.construct_state_action_pair(X, gp, y_star, 0)

        # The first transition uses an all-zero state-action placeholder.
        learner.trajectory.append({"state_action": np.array([0.0]*(3*args.f_num+1)), "reward": float(reward), "next_state_actions": state_actions})

        # BO iterations of this episode.
        for t in tqdm.tqdm(range(1, args.T)):
            actions = learner.select_action(state_actions, target_network, N_m = args.N_m)
            if len(actions) == 1:
                action = actions[0]
                chosen_state_action = state_actions[action]
                y_star, reward, regret = env.step(X[action])
            else:
                # Local refinement: perturb each proposed point with
                # 0.05 * (Sobol - 0.5) offsets, clip to [0, 1], pick again.
                candidate = []
                for proposal in actions:
                    candidate.append(X[proposal] + 0.05*(sobol_seq.i4_sobol_generate(env.domain_dim, args.N_local, seed+t)-0.5))
                candidate = np.clip(np.concatenate(candidate, axis = 0), 0, 1)
                state_action_pairs = Environment.env.construct_state_action_pair(candidate, gp, y_star, t/args.T)
                action = learner.select_action(state_action_pairs, target_network, N_m = 1)[0]
                # `action` indexes the candidate pairs here, not state_actions.
                chosen_state_action = state_action_pairs[action]
                y_star, reward, regret = env.step(candidate[action])

            # Refit the GP (including its length-scales) on the new data.
            gp = env.fit_gp(t)

            X = Environment.function_preprocessing.domain(args.function_type, args.domain_size, seed+t)
            next_state_actions = Environment.env.construct_state_action_pair(X, gp, y_star, t/args.T)

            # Record the state-action that was actually evaluated.
            learner.trajectory.append({"state_action": chosen_state_action, "reward": reward})

            state_actions = next_state_actions

        # Terminal transition with reward 0, taken from the latest candidate
        # set at the last action index.
        learner.trajectory.append({"state_action": state_actions[action], "reward": 0.0})

        # No learner update during testing; the trajectory is only buffered.
        learner.trajectory_buffer.append(learner.trajectory)
        learner.reset_trajectory()

        print('EP:{} | R: {:.3f}'.format(e, regret))
        filename = '{}_function_type_{}_N_m_{}_N_local_{}_ls_learned_freq_{}_initial_sample_{}_online_ls_{}_episode_{}.pkl'.format(
            "BOFormer",
            args.function_type,
            args.N_m,
            args.N_local,
            args.ls_learned_freq,
            args.initial_sample,
            args.online_ls,
            e)
        with open(os.path.join(out_dir, filename), 'wb') as f:
            pickle.dump(env.history, f)
