import torch
import torch.optim as optim
import numpy as np
import os
import copy
import pandas as pd
from copy import deepcopy

from es_llm.evaluator.evaluator import Evaluator
from es_llm.task import create_task
from es_llm.model import create_model
from es_llm.dataset import create_dataset
from es_llm.tokenizer import create_tokenizer
from es_llm.dataloader import create_dataloader
from es_llm.collator import create_collator
from es_llm.generator import create_generator




class ModelEvaluator(Evaluator):
    """Evaluate a trained policy model against a frozen reference model.

    For every batch of prompts from the configured dataset, completions are
    sampled from the trained model, scored by the configured reward task, and
    compared token-by-token with the reference model through the
    ``kl_divergence`` task. One log row per completion is accumulated in
    ``self.log``.

    Args:
        config: Nested experiment configuration, indexed with the keys
            ``experiment``, ``evaluator``, ``gpu``, ``task``, ``dataset``,
            ``tokenizer``, ``collator``, ``trained_model``,
            ``reference_model`` and ``generator``.
        model_weights: Optional weights forwarded to ``create_model`` when
            building the trained model.
    """

    def __init__(self, config, model_weights=None):
        self.config = config

        self.experiment_directory = self.config["experiment"]["experiment_directory"]
        self.num_generations = self.config["evaluator"]["num_generations"]
        # Chunk size used when scoring completions with get_per_token_logp.
        # Optional config key; defaults to the previously hard-coded value 8.
        self.logp_batch_size = self.config["evaluator"].get("logp_batch_size", 8)

        # os.environ only accepts strings; also accept an int or a list of ids.
        # NOTE: presumably only effective if CUDA has not been initialised yet
        # in this process.
        visible_devices = self.config["gpu"]["use"]
        if isinstance(visible_devices, (list, tuple)):
            visible_devices = ",".join(str(device) for device in visible_devices)
        os.environ["CUDA_VISIBLE_DEVICES"] = str(visible_devices)

        # The task named "kl_divergence" is the KL metric; any other task is
        # used as the reward function (if several are listed, the last wins).
        for task in self.config["task"]:
            if task["name"] == "kl_divergence":
                self.kl_divergence = create_task(name=task["name"])
            else:
                self.reward_function = create_task(name=task["name"])

        self.dataset = create_dataset(
            name=config["dataset"]["name"], config=config["dataset"]
        )
        self.dataloader = create_dataloader(
            self.dataset,
            batch_size=self.config["evaluator"]["batch_size"],
            shuffle=False,
        )
        self.tokenizer = create_tokenizer(
            name=config["tokenizer"]["name"], config=config["tokenizer"]
        )
        self.data_collator = create_collator(
            name=config["collator"]["name"], tokenizer=self.tokenizer
        )

        self.trained_model = create_model(
            name=config["trained_model"]["name"],
            config=config["trained_model"],
            model_weights=model_weights,
        )

        self.reference_model = create_model(
            name=config["reference_model"]["name"], config=config["reference_model"]
        )
        self.generator = create_generator(
            config=self.config["generator"], tokenizer=self.tokenizer
        )

        # NOTE(review): trained_model.eval() was deliberately left disabled in
        # the original code; confirm whether the trained model should run in
        # eval mode (e.g. dropout off) during evaluation.
        self.reference_model.eval()

        self.log = pd.DataFrame(
            columns=[
                "prompt",
                "completion",
                "mean_kl",
                "max_kl",
                "min_kl",
                "reward",
                "norm_reward",
                "completion_length",
            ]
        )

    def run(self):
        """Sample, score and log completions for every batch of the dataset.

        Returns:
            pandas.DataFrame: ``self.log`` with one row per generated
            completion from this run appended (rows from earlier calls are
            kept).
        """
        device = self.trained_model.device
        batch_logs = []
        try:
            for input_text, target_text in self.dataloader:
                batch_logs.append(
                    self._evaluate_batch(input_text, target_text, device)
                )
        finally:
            # Concatenate once instead of once per batch: per-batch concat is
            # quadratic, and concatenating with the initial empty frame is
            # deprecated by pandas (result dtypes may become object). Done in
            # ``finally`` so rows of completed batches are kept if a later
            # batch raises.
            if batch_logs:
                frames = batch_logs if self.log.empty else [self.log, *batch_logs]
                self.log = pd.concat(frames, ignore_index=True)

        return self.log

    def _evaluate_batch(self, input_text, target_text, device):
        """Return a DataFrame with one log row per completion for one batch."""
        # The generator yields num_generations rows per prompt, so repeat the
        # prompts/targets to line up with the generated rows.
        batch_target_text = [
            item for item in target_text for _ in range(self.num_generations)
        ]
        batch_input_text_expanded = [
            text for text in input_text for _ in range(self.num_generations)
        ]

        # Tokenise prompts and sample completions from the trained model.
        batch = self.data_collator(input_text).to(device)
        labels = self.generator.sample(self.trained_model, batch)

        # Drop the first prompt_length positions to keep only response tokens.
        prompt_length = batch["input_ids"].shape[1]
        response_only_tokens = labels[:, prompt_length:].clone()
        response_only_mask = self._response_mask(response_only_tokens, device)
        labels_mask = torch.cat(
            [
                batch["attention_mask"].repeat_interleave(
                    self.num_generations, dim=0
                ),
                response_only_mask,
            ],
            dim=1,
        )

        # Decode the responses and score them under the reward function.
        string_text = self._decode(response_only_tokens, response_only_mask)
        rewards, norm_rewards = self.reward_function(
            generated_text=string_text, target_text=batch_target_text
        )

        # Drop the first response position so labels/mask line up with the
        # per-token log-probs returned by get_per_token_logp.
        response_only_tokens = response_only_tokens[:, 1:]
        response_only_mask = response_only_mask[:, 1:]

        prev_policy_batch = {"input_ids": labels, "attention_mask": labels_mask}
        # Presumably (batch, response_len - 1) per-token log-probs of the
        # sampled tokens -- TODO confirm against generator.get_per_token_logp.
        trained_model_per_token_logp = self._per_token_logp(
            self.trained_model, prev_policy_batch, response_only_tokens, prompt_length
        )
        reference_model_per_token_logp = self._per_token_logp(
            self.reference_model,
            prev_policy_batch,
            response_only_tokens,
            prompt_length,
        )

        kl = self.kl_divergence(
            trained_model_per_token_logp, reference_model_per_token_logp
        )

        num_tokens = response_only_mask.sum(dim=1)
        kl_per_sample = (kl * response_only_mask).sum(dim=1) / num_tokens.clamp(min=1)
        # Extrema must ignore padded positions, like the masked mean above.
        max_kl, min_kl = self._masked_extrema(kl, response_only_mask)

        return pd.DataFrame(
            {
                "prompt": batch_input_text_expanded,
                "completion": string_text,
                "mean_kl": kl_per_sample.cpu().tolist(),
                "max_kl": max_kl.cpu().tolist(),
                "min_kl": min_kl.cpu().tolist(),
                "reward": rewards.cpu().tolist(),
                "norm_reward": norm_rewards.cpu().tolist(),
                "completion_length": num_tokens.cpu().float().tolist(),
            }
        )

    def _response_mask(self, response_only_tokens, device):
        """Return a 0/1 long mask keeping each row up to and including its first pad.

        Rows without any pad token are fully kept. The first pad position is
        kept (mask 1) so that end token is scored too -- presumably the pad
        token doubles as end-of-text here; everything after it is 0.
        """
        is_pad = response_only_tokens == self.tokenizer.pad_token_id
        batch_size, sequence_length = response_only_tokens.shape
        # argmax returns the first maximal index, i.e. the first pad position.
        first_pad_idx = torch.where(
            is_pad.any(dim=1),
            is_pad.float().argmax(dim=1),
            torch.full((batch_size,), sequence_length, device=device),
        )
        position_ids = torch.arange(sequence_length, device=device).unsqueeze(0)
        return (position_ids <= first_pad_idx.unsqueeze(1)).long()

    def _decode(self, response_only_tokens, response_only_mask):
        """Decode response tokens to strings, ignoring positions outside the mask."""
        # Positions outside the mask are overwritten with the pad id; the
        # tokenizer is then asked to skip special tokens when converting.
        masked_tokens = response_only_tokens.masked_fill(
            response_only_mask == 0, self.tokenizer.pad_token_id
        )
        return [
            self.tokenizer.convert_tokens_to_string(
                self.tokenizer.convert_ids_to_tokens(row, skip_special_tokens=True)
            )
            for row in masked_tokens
        ]

    def _per_token_logp(self, model, batch, labels, prompt_length):
        """Score ``labels`` under ``model`` with the generator's no-grad/detach options."""
        return self.generator.get_per_token_logp(
            model=model,
            batch=batch,
            labels=labels,
            prompt_length=prompt_length,
            batch_size=self.logp_batch_size,
            use_no_grad=True,
            detach=True,
        )

    @staticmethod
    def _masked_extrema(values, mask):
        """Return per-row (max, min) of ``values`` over positions where ``mask`` is set.

        Rows with no valid position get 0, matching the masked mean's
        ``clamp(min=1)`` convention (0 for empty completions).
        """
        valid = mask.bool()
        has_tokens = valid.any(dim=1)
        row_max = values.masked_fill(~valid, float("-inf")).max(dim=1).values
        row_min = values.masked_fill(~valid, float("inf")).min(dim=1).values
        zeros = torch.zeros_like(row_max)
        return (
            torch.where(has_tokens, row_max, zeros),
            torch.where(has_tokens, row_min, zeros),
        )
