import os
import time
import re
import random
import tqdm
import wandb
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from transformers import DataCollatorWithPadding, DataCollatorForLanguageModeling

from embodied_cd.common.print_utils import *
from embodied_cd.common.dataset_utils import PromptTemplate
from embodied_cd.common.llm_utils import OpenAILLM
from embodied_cd.trl.algos.pipe import SentenceSimilarityPipeline
from embodied_cd.trl.algos.core import FixedKLController, AdaptiveKLController
from embodied_cd.trl.algos.core import (
    custom_collate,
    logprobs_from_logits, 
    entropy_from_logits, 
    clip_by_value, 
    whiten, 
    flatten_dict, 
    stack_dicts, 
    stats_to_np, 
    WANDB_PADDING,
)


class ThinkTrainer:
    """
    PPO trainer for "think" generation with an optional self-correction stage.

    A single model carrying several adapters (switched with ``set_adapter``)
    is used throughout:

    * ``adapter_1`` generates the cached initial responses
      (:meth:`set_init_response`) and is the reference in :meth:`warmup`.
    * ``adapter_2`` is fitted in :meth:`warmup` to map an initial think to the
      ground-truth think. :meth:`train` samples corrections from it, and
      :meth:`forward_pass` uses it as the reference policy.
    * ``adapter_3`` is the adapter active while :meth:`train` optimizes the
      PPO objective.

    Rewards are sentence-similarity scores against the ground-truth think
    (:meth:`compute_score`). Each score is placed on the last token of its
    sentence and combined with a per-token KL penalty (:meth:`compute_rewards`).
    """

    default_params = {
        "warmup_epochs": 10,                     # max epochs of warmup()
        "warmup_kl_coef": 0.1,                   # weight of the KL term in the warmup loss
        "warmup_early_stopping_threshold": 1.1,  # stop warmup once the epoch-mean LM loss is below this
        "total_epochs": 100,
        "ppo_epochs": 1,                         # optimization passes over each batch in train()
        "lr": 2.82e-6,                           # learning rate of both Adam optimizers
        "batch_size": 4,
        "num_few_shot_example": 3,               # used by set_correct_few_shot_example()
        "gamma": 0.99,                           # discount used in the GAE recursion
        "lam": 0.95,                             # lambda for advantage calculation
        "cliprange": .0,                         # policy-ratio clip; 0.0 clamps the ratio to exactly 1
        "cliprange_value": .1,
        "vf_coef": .1,                           # NOTE(review): unused while train_minibatch optimizes the value loss only
        "adapt_kl_ctrl": True,
        "init_kl_coef": 0.2,
        "kl_target": 6,
        "kl_horizon": 10000,
        "max_think_token": 80,                   # max new tokens per generation; also the reward-placement cutoff
        "alpha": 1.2,                            # for score regularization (NOTE(review): not read in this class)
        "self_correct": True,
        "correction_bonus": 1.5,                 # weight of the (corrected - initial) score gain in train()
        "save_every": 50,                        # write a checkpoint every N epochs of train()
    }

    # Greedy decoding, used for the initial responses.
    gen0_params = {
        "do_sample": False,
        "top_k": None,
        "top_p": None,
        "temperature": None,
        "num_beams": 1,
        "repetition_penalty": 1.0,
    }

    # Low-temperature nucleus sampling (currently not referenced in this class).
    gen1_params = {
        "do_sample": True,
        "top_k": 0,
        "top_p": 0.4,
        "temperature": 0.2,
    }

    # Nucleus sampling, used for the corrective responses.
    gen2_params = {
        "do_sample": True,
        "top_k": 0,
        "top_p": 0.6,
        "temperature": 0.8,
    }

    def __init__(
        self,
        env_name,
        model,
        ref_model,
        tokenizer,
        ref_tokenizer,
        rew_model,
        rew_tokenizer,
        dataset,
        output_dir: str = None,
        **params
    ):
        """Set up data loading, optimizers and KL control, then cache initial responses.

        ``ref_model`` / ``ref_tokenizer`` default to ``model`` / ``tokenizer``
        when None. ``params`` override entries of :attr:`default_params` for
        this instance only.
        """
        self.env_name = env_name
        self.output_dir = output_dir

        # Copy, never alias: updating the class-level dict in place would leak
        # this trainer's overrides into every other ThinkTrainer.
        self.params = self._merge_params(params)

        self.model = model
        self.tokenizer = tokenizer
        self.ref_model = ref_model if ref_model is not None else model
        self.ref_tokenizer = ref_tokenizer if ref_tokenizer is not None else tokenizer
        self.device = model.device

        self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
        self.dataset = dataset
        self.dataset_len = len(dataset)

        # drop_last=True guarantees every batch holds exactly `batch_size` samples,
        # which the per-sample loops below rely on.
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.params['batch_size'],
            shuffle=True,
            collate_fn=custom_collate,
            drop_last=True,
        )

        self.warmup_optimizer = torch.optim.Adam(model.parameters(), lr=self.params['lr'])
        self.optimizer = torch.optim.Adam(model.parameters(), lr=self.params['lr'])

        if self.params['adapt_kl_ctrl']:
            self.kl_ctl = AdaptiveKLController(
                self.params['init_kl_coef'], self.params['kl_target'], self.params['kl_horizon'])
        else:
            self.kl_ctl = FixedKLController(self.params['init_kl_coef'])

        # reward models
        self.score_pipe = SentenceSimilarityPipeline(rew_model, rew_tokenizer)

        # set init response
        print_warn("Setting Init Response")
        self.set_init_response()

    @classmethod
    def _merge_params(cls, overrides):
        """Return a new dict of :attr:`default_params` updated with ``overrides``."""
        merged = dict(cls.default_params)
        merged.update(overrides)
        return merged

    @staticmethod
    def _correction_prompt(instruction, state, history, think):
        """Build the user message asking the model to correct ``think``."""
        return f"{PromptTemplate.correct_think_prompt}\nInstruction: {instruction}\nState: {state}\nPrevious Actions: {history}\nThink: {think}"

    def set_correct_few_shot_example(self):
        """Build ``self.correct_few_shot_example`` from the first dataset items.

        For each of the first ``num_few_shot_example`` items, the current model
        greedily generates a think. The generated think is paired with the
        item's ground-truth think as a "Think / Corrected Think" example.
        (Its callers in warmup/train are currently commented out.)
        """
        self.correct_few_shot_example = ""
        for idx, data in enumerate(self.dataset):
            if idx == self.params['num_few_shot_example']:
                break
            instruction, state, history, think = data["instruction"], data["state"], data["history"], data["think"]
            state = PromptTemplate.preprocess(state)

            query_id = torch.tensor(data["query_ids"]).to(self.device)
            response_id = self.model.generate(
                query_id,
                **self.gen0_params,
                max_new_tokens=self.params['max_think_token'],
                pad_token_id=self.tokenizer.eos_token_id,  # this makes limited output
            )
            response_text = self.tokenizer.decode(
                response_id.squeeze()[len(query_id[0]):], skip_special_tokens=True)

            self.correct_few_shot_example += f"\nInstruction: {instruction}\nState: {state}\nPrevious Actions: {history}\nThink: {response_text}\nCorrected Think: {think}\n"

    def set_init_response(self):
        """Generate, score and cache an initial response for every dataset sample.

        Uses ``adapter_1``. Results are stored in ``self.init_response_dict``,
        keyed by ``batch['indexes'][i]``.
        """
        self.model.set_adapter("adapter_1")

        self.init_response_dict = {}
        for batch in self.dataloader:
            query_texts, gt_response_texts = batch['query_texts'], batch['response_texts']
            query_ids = batch['query_ids']

            ### 1. Get initial response from the model
            query_tensors, response_tensors, response_texts, scores, indexes = \
                self.generate_init_response(query_ids, gt_response_texts, batch)

            for i in range(self.params['batch_size']):
                print_pass(gt_response_texts[i])
                print_warn(batch['indexes'][i], response_texts[i])
                self.init_response_dict[batch['indexes'][i]] = {
                    'query_text': query_texts[i],
                    'query_tensor': query_tensors[i],
                    'response_text': response_texts[i],
                    'response_tensor': response_tensors[i],
                    'score': scores[i],
                    'index': indexes[i],  # per-sentence token counts from compute_score
                }

    def _warmup_item_loss(self, batch, idx):
        """Return ``(loss, detached_stats)`` of the warmup objective for sample ``idx``.

        The objective is the LM loss of ``adapter_2`` on the ground-truth think,
        given the correction prompt built from the cached initial response, plus
        ``warmup_kl_coef`` times a KL-style term against ``adapter_1``.
        """
        instruction, state, history = \
            batch['instructions'][idx], batch['states'][idx], batch['histories'][idx]
        state = PromptTemplate.preprocess(state)
        init_text = self.init_response_dict[batch['indexes'][idx]]['response_text']
        prompt = self._correction_prompt(instruction, state, history, init_text)

        ### Reference model (adapter_1) on the correction prompt only
        query_tensor = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=True, add_generation_prompt=True, return_tensors="pt").to(self.device)
        self.ref_model.set_adapter("adapter_1")
        with torch.no_grad():
            ref_model_output = self.ref_model(query_tensor)
        self.ref_model.set_adapter("adapter_2")
        ref_logprobs = logprobs_from_logits(ref_model_output.logits[:, :-1, :], query_tensor[:, 1:])

        ### Correction model (adapter_2) on prompt + ground-truth think
        cquery_tensor = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt},
             {"role": "assistant", "content": batch['response_texts'][idx]}],
            tokenize=True, return_tensors="pt").to(self.device)
        cquery_label = cquery_tensor.clone()
        # leave only the last assistant answer and mask the rest
        cquery_label[:, :len(query_tensor[0])] = -100

        model_output = self.model(cquery_tensor, labels=cquery_label)
        logprobs = logprobs_from_logits(model_output.logits[:, :-1, :], cquery_tensor[:, 1:])

        # NOTE(review): this compares log-probs over the prompt positions only
        # (the length of ref_logprobs), not over the supervised answer tokens.
        kl = torch.abs(torch.mean(logprobs[:, :ref_logprobs.shape[-1]] - ref_logprobs))

        loss = model_output.loss + self.params['warmup_kl_coef'] * kl
        stats = {
            "warmup_loss": loss.detach(),
            "warmup_lm_loss": model_output.loss.detach(),
            "warmup_kl_loss": kl.detach(),
        }
        return loss, stats

    def warmup(self):
        """Behavior-clone the correction adapter (``adapter_2``).

        Runs up to ``warmup_epochs`` epochs. Per-epoch mean stats are logged to
        wandb, and training stops early once the mean LM loss drops below
        ``warmup_early_stopping_threshold``.
        """
        self.model.set_adapter("adapter_2")

        for epoch in tqdm.tqdm(range(1, self.params['warmup_epochs'] + 1), desc="epoch"):
            all_stats = []
            for batch in self.dataloader:
                idxs = list(range(self.params['batch_size']))
                random.shuffle(idxs)
                # Backward per sample: this accumulates the same gradient as summing
                # the losses first, but frees each sample's graph right away.
                self.warmup_optimizer.zero_grad()
                for idx in idxs:
                    loss, item_stats = self._warmup_item_loss(batch, idx)
                    loss.backward()
                    all_stats.append(item_stats)
                self.warmup_optimizer.step()

            train_stats = stack_dicts(all_stats)
            stats = {k: torch.mean(v, axis=0) for k, v in train_stats.items()}
            stats = stats_to_np(stats)
            print_warn(f"Epoch {epoch}: {stats}")
            wandb.log(stats)

            ### Early Stopping
            if stats["warmup_lm_loss"] < self.params["warmup_early_stopping_threshold"]:
                print_warn("Early Breaking!!")
                break

    def _print_step_samples(self, gt_response_texts, response_texts, scores,
                            cresponse_texts, cscores, fscores):
        """Print the first two scored candidates of a train step for inspection.

        With self-correction, candidates ``2k`` and ``2k + 1`` belong to sample
        ``k``. Without it, candidate ``k`` is sample ``k``.
        """
        print_error("===== Checking Texts =====")
        for k in range(min(2, len(fscores))):
            sample = k // 2 if cresponse_texts is not None else k
            if k > 0:
                print()
            print_pass(f"Ground Truth: {gt_response_texts[sample]}")
            print_warn(f"Initial Response: {response_texts[sample]}")
            print(f"Score: {scores[sample]}")
            if cresponse_texts is not None:
                print_check(f"Corrected Response: {cresponse_texts[k]}")
                print(f"Score: {cscores[k]}")
            print(f"Final scores {fscores[k]}")
        print_error("===== Checking End =====")

    def train(self):
        """Run PPO with ``adapter_3`` active for ``total_epochs`` epochs.

        For every sample in a batch, the cached initial response is looked up.

        With ``self_correct``, the correction adapter (``adapter_2``) yields two
        candidates per sample: a sampled correction and the ground-truth think.
        Each candidate is scored as ``cs + correction_bonus * (cs - s)``, where
        ``s`` is the initial response's score. Without ``self_correct``, the
        initial responses are trained on directly with their own scores.

        Checkpoints go to ``output_dir`` every ``save_every`` epochs.
        """
        self.model.set_adapter("adapter_3")
        batch_size = self.params['batch_size']

        for epoch in tqdm.tqdm(range(1, self.params['total_epochs'] + 1), desc="epoch"):  # iterate over epochs
            for batch in self.dataloader:  # iterate over dataset
                query_texts, gt_response_texts = batch['query_texts'], batch['response_texts']

                ### 1. Cached initial response of every sample in the batch
                init_entries = [self.init_response_dict[batch['indexes'][i]] for i in range(batch_size)]
                response_texts = [entry['response_text'] for entry in init_entries]
                scores = [entry['score'] for entry in init_entries]

                if self.params['self_correct']:
                    ### 2. Corrective candidates (sampled correction + ground truth) per sample
                    self.model.set_adapter("adapter_2")
                    query_tensors, response_tensors, cresponse_texts, cscores, token_indexes = \
                        self.generate_correct_response(query_texts, response_texts, gt_response_texts, batch)
                    self.model.set_adapter("adapter_3")

                    ### 3. Final score with improvement bonus; candidates 2k, 2k+1 come from sample k
                    bonus = self.params['correction_bonus']
                    fscores = [cs + bonus * (cs - scores[k // 2]) for k, cs in enumerate(cscores)]
                else:
                    cresponse_texts, cscores = None, None
                    query_tensors = [entry['query_tensor'] for entry in init_entries]
                    response_tensors = [entry['response_tensor'] for entry in init_entries]
                    token_indexes = [entry['index'] for entry in init_entries]
                    fscores = scores

                self._print_step_samples(gt_response_texts, response_texts, scores,
                                         cresponse_texts, cscores, fscores)

                ### 4. Log-probs / values of every candidate under the policy and the reference
                logprobs, ref_logprobs, values = \
                    self.forward_pass(query_tensors, response_tensors)

                ### 5. Per-token rewards
                rewards, non_score_rewards = self.compute_rewards(fscores, token_indexes, logprobs, ref_logprobs)

                ### 6. PPO updates. Backward per candidate accumulates the same gradient
                # as summing the losses first, without keeping every graph alive at once.
                all_stats = []
                order = list(range(len(rewards)))
                for _ in range(self.params['ppo_epochs']):
                    random.shuffle(order)
                    self.optimizer.zero_grad()
                    for idx in order:
                        loss, train_stats = self.train_minibatch(
                            logprobs[idx].unsqueeze(0), values[idx].unsqueeze(0), rewards[idx].unsqueeze(0),
                            query_tensors[idx].unsqueeze(0), response_tensors[idx].unsqueeze(0),
                            torch.cat([query_tensors[idx], response_tensors[idx]]).unsqueeze(0)
                        )
                        loss.backward()
                        all_stats.append(train_stats)
                    self.optimizer.step()

                train_stats = stack_dicts(all_stats)
                # reshape advantages/ratios such that they are not averaged.
                train_stats['pi/advantages'] = torch.flatten(train_stats['pi/advantages']).unsqueeze(0)
                train_stats['pi/advantages'] = torch.nan_to_num(train_stats['pi/advantages'], WANDB_PADDING)
                train_stats['pi/ratio'] = torch.flatten(train_stats['pi/ratio']).unsqueeze(0)
                stats = self.record_step_stats(scores=fscores, logprobs=logprobs, ref_logprobs=ref_logprobs,
                                               non_score_reward=non_score_rewards, train_stats=train_stats,
                                               kl_coef=self.kl_ctl.value)
                stats = stats_to_np(stats)
                self.kl_ctl.update(stats['objective/kl'], batch_size)
                wandb.log(stats)

            if epoch % self.params['save_every'] == 0:
                if self.output_dir is None:
                    print_warn(f"Epoch {epoch}: output_dir is None, skipping checkpoint")
                else:
                    self.save_pretrained(os.path.join(self.output_dir, f"checkpoint_{epoch}"))

    def generate_init_response(self, query_ids, gt_response_texts, batch):
        """Greedily generate (``gen0_params``) and score one response per query.

        Returns parallel lists ``(query_tensors, response_tensors,
        response_texts, scores, indexes)``. ``response_tensors`` hold only the
        newly generated tokens. ``indexes`` are the per-sentence token counts
        from :meth:`compute_score`.
        """
        query_tensors, response_texts, response_tensors = [], [], []
        scores, indexes = [], []
        for i, query_id in enumerate(query_ids):
            query_id = torch.tensor(query_id).to(self.device)
            query_tensors.append(query_id.squeeze())

            response_id = self.model.generate(
                query_id,
                **self.gen0_params,
                max_new_tokens=self.params['max_think_token'],
                pad_token_id=self.tokenizer.eos_token_id,  # this makes limited output
            )

            # Drop the echoed prompt and keep only the generated continuation.
            generated = response_id.squeeze()[len(query_id[0]):]
            response_tensors.append(generated)
            response_text = self.tokenizer.decode(generated, skip_special_tokens=True)
            response_texts.append(response_text)

            # Calculate Scores
            score, index = self.compute_score(
                response_text, gt_response_texts[i], batch['states'][i], batch['histories'][i])
            scores.append(score)
            indexes.append(index)
        return query_tensors, response_tensors, response_texts, scores, indexes

    def generate_correct_response(self, query_texts, response_texts, gt_response_texts, batch):
        """Produce two scored corrective candidates per (query, initial response) pair.

        For pair ``i`` the outputs hold, in order:
        (1) a correction sampled with ``gen2_params`` from the correction
        prompt, and (2) the positive example ``batch['thinks_copy'][i]``.
        Both share the same prompt tensor.

        Returns parallel lists ``(cquery_tensors, cresponse_tensors,
        cresponse_texts, cscores, cindexes)`` of length ``2 * len(pairs)``.
        """
        cquery_tensors, cresponse_texts, cresponse_tensors = [], [], []
        cscores, cindexes = [], []
        for i, (query_text, response_text) in enumerate(zip(query_texts, response_texts)):
            # 1. add random samples
            instruction, state, history = batch['instructions'][i], batch['states'][i], batch['histories'][i]
            state = PromptTemplate.preprocess(state)
            cquery = [
                {"role": "user", "content": self._correction_prompt(instruction, state, history, response_text)}
            ]
            cquery_tensor = self.tokenizer.apply_chat_template(
                cquery, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(self.device)
            cquery_tensors.append(cquery_tensor.squeeze())

            cresponse_id = self.model.generate(
                cquery_tensor,
                **self.gen2_params,
                max_new_tokens=self.params['max_think_token'],
                pad_token_id=self.tokenizer.eos_token_id,  # this makes limited output
            )
            generated = cresponse_id.squeeze()[len(cquery_tensor[0]):]
            cresponse_tensors.append(generated)
            cresponse_text = self.tokenizer.decode(generated, skip_special_tokens=True)
            cresponse_texts.append(cresponse_text)

            # Calculate Scores
            score, index = self.compute_score(
                cresponse_text, gt_response_texts[i], batch['states'][i], batch['histories'][i])
            cscores.append(score)
            cindexes.append(index)

            # 2. add positive samples
            # NOTE(review): tokenizer.encode may prepend special tokens (e.g. BOS),
            # unlike the generated continuation above -- confirm this is intended.
            cquery_tensors.append(cquery_tensor.squeeze())
            cresponse_text = batch['thinks_copy'][i]
            cresponse_id = self.tokenizer.encode(cresponse_text, return_tensors="pt").to(self.device)
            cresponse_tensors.append(cresponse_id.squeeze())
            cresponse_texts.append(cresponse_text)

            score, index = self.compute_score(
                cresponse_text, gt_response_texts[i], batch['states'][i], batch['histories'][i])
            cscores.append(score)
            cindexes.append(index)

        return cquery_tensors, cresponse_tensors, cresponse_texts, cscores, cindexes

    def forward_pass(self, query_ids, response_ids):
        """Compute per-token log-probs and values for each (query, response) pair.

        All pairs are concatenated, padded by ``self.data_collator`` and run in
        one no-grad forward pass through the currently active policy adapter.
        A second pass runs through ``ref_model`` with ``adapter_2``, after which
        ``ref_model`` is switched to ``adapter_3``.

        Returns three lists with one entry per pair, each covering the response
        positions only: policy log-probs, reference log-probs and values.
        """
        all_logprobs, all_ref_logprobs, all_values = [], [], []

        input_ids = self.data_collator([torch.cat([q, r]) for q, r in zip(query_ids, response_ids)])["input_ids"]

        with torch.no_grad():
            model_output = self.model(input_ids)
        logprobs = logprobs_from_logits(model_output.logits[:, :-1, :], input_ids[:, 1:])

        self.ref_model.set_adapter("adapter_2")
        with torch.no_grad():
            ref_model_output = self.ref_model(input_ids)
        self.ref_model.set_adapter("adapter_3")
        ref_logprobs = logprobs_from_logits(ref_model_output.logits[:, :-1, :], input_ids[:, 1:])

        # One slice per pair: iterate over every pair, not just `batch_size` of them.
        for i in range(len(query_ids)):
            # logprobs[t] scores token t+1, so the response tokens sit at [start, end).
            start = len(query_ids[i]) - 1
            end = len(query_ids[i]) + len(response_ids[i]) - 1
            all_logprobs.append(logprobs[i, start:end])
            all_ref_logprobs.append(ref_logprobs[i, start:end])
            # Values are taken one position earlier than the log-probs.
            all_values.append(model_output.values[i, start - 1:end - 1])
        return all_logprobs, all_ref_logprobs, all_values

    def compute_score(self, response, gt_response, state, history):
        """Score ``response`` sentence-by-sentence against ``gt_response``.

        The response is split on ``'. '`` and the ground truth on ``'.'``.
        The first four aligned pairs are scored with ``self.score_pipe``
        (0: agent status, 1: key observation, 2: action history,
        3: next action).

        Returns ``(score, index)`` as numpy arrays of length 4. ``index[i]`` is
        the token count of response sentence ``i`` (used to place rewards).
        ``index[3]`` is always 0 because the loop stops before indexing the
        fourth sentence. ``state`` and ``history`` are currently unused.
        """
        response_list = response.split('. ')
        gt_response_list = gt_response.split('.')

        score = [0., 0., 0., 0.]
        index = [0, 0, 0, 0]
        total_res = ""
        for i, (res, gres) in enumerate(zip(response_list, gt_response_list)):
            score[i] = self.score_pipe(res, gres)
            if i == 3:
                # NOTE(review): score[3] is computed but never indexed, so
                # compute_rewards never places it into the reward.
                break

            # for indexing; endswith() also handles an empty segment
            if not res.endswith('.'):
                res += '.'
            if i != 0:
                res = ' ' + res
            token = self.tokenizer.encode(res, return_tensors="np")[0]
            # Presumably encode() prepends <|begin_of_text|> for such tokenizers:
            # only the first segment keeps that token so the counts add up to
            # the encoding of the joined text (checked by the assert below).
            if "<|begin_of_text|>" in self.tokenizer.all_special_tokens and i > 0:
                index[i] = len(token) - 1
            else:
                index[i] = len(token)
            total_res += res

        score, index = np.array(score), np.array(index)
        # assertion
        token = self.tokenizer.encode(total_res, return_tensors="np")[0]
        assert len(token) == np.sum(index), f"{len(token)} != {np.sum(index)} is not same!\n{response}"
        return score, index

    def compute_rewards(self, scores, indexes, logprobs, ref_logprobs):
        """Compute per-token rewards from sentence scores and a KL penalty.

        Each token receives ``-kl_coef * (logprob - ref_logprob)``. Sentence
        score ``s_j`` is added on the last token of sentence ``j``, at
        cumulative offset ``sum(index[:j + 1]) - 1``. Boundaries at or beyond
        ``max_think_token`` are skipped, and boundaries past the end of the
        response are reported and skipped.

        Returns ``(rewards, non_score_rewards)``: two lists with one tensor per
        sample.
        """
        rewards, non_score_rewards = [], []
        max_tokens = self.params['max_think_token']
        for score, index, logprob, ref_logprob in zip(scores, indexes, logprobs, ref_logprobs):
            kl = logprob - ref_logprob
            non_score_reward = -self.kl_ctl.value * kl
            non_score_rewards.append(non_score_reward)
            reward = non_score_reward.clone()

            boundary = 0
            for seg_score, seg_len in zip(score, index):
                if seg_len == 0:
                    continue
                boundary += seg_len
                if boundary >= max_tokens:
                    continue
                if boundary <= len(reward):
                    reward[boundary - 1] += seg_score
                else:
                    # Sentence boundary falls outside this response's tokens.
                    print(score, index, boundary, len(reward))
            rewards.append(reward)
        return rewards, non_score_rewards

    def train_minibatch(self, old_logprobs, values, rewards, query, response, model_input):
        """Compute the PPO losses and stats for one (query, response) sample.

        Advantages come from a GAE recursion over the response positions
        (``gamma``, ``lam``) and are whitened. Returns ``(loss, stats)``, where
        ``loss`` is currently the clipped value loss only; the policy loss is
        computed and logged but not optimized.
        """
        lastgaelam = 0
        advantages_reversed = []
        gen_len = response.shape[1]

        for t in reversed(range(gen_len)):
            nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
            delta = rewards[:, t] + self.params['gamma'] * nextvalues - values[:, t]
            lastgaelam = delta + self.params['gamma'] * self.params['lam'] * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)

        returns = advantages + values
        advantages = whiten(advantages)
        advantages = advantages.detach()

        # NOTE(review): with num_logits_to_keep=gen_len, the logits cover only
        # the last gen_len positions while the gathered labels span the whole
        # input -- confirm the model ignores/handles this argument.
        model_output = self.model(model_input, labels=model_input, num_logits_to_keep=gen_len)
        lm_loss, logits, vpred = model_output.loss, model_output.logits, model_output.values
        logprob = logprobs_from_logits(logits[:, :-1, :], model_input[:, 1:])

        # only the generation part of the values/logprogs is needed
        logprob, vpred = logprob[:, -gen_len:], vpred[:, -gen_len - 1: -1]
        vpredclipped = clip_by_value(vpred, values - self.params['cliprange_value'], values + self.params['cliprange_value'])

        # value loss
        vf_losses1 = (vpred - returns) ** 2
        vf_losses2 = (vpredclipped - returns) ** 2
        vf_loss = .5 * torch.mean(torch.max(vf_losses1, vf_losses2))
        vf_clipfrac = torch.mean(torch.gt(vf_losses2, vf_losses1).double())

        ratio = torch.exp(logprob - old_logprobs)

        # policy loss
        pg_losses = -advantages * ratio
        pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.params['cliprange'], 1.0 + self.params['cliprange'])
        pg_loss = torch.mean(torch.max(pg_losses, pg_losses2))
        pg_clipfrac = torch.mean(torch.gt(pg_losses2, pg_losses).double())

        # NOTE(review): only the value loss is optimized; the original PPO
        # objective would be `pg_loss + self.params['vf_coef'] * vf_loss`.
        loss = vf_loss

        entropy = torch.mean(entropy_from_logits(logits))
        approxkl = .5 * torch.mean((logprob - old_logprobs) ** 2)
        policykl = torch.mean(logprob - old_logprobs)
        return_mean, return_var = torch.mean(returns), torch.var(returns)
        value_mean, value_var = torch.mean(values), torch.var(values)

        stats = dict(
            loss=dict(policy=pg_loss, value=vf_loss, bc=lm_loss, total=loss),
            pi=dict(entropy=entropy, approxkl=approxkl, policykl=policykl, clipfrac=pg_clipfrac,
                    advantages=advantages, advantages_mean=torch.mean(advantages), ratio=ratio),
            returns=dict(mean=return_mean, var=return_var),
            vf=dict(vpred=torch.mean(vpred), error=torch.mean((vpred - returns) ** 2),
                    clipfrac=vf_clipfrac, mean=value_mean, var=value_var),
        )
        return loss, flatten_dict(stats)

    def record_step_stats(self, kl_coef, **data):
        """Aggregate one train step's KL, entropy proxy, non-score reward and minibatch stats.

        Expects ``logprobs``, ``ref_logprobs`` and ``non_score_reward`` as lists
        of per-sample tensors, and ``train_stats`` as a dict of stacked tensors.
        """
        kl_list = [logprobs - ref_logprobs for logprobs, ref_logprobs in zip(data['logprobs'], data['ref_logprobs'])]
        mean_kl = torch.mean(torch.stack([torch.sum(kl) for kl in kl_list]))
        # Sum of negative log-probs of the sampled tokens, averaged over samples.
        mean_entropy = torch.mean(torch.stack([torch.sum(-log_probs) for log_probs in data['logprobs']]))
        mean_non_score_reward = torch.mean(torch.stack([torch.sum(non_score_reward) for non_score_reward in data['non_score_reward']]))
        stats = {
            'objective/kl': mean_kl,
            'objective/kl_dist': kl_list,
            'objective/logprobs': data['logprobs'],
            'objective/ref_logprobs': data['ref_logprobs'],
            'objective/kl_coef': kl_coef,
            'objective/entropy': mean_entropy,
            'mean_non_score_reward': mean_non_score_reward,
        }

        for k, v in data['train_stats'].items():
            stats[f'{k}'] = torch.mean(v, axis=0)
        stats['vf/var_explained'] = 1 - stats['vf/error'] / stats['returns/var']
        return stats

    def save_pretrained(self, output_dir: str):
        """Save the model, the tokenizer and the value head (``vl_head.pth``) to ``output_dir``."""
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        torch.save(self.model.vl_head.state_dict(), os.path.join(output_dir, 'vl_head.pth'))


"""
    def compute_llm_score(self, instruction, state, action, history, think):
        query_spatial = {"instruction": instruction, "state": state, "think": think}
        spatial = self.llm_spatial.invoke(query_spatial).strip()
        query_temporal = {"instruction": instruction, "history": history, "action": action, "think": think}
        temporal = self.llm_temporal.invoke(query_temporal).strip()
        
        pattern = r"[-+]?\d*\.\d+|\d+"
        try:
            spatial_match = re.findall(pattern, spatial)
            spatial_score = [float(num) for num in spatial_match][0]
        except:
            print_warn(f"[Spatial Scoring]: Cannot parse {spatial}!")
            sptial_score = 0.

        try:
            temporal_match = re.findall(pattern, temporal)
            temporal_score = [float(num) for num in temporal_match][0]
        except:
            print_warn(f"[Temporal Scoring]: Cannot parse {temporal}!")
            temporal_score = 0.
        return spatial_score, temporal_score
        spatial_template = "Instruction: {instruction}\nCurrent State: {state}\nThink: {think}\nGiven the instruction and current state, evalulate how accurately a 'think' statement captures the spatial information. The think statement should include:\n1. the agent's current physical location\n2. relative location of objects related to accomplish the instruction.\nScore the spatial accuracy of the think statement on this scale:\n0: Does not capture any spatial information\n0.5: Partially captures spatial information\n1: Fully captures spatial information\nPlease only provide your score (0, 0.5, or 1)."
        self.llm_spatial= OpenAILLM(
            "gpt-4o-mini",
            temperature=0.0,
            top_p=0.9,
            template=spatial_template,
        )
        temporal_template = "Instruction: {instruction}\nPrevious Actions: {history}\nCurrent Optimal Action: {action}\nThink: {think}\nGiven the instruction and previous actions, evaluate how accurately a 'think' statement captures temporal information. The think statement should include:\n1. Summary of past actions: an accurate summary of previously completed steps\n2. Future planning: A logical decomposition of remaining instruction to be completed considering the current optimal action.\nScore the temporal accuracy of the think statement on this scale:\n0: Does not accurately capture action history or provide future planning\n0.5: Partially capture either action history or future planning, but not bot effectively\n1: Fully capture both past actions and provide clear future task decomposition.\nPlease only provide your score (0, 0.5, or 1)."
        self.llm_temporal = OpenAILLM(
            "gpt-4o-mini",
            temperature=0.0,
            top_p=0.9,
            template=temporal_template,
        )
"""
