import random
import json
import torch
from torch.utils.data import DataLoader
from accelerate.state import PartialState, DistributedType
from accelerate.utils import release_memory, InitProcessGroupKwargs
import datasets 
from datasets import Dataset
datasets.disable_progress_bar()
from datetime import timedelta
import os  
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  
os.environ["WANDB_LOG_MODEL"]="false"
from accelerate.state import DistributedType
import threading


from tqdm.auto import tqdm

from utils.loggers import loggers

from transformers import (
    AdamW,
    AutoModel,
    AutoModelForCausalLM, 
    AutoModelForSequenceClassification,
    AutoTokenizer, 
    DataCollatorWithPadding,
    get_constant_schedule_with_warmup,
)
from utils.util import get_answer_start_from

from accelerate import Accelerator
import torch.nn as nn
import torch.distributed as dist
from llms.decoder import DecoderModel
from safetensors.torch import load_file
import time

# Import-time side effect: release unoccupied memory held by the CUDA caching allocator.
torch.cuda.empty_cache()
# Raise the element count at which printed tensors are summarised instead of shown in full.
torch.set_printoptions(threshold=10_000)

class CriticModel():
    '''
    This class implements the critic model used to guide the search process, like Roberta
    '''
    def __init__(self, config, accelerator):
        """Load the critic tokenizer and model and create the optimisation state.

        Args:
            config: Mapping of run settings. Required keys read here are
                ``wandb_project``, ``wandb_group``, ``wandb_run_name`` and
                ``gradient_checkpointing``; ``critic`` is also required when
                ``critic_mode`` starts with ``"generation"``. Optional keys
                are ``load_critic_model``, ``critic_mode`` (default
                ``"classification"``), ``learning_rate`` (default ``5e-6``)
                and ``warmup_steps`` (default ``50``).
            accelerator: Accelerator used to prepare the model, print
                messages and expose ``gradient_accumulation_steps``.

        Raises:
            ValueError: If ``load_critic_model`` is not set in ``config``.
            NotImplementedError: If ``critic_mode`` is neither
                ``"classification"`` nor a ``"generation*"`` mode.
        """
        self.config = config
        self.accelerator = accelerator

        # init wandb: its run metadata is passed through environment variables
        if config["wandb_project"]:
            os.environ["WANDB_PROJECT"] = config["wandb_project"]
        if config["wandb_group"]:
            os.environ["WANDB_GROUP"] = config["wandb_group"]
        if config["wandb_run_name"]:
            os.environ["WANDB_RUN_NAME"] = config["wandb_run_name"]

        load_critic_model = self.config.get("load_critic_model", None)
        if load_critic_model is None:
            # Fail early with a clear message instead of an opaque error from
            # the tokenizer loader or ``os.path.exists(None)`` further down.
            raise ValueError("config['load_critic_model'] must be set to a model name or path")

        self.accelerator.print(f"Loading tokenizer from {load_critic_model}")
        self.tokenizer = AutoTokenizer.from_pretrained(load_critic_model, truncation_side='left')
        # Reuse EOS as the padding token whenever the tokenizer defines one.
        if self.tokenizer.eos_token is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

        self.accelerator.print(f"Loading critic model from {load_critic_model}")
        self.mode = self.config.get("critic_mode", "classification")
        if self.mode == "classification":
            # A single output logit is used as the critic score.
            self.model = AutoModelForSequenceClassification.from_pretrained(
                load_critic_model,
                trust_remote_code=True,
                num_labels=1,
            )

            if self.tokenizer.eos_token_id is not None:
                self.model.config.pad_token_id = self.tokenizer.eos_token_id

            # disable reference compile for modernbert
            if hasattr(self.model.config, 'reference_compile'):
                self.model.config.reference_compile = False

            if config["gradient_checkpointing"]:
                self.model.gradient_checkpointing_enable()
            self.model.config.pretraining_tp = 1

        elif self.mode.startswith("generation"):
            self.model = DecoderModel(self.config["critic"], mode=self.mode)
            # Only a local checkpoint directory is loaded on top of the decoder.
            if os.path.exists(load_critic_model):
                print("Loading critic model....")
                state_dict = load_file(os.path.join(load_critic_model, 'model.safetensors'), device='cpu')
                missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
                if missing_keys:
                    print(f"Load Generation Model Missing keys: {missing_keys}")
                if unexpected_keys:
                    print(f"Load Generation Model Unexpected keys: {unexpected_keys}")

            self.model.base_model.config.pad_token_id = self.tokenizer.eos_token_id

            if config["gradient_checkpointing"]:
                self.model.base_model.gradient_checkpointing_enable()
            self.model.base_model.config.pretraining_tp = 1

        else:
            raise NotImplementedError(f"Mode {self.mode} not implemented")

        if self.tokenizer.pad_token is None:
            self.accelerator.print("Adding pad token to the tokenizer...")
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            # Select the module that owns the embeddings from the *mode* string;
            # comparing the model object itself to a string is always False.
            embedding_owner = self.model if self.mode == "classification" else self.model.base_model
            embedding_owner.resize_token_embeddings(len(self.tokenizer))
            # Register the new token so the model knows which id is padding.
            embedding_owner.config.pad_token_id = self.tokenizer.pad_token_id

        # Token ids of " \nA:" with the first encoded token dropped.
        self.answer_token = self.tokenizer.encode(" \nA:", return_tensors="pt", add_special_tokens=False)[0, 1:]

        self.model = self.accelerator.prepare(self.model)

        # NOTE(review): the optimizer and scheduler below are built on the
        # prepared model but are not themselves passed to accelerator.prepare();
        # confirm this is intended for every launch configuration in use.
        learning_rate = config.get("learning_rate", 5e-6)
        self.optimizer = AdamW(
                self.model.parameters(),
                # The base learning rate is scaled by the accumulation steps.
                lr=learning_rate * self.accelerator.gradient_accumulation_steps,
                weight_decay=0.01,
            )
        warmup_steps = config.get("warmup_steps", 50)
        self.lr_scheduler = get_constant_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=warmup_steps,
        )

        self.accelerator.print(
            f"Distributed: {self.accelerator.distributed_type}, Mixed precision: {self.accelerator.mixed_precision}"
        )

        # Serialises model use in the scoring and embedding methods.
        self.model_lock = threading.Lock()


    @PartialState().on_main_process
    def build_proposal_dataset(self, qa_pairs, save_to, mode="dpo"):
        """Write DPO preference pairs or SFT examples as a JSON-lines file.

        Args:
            qa_pairs: Iterable of dicts with the keys ``"prompt"``,
                ``"positive_texts"`` and ``"negative_texts"``.
            save_to: Output path. In ``"dpo"`` mode the file is written here;
                in ``"sft"`` mode ``train_sft.jsonl`` is written next to it.
            mode: ``"dpo"`` pairs an equal number of randomly sampled positive
                and negative answers per prompt; ``"sft"`` keeps every
                positive answer. Any other value writes nothing.
        """
        def _wrap_data(prompt, answer):
            # Two-turn chat: the prompt as user, the answer as assistant.
            return [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": answer},
            ]

        parent_dir = os.path.dirname(save_to)
        # dirname() is "" for a bare filename, and makedirs("") would raise.
        if parent_dir and not os.path.exists(parent_dir):
            print(f"Creating directory at {parent_dir}")
            # exist_ok guards against the directory appearing after the check.
            os.makedirs(parent_dir, exist_ok=True)

        records = []
        for qa_pair in qa_pairs:
            prompt = qa_pair["prompt"]
            positive_texts = qa_pair["positive_texts"]
            negative_texts = qa_pair["negative_texts"]
            if mode == "dpo":
                # Use as many pairs as the smaller side allows; positives are
                # sampled before negatives so seeded runs stay reproducible.
                pair_cnt = min(len(positive_texts), len(negative_texts))
                chosen_texts = random.sample(positive_texts, pair_cnt)
                rejected_texts = random.sample(negative_texts, pair_cnt)
                records.extend(
                    {
                        "rejected": _wrap_data(prompt, rejected),
                        "chosen": _wrap_data(prompt, chosen),
                    }
                    for chosen, rejected in zip(chosen_texts, rejected_texts)
                )
            elif mode == "sft":  # only collect positive data for sft
                records.extend(
                    {
                        "prompt": [{"role": "user", "content": prompt}],
                        "response": [{"role": "assistant", "content": positive_text}],
                    }
                    for positive_text in positive_texts
                )

        if mode == "dpo":
            output_path = save_to
        elif mode == "sft":
            output_path = os.path.join(parent_dir, "train_sft.jsonl")
        else:
            # Unknown modes produce no output file.
            return

        with open(output_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(record) + "\n" for record in records)

    @PartialState().on_main_process
    def build_critic_dataset(self, positive_texts, negative_texts, save_to, current_len=None):
        """Tokenize labelled texts and save them to disk as a dataset.

        Positive texts are labelled ``1`` and negative texts ``-1``. The texts
        are tokenized in batches of 64 and the raw ``texts`` column is dropped.

        Args:
            positive_texts: List of texts labelled positive.
            negative_texts: List of texts labelled negative.
            save_to: Directory the dataset is saved to.
            current_len: When truthy, the dataset is saved to the
                ``length_<current_len>`` sub-directory of ``save_to`` instead.
        """
        def _tokenize(batch):
            # Pads each mapped batch to its own longest sequence.
            return self.tokenizer(
                batch["texts"],
                return_tensors="pt",
                padding=True,
                truncation=True,
                add_special_tokens=self.config["add_special_tokens"],
            )

        # +1 for every positive example, followed by -1 for every negative one.
        targets = torch.cat([
            torch.ones(len(positive_texts)),
            -torch.ones(len(negative_texts)),
        ])

        raw_dataset = Dataset.from_dict({
            "texts": positive_texts + negative_texts,
            "labels": targets,
        })
        tokenized_dataset = raw_dataset.with_format("torch").map(
            _tokenize,
            remove_columns=["texts"],
            batched=True,
            batch_size=64,
        )

        if current_len:
            save_to = os.path.join(save_to, f"length_{current_len}")
            os.makedirs(save_to, exist_ok=True)

        tokenized_dataset.save_to_disk(save_to)
        print(f"\nDataset saved to {save_to}\n")
        
    def build_dataloader(self, batch_dataset):
        """Wrap a tokenized dataset in a shuffled, accelerator-prepared loader.

        Args:
            batch_dataset: Dataset whose items the padding collator can batch.

        Returns:
            The ``DataLoader`` after ``accelerator.prepare``; it uses
            ``config["batch_size"]`` and pads each batch with the tokenizer.
        """
        padding_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)
        # TODO: check the sampler category if available
        loader = DataLoader(
            batch_dataset,
            batch_size=self.config["batch_size"],
            collate_fn=padding_collator,
            num_workers=0,
            pin_memory=True,
            shuffle=True,
        )
        return self.accelerator.prepare(loader)

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute the energy-based contrastive loss for one batch.

        The energy of a sample is its negated logit. The loss is the mean
        energy of positive samples minus the mean energy of negative samples
        (both divided by ``config["energy_temp"]``), plus
        ``config["l2_reg_coef"]`` times the mean squared energy when that
        coefficient is non-zero. Each component is logged via the accelerator.

        Args:
            model: Callable model returning an object with a ``"logits"``
                entry.
            inputs: Batch with a ``"labels"`` entry (> 0 positive, < 0
                negative) and a ``.to(device)`` method. The ``"labels"``
                entry is popped, so the caller's batch is modified.
            return_outputs: When true, return ``(loss, outputs)``.

        Returns:
            The scalar loss tensor, or ``(loss, outputs)``.
        """
        device = self.accelerator.device

        # Cast and move in one step; converting through ``torch.LongTensor``
        # would first copy the labels to the CPU and then back to the device.
        labels = inputs.pop("labels").to(device=device, dtype=torch.long)

        # forward pass
        inputs = inputs.to(device)
        outputs = model(**inputs)

        alpha = self.config["l2_reg_coef"]
        energy_temp = self.config["energy_temp"]
        output_logits = outputs.get("logits")  # (B, 1)
        energies = -output_logits.squeeze(-1)  # (B)

        pos_energy = energies[labels > 0] / energy_temp
        neg_energy = energies[labels < 0] / energy_temp

        # when batch size is really small, pos_energy and neg_energy can be empty
        # which will cause nan loss
        if pos_energy.shape[0] == 0:
            pos_energy = torch.zeros(1).to(device)
        if neg_energy.shape[0] == 0:
            neg_energy = torch.zeros(1).to(device)

        ml_loss = pos_energy.mean() - neg_energy.mean()

        # Use one flag for both applying and logging the regulariser so the
        # logged value always matches what was added to the loss.
        use_l2 = alpha != 0
        l2_loss = alpha * energies.square().mean() if use_l2 else 0.

        loss = ml_loss + l2_loss

        self.accelerator.log({"total_loss": loss.item()})
        self.accelerator.log({"l2_loss": l2_loss.item() if use_l2 else 0.})
        self.accelerator.log({"ml_loss": ml_loss.item()})
        self.accelerator.log({"pos_energy": pos_energy.mean().item()})
        self.accelerator.log({"neg_energy": neg_energy.mean().item()})

        return (loss, outputs) if return_outputs else loss
    
    def train_step(self, train_loader):
        """Train the critic for one pass over ``train_loader``.

        Each batch is run under ``accelerator.accumulate``. The optimizer,
        scheduler and gradient reset only run on iterations where the
        accelerator reports ``sync_gradients``, so gradients really accumulate
        across micro-batches. With one accumulation step that is every
        iteration.

        Args:
            train_loader: Sized iterable of batches accepted by
                ``compute_loss``.
        """
        progress_bar = tqdm(
            range(len(train_loader)),
            desc="Training",
            disable=not self.accelerator.is_local_main_process,
        )
        avg_loss = 0.

        self.model.train()
        try:
            for step, batch in enumerate(train_loader):
                with self.accelerator.accumulate(self.model):

                    loss = self.compute_loss(
                        model=self.model,
                        inputs=batch,
                    )

                    avg_loss += loss.item()

                    self.accelerator.backward(loss)

                    if self.accelerator.sync_gradients:
                        # Gradient clipping is skipped under FSDP.
                        if self.accelerator.state.distributed_type != DistributedType.FSDP:
                            grad_norm = self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
                            # Some backends return no norm from clipping.
                            if grad_norm is not None:
                                self.accelerator.log({"gradient_norm": grad_norm.mean()})

                        # Log and reset the running loss on every sync step,
                        # including FSDP, so it never grows without bound.
                        self.accelerator.log({"avg_loss": avg_loss / self.accelerator.gradient_accumulation_steps})
                        avg_loss = 0.

                        # Step only once gradients are synchronised: updating
                        # and zeroing on every micro-batch would discard the
                        # accumulated gradients.
                        self.optimizer.step()
                        self.lr_scheduler.step()
                        self.optimizer.zero_grad()

                progress_bar.update(1)
                progress_bar.set_description(f"Loss: {loss.item():.3f}")

                self.accelerator.log({"learning_rate": self.lr_scheduler.get_last_lr()[0]})
                self.accelerator.log({"update_step": step})
        finally:
            # Always release the progress bar, even if a batch fails.
            progress_bar.close()

        release_memory()
        
    def get_scores_from_texts(self, input_texts):
        """Score texts with the critic model.

        Args:
            input_texts: Text or list of texts accepted by the tokenizer.

        Returns:
            The detached model logits with the trailing dimension squeezed.
        """
        encoded = self.tokenizer(
            input_texts,
            return_tensors="pt",
            add_special_tokens=self.config["add_special_tokens"],
            padding=True,
            truncation=True,
        )
        encoded = encoded.to(self.accelerator.device)

        # The lock keeps the eval switch and forward pass together.
        with self.model_lock:
            self.model.eval()
            with torch.no_grad():
                model_outputs = self.model(**encoded)

            logits = model_outputs.get("logits")  # (B, 1)
            scores = logits.detach().squeeze(-1)

        return scores
        
    def get_embeddings_from_texts(self, input_texts):
        """Return pooled embeddings of the texts from the critic model.

        In classification mode the pooled output comes from the model's
        ``deberta`` + ``pooler`` modules, or from its ``model`` + ``head`` +
        ``drop`` modules, whichever the unwrapped model exposes. In generation
        modes it is the ``"pooler_output"`` of a forward pass made with
        ``get_pooler_output=True``.

        Args:
            input_texts: Text or list of texts accepted by the tokenizer.

        Returns:
            The detached pooled output, L2-normalized over the last dimension
            when ``config["normalize"]`` is truthy.

        Raises:
            NotImplementedError: If a classification model exposes neither a
                ``deberta`` nor a ``model`` attribute.
        """
        parallel_wrappers = (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel)

        # whether to normalize the output of the pooling layer
        normalize = self.config.get("normalize", False)
        encoded = self.tokenizer(
            input_texts,
            return_tensors="pt",
            add_special_tokens=self.config["add_special_tokens"],
            padding=True,
            truncation=True,
        ).to(self.accelerator.device)

        with self.model_lock:
            self.model.eval()
            # Reach through a DDP / DataParallel wrapper to the real module.
            core = self.model.module if isinstance(self.model, parallel_wrappers) else self.model

            with torch.no_grad():
                if self.mode == "classification":
                    if hasattr(core, "deberta"):
                        hidden = core.deberta(**encoded).get("last_hidden_state")  # (B, L, H)
                        pooled = core.pooler(hidden)  # (B, H)
                    elif hasattr(core, "model"):
                        hidden = core.model(**encoded).last_hidden_state
                        pooling = core.config.classifier_pooling
                        if pooling == "cls":
                            # Keep only the first position of every sequence.
                            hidden = hidden[:, 0]
                        elif pooling == "mean":
                            # Attention-mask weighted mean over the sequence.
                            mask = encoded["attention_mask"]
                            hidden = (hidden * mask.unsqueeze(-1)).sum(1) / mask.sum(dim=1, keepdim=True)
                        pooled = core.drop(core.head(hidden))
                    else:
                        raise NotImplementedError(f"Model is not implemented")
                elif self.mode.startswith("generation"):
                    pooled = core(**encoded, get_pooler_output=True).get("pooler_output")

                if normalize:
                    pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1)

                return pooled.detach()