import torch
import numpy as np
import os
import copy
import pandas as pd

from es_llm.trainer.trainer import Trainer
from es_llm.task import create_task
from es_llm.model import create_model
from es_llm.dataset import create_dataset
from es_llm.tokenizer import create_tokenizer
from es_llm.dataloader import create_dataloader
from es_llm.collator import create_collator
from es_llm.generator import create_generator




class ESLLMTrainerOrig(Trainer):
    """Fine-tune a language model with seed-based Evolution Strategies (ES).

    Each iteration samples ``population_size`` integer seeds. For every seed
    the model is evaluated on the whole dataset with its floating-point
    weights perturbed by ``sigma * N(0, 1)`` noise regenerated from that seed.
    The per-member mean rewards are z-normalised, and the base weights move by
    ``alpha`` times the reward-weighted mean of the same noise. Because noise
    is regenerated from the seeds, no noise tensors are kept between the
    evaluation and update phases.
    """

    def __init__(self, config, seed):
        """Read hyper-parameters from ``config`` and build all components.

        Args:
            config: Nested configuration mapping. This constructor reads
                ``gpu.use``, ``trainer.parameters.{num_iterations,
                population_size, sigma, alpha, max_new_tokens, do_sample,
                batch_size, save_every}`` and the ``task``, ``dataset``,
                ``tokenizer``, ``collator``, ``generator`` and ``model``
                sections.
            seed: Forwarded to the base ``Trainer`` constructor.
        """
        # Presumably the base Trainer stores ``self.config`` and provides
        # ``set_seed`` — both are used below without being defined here.
        super().__init__(config, seed)
        self.log = pd.DataFrame()
        os.environ["CUDA_VISIBLE_DEVICES"] = self._visible_devices(
            self.config["gpu"]["use"]
        )
        self.set_seed()

        params = self.config["trainer"]["parameters"]
        # Attribute name keeps its historical spelling for backward compatibility.
        self.num_interations = params["num_iterations"]
        self.population_size = params["population_size"]
        self.sigma = params["sigma"]
        self.alpha = float(params["alpha"])
        # max_new_tokens / do_sample are not referenced elsewhere in this class.
        self.max_new_tokens = params["max_new_tokens"]
        self.do_sample = params["do_sample"]
        self.batch_size = params["batch_size"]
        # A falsy value (0 / None) disables intermediate checkpoints in run().
        self.save_every = params["save_every"]

        self.task = create_task(name=config["task"]["name"], config=config["task"])
        self.dataset = create_dataset(
            name=config["dataset"]["name"], config=config["dataset"]
        )
        self.dataloader = create_dataloader(
            self.dataset, batch_size=self.batch_size, shuffle=False
        )
        self.tokenizer = create_tokenizer(
            name=config["tokenizer"]["name"], config=config["tokenizer"]
        )
        self.data_collator = create_collator(
            name=config["collator"]["name"], tokenizer=self.tokenizer
        )
        self.generator = create_generator(
            config=self.config["generator"], tokenizer=self.tokenizer
        )
        self.model = create_model(name=config["model"]["name"], config=config["model"])
        self.model.eval()

    @staticmethod
    def _visible_devices(value):
        """Format a ``gpu.use`` config value for ``CUDA_VISIBLE_DEVICES``.

        Lists/tuples (e.g. ``[0, 1]``) become ``"0,1"``; anything else is
        converted with ``str`` (strings are returned unchanged).
        """
        if isinstance(value, (list, tuple)):
            return ",".join(str(v) for v in value)
        return str(value)

    def evaluate_model(self, input_text, target_text):
        """Generate responses for a batch of prompts and score them.

        Args:
            input_text: Batch of prompts, passed to the data collator.
            target_text: Batch of references, passed to the task.

        Returns:
            tuple: ``(generated_text, reward)``.
                - ``generated_text`` is a list (one entry per example) of
                  response token strings. Each is truncated after the first
                  pad token, with special tokens removed.
                - ``reward`` is the first value returned by ``self.task``.
        """
        device = self.model.device
        batch = self.data_collator(input_text).to(device)
        prompt_length = batch["input_ids"].shape[1]

        with torch.no_grad():
            labels = self.generator.sample(self.model, batch)
        # Assumes the sampler returns prompt ids followed by the continuation;
        # keep only the continuation.
        response_only_tokens = labels[:, prompt_length:].clone()
        pad_id = self.tokenizer.pad_token_id

        # Index of the first pad token per row (sequence_length if none).
        is_pad = response_only_tokens == pad_id
        batch_size, sequence_length = response_only_tokens.shape
        first_pad_idx = torch.where(
            is_pad.any(dim=1),
            is_pad.float().argmax(dim=1),
            torch.full((batch_size,), sequence_length, device=device),
        )
        position_ids = torch.arange(sequence_length, device=device).unsqueeze(0)

        # Keep tokens up to and including the first pad; overwrite the rest
        # with pad. A boolean mask is used directly so the result does not
        # depend on the numeric value of pad_token_id.
        keep = position_ids <= first_pad_idx.unsqueeze(1)
        labels_w_eos = torch.where(keep, response_only_tokens, pad_id)

        generated_text = [
            self.tokenizer.convert_ids_to_tokens(row, skip_special_tokens=True)
            for row in labels_w_eos
        ]
        string_text = [
            self.tokenizer.convert_tokens_to_string(tokens) for tokens in generated_text
        ]

        reward, _ = self.task(string_text, target_text)

        print(f"Generated Text: {string_text}")
        print(f"Reward: {reward}")
        print("")

        return generated_text, reward

    @staticmethod
    def _sample_noise(param, seed):
        """Regenerate the standard-normal noise for ``param`` from ``seed``.

        A generator is freshly seeded with the member seed for every tensor,
        so the noise can be reproduced from the seed alone.

        NOTE(review): this also means tensors with the same shape, dtype and
        device receive identical noise. It is kept as-is to match the original
        algorithm; tied weights rely on it to stay consistent.
        """
        gen = torch.Generator(device=param.device)
        gen.manual_seed(int(seed))
        return torch.randn(
            param.shape,
            generator=gen,
            device=param.device,
            dtype=param.dtype,
        )

    def _perturbed_state(self, base_state, seed):
        """Return a copy of ``base_state`` with seeded Gaussian noise added.

        Floating-point tensors become ``base + sigma * noise``. Other tensors
        (e.g. integer buffers; ``torch.randn`` does not support integer
        dtypes) are passed through unchanged.
        """
        return {
            name: (
                param + self.sigma * self._sample_noise(param, seed)
                if param.is_floating_point()
                else param
            )
            for name, param in base_state.items()
        }

    def _apply_es_update(self, base_state, seeds, rewards_normalized):
        """Update ``base_state`` in place with the ES step.

        For every floating-point tensor:
            ``base += alpha * mean_i(r_i * noise_i)``

        Here ``noise_i`` is regenerated from ``seeds[i]`` and ``r_i`` is the
        normalised reward for that member. No ``1/sigma`` factor is applied,
        so ``alpha`` absorbs it.

        Each tensor's update is applied right away, so a full-model update
        dict is never held in memory.
        """
        for name, base_param in list(base_state.items()):
            if not base_param.is_floating_point():
                continue
            update = torch.zeros_like(base_param)
            for r_norm, seed in zip(rewards_normalized, seeds):
                update += self._sample_noise(base_param, seed) * float(r_norm)
            update /= self.population_size
            base_state[name] = base_param + self.alpha * update

    def _run_name(self, iterations):
        """Build the artefact base name for the given iteration count."""
        return "finetuned_qwen_es_random_seed_pop{}_iter{}_sigma{}_alpha{}".format(
            self.population_size,
            iterations,
            self.sigma,
            self.alpha,
        )

    def run(self, output_dir):
        """Run ES fine-tuning, saving checkpoints and a CSV log.

        Args:
            output_dir: Root output directory.
                - Every ``save_every`` iterations the model is written to
                  ``<output_dir>/model/<run name>``; this is disabled when
                  ``save_every`` is 0/None.
                - At the end the per-batch log is written to
                  ``<output_dir>/log/<run name>.csv``; the ``log`` directory
                  is created if missing.

        Returns:
            pandas.DataFrame: one row per evaluated batch, with columns
            ``iteration``, ``prompts``, ``generations`` and ``rewards``.
            It is also stored on ``self.log``.
        """
        base_state = copy.deepcopy(self.model.state_dict())
        log_iterations, log_rewards, log_prompts, log_responses = [], [], [], []

        for iteration in range(self.num_interations):
            # One seed per population member; the seed alone identifies its noise.
            seeds = np.random.randint(
                0, 2**31, size=self.population_size, dtype=np.int64
            ).tolist()

            rewards = []
            for seed in seeds:
                # The perturbed copy is released once loaded into the model.
                self.model.load_state_dict(self._perturbed_state(base_state, seed))

                # Evaluate on the whole dataset and average over examples.
                total_reward = 0.0
                for input_text, target_text in self.dataloader:
                    generated_text, reward = self.evaluate_model(
                        input_text, target_text
                    )
                    total_reward += sum(reward)
                    log_iterations.append(iteration)
                    log_prompts.append(input_text)
                    log_responses.append(generated_text)
                    log_rewards.append(reward)

                rewards.append(total_reward / len(self.dataset))

            # z-normalise rewards across the population (epsilon avoids 0/0).
            rewards_array = np.array(rewards, dtype=np.float32)
            rewards_normalized = (rewards_array - rewards_array.mean()) / (
                rewards_array.std() + 1e-8
            )

            self._apply_es_update(base_state, seeds, rewards_normalized)
            # The model still holds the last member's perturbed weights;
            # this load replaces them with the updated base.
            self.model.load_state_dict(base_state)

            mean_reward = rewards_array.mean().item()
            print(
                f"Iteration {iteration + 1}/{self.num_interations}, Mean Reward: {mean_reward:.4f}"
            )

            if self.save_every and (iteration + 1) % self.save_every == 0:
                self.model.save_pretrained(
                    "{}/model/{}".format(str(output_dir), self._run_name(iteration + 1))
                )

        # Fresh frame so repeated run() calls do not clash with stale rows.
        self.log = pd.DataFrame()
        self.log["iteration"] = log_iterations
        self.log["prompts"] = log_prompts
        self.log["generations"] = log_responses
        self.log["rewards"] = log_rewards

        log_dir = "{}/log".format(str(output_dir))
        os.makedirs(log_dir, exist_ok=True)
        self.log.to_csv(
            "{}/{}.csv".format(log_dir, self._run_name(self.num_interations)),
            index=False,
        )

        return self.log
