import argparse
import os
import random
import time
from random import choices

import numpy as np
import pandas as pd
import torch
import transformers
import wandb
# from google.colab import userdata
from datasets import Dataset, load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, pipeline, set_seed)
from trl.trainer import PPOTrainer, PPOConfig
from trl.models import AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
from tot.solver_utils import solvercheck_propose_prompt_out
from tot.tasks.game24 import get_current_numbers

# Register tqdm's progress-bar helpers on pandas objects.
tqdm.pandas()

# Seeds NumPy's global RNG only; nothing else is seeded in this part of the script.
np.random.seed(99)

os.environ["WANDB_DISABLED"] = "false"

# Command-line options that shape the reward signal used in the PPO loop below.
parser = argparse.ArgumentParser()
# 'binary': one scalar reward per response; 'cert': one reward per generated token.
parser.add_argument(
    "--feedback_mode",
    type=str,
    choices=["binary", "cert"],
    default="cert",
)
# Per-token reward for a token that differs from / matches the solver-corrected response.
parser.add_argument("--penalty_value", type=float, default=0.0)
parser.add_argument("--rew_value", type=float, default=1.0)
# Whole-response score when the response differs from / equals its corrected form.
parser.add_argument("--eos_penalty_value", type=float, default=0.0)
parser.add_argument("--eos_rew_value", type=float, default=1.0)
args = parser.parse_args()
print(args)

# Start the Weights & Biases run and record the parsed options in its config.
run = wandb.init()
wandb.config.update(args)

# lora_config = LoraConfig(
#     r=16,
#     lora_alpha=32,
#     lora_dropout=0.05,
#     bias="none",
#     task_type="CAUSAL_LM",
# )

# LoRA adapter configuration: rank 8, applied to the attention and MLP projections.
lora_config = LoraConfig(
    r=8,
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    task_type="CAUSAL_LM",
)

model_id = "google/gemma-2b-it"

# 4-bit NF4 quantisation with double quantisation; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# The Hugging Face access token must be present in the environment
# (a missing HF_TOKEN raises KeyError here, before anything is loaded).
hf_token = os.environ["HF_TOKEN"]

tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Policy model with a value head, quantised and wrapped with the LoRA adapters,
# placed entirely on device 0.
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": 0},
    peft_config=lora_config,
    token=hf_token,
)

model.gradient_checkpointing_enable()

# PPO hyper-parameters. remove_unused_columns is False because we want to
# preserve the text version of the query in the dataset.
config = PPOConfig(
    batch_size=4,
    mini_batch_size=4,
    steps=10000,
    learning_rate=1.41e-5,
    remove_unused_columns=False,
    log_with="wandb",
)

def collator(data):
    """Turn a list of per-sample dicts into one dict of per-key lists.

    The keys (and their order) are taken from the first sample; every sample
    must provide each of those keys. Keys that appear only in later samples
    are ignored. Values are collected as-is, without stacking or padding.
    """
    first_sample = data[0]
    batch = {}
    for key in first_sample:
        batch[key] = [sample[key] for sample in data]
    return batch


# Read the prompt table and convert it column-wise into a datasets.Dataset.
df = pd.read_csv("input_prompt_csv_demo.csv")

df_dict = df.to_dict(orient="list")
print(df_dict.keys())

dataset = Dataset.from_dict(df_dict)
print(dataset)


def _encode_prompt(row):
    """Tokenize one row's ``final_prompt`` (with a leading space) into ``input_ids``."""
    encoded = tokenizer.encode(" " + row["final_prompt"], return_tensors="pt")
    # encode() is asked for a tensor; keep only its first row.
    return {"input_ids": encoded[0]}


# Add the tokenized prompt to every row, one row at a time.
dataset = dataset.map(_encode_prompt, batched=False)

# Have the dataset return PyTorch tensors when indexed.
dataset.set_format("pytorch")

# Build the PPO trainer. The third positional argument is passed as None
# (presumably the reference model slot of PPOTrainer -- confirm against the
# installed TRL version).
ppo_trainer = PPOTrainer(
    config,
    model,
    None,
    tokenizer,
    dataset=dataset,
    data_collator=collator,
)

# Sampling settings forwarded to ppo_trainer.generate() in the training loop.
generation_kwargs = dict(
    min_length=-1,
    do_sample=True,
    temperature=1,
    max_new_tokens=200,
    return_prompt=False,
    # top_k=0.0,  # no top-k sampling
    # top_p=1.0,  # no nucleus sampling
)

# PPO training loop.
#
# For every batch: sample a response from the policy (and the reference model),
# run each line of the response through the solver check to obtain a
# "corrected" response, turn the comparison into rewards, and run one PPO step.
#
# Reward construction depends on --feedback_mode:
#   * 'binary': one scalar per response (eos_rew_value if the response equals
#     its corrected form, eos_penalty_value otherwise).
#   * 'cert':   one value per generated token (rew_value where the token id
#     equals the token id at the same position of the tokenized corrected
#     response, penalty_value elsewhere); the first position is always
#     rew_value and the last position carries the whole-response score.
#     NOTE(review): this hands per-token reward vectors to ppo_trainer.step();
#     presumably the installed TRL build accepts that -- confirm.
for epoch in range(10):
    for batch in tqdm(ppo_trainer.dataloader):
        query_tensors = batch["input_ids"]  # these are the tokenized list of tensors

        #### Get response from gemma
        response_tensors, ref_response_tensors = ppo_trainer.generate(
            query_tensors, generate_ref_response=True, **generation_kwargs
        )
        batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
        batch["ref_response"] = tokenizer.batch_decode(ref_response_tensors, skip_special_tokens=True)

        #### Compute score
        binary_feedback = args.feedback_mode == 'binary'
        rewards = []
        for i, (x, y_second_to_last, decoded_response) in enumerate(
            zip(batch["input_x"], batch["input_y_second_to_last"], batch["response"])
        ):
            decoded_response = decoded_response.strip()
            proposals = decoded_response.split('\n')
            # Use the previous step's output when there is one, otherwise the raw input.
            input_to_prompt = get_current_numbers(y_second_to_last if y_second_to_last else x)

            corrected_proposals = []
            for proposal in proposals:
                print("Proposal: ", proposal)
                try:
                    proposal_correct = solvercheck_propose_prompt_out(
                        second_to_last_line_left=input_to_prompt,
                        last_line=proposal,
                        current_numbers=get_current_numbers(proposal),
                    )  # False, None, or a string
                except Exception as e:
                    # Best effort: a failing check is logged and treated like "no correction".
                    print(e)
                    proposal_correct = None
                print("Corrected proposal: ", proposal_correct)
                if proposal_correct is None:
                    proposal_correct = proposal
                # NOTE(review): per the comment above the solver may also return False;
                # a non-string here would make the join below raise TypeError.
                # Confirm the intended handling of False.
                corrected_proposals.append(proposal_correct)
            # back to original format - only used for error vector calc
            decoded_responses_corrected = "\n".join(corrected_proposals)

            if decoded_responses_corrected.strip() != decoded_response.strip():
                score_int = args.eos_penalty_value
            else:
                score_int = args.eos_rew_value
                print("decoded_response is correct! -> ", decoded_response)

            if binary_feedback:
                # float32 (not float16) so user-supplied reward values are not rounded
                # and the dtype matches the cert-mode rewards. The per-token vector is
                # not needed in this mode, so it is not computed.
                rewards.append(torch.tensor(score_int, dtype=torch.float32))
                continue

            corrected_ids = tokenizer.encode(decoded_responses_corrected)
            # Convert once to plain ints instead of indexing the tensor per token.
            response_ids = response_tensors[i].tolist()
            reward_vector = [
                args.rew_value
                if j < len(corrected_ids) and token_id == corrected_ids[j]
                else args.penalty_value
                for j, token_id in enumerate(response_ids)
            ]
            reward_vector[0] = args.rew_value  # ignore new line/<sos> mismatches
            reward_vector[-1] = score_int  # last position carries the whole-response score
            rewards.append(torch.FloatTensor(reward_vector))

        rewards_to_train = rewards
        # Log one scalar per sample: the reward itself (binary) or the mean token reward (cert).
        rewards_to_log = rewards if binary_feedback else [r.mean() for r in rewards]

        #### Run PPO step
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards_to_train)
        ppo_trainer.log_stats(
            stats,
            batch,
            rewards_to_log,
            columns_to_log=["input_x", "input_y_second_to_last", "final_prompt", "response", "ref_response"],
        )

