import torch
import torch.optim as optim
import numpy as np
import os
import pandas as pd
from pathlib import Path
import gc
from scipy.stats import pearsonr

from es_llm.trainer.trainer import Trainer
from es_llm.task import create_task
from es_llm.model import create_model
from es_llm.dataset import create_dataset
from es_llm.tokenizer import create_tokenizer
from es_llm.dataloader import create_dataloader
from es_llm.collator import create_collator
from es_llm.generator import create_generator
from es_llm.scheduler import create_scheduler




class GRPOTrainer(Trainer):
    """Group Relative Policy Optimization (GRPO) trainer.

    For each prompt batch the current policy samples ``num_generations``
    completions per prompt, the task scores them, the rewards are normalised
    within each prompt's group to obtain advantages, and the policy is then
    updated ``num_epochs`` times on those samples with a clipped probability
    ratio objective plus a ``beta``-weighted KL penalty towards the reference
    model (whose parameters are never passed to the optimizer).
    """

    # Accepted values of ``config["trainer"]["parameters"]["loss_type"]``.
    _LOSS_TYPES = ("grpo", "bnpo", "dr_grpo")

    # Max global gradient norm enforced before every optimizer step.
    _MAX_GRAD_NORM = 1.0

    def __init__(self, config, seed, model_weights=None):
        """Build the task, data pipeline, models, optimizer and LR scheduler.

        Args:
            config: Nested configuration mapping. Reads the ``gpu``, ``trainer``,
                ``task``, ``dataset``, ``tokenizer``, ``collator``,
                ``policy_model``, ``reference_model`` and ``generator`` sections.
            seed: Random seed, forwarded to the base ``Trainer``.
            model_weights: Accepted for interface compatibility; unused here.

        Raises:
            ValueError: if ``loss_type`` is not one of ``_LOSS_TYPES`` or
                ``num_epochs`` is smaller than 1.
        """
        super().__init__(config, seed)

        self.set_seed()

        # Only effective if CUDA has not been initialised yet in this process.
        # os.environ values must be strings, so integer device ids are converted.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(self.config["gpu"]["use"])

        params = self.config["trainer"]["parameters"]
        self.learning_steps = params["learning_steps"]
        self.disable_dropout = params["disable_dropout"]
        self.num_epochs = params["num_epochs"]
        self.epsilon = params["epsilon"]
        self.beta = params["beta"]
        self.loss_type = params["loss_type"]
        self.num_generations = params["num_generations"]
        self.max_tokens = self.config["generator"]["max_new_tokens"]
        self.learning_rate = float(params["learning_rate"])
        self.save_every = params["save_every"]

        # Fail fast, before any model is loaded: an unknown loss type would
        # otherwise only surface mid-training, and num_epochs < 1 would make
        # ``run`` loop forever without ever taking an optimizer step.
        if self.loss_type not in self._LOSS_TYPES:
            raise ValueError(
                f"Unknown loss_type {self.loss_type!r}; expected one of {self._LOSS_TYPES}"
            )
        if self.num_epochs < 1:
            raise ValueError(f"num_epochs must be >= 1, got {self.num_epochs}")

        self.task = create_task(name=config["task"]["name"], config=config["task"])
        self.dataset = create_dataset(
            name=config["dataset"]["name"], config=config["dataset"]
        )
        self.dataloader = create_dataloader(
            self.dataset,
            batch_size=params["batch_size"],
            shuffle=False,
        )
        self.tokenizer = create_tokenizer(
            name=config["tokenizer"]["name"], config=config["tokenizer"]
        )
        self.data_collator = create_collator(
            name=config["collator"]["name"], tokenizer=self.tokenizer
        )
        self.policy_model = create_model(
            name=config["policy_model"]["name"], config=config["policy_model"]
        )
        self.reference_model = create_model(
            name=config["reference_model"]["name"], config=config["reference_model"]
        )
        self.generator = create_generator(
            config=self.config["generator"], tokenizer=self.tokenizer
        )

        if self.disable_dropout:
            self.disable_dropout_in_model(self.policy_model)
            self.disable_dropout_in_model(self.reference_model)

        self.policy_model.train()
        self.reference_model.eval()

        # Only the policy is optimised; the reference model stays fixed.
        self.optimizer = optim.AdamW(
            self.policy_model.parameters(), lr=self.learning_rate
        )

        # A missing or unnamed scheduler section falls back to the default
        # (no-op) scheduler; errors raised while *building* a configured
        # scheduler are real misconfigurations and propagate.
        try:
            scheduler_config = self.config["trainer"]["scheduler"]
            scheduler_name = scheduler_config["name"]
        except (KeyError, TypeError):
            scheduler_config = None
        if scheduler_config is None:
            print("No scheduler found: Using default no-op scheduler")
            self.lr_scheduler = create_scheduler(optimizer=self.optimizer)
        else:
            self.lr_scheduler = create_scheduler(
                name=scheduler_name,
                config=scheduler_config,
                optimizer=self.optimizer,
            )

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.pad_token = self.tokenizer.pad_token
        self.pad_token_id = self.tokenizer.pad_token_id
        self.eos_token_id = self.tokenizer.eos_token_id

        self.log = pd.DataFrame(
            columns=[
                "iteration",
                "step",
                "epoch",
                "step_lr",
                "loss",
                "mean_kl",
                "kl_fw_direct",
                "reward_mean",
                "reward_min",
                "reward_max",
                "r_t_mean",
                "r_t_min",
                "r_t_max",
                "ref_r_t_mean",
                "ref_r_t_min",
                "ref_r_t_max",
                "completion_length_mean",
                "completion_length_min",
                "completion_length_max",
            ]
        )

    def compute_kl_divergence(self, policy_logps, reference_logps):
        """Return the element-wise KL penalty ``exp(d) - d - 1`` with ``d = ref - policy``.

        The expression is non-negative for every real ``d`` and is zero when the
        two log-probabilities agree.
        """
        return (
            torch.exp(reference_logps - policy_logps)
            - (reference_logps - policy_logps)
            - 1
        )

    def disable_dropout_in_model(self, model: torch.nn.Module):
        """Set ``p = 0`` on every ``torch.nn.Dropout`` submodule of ``model`` (in place)."""
        for module in model.modules():
            if isinstance(module, torch.nn.Dropout):
                module.p = 0

    def clear_gpu_memory(self):
        """Drop models/optimizer/scheduler, tear down torch.distributed, free GPU caches.

        After this call the trainer can no longer train: ``policy_model``,
        ``reference_model``, ``optimizer`` and ``lr_scheduler`` are ``None``.
        Safe to call on hosts without CUDA.
        """
        # The scheduler keeps a reference to the optimizer, which references
        # the policy parameters and its own state, so it must be dropped too
        # for that memory to become collectable.
        del self.policy_model
        del self.reference_model
        del self.optimizer
        del self.lr_scheduler

        self.policy_model, self.reference_model, self.optimizer = None, None, None
        self.lr_scheduler = None

        # distributed teardown
        import torch.distributed as dist
        if dist.is_available() and dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.reset_peak_memory_stats()

        # Optional: clean Inductor/Dynamo caches if torch.compile was used.
        # Deliberately best-effort: failure here must not abort teardown.
        try:
            import torch._dynamo as dynamo
            dynamo.reset()
        except Exception:
            pass

    def _response_mask(self, response_only_tokens):
        """Return a 0/1 (long) mask that is 1 up to and including each row's first pad token.

        Rows with no pad token are kept in full. The first pad position is kept
        so that the terminating token is part of the trained response.
        """
        device = response_only_tokens.device
        is_pad = response_only_tokens == self.tokenizer.pad_token_id
        num_rows, sequence_length = response_only_tokens.shape
        # argmax over a 0/1 row returns the index of the first 1.
        first_pad_idx = torch.where(
            is_pad.any(dim=1),
            is_pad.float().argmax(dim=1),
            torch.full((num_rows,), sequence_length, device=device),
        )
        position_ids = torch.arange(sequence_length, device=device).unsqueeze(0)
        return (position_ids <= first_pad_idx.unsqueeze(1)).long()

    def _per_token_logp(self, model, batch, labels, prompt_length, track_grad):
        """Return per-token log-probs of ``labels`` under ``model`` via the generator.

        With ``track_grad=False`` the call runs without autograd and detaches
        the result; with ``track_grad=True`` gradients flow back into ``model``.
        """
        return self.generator.get_per_token_logp(
            model=model,
            batch=batch,
            labels=labels,
            prompt_length=prompt_length,
            batch_size=self.num_generations,
            use_no_grad=not track_grad,
            detach=not track_grad,
        )

    def _group_advantages(self, rewards):
        """Normalise rewards within each group of ``num_generations`` consecutive rows.

        Returns ``(r - group_mean) / (group_std + 1e-4)``, one value per row.
        Note: with ``num_generations == 1`` the (unbiased) std is NaN.
        """
        grouped = rewards.view(-1, self.num_generations)
        group_mean = grouped.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
        group_std = grouped.std(dim=1).repeat_interleave(self.num_generations, dim=0)
        return (rewards - group_mean) / (group_std + 1e-4)

    def _aggregate_loss(self, per_token_loss, mask):
        """Reduce masked per-token losses to a scalar according to ``self.loss_type``.

        * ``grpo``: mean over valid tokens of each row, then mean over rows.
        * ``bnpo``: mean over all valid tokens in the batch.
        * ``dr_grpo``: sum over valid tokens divided by ``rows * max_tokens``.

        Raises:
            ValueError: for any other ``loss_type``.
        """
        maskf = mask.float()
        masked = per_token_loss * maskf
        if self.loss_type == "grpo":
            return (masked.sum(-1) / maskf.sum(-1).clamp(min=1.0)).mean()
        if self.loss_type == "bnpo":
            return masked.sum() / maskf.sum().clamp(min=1.0)
        if self.loss_type == "dr_grpo":
            return masked.sum() / (per_token_loss.size(0) * self.max_tokens)
        raise ValueError(
            f"Unknown loss_type {self.loss_type!r}; expected one of {self._LOSS_TYPES}"
        )

    def _append_log(self, log_entry):
        """Append one row (a dict keyed by the log columns) to ``self.log``.

        The first row is built directly rather than concatenated onto the empty
        frame, which recent pandas versions warn about (dtype inference).
        """
        row = pd.DataFrame([log_entry], columns=self.log.columns)
        if self.log.empty:
            self.log = row
        else:
            self.log = pd.concat([self.log, row], ignore_index=True)

    def run(self, output_dir):
        """Train the policy for exactly ``learning_steps`` optimizer updates.

        Args:
            output_dir: Directory under which ``log/`` (CSV training log,
                rewritten after every update) and ``model/`` (a
                ``save_pretrained`` checkpoint every ``save_every`` updates)
                are written.

        Returns:
            The ``pandas.DataFrame`` training log, one row per optimizer update.

        Raises:
            ValueError: if the dataloader yields no batches, since no update
                could ever be made.
        """
        self.output_dir = output_dir
        # DataFrame.to_csv does not create missing parent directories.
        Path(f"{self.output_dir}/log").mkdir(parents=True, exist_ok=True)
        device = self.policy_model.device

        # ``iteration`` counts prompt batches; ``learning_step`` counts optimizer
        # updates (up to ``num_epochs`` per batch).
        iteration, learning_step = 0, 0
        while learning_step < self.learning_steps:
            yielded_batch = False
            for input_text, target_text in self.dataloader:
                yielded_batch = True

                # Repeat each target once per sampled completion so rewards line
                # up with the (prompt, generation) rows produced by the sampler.
                batch_target_text = [
                    item for item in target_text for _ in range(self.num_generations)
                ]
                # prepare the input text for model input (tokenization, etc.)
                batch = self.data_collator(input_text).to(device)

                # Sample completions from the policy as it is before this batch's updates.
                labels = self.generator.sample(self.policy_model, batch)

                # Isolate the response by removing the prompt tokens.
                prompt_length = batch["input_ids"].shape[1]
                response_only_tokens = labels[:, prompt_length:].clone()
                response_only_mask = self._response_mask(response_only_tokens)
                labels_mask = torch.cat(
                    [
                        batch["attention_mask"].repeat_interleave(
                            self.num_generations, dim=0
                        ),
                        response_only_mask,
                    ],
                    dim=1,
                )

                # Everything after the first pad becomes pad, so only the actual
                # completion is decoded for reward computation.
                labels_w_eos = torch.where(
                    response_only_mask.bool(),
                    response_only_tokens,
                    self.tokenizer.pad_token_id,
                )
                string_text = self.tokenizer.batch_decode(
                    labels_w_eos, skip_special_tokens=True, clean_up_tokenization_spaces=False
                )
                print(string_text[0], string_text[-1])

                rewards, _ = self.task(
                    generated_text=string_text, target_text=batch_target_text
                )

                # Drop the first response position; alignment with the logits is
                # defined by generator.get_per_token_logp (not visible here).
                response_only_tokens = response_only_tokens[:, 1:]
                response_only_mask = response_only_mask[:, 1:]

                prev_policy_batch = {"input_ids": labels, "attention_mask": labels_mask}

                # Fixed (no-grad) log-probs of the sampled tokens under the
                # pre-update policy and under the reference model.
                prev_policy_model_per_token_logp = self._per_token_logp(
                    self.policy_model, prev_policy_batch, response_only_tokens,
                    prompt_length, track_grad=False,
                )
                reference_model_per_token_logp = self._per_token_logp(
                    self.reference_model, prev_policy_batch, response_only_tokens,
                    prompt_length, track_grad=False,
                )

                advantages = self._group_advantages(rewards)
                print(advantages[0], advantages[-1])
                # One advantage per row, broadcast over the token dimension.
                token_advantages = advantages.unsqueeze(1).to(device)

                epoch_loss, epoch_kl = [], []
                for epoch in range(self.num_epochs):
                    policy_model_per_token_logp = self._per_token_logp(
                        self.policy_model, prev_policy_batch, response_only_tokens,
                        prompt_length, track_grad=True,
                    )

                    # Probability ratio pi_theta / pi_old per token.
                    exp_logp_ratio = torch.exp(
                        policy_model_per_token_logp - prev_policy_model_per_token_logp
                    )
                    clipped_ratio = torch.clamp(
                        exp_logp_ratio,
                        min=1 - self.epsilon,
                        max=1 + self.epsilon,
                    )
                    kl = self.compute_kl_divergence(
                        policy_model_per_token_logp, reference_model_per_token_logp
                    )
                    per_token_loss = (
                        -torch.min(
                            exp_logp_ratio * token_advantages,
                            clipped_ratio * token_advantages,
                        )
                        + self.beta * kl
                    )
                    loss = self._aggregate_loss(per_token_loss, response_only_mask)

                    # Gradients must exist before they can be clipped.
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        self.policy_model.parameters(), self._MAX_GRAD_NORM
                    )
                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()

                    # Logging statistics; denominators clamped so an all-masked
                    # batch yields 0 instead of NaN.
                    with torch.no_grad():
                        maskf = response_only_mask.float()
                        token_count = maskf.sum().clamp(min=1.0)
                        mean_kl_per_response = (kl.detach() * maskf).sum() / token_count
                        kl_fw_direct = (
                            (policy_model_per_token_logp - reference_model_per_token_logp)
                            * maskf
                        ).sum() / token_count
                        ref_exp_logp_ratio = torch.exp(
                            policy_model_per_token_logp - reference_model_per_token_logp
                        )
                        completion_lengths = response_only_mask.sum(-1).float()

                    step_lr = self.lr_scheduler.get_last_lr()[0]
                    print(
                        f"iteration: {iteration+1}; step {learning_step+1}; epoch {epoch+1} ratio mean: {exp_logp_ratio.mean().item()}, kl {mean_kl_per_response.mean().item()}, lr {step_lr}"
                    )

                    epoch_loss.append(loss.item())
                    epoch_kl.append(mean_kl_per_response.item())

                    self._append_log(
                        {
                            "iteration": iteration + 1,
                            "step": learning_step + 1,
                            "epoch": epoch + 1,
                            "step_lr": step_lr,
                            "loss": loss.item(),
                            "mean_kl": mean_kl_per_response.item(),
                            "kl_fw_direct": kl_fw_direct.item(),
                            "reward_mean": rewards.mean().item(),
                            "reward_min": rewards.min().item(),
                            "reward_max": rewards.max().item(),
                            "r_t_mean": exp_logp_ratio.mean().item(),
                            "r_t_min": exp_logp_ratio.min().item(),
                            "r_t_max": exp_logp_ratio.max().item(),
                            "ref_r_t_mean": ref_exp_logp_ratio.mean().item(),
                            "ref_r_t_min": ref_exp_logp_ratio.min().item(),
                            "ref_r_t_max": ref_exp_logp_ratio.max().item(),
                            "completion_length_mean": completion_lengths.mean().item(),
                            "completion_length_min": completion_lengths.min().item(),
                            "completion_length_max": completion_lengths.max().item(),
                        }
                    )
                    self.log.to_csv(
                        f"{self.output_dir}/log/grpo_qwen2_train_log_steps{self.learning_steps}_max_tokens{self.max_tokens}_beta{self.beta}_epsilon{self.epsilon}_lr{self.learning_rate}_loss{self.loss_type}_seed{self.seed}.csv"
                    )

                    learning_step += 1
                    if learning_step % self.save_every == 0:
                        step_model_name = f"grpo_qwen2_train_step{learning_step}_steps_{self.learning_steps}_max_tokens{self.max_tokens}_beta{self.beta}_epsilon{self.epsilon}_lr{self.learning_rate}_loss{self.loss_type}_seed{self.seed}"
                        model_dir = f"{self.output_dir}/model/{step_model_name}"
                        Path(model_dir).mkdir(parents=True, exist_ok=True)
                        self.policy_model.save_pretrained(model_dir)

                    if learning_step >= self.learning_steps:
                        break

                iteration += 1

                print(
                    f"iteration: {iteration}; mean loss: {sum(epoch_loss)/len(epoch_loss)}; max reward: {rewards.max()}; min reward: {rewards.min()}; mean rewards: {rewards.mean()}, mean kl: {sum(epoch_kl)/len(epoch_kl)}, mean completion length {response_only_mask.sum(-1).float().mean()}"
                )

                # Stop consuming batches once the update budget is spent;
                # otherwise each remaining batch would still take a step.
                if learning_step >= self.learning_steps:
                    break

            if not yielded_batch:
                raise ValueError(
                    "dataloader yielded no batches; cannot perform any learning step"
                )

        # delete the models and clear cache
        self.clear_gpu_memory()

        return self.log
