import pickle
import argparse
import random
import numpy as np
import torch
import transformers
import tqdm
import math
from modeling_gpt2 import GPT2Model
# from transformers import GPT2Model
import os
import sys
import copy
import Environment.env
import Environment.function_preprocessing
import pathlib
from torch.utils.tensorboard import SummaryWriter
import wandb
import uuid
from dataclasses import asdict, dataclass
from torch import nn
from ehi import ExpectedHypervolumeImprovement
from prioritized_memory import Memory
from botorch.utils.multi_objective.box_decompositions.non_dominated import FastNondominatedPartitioning
from botorch.acquisition.multi_objective.monte_carlo import qExpectedHypervolumeImprovement, qNoisyExpectedHypervolumeImprovement
from botorch.acquisition.multi_objective.joint_entropy_search import qLowerBoundMultiObjectiveJointEntropySearch
from botorch.optim.optimize import optimize_acqf

# Scene names, grouped into two lists.
# NOTE(review): presumably the NeRF scenes selectable via --NERF_scene, split
# into real captures and synthetic renders; neither list is referenced in the
# visible part of this file -- confirm where (or whether) they are used.
real = ["bicycle", "bonsai", "flowers", "garden", "room", "stump", "treehill"]
synthetic = ["chair", "drums", "ficus", "hotdog", "lego", "materials", "mic", "ship"]

class QT(torch.nn.Module):
    """Q-value sequence model built on a GPT-2 backbone.

    Every time step of a trajectory is embedded as up to three tokens, in the
    order (state-action, Q-value, reward), each with a learned time embedding
    added.  The transformer output at a state-action token is projected to a
    scalar Q-value.  Besides the network itself, the instance owns its
    optimizer, a warm-up learning-rate schedule, the trajectory currently
    being collected (``self.trajectory``) and a buffer of whole trajectories
    (``self.trajectory_buffer``).

    Transitions are dicts with the keys ``"state_action"``, ``"reward"`` and
    ``"next_state_actions"``.

    NOTE: sequences longer than ``n_positions`` tokens are cut to their last
    ``n_positions`` tokens.  The stride-3 selection of state-action tokens
    stays aligned after that cut only when ``n_positions % 3 == 1`` (true for
    the constructor default 301).
    """

    def __init__(self,
                 T,
                 domain_size,
                 f_num,
                 gamma = 0.98,
                 hidden_size = 128,
                 lr = 0.01,
                 weight_decay = 0.1,
                 n_layer = 4,
                 n_head = 4,
                 n_batch = 1,
                 n_positions = 301,
                 warmup_steps=1000,
                 update_freq = 1,
                 optimizer = "sgd",
                 dropout = 0.1,
                 epsilon = 0.1,
                 target_update_freq = 1,
                 batch_size = 16,
                 sample_rate = 0.1,
                 buffer_size = 64,
                 initial_sample = 1,
                 temperature = 1000,
                 device = "cpu"):
        """Build the network, the trajectory buffer, the optimizer and the scheduler.

        Args:
            T: stored as ``self.max_length``.
            domain_size: number of candidate actions to choose from.
            f_num: the state-action vector has ``3 * f_num + 1`` entries.
            gamma: discount factor of the TD target; also used as the mean of
                the ``c_proj.weight`` initialisation.
            hidden_size: embedding width of the transformer.
            lr, weight_decay: optimizer hyper-parameters.
            n_layer, n_head: GPT-2 depth and number of attention heads.
            n_batch, update_freq, target_update_freq: only stored on the instance.
            n_positions: maximum number of tokens fed to the transformer.
            warmup_steps: scheduler steps over which the learning-rate factor
                ramps linearly up to 1.
            optimizer: ``"sgd"`` (SGD with momentum 0.9) or ``"adam"`` (AdamW).
            dropout: residual / attention / embedding dropout probability.
            epsilon: probability of a uniformly random action in ``select_action``.
            batch_size: maximum number of trajectories per ``update``.
            sample_rate: fraction of the domain evaluated per step when
                maximising over next actions in ``forward_without_q``.
            buffer_size: capacity passed to the trajectory buffer.
            initial_sample: number of leading steps that are always random.
            temperature: factor multiplied onto the Q-values before they are
                used as softmax logits (larger values act greedier).
            device: device the inputs are moved to.

        Raises:
            ValueError: if ``optimizer`` is neither ``"sgd"`` nor ``"adam"``.
        """
        super().__init__()

        self.device = device
        self.update_freq = update_freq
        self.hidden_size = hidden_size
        self.lr = lr
        self.weight_decay = weight_decay

        config = transformers.GPT2Config(
            vocab_size=1,  # doesn't matter -- we don't use the vocab
            n_embd=hidden_size,
            n_layer=n_layer,
            n_head=n_head,
            n_inner=4*hidden_size,
            activation_function='relu',
            n_positions=n_positions,
            resid_pdrop=dropout,
            attn_pdrop=dropout,
            embd_pdrop=dropout
        )
        self.buffer_size = buffer_size
        self.sample_rate = sample_rate
        self.n_positions = n_positions
        self.n_batch = n_batch
        self.epsilon = epsilon  # for exploration
        self.target_update_freq = target_update_freq
        self.batch_size = batch_size
        self.initial_sample = initial_sample
        self.temperature = temperature
        self.max_length = T
        self.domain_size = domain_size
        # Sub-modules are created in a fixed order so that seeded runs draw
        # the same initial weights.
        self.transformer = GPT2Model(config)
        self.state_action_dim = f_num * 3 + 1
        self.gamma = gamma
        self.embed_reward = torch.nn.Linear(1, hidden_size)
        self.embed_q_value = torch.nn.Linear(1, hidden_size)
        self.embed_time = torch.nn.Linear(1, hidden_size)
        self.embed_state_action = torch.nn.Linear(self.state_action_dim, hidden_size)

        self.predict_state_action = torch.nn.Linear(hidden_size, self.state_action_dim)
        self.predict_q_value = torch.nn.Linear(hidden_size, 1)
        self.predict_reward = torch.nn.Linear(hidden_size, 1)

        self.trajectory_buffer = Memory(buffer_size)
        self.trajectory = []

        # Initialise all weights, then re-initialise the residual projections
        # with the GPT-2 scaled std.  Unlike GPT-2 (mean 0) the mean used here
        # evaluates to gamma.
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=1.0-(1.0-self.gamma), std=0.02/math.sqrt(2 * n_layer))
        # report number of parameters
        print("number of parameters: %.2fM" % (sum(p.numel() for p in self.parameters())/1e6,))

        if optimizer == "sgd":
            self.optimizer = torch.optim.SGD(params=self.parameters(), lr=lr, weight_decay=weight_decay, momentum=0.9)
        elif optimizer == "adam":
            self.optimizer = torch.optim.AdamW(params=self.parameters(), lr=lr, weight_decay=weight_decay)
        else:
            raise ValueError("Unknown optimizer %s" % optimizer)

        # Linear warm-up; the divisor is clamped so warmup_steps == 0 does not
        # divide by zero.
        warmup = max(1, warmup_steps)
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lambda steps: min((steps+1)/warmup, 1))

    # ------------------------------------------------------------------
    # private helpers shared by forward() and forward_without_q()
    # ------------------------------------------------------------------
    def _time_embeddings(self, length):
        """Embed the step indices ``0 .. length-1``; result is (1, length, hidden_size)."""
        steps = torch.arange(length, dtype=torch.float32).reshape(1, length, 1)
        return self.embed_time(steps.to(self.device))

    def _interleave(self, state_action_embeddings, q_value_embeddings, reward_embeddings, tail):
        """Build the token sequence (sa_0, q_0, r_0, sa_1, q_1, r_1, ..., tail).

        The number of complete (sa, q, r) triples equals the length of
        ``q_value_embeddings``; ``tail`` is the embedding of the state-action
        whose Q-value is being queried.  Only the last ``n_positions`` tokens
        are kept.
        """
        steps = q_value_embeddings.size(1)
        batch = state_action_embeddings.size(0)
        tokens = torch.stack(
            (state_action_embeddings[:, :steps, :], q_value_embeddings, reward_embeddings[:, :steps, :]),
            dim=2,
        ).reshape(batch, steps * 3, self.hidden_size)
        tokens = torch.cat((tokens, tail), dim=1)
        if tokens.size(1) > self.n_positions:
            tokens = tokens[:, -self.n_positions:, :]
        return tokens

    def _hidden(self, tokens):
        """Run the transformer on embedded tokens and return its last hidden state."""
        return self.transformer(inputs_embeds=tokens)['last_hidden_state']

    def _last_state_q(self, hidden):
        """Q-value at the last stride-3 position (offset 3), shape (batch, 1, 1), detached."""
        return self.predict_q_value(hidden[:, 3::3, :])[:, -1:, :].detach()

    @torch.no_grad()
    def forward_without_q(self, state_actions, rewards, next_state_actions):
        """Roll Q-values forward step by step, feeding earlier predictions back in.

        Args:
            state_actions: tensor (batch, steps, state_action_dim).
            rewards: tensor (batch, steps, 1).
            next_state_actions: ``None``, or nested data convertible by
                ``np.array`` to (batch, steps, domain_size, state_action_dim)
                holding the candidate state-actions of every step.

        Returns:
            Tensor (batch, steps, 1) without gradient.  With
            ``next_state_actions=None`` these are the Q-values of the taken
            state-actions.  Otherwise each entry is the maximum of that value
            and the best Q-value among a random subset of the candidates.
        """
        pred_len = state_actions.size(1)
        time_embeddings = self._time_embeddings(pred_len)
        state_action_embeddings = self.embed_state_action(state_actions.to(self.device)) + time_embeddings
        reward_embeddings = self.embed_reward(rewards.to(self.device)) + time_embeddings

        # Pass 1: Q-values of the state-actions that were actually taken.
        q_values = []
        for i in range(pred_len):
            if i == 0:
                hidden = self._hidden(state_action_embeddings[:, :1, :])
                q_values.append(self.predict_q_value(hidden).detach())
                continue
            q_value_embeddings = self.embed_q_value(torch.cat(q_values, dim=1)) + time_embeddings[:, :i, :]
            tokens = self._interleave(state_action_embeddings, q_value_embeddings, reward_embeddings,
                                      state_action_embeddings[:, i:i+1, :])
            q_values.append(self._last_state_q(self._hidden(tokens)))

        q_values = torch.cat(q_values, dim=1)
        if next_state_actions is None:
            return q_values

        # Pass 2: approximate max over actions with a random candidate subset.
        # The conversion is loop-invariant, so it is done once up front.
        all_candidates = torch.from_numpy(np.array(next_state_actions))
        # At least one candidate is needed; otherwise the max below has
        # nothing to reduce over.
        n_candidates = min(self.domain_size, max(1, int(self.domain_size * self.sample_rate)))

        candidate_q_values = []
        for i in range(pred_len):
            candidate_points = random.sample(range(self.domain_size), n_candidates)
            if i != 0:
                # The history uses the maximised candidate Q-values of earlier steps.
                q_value_embeddings = self.embed_q_value(torch.cat(candidate_q_values, dim=1)) + time_embeddings[:, :i, :]

            step_q = []
            for j in candidate_points:
                candidate = all_candidates[:, i:i+1, j, :].float().to(self.device)
                candidate_embedding = self.embed_state_action(candidate) + time_embeddings[:, i:i+1, :]
                if i == 0:
                    hidden = self._hidden(candidate_embedding)
                    step_q.append(self.predict_q_value(hidden).detach())
                else:
                    tokens = self._interleave(state_action_embeddings, q_value_embeddings, reward_embeddings,
                                              candidate_embedding)
                    step_q.append(self._last_state_q(self._hidden(tokens)))
            candidate_q_values.append(torch.max(torch.stack(step_q), dim=0)[0])

        candidate_q_values = torch.cat(candidate_q_values, dim=1)
        return torch.max(torch.cat((candidate_q_values, q_values), dim=2), dim=2)[0].unsqueeze(2)

    def forward(self, state_actions, rewards, q_values):
        """Predict Q-values for every state-action token of a sequence.

        Args:
            state_actions: tensor (batch, steps, state_action_dim) already on
                the model device.  When ``rewards`` or ``q_values`` is
                ``None`` it is instead a 1-D numpy array holding a single
                state-action.
            rewards: tensor (batch, steps - 1, 1) or ``None``.
            q_values: tensor (batch, steps - 1, 1) or ``None``.

        Returns:
            Tensor (batch, n, 1), where ``n`` is the number of state-action
            tokens left after the ``n_positions`` cut.  In the single
            state-action case the result is (1, 1, 1).
        """
        if rewards is None or q_values is None:
            state_actions = torch.from_numpy(state_actions).float().unsqueeze(0).unsqueeze(0).to(self.device)
            hidden = self._hidden(self.embed_state_action(state_actions))
            return self.predict_q_value(hidden)

        batch_size = state_actions.size(0)
        pred_len = state_actions.size(1)

        # Embed each modality with its own head and add the time embedding.
        time_embeddings = self._time_embeddings(pred_len).expand(batch_size, -1, -1)
        state_action_embeddings = self.embed_state_action(state_actions) + time_embeddings
        q_value_embeddings = self.embed_q_value(q_values) + time_embeddings[:, :pred_len-1, :]
        reward_embeddings = self.embed_reward(rewards) + time_embeddings[:, :pred_len-1, :]

        tokens = self._interleave(state_action_embeddings, q_value_embeddings, reward_embeddings,
                                  state_action_embeddings[:, -1:, :])
        hidden = self._hidden(tokens)

        # Every third token, starting at 0, is a state-action token.
        return self.predict_q_value(hidden[:, 0::3, :])

    def _init_weights(self, module):
        """GPT-2 style initialisation for Linear, Embedding and LayerNorm modules."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def select_action(self, state_actions, target_network):
        """Choose the index of the next action.

        The choice is uniformly random while fewer than ``initial_sample``
        transitions are recorded, and with probability ``epsilon`` afterwards.
        Otherwise it is sampled from a softmax over
        ``temperature * Q(candidate)``.

        Args:
            state_actions: numpy array (domain_size, state_action_dim) with
                one row per candidate.
            target_network: network whose ``forward_without_q`` supplies the
                Q-values of the recorded history.

        Returns:
            int: index into the candidate set.
        """
        torch.cuda.empty_cache()
        # Warm-up: the first action(s) are random.
        if len(self.trajectory) < self.initial_sample:
            return random.randint(0, self.domain_size - 1)

        # Epsilon-greedy exploration.
        if random.random() < self.epsilon:
            return random.randint(0, self.domain_size - 1)

        steps = len(self.trajectory)
        previous_state_actions = torch.from_numpy(
            np.array([b["state_action"] for b in self.trajectory]).reshape(1, steps, self.state_action_dim)
        ).float()
        previous_rewards = torch.from_numpy(
            np.array([b["reward"] for b in self.trajectory]).reshape(1, steps, 1)
        ).float()
        self.previous_q_values = target_network.forward_without_q(previous_state_actions, previous_rewards, None).detach()
        state_actions = torch.from_numpy(state_actions).float()

        # One sequence per candidate: the shared history followed by the candidate.
        batch_state_actions = torch.cat(
            (torch.tile(previous_state_actions, (self.domain_size, 1, 1)), state_actions.unsqueeze(1)),
            dim=1,
        ).to(self.device)                                                        # (domain_size, steps + 1, state_action_dim)
        batch_rewards = torch.tile(previous_rewards, (self.domain_size, 1, 1)).to(self.device)   # (domain_size, steps, 1)
        batch_q_values = torch.tile(self.previous_q_values, (self.domain_size, 1, 1))            # (domain_size, steps, 1)
        with torch.no_grad():
            q_values = self.forward(batch_state_actions, batch_rewards, batch_q_values)

        # reshape(-1) instead of squeeze() keeps a 1-D logits vector even for
        # a single candidate.
        logits = q_values[:, -1, :].cpu().reshape(-1).double() * self.temperature
        action = torch.distributions.Categorical(logits=logits).sample()
        # Return a plain int on every path, matching the random branches.
        return int(action.item())

    def log_grad_norm(self):
        """Return the global L2 norm of all parameter gradients (0.0 if there are none)."""
        squared = sum(
            p.grad.data.norm(2).item() ** 2
            for p in self.parameters()
            if p.grad is not None
        )
        return squared ** 0.5

    def update(self, target_network):
        """Run one optimisation step on a batch of buffered trajectories.

        Args:
            target_network: network used to compute the bootstrap targets.

        Returns:
            tuple: ``(loss_value, grad_norm)``, the importance-weighted loss
            as a float and the gradient norm measured before clipping.
        """
        self.optimizer.zero_grad()

        batch_size = min(self.batch_size, self.trajectory_buffer.tree.n_entries)
        batch, idxs, is_weights = self.trajectory_buffer.sample(batch_size)

        # Unpack the sampled trajectories into dense arrays.
        state_actions = torch.from_numpy(
            np.array([np.array([transit['state_action'] for transit in traj]) for traj in batch])
        ).float()
        rewards = torch.from_numpy(
            np.array([np.array([[transit['reward']] for transit in traj]) for traj in batch])
        ).float()
        next_state_actions = [[transit['next_state_actions'] for transit in traj] for traj in batch]

        # Bootstrap values: max over (sampled) next actions from the target network.
        next_q_values = target_network.forward_without_q(state_actions, rewards, next_state_actions).detach()

        # Terminal handling: the last step bootstraps from 0 and the step
        # before it bootstraps from the final reward.
        next_q_values[:, -1, :] = 0
        next_q_values[:, -2, :] = rewards[:, -1, :].to(next_q_values.device)

        state_actions = state_actions.to(self.device)   # (batch_size, steps, state_action_dim)
        rewards = rewards.to(self.device)               # (batch_size, steps, 1)
        q_values = self.forward(state_actions, rewards[:, :-1, :], next_q_values[:, :-1, :])

        pred_len = q_values.size(1)
        loss = (rewards[:, -pred_len:, :] + self.gamma * next_q_values[:, -pred_len:, :] - q_values) ** 2
        loss = torch.sum(loss, dim=1)  # sum over the sequence dimension -> (batch_size, 1)

        # Hand each trajectory's summed squared TD error back to the buffer.
        per_trajectory = loss.detach().reshape(-1).tolist()
        for i in range(batch_size):
            self.trajectory_buffer.update(idxs[i], per_trajectory[i])

        weights = torch.as_tensor(is_weights, dtype=torch.float32, device=self.device)
        loss = torch.sum(weights * loss.reshape(-1))
        loss.backward()
        grad_norm = self.log_grad_norm()
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.scheduler.step()

        return loss.detach().cpu().item(), grad_norm

    def reset_trajectory(self):
        """Forget the trajectory being collected and release cached GPU memory."""
        self.trajectory = []
        # str() makes the check work for both "cpu" and torch.device("cpu").
        if str(self.device) != 'cpu':
            torch.cuda.empty_cache()

    def append_sample(self, trajectory, target_network=None):
        """Add a finished trajectory to the buffer with its squared TD error.

        Args:
            trajectory: list of transition dicts (needs at least two steps).
            target_network: network used for the bootstrap values.  When it
                is omitted, the module-level ``target_network`` global is
                used, which keeps existing one-argument calls working.

        Raises:
            NameError: if no target network is passed and no module-level
                ``target_network`` exists.
        """
        if target_network is None:
            target_network = globals().get("target_network")
            if target_network is None:
                raise NameError("name 'target_network' is not defined")

        state_actions = torch.from_numpy(
            np.array([transit['state_action'] for transit in trajectory])
        ).unsqueeze(0).float().to(self.device)
        rewards = torch.from_numpy(
            np.array([[transit['reward']] for transit in trajectory])
        ).unsqueeze(0).float().to(self.device)

        next_q_values = target_network.forward_without_q(state_actions, rewards, None).detach().to(self.device)
        with torch.no_grad():
            # [:-1] matches update(); it equals the former [:max_length-1]
            # slice whenever the trajectory has exactly max_length steps.
            q_values = self.forward(state_actions, rewards[:, :-1, :], next_q_values[:, :-1, :])
        pred_len = q_values.size(1)
        loss = torch.sum((rewards[:, -pred_len:, :] + self.gamma * next_q_values[:, -pred_len:, :] - q_values) ** 2)
        self.trajectory_buffer.add(loss.item(), trajectory)

def wandb_init(function_type: str) -> None:
    """Start a Weights & Biases run whose display name is ``function_type``.

    The project and group names are fixed, the run id is a freshly generated
    UUID4 string, and ``wandb.run.save()`` is called once the run exists.
    """
    run_settings = {
        "project": "BOFormer_finetune",
        "group": "grad_clip",
        "name": function_type,
        "id": str(uuid.uuid4()),
    }
    wandb.init(**run_settings)
    wandb.run.save()

def args_to_info(args):
    """Return the underscore-separated tag that identifies a run configuration.

    ``args`` is any object exposing the hyper-parameter attributes listed in
    ``ordered_fields`` below (e.g. an ``argparse.Namespace``).  The values are
    substituted into the template in exactly that order.
    """
    template = "{}_{}_domain_{}_fnum_{}_gamma_{}_hidden_{}_lr_{}_weight_decay_{}_seed_{}_n_layer_{}_n_head_{}_n_positions_{}_dropout_{}_epsilon_{}_target_update_freq_{}_batch_size_{}_demo_rate_{}_buffer_size_{}_new_reward_{}_sample_rate_{}"
    # One attribute per "{}" placeholder, in placeholder order.
    ordered_fields = (
        "model_episode",
        "model_type",
        "domain_size",
        "f_num",
        "gamma",
        "hidden_size",
        "lr",
        "weight_decay",
        "seed",
        "n_layer",
        "n_head",
        "n_positions",
        "dropout",
        "epsilon",
        "target_update_freq",
        "batch_size",
        "demo_rate",
        "buffer_size",
        "new_reward",
        "sample_rate",
    )
    values = [getattr(args, field) for field in ordered_fields]
    return template.format(*values)

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Multi-objective BO')
    parser.add_argument('--device', type=str, default="1", help='gpu device')
    parser.add_argument('--env', type=str, default='BO')
    parser.add_argument('--testing_episode', type=int, default=100)
    parser.add_argument('--T', type=int, default=30)
    parser.add_argument('--domain_size', type=int, default=1000)
    parser.add_argument('--f_num', type=int, default=2)
    parser.add_argument('--hidden_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=0.00001)
    parser.add_argument('--weight_decay', type=float, default=0.00001)
    parser.add_argument('--n_batch', type=int, default=1)
    parser.add_argument('--update_freq', type=int, default=1)
    parser.add_argument('--n_layer', type=int, default=4)
    parser.add_argument('--n_head', type=int, default=4)
    parser.add_argument('--target_update_freq', type=int, default=5)
    parser.add_argument('--n_positions', type=int, default=31)
    parser.add_argument('--warmup_steps', type=int, default=1)
    parser.add_argument('--gamma', type=float, default=0.95)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--function_type', type=str, default='NERF_synthetic')
    parser.add_argument('--model_episode', type=int, default=1000)
    parser.add_argument('--model_type', type=str, default="train_large")
    parser.add_argument('--optimizer', type=str, default="adam")
    parser.add_argument('--dropout', type=float, default = 0.1)
    parser.add_argument('--epsilon', type=float, default = 0.1)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--sample_rate', type=float, default=0.01)
    parser.add_argument('--demo_rate', type=float, default=0.01)
    parser.add_argument('--buffer_size', type=int, default=128)
    parser.add_argument('--new_reward',type=int, default=1)
    parser.add_argument('--ls_learned_freq', type=int, default=10, help='freq of learning ls')
    parser.add_argument('--perturb_noise_level', type=float, default=0.01, help='perturbed noise')
    parser.add_argument('--observation_noise_level', type=float, default=0.0, help='observation noise')
    parser.add_argument('--update_step', type=int, default=10, help='# of GD steps')
    parser.add_argument('--N_m', type=int, default=1, help='N_m')
    parser.add_argument('--N_local', type=int, default=1, help='N_local')
    parser.add_argument('--initial_sample', type=int, default=1, help='# of initial sample')
    parser.add_argument('--online_ls', type=int, default=1, help='ls in testing')
    parser.add_argument('--temperature', type=int, default=1000, help='temperature of softmax policy')
    parser.add_argument('--domain_dim', type=int, default=2, help='domain dimension')
    parser.add_argument('--yahpo_scenario', type=str, default='lcbench', help='lcbench, rbv2_xgboost, rbv2_svm, rbv2_glmnet')
    parser.add_argument('--NERF_scene', type=str, default="bicycle", help='NERF_scene')
    parser.add_argument('--tuning_episode', type=int, default=4, help='tuning episode')
    parser.add_argument('--record_idx', type=int, default=0, help='record_index')
    parser.add_argument('--demo', type=int, default=0, help='demo policy')
    parser.add_argument('--demo_policy', type=str, default="qNEHVI", help='demo policy')
    args = parser.parse_args()
    learner_info = args_to_info(args)
    print(args.function_type, learner_info)
    
    record_index = args.record_idx
    with open('BOFormer_record.txt', 'r') as f:
        for line in f.readlines():
            index, namespace_str = line.split('\t', 1)
            args_str = namespace_str[10:-2]  # remove 'Namespace(' and ')'
            args_list = args_str.split(', ')
            args_dict = {arg.split('=')[0]: arg.split('=')[1] for arg in args_list}
            args_dict['model_episode'] = int(args.model_episode)
            args_dict['model_type'] = args_dict['function_type'][1:-1]
            args_dict = argparse.Namespace(**args_dict)
            c_info = args_to_info(args_dict)
            print(c_info)
            if c_info == learner_info:
                record_index = int(index)
                break
    
    if args.record_idx != -1:
        record_index = args.record_idx

    print("record index: ", record_index)
    if args.device != "cpu":
        device = torch.device("cuda:" + args.device if torch.cuda.is_available() else "cpu")
    else: 
        device = "cpu"
    #os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    torch.set_num_threads(8)
    torch.set_num_interop_threads(8)

    # set seed for reproducibility
    if args.seed > 0:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    env = Environment.env.Environment(T = args.T, 
                domain_size = args.domain_size, 
                f_num = args.f_num, 
                function_type = args.function_type, 
                seed = args.seed,
                new_reward = args.new_reward,
                perturb_noise_level = args.perturb_noise_level,
                observation_noise_level = args.observation_noise_level,
                domain_dim = args.domain_dim,
                online_ls=args.online_ls,
                yahpo_scenario=args.yahpo_scenario,
                NERF_scene= args.NERF_scene)
    
    learner = QT(T = args.T, 
                 domain_size = args.domain_size, 
                 f_num = args.f_num, 
                 gamma = args.gamma, 
                 hidden_size = args.hidden_size, 
                 lr = args.lr, 
                 weight_decay = args.weight_decay, 
                 n_layer = args.n_layer, 
                 n_head = args.n_head,
                 n_batch = args.n_batch,
                 n_positions = args.n_positions,
                 warmup_steps = args.warmup_steps,
                 update_freq = args.update_freq,
                 optimizer = args.optimizer,
                 dropout = args.dropout,
                 epsilon = args.epsilon,
                 target_update_freq = args.target_update_freq,
                 batch_size = args.batch_size,
                 sample_rate = args.sample_rate,
                 initial_sample = args.initial_sample,
                 temperature = args.temperature,
                 device = device).to(device)
     
    if args.model_episode > 0:
        learner.load_state_dict(torch.load("./BOFormer_models/model{}/{}.pth".format(record_index, args.model_episode), map_location=device))
    learner.train()

    target_network = QT(T = args.T, 
                 domain_size = args.domain_size, 
                 f_num = args.f_num, 
                 gamma = args.gamma, 
                 hidden_size = args.hidden_size, 
                 lr = args.lr, 
                 weight_decay = args.weight_decay, 
                 n_layer = args.n_layer, 
                 n_head = args.n_head,
                 n_batch = args.n_batch,
                 n_positions = args.n_positions,
                 warmup_steps = args.warmup_steps,
                 update_freq = args.update_freq,
                 optimizer = args.optimizer,
                 dropout = args.dropout,
                 epsilon = args.epsilon,
                 target_update_freq = args.target_update_freq,
                 batch_size = args.batch_size,
                 sample_rate = args.sample_rate,
                 initial_sample = args.initial_sample,
                 temperature = args.temperature,
                 device = device).to(device)
    
    target_network.load_state_dict(learner.state_dict())
    target_network.eval()

    if args.demo:
        wandb_init(args.function_type+"_"+args.NERF_scene+"_max_norm=1.0_demo_"+args.demo_policy)
    else:
        wandb_init(args.function_type+"_"+args.NERF_scene+"_max_norm=1.0")
        # wandb_init(args.function_type+"_"+args.yahpo_scenario)
    losses = [] # training record
    sass = []
    q_values = []
    iterations = 0
    demo = args.demo
    for e in range(args.tuning_episode):
        for ep in range(10):
            seed=args.seed+e*10+ep
            env.reset(seed=seed)
            env.history['info'] = str(args)
            env.history['demo'] = demo
            lss = [] # training record
            sas = [] # training record

            args.domain_size = np.shape(env.X)[0]
            learner.domain_size = args.domain_size
            target_network.domain_size = args.domain_size
            # initial sample
            X = Environment.function_preprocessing.domain(args.function_type, args.domain_size, seed, args.domain_dim) #here
            action = random.randint(0,args.domain_size-1)
            y_star, reward, regret = env.step(X[action])
            gp, _ = env.fit_gp(0)
            state_actions = Environment.env.construct_state_action_pair(X, gp, y_star, 0) #here
            # gp = Environment.env.GaussianProcess(np.array(env.history["x"]), np.array(env.history["y_observed"]), env.kernel, env.kernel_ls, env.f_num, dim = args.domain_dim)
            # if args.online_ls:
            #     gp.fit(x = np.array(env.history["x"]), y = np.array(env.history["y_observed"]))
            # state_actions = gp.construct_state_action_pair(X, y_star, 0/args.T)
        

            # record transition
            learner.trajectory.append({"state_action": np.array([0.0]*(3*args.f_num+1)), "reward": float(reward), "next_state_actions": state_actions})

            # training record
            # lss.append(torch.cat([gp.GP[i].covar_module.lengthscale.detach().cpu() for i in range(args.f_num)], dim=1).numpy())
            sas.append(state_actions)
            if demo:
                gp_demo, _ = env.fit_gp(0)
                pred = gp_demo.posterior(torch.tensor(X, dtype=torch.double)).mean # gp(torch.tensor(X)).mean.T
                partitioning = FastNondominatedPartitioning(
                    ref_point=torch.tensor(env.min_function_values, dtype=torch.double),
                    Y=pred)
                
                standard_bounds = torch.zeros(2, env.domain_dim)
                standard_bounds[1] = 1
                
                if args.demo_policy == "qEHVI":
                    demo_policy = qExpectedHypervolumeImprovement(
                        model=gp_demo,
                        ref_point=torch.tensor(env.min_function_values, dtype=torch.double),
                        partitioning=partitioning,
                    )
                elif args.demo_policy == "qNEHVI":
                    demo_policy = qNoisyExpectedHypervolumeImprovement(
                        model=gp_demo,
                        ref_point=torch.tensor(env.min_function_values, dtype=torch.double),
                        X_baseline=torch.tensor(env.history['x'], dtype=torch.double),
                        prune_baseline=True,
                    )
                elif args.demo_policy == "EHI":
                    demo_policy = ExpectedHypervolumeImprovement(domain_size=env.domain_size, f_num=env.f_num, min_function_values=env.min_function_values)
                
            # training iterations: one environment step per t, for t = 1 .. args.T - 1.
            # Each pass picks a query point (demo policy or learner), steps the
            # environment, rebuilds the state-action features and records the
            # transition on learner.trajectory.
            for t in tqdm.tqdm(range(1, args.T)):
                
                # select action
                # The learner is switched to eval mode while the point is chosen and
                # back to train mode right after the if/else below.
                learner.eval()
                if demo:
                    # Demonstration branch: the query point comes from an acquisition
                    # function optimised over a GP fitted by the environment.
                    gp_demo, _ = env.fit_gp(t)
                    # Posterior mean of the GP at the current candidate set X.
                    pred = gp_demo.posterior(torch.tensor(X, dtype=torch.double)).mean # gp(torch.tensor(X)).mean.T
                    # Box decomposition built from the posterior means at X (not from
                    # observed outcomes), with env.min_function_values as reference point.
                    # NOTE(review): only the qEHVI branch below consumes `partitioning`,
                    # yet it is computed for every demo policy.
                    partitioning = FastNondominatedPartitioning(
                        ref_point=torch.tensor(env.min_function_values, dtype=torch.double),
                        Y=pred)
                    
                    # Search box for the optimiser: row 0 is all zeros (lower bounds),
                    # row 1 is all ones (upper bounds), one column per input dimension.
                    standard_bounds = torch.zeros(2, env.domain_dim)
                    standard_bounds[1] = 1
                    
                    # The acquisition function is rebuilt every step because it is
                    # bound to the freshly fitted GP (qEHVI / qNEHVI).
                    if args.demo_policy == "qEHVI":
                        demo_policy = qExpectedHypervolumeImprovement(
                            model=gp_demo,
                            ref_point=torch.tensor(env.min_function_values, dtype=torch.double),
                            partitioning=partitioning,
                        )
                    elif args.demo_policy == "qNEHVI":
                        # Uses every point queried so far (env.history['x']) as baseline.
                        demo_policy = qNoisyExpectedHypervolumeImprovement(
                            model=gp_demo,
                            ref_point=torch.tensor(env.min_function_values, dtype=torch.double),
                            X_baseline=torch.tensor(env.history['x'], dtype=torch.double),
                            prune_baseline=True,
                        )
                    elif args.demo_policy == "EHI":
                        # NOTE(review): this project-local object is built without the GP
                        # and is then handed to optimize_acqf below; whether it satisfies
                        # the acquisition-function interface cannot be seen from here.
                        demo_policy = ExpectedHypervolumeImprovement(domain_size=env.domain_size, f_num=env.f_num, min_function_values=env.min_function_values)
                    
                    # Deliberately small optimisation budget: a single restart from a
                    # single raw sample, at most 10 optimiser iterations, one candidate.
                    # NOTE(review): if args.demo_policy matches none of the three names,
                    # `demo_policy` keeps whatever value it had before this step.
                    candidates, _ = optimize_acqf(
                        acq_function=demo_policy,
                        bounds=standard_bounds.double(),
                        q=1,
                        num_restarts=1,
                        raw_samples=1,
                        options={"batch_limit": 1, "maxiter": 10},
                        sequential=True,
                    )
                    # First (and, with q=1, only requested) candidate as a NumPy array.
                    x_best = candidates.detach().numpy()[0]
                else:
                    # Learner branch: the learner returns an index into the candidate set X.
                    action = learner.select_action(state_actions, target_network)
                    x_best = X[action]
                learner.train()

                # env update
                y_star, reward, regret = env.step(x_best)

                
                # Draw a new candidate set for the next step (seeded with seed + t),
                # refit the GP on the updated history and build the next state-action
                # features; t/args.T is passed as the last argument.
                X = Environment.function_preprocessing.domain(args.function_type, args.domain_size, seed+t, d = args.domain_dim)
                gp, _ = env.fit_gp(t)
                next_state_actions = Environment.env.construct_state_action_pair(X, gp, y_star, t/args.T) 
                # gp = Environment.env.GaussianProcess(np.array(env.history["x"]), np.array(env.history["y_observed"]), env.kernel, env.kernel_ls, env.f_num, dim = args.domain_dim)
                # learn ls for GP
                # if args.online_ls and t % args.ls_learned_freq == 0:
                    # env.kernel_ls = gp.fit(x = np.array(env.history["x"]), y = np.array(env.history["y_observed"]))
                    
                # next_state_actions = gp.construct_state_action_pair(X, y_star, t/args.T) 

                # record transition
                # NOTE(review): `action` is only assigned in the non-demo branch above.
                # In demo mode this index is either stale (set before the loop) or
                # undefined, and it does not correspond to the optimised x_best --
                # confirm what should be stored for demonstration transitions.
                learner.trajectory.append({"state_action": state_actions[action], "reward": float(reward), "next_state_actions": next_state_actions})
                
                # update current state
                state_actions = next_state_actions

                # training record
                # lss.append(torch.cat([gp.GP[i].covar_module.lengthscale.detach().cpu() for i in range(args.f_num)], dim=1).numpy())
                # Appended after the state update, so this stores the features that
                # the next step will choose from.
                sas.append(state_actions)

            # record final transition
            # learner.trajectory.append({"state_action": state_actions[action], "reward": 0.0, "next_state_actions": next_state_actions})
            
            # learner update
            # Pass the transitions gathered on learner.trajectory to the learner via
            # append_sample (presumably into its replay memory -- confirm against the
            # learner class), then clear the trajectory for the next run.
            learner.append_sample(learner.trajectory)
            learner.reset_trajectory()

        # Keep the state-action features recorded during this run.
        sass.append(sas)
        # Q-values are only recorded when the learner itself chose the actions.
        if not demo:
            q_values.append(learner.previous_q_values)

    # target_network.domain_size = 77 # real
    # Gradient-update phase: run args.update_step learner updates, logging and
    # checkpointing along the way, then persist the history and final weights.
    i = 0
    # NOTE(review): cnt, prev_loss, best_loss and threshold are assigned here but
    # not read anywhere in the remainder of this routine -- they look like
    # leftovers from an early-stopping check.
    cnt = 0
    prev_loss = None
    best_loss = float('inf')
    threshold = 0.0001
    grad_norms = []
    while i < args.update_step:
        i+=1
        # One update of the learner against the target network; the call returns
        # the loss and the gradient norm, which are logged with step=i.
        loss, grad_norm = learner.update(target_network)
        wandb.log({'loss': loss, 'grad_norm': grad_norm}, step=i)
        losses.append(loss)
        grad_norms.append(grad_norm)

        # Intermediate checkpoint every 500 updates. The demo variant additionally
        # embeds the demo policy name in the file name.
        # NOTE(review): this tests args.demo while other places use the local `demo`;
        # confirm both always agree.
        # NOTE(review): no os.makedirs for this checkpoint directory is visible here;
        # torch.save will fail if it does not already exist.
        if i % 500 == 0:
            if args.demo:
                torch.save(learner.state_dict(), "./BOFormer_models/model{}/self_finetune_weight/{}_{}_{}_finetuning_v3_per_10_ep_{}_max_norm_1_demo_{}_{}_{}_steps.pth".format(record_index, args.model_episode, args.function_type, args.NERF_scene, args.tuning_episode, args.demo, args.demo_policy, i))
            else:
                torch.save(learner.state_dict(), "./BOFormer_models/model{}/self_finetune_weight/{}_{}_{}_finetuning_v3_per_10_ep_{}_max_norm_1_demo_{}_{}_steps.pth".format(record_index, args.model_episode, args.function_type, args.NERF_scene, args.tuning_episode, args.demo, i))

    # Attach the training curves and recorded state-action features to the
    # environment history so they are pickled together with it below.
    env.history["loss"] = losses
    env.history["grad_norm"] = grad_norms
    env.history["sa"] = sass

    # `regret` below is the value left over from the most recent env.step call.
    if demo:
        env.history["q_values"] = None
        print('EP:{} | Loss: {:.3f} | P: {} | R: {:.3f}'.format(e + args.model_episode, np.mean(losses), args.demo_policy, regret))
    else:
        env.history["q_values"] = q_values
        # NOTE(review): .item() only works on a one-element tensor, so this assumes
        # previous_q_values[:, -1, :] holds exactly one value -- confirm the shape.
        print('E:{} | tuning EP:{} | Loss: {:.3f} | Q: {:.3f} | R: {:.3f}'.format(e, ep, np.mean(losses), learner.previous_q_values[:,-1,:].item(), regret))

    # save regrets
    # The whole env.history dict is pickled; the file name encodes the run
    # configuration (and the demo policy when args.demo is set).
    if args.demo:
        filename = '{}_{}_{}_max_norm_1_demo_{}_{}_{}_steps.pkl'.format(args.model_episode, args.function_type, args.NERF_scene, args.demo, args.demo_policy, args.update_step)
    else:
        filename = '{}_{}_{}_max_norm_1_demo_{}_{}_steps.pkl'.format(args.model_episode, args.function_type, args.NERF_scene, args.demo, args.update_step)
    os.makedirs(f"records/record{record_index}/self_finetune", exist_ok=True)
    with open(os.path.join(f"records/record{record_index}/self_finetune", filename), 'wb') as f:
        pickle.dump(env.history, f)
    # update target network
    # target_network.load_state_dict(learner.state_dict())
    # target_network.eval()
    # Final weights after all updates; same naming scheme as the intermediate
    # checkpoints but without the step suffix.
    if args.demo:
        torch.save(learner.state_dict(), "./BOFormer_models/model{}/self_finetune_weight/{}_{}_{}_finetuning_v3_per_10_ep_{}_max_norm_1_demo_{}_{}.pth".format(record_index, args.model_episode, args.function_type, args.NERF_scene, args.tuning_episode, args.demo, args.demo_policy))
    else:
        torch.save(learner.state_dict(), "./BOFormer_models/model{}/self_finetune_weight/{}_{}_{}_finetuning_v3_per_10_ep_{}_max_norm_1_demo_{}.pth".format(record_index, args.model_episode, args.function_type, args.NERF_scene, args.tuning_episode, args.demo))
