import torch
import torch.optim as optim
import numpy as np
import os
import copy
import pandas as pd
from copy import deepcopy

from es_llm.evaluator.evaluator import Evaluator
from es_llm.task import create_task
from es_llm.model import create_model
from es_llm.dataset import create_dataset
from es_llm.tokenizer import create_tokenizer
from es_llm.dataloader import create_dataloader
from es_llm.collator import create_collator
from es_llm.generator import create_generator

class NoiseEvaluator(Evaluator):
    """Evaluate a model under Gaussian weight perturbations of varying scale.

    For each noise scale (sigma) and each noise sample, the trained model's
    weights are replaced by ``base + sigma * noise``, completions are sampled
    for every batch of the dataset, scored with the configured reward
    function, and compared against the unperturbed reference model through
    the configured KL-divergence task. One log row is produced per generated
    completion.
    """

    # Noise scales swept by ``run`` when the caller does not pass its own.
    DEFAULT_SIGMAS = (
        0.0002,
        0.0004,
        0.0006,
        0.0008,
        0.001,
        0.0012,
        0.0014,
        0.0016,
        0.0018,
        0.002,
    )

    # Column layout of ``self.log``; every per-batch frame uses this schema so
    # that concatenation never introduces stray or all-NaN columns.
    LOG_COLUMNS = (
        "noise_idx",
        "prompt",
        "completion",
        "mean_kl",
        "max_kl",
        "min_kl",
        "reward",
        "norm_reward",
        "completion_length",
        "sigma",
    )

    def __init__(self, config, model_weights=None):
        """Build every component needed for the noise sweep from ``config``.

        Args:
            config: Nested mapping. Keys read here: ``gpu.use``,
                ``experiment.experiment_directory``,
                ``evaluator.{num_generations, noise_samples, sigma, seed,
                batch_size}``, ``task`` (iterable of mappings with a
                ``name``), ``dataset``, ``tokenizer``, ``collator``,
                ``trained_model`` and ``generator``.
            model_weights: Optional weights forwarded to ``create_model`` for
                both the trained and the reference model.
        """
        self.config = config
        # NOTE(review): this only restricts visible GPUs if CUDA has not been
        # initialised yet in this process - confirm against the launcher.
        os.environ["CUDA_VISIBLE_DEVICES"] = self.config["gpu"]["use"]

        self.experiment_directory = self.config["experiment"]["experiment_directory"]
        self.num_generations = self.config["evaluator"]["num_generations"]
        self.noise_samples = self.config["evaluator"]["noise_samples"]
        # Kept for callers that read the attribute; ``run`` sweeps
        # ``DEFAULT_SIGMAS`` (or its ``sigmas`` argument), not this value.
        self.sigma = self.config["evaluator"]["sigma"]

        self.seed = self.config["evaluator"]["seed"]
        self.set_seed(self.seed)

        # The task named "kl_divergence" becomes the KL function; any other
        # task becomes the reward function (the last one listed wins).
        for task in self.config["task"]:
            if task["name"] == "kl_divergence":
                self.kl_divergence = create_task(name=task["name"])
            else:
                self.reward_function = create_task(name=task["name"])

        self.dataset = create_dataset(
            name=config["dataset"]["name"], config=config["dataset"]
        )
        self.dataloader = create_dataloader(
            self.dataset,
            batch_size=self.config["evaluator"]["batch_size"],
            shuffle=False,
        )
        self.tokenizer = create_tokenizer(
            name=config["tokenizer"]["name"], config=config["tokenizer"]
        )
        self.data_collator = create_collator(
            name=config["collator"]["name"], tokenizer=self.tokenizer
        )

        self.trained_model = create_model(
            name=config["trained_model"]["name"],
            config=config["trained_model"],
            model_weights=model_weights,
        )

        # The reference model is built from the same config and weights as the
        # trained model; it is never perturbed and serves as the KL baseline.
        self.reference_model = create_model(
            name=config["trained_model"]["name"],
            config=config["trained_model"],
            model_weights=model_weights,
        )

        self.generator = create_generator(
            config=self.config["generator"], tokenizer=self.tokenizer
        )

        self.trained_model.eval()
        self.reference_model.eval()

        self.log = pd.DataFrame(columns=list(self.LOG_COLUMNS))

    def run(self, sigmas=None):
        """Run the noise sweep and return the accumulated log.

        Args:
            sigmas: Optional noise scale or iterable of noise scales. Defaults
                to ``DEFAULT_SIGMAS``.

        Returns:
            ``self.log``: a DataFrame with ``LOG_COLUMNS`` and one row per
            generated completion, including rows from earlier calls.
        """
        if sigmas is None:
            sigmas = self.DEFAULT_SIGMAS
        elif isinstance(sigmas, (int, float)):
            sigmas = (sigmas,)

        base_state = copy.deepcopy(self.trained_model.state_dict())
        batch_logs = []

        try:
            for sigma in sigmas:
                for noise_idx in range(self.noise_samples):
                    print(f"*** Running Noise Interation: {noise_idx+1} ***")

                    # One seed per noise sample, drawn from NumPy's global RNG
                    # (seeded via ``set_seed`` in ``__init__``, presumably).
                    seed = int(
                        np.random.randint(0, 2**31, size=1, dtype=np.int64)[0]
                    )

                    try:
                        self.trained_model.load_state_dict(
                            self._perturb_state(base_state, sigma, seed)
                        )
                        for input_text, target_text in self.dataloader:
                            batch_logs.append(
                                self._evaluate_batch(
                                    input_text, target_text, sigma, noise_idx
                                )
                            )
                    finally:
                        # Always put the clean weights back, even when an
                        # evaluation step raises, so the model is never left
                        # in a perturbed state.
                        self.trained_model.load_state_dict(base_state)
        finally:
            # Concatenate once instead of once per batch; skip the empty
            # placeholder frame so it cannot influence the resulting dtypes.
            if batch_logs:
                frames = batch_logs if self.log.empty else [self.log, *batch_logs]
                self.log = pd.concat(frames, ignore_index=True)

        return self.log

    @staticmethod
    def _perturb_state(base_state, sigma, seed):
        """Return a copy of ``base_state`` with ``sigma``-scaled Gaussian noise added."""
        perturbed_state = {}
        for name, base_param in base_state.items():
            # Integer/bool buffers and non-tensor entries cannot receive
            # Gaussian noise (``torch.randn`` rejects those dtypes); pass them
            # through unchanged.
            if not torch.is_tensor(base_param) or not (
                base_param.is_floating_point() or base_param.is_complex()
            ):
                perturbed_state[name] = base_param
                continue

            # NOTE(review): the generator is re-seeded with the same seed for
            # every tensor, so tensors of identical shape get identical noise.
            # Kept as-is because it may deliberately mirror the training-time
            # perturbation scheme - confirm.
            gen = torch.Generator(device=base_param.device)
            gen.manual_seed(seed)
            noise = torch.randn(
                base_param.shape,
                generator=gen,
                device=base_param.device,
                dtype=base_param.dtype,
            )
            perturbed_state[name] = base_param + sigma * noise
        return perturbed_state

    def _evaluate_batch(self, input_text, target_text, sigma, noise_idx):
        """Sample, score and KL-compare one batch; return its log rows as a DataFrame."""
        device = self.trained_model.device
        pad_token_id = self.tokenizer.pad_token_id

        # Each prompt yields ``num_generations`` completions, so targets and
        # prompts are repeated to line up with the generated rows.
        batch_target_text = [
            item for item in target_text for _ in range(self.num_generations)
        ]
        batch_input_text_expanded = [
            text for text in input_text for _ in range(self.num_generations)
        ]

        # prepare the input text for model input (tokenization, etc.)
        batch = self.data_collator(input_text).to(device)

        # sample completions from the (perturbed) trained model
        labels = self.generator.sample(self.trained_model, batch)

        # isolate the response tokens by dropping the prompt tokens
        prompt_length = batch["input_ids"].shape[1]
        response_only_tokens = labels[:, prompt_length:].clone()

        # True where pad tokens are located
        is_pad = response_only_tokens == pad_token_id

        # Index of the first PAD in each row, or sequence_length if none.
        batch_size, sequence_length = response_only_tokens.shape
        first_pad_idx = torch.where(
            is_pad.any(dim=1),
            is_pad.float().argmax(dim=1),
            torch.full((batch_size,), sequence_length, device=device),
        )
        position_ids = torch.arange(sequence_length, device=device).unsqueeze(0)

        # Positions up to AND including the first pad are kept: the first
        # end-of-text token must stay visible so it is attended to and scored.
        response_only_mask = (position_ids <= first_pad_idx.unsqueeze(1)).long()
        labels_mask = torch.cat(
            [
                batch["attention_mask"].repeat_interleave(
                    self.num_generations, dim=0
                ),
                response_only_mask,
            ],
            dim=1,
        )

        # Overwrite everything after the first pad with the pad id. Masking on
        # the response mask directly stays correct when ``pad_token_id`` is 1.
        labels_w_eos = response_only_tokens.masked_fill(
            response_only_mask == 0, pad_token_id
        )

        # convert the sampled token ids to text and score them
        generated_text = [
            self.tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True)
            for ids in labels_w_eos
        ]
        string_text = [
            self.tokenizer.convert_tokens_to_string(tokens)
            for tokens in generated_text
        ]

        rewards, norm_rewards = self.reward_function(
            generated_text=string_text, target_text=batch_target_text
        )

        print(rewards)

        # adjust the labels to ensure token alignment with logits
        response_only_tokens = response_only_tokens[:, 1:]
        response_only_mask = response_only_mask[:, 1:]

        prev_policy_batch = {"input_ids": labels, "attention_mask": labels_mask}

        # Per-token log-probabilities of the sampled tokens under the perturbed
        # model and under the unperturbed reference model.
        trained_model_per_token_logp = self.generator.get_per_token_logp(
            model=self.trained_model,
            batch=prev_policy_batch,
            labels=response_only_tokens,
            prompt_length=prompt_length,
            batch_size=8,
            use_no_grad=True,
            detach=True,
        )
        reference_model_per_token_logp = self.generator.get_per_token_logp(
            model=self.reference_model,
            batch=prev_policy_batch,
            labels=response_only_tokens,
            prompt_length=prompt_length,
            batch_size=8,
            use_no_grad=True,
            detach=True,
        )

        kl = self.kl_divergence(
            trained_model_per_token_logp, reference_model_per_token_logp
        )

        token_counts = response_only_mask.sum(dim=1)
        kl_per_sample = (kl * response_only_mask).sum(dim=1) / token_counts.clamp(
            min=1
        )

        # Restrict max/min to real response tokens, like the mean; rows with no
        # valid token report 0.0, matching what the mean yields for them.
        valid = response_only_mask.bool()
        has_valid = valid.any(dim=1)
        max_kl = kl.masked_fill(~valid, float("-inf")).max(dim=1).values
        min_kl = kl.masked_fill(~valid, float("inf")).min(dim=1).values
        max_kl = torch.where(has_valid, max_kl, torch.zeros_like(max_kl))
        min_kl = torch.where(has_valid, min_kl, torch.zeros_like(min_kl))

        completion_lengths = token_counts.cpu()
        num_rows = len(rewards)

        return pd.DataFrame(
            {
                "noise_idx": [noise_idx] * num_rows,
                "prompt": batch_input_text_expanded,
                "completion": string_text,
                "mean_kl": kl_per_sample.cpu().tolist(),
                "max_kl": max_kl.cpu().tolist(),
                "min_kl": min_kl.cpu().tolist(),
                "reward": rewards.cpu().tolist(),
                "norm_reward": norm_rewards.cpu().tolist(),
                "completion_length": completion_lengths.float().tolist(),
                # Needed to tell rows of different noise scales apart, since
                # ``noise_idx`` restarts at 0 for every sigma.
                "sigma": [sigma] * num_rows,
            },
            columns=list(self.LOG_COLUMNS),
        )
