import torch
import numpy as np
import os
import copy
import pandas as pd

from es_llm.trainer.trainer import Trainer
from es_llm.task import create_task
from es_llm.model import create_model
from es_llm.dataset import create_dataset
from es_llm.tokenizer import create_tokenizer
from es_llm.dataloader import create_dataloader

# Default to the first four GPUs, but respect a CUDA_VISIBLE_DEVICES value the
# caller already exported instead of silently overriding it. This has to run
# before CUDA is initialised, which is why it lives at import time.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1,2,3")


class ESLLMTrainerOrig(Trainer):
    """Fine-tune a language model with a simple evolution strategy (ES).

    Every iteration draws ``population_size`` random seeds. For each seed the
    model's parameters are perturbed with seeded Gaussian noise scaled by
    ``sigma`` and the perturbed model is scored on the whole dataset through
    ``self.task``. Rewards are standardised and the base weights move by
    ``alpha`` times the reward-weighted mean of the noise. Noise is never
    stored: it is regenerated from the seeds, so only one extra copy of the
    weights is kept in memory.
    """

    def __init__(self, config):
        """Read hyper-parameters from ``config`` and build the ES components.

        Args:
            config: Nested mapping with ``trainer``, ``task``, ``dataset``,
                ``tokenizer`` and ``model`` sections. ``trainer.parameters``
                must provide ``num_iterations``, ``population_size``,
                ``sigma``, ``alpha``, ``max_new_tokens``, ``do_sample`` and
                ``batch_size``; ``trainer.experiment_directory`` is the root
                under which the fine-tuned model and tokenizer are saved.
        """
        self.config = config
        self.log = pd.DataFrame()
        params = self.config["trainer"]["parameters"]
        # Attribute keeps its historical spelling ("interations") so any
        # external code reading it keeps working.
        self.num_interations = params["num_iterations"]
        self.population_size = params["population_size"]
        self.sigma = params["sigma"]
        self.alpha = params["alpha"]
        self.max_new_tokens = params["max_new_tokens"]
        self.do_sample = params["do_sample"]
        self.batch_size = params["batch_size"]
        self.output_directory = self.config["trainer"]["experiment_directory"]

        self.task = create_task(name=config["task"]["name"], config=config["task"])
        self.dataset = create_dataset(
            name=config["dataset"]["name"], config=config["dataset"]
        )
        self.dataloader = create_dataloader(
            self.dataset, batch_size=self.batch_size, shuffle=False
        )
        self.tokenizer = create_tokenizer(
            name=config["tokenizer"]["name"], config=config["tokenizer"]
        )
        self.model = create_model(name=config["model"]["name"], config=config["model"])
        self.model.eval()

    def evaluate_model(self, input_text, target_text):
        """Generate a response for ``input_text`` and score it with the task.

        Only ``outputs[0]`` is decoded and scored, so this effectively assumes
        one prompt per call (i.e. ``batch_size == 1``).

        Args:
            input_text: Prompt(s) passed to the tokenizer.
            target_text: Reference passed unchanged to ``self.task``.

        Returns:
            The first element of the pair returned by ``self.task``.
        """
        # Send inputs to the model's device when it exposes one; otherwise
        # fall back to the default CUDA device.
        device = getattr(self.model, "device", "cuda")
        input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids.to(device)
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids, max_new_tokens=self.max_new_tokens, do_sample=self.do_sample
            )
        # NOTE(review): for decoder-only models `generate` usually returns the
        # prompt tokens followed by the completion, so the decoded text likely
        # includes the prompt — confirm the task expects that.
        try:
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        except TypeError:
            # Fallback: convert ids -> tokens, drop ids that map to no token
            # (None), then join the remaining tokens back into text.
            tokens = self.tokenizer.convert_ids_to_tokens(
                outputs[0], skip_special_tokens=True
            )
            filtered = [t for t in tokens if t is not None]
            generated_text = self.tokenizer.convert_tokens_to_string(filtered)
        print(generated_text)
        reward, __ = self.task(generated_text, target_text)
        print(reward)
        return reward

    @staticmethod
    def _sample_noise(seed, index, like):
        """Return standard-normal noise shaped like ``like`` for (seed, index).

        The same (seed, index) pair always reproduces the same tensor, which
        lets the update step rebuild each member's noise from its seed alone.
        Mixing the tensor index into the seed gives every parameter tensor its
        own stream; re-using the bare seed for all tensors would hand identical
        noise to every tensor of the same shape.
        """
        gen = torch.Generator(device=like.device)
        # seed < 2**31 and index < 2**32, so this packing is collision-free and
        # stays within the 64-bit seed range torch accepts.
        gen.manual_seed((int(seed) << 32) + index)
        return torch.randn(
            like.shape,
            generator=gen,
            device=like.device,
            dtype=like.dtype,
        )

    @torch.no_grad()
    def _load_perturbed_weights(self, params, base_weights, seed):
        """Write ``base + sigma * noise(seed)`` into ``params`` in place."""
        for index, (param, base) in enumerate(zip(params, base_weights)):
            if not base.is_floating_point():
                continue  # Gaussian noise is not defined for integer tensors.
            param.copy_(base + self.sigma * self._sample_noise(seed, index, base))

    def _dataset_reward(self):
        """Return the mean ``evaluate_model`` reward over all dataloader batches.

        Raises:
            ValueError: if the dataloader yields no batches.
        """
        total_reward = 0.0
        num_batches = 0
        for input_text, target_text in self.dataloader:
            total_reward += self.evaluate_model(input_text, target_text)
            num_batches += 1
        if num_batches == 0:
            raise ValueError("Dataset is empty; cannot compute an ES reward.")
        # One reward is produced per evaluate_model call (per batch), so that
        # is the correct denominator, not the number of dataset rows.
        return total_reward / num_batches

    @torch.no_grad()
    def _apply_es_update(self, params, base_weights, seeds, rewards_normalized):
        """Move ``base_weights`` along the reward-weighted noise; load into ``params``.

        Each tensor's update is built and applied one at a time, so no
        model-sized update dictionary is ever materialised.
        """
        for index, (param, base) in enumerate(zip(params, base_weights)):
            if base.is_floating_point():
                update = torch.zeros_like(base)
                for r_norm, seed in zip(rewards_normalized, seeds):
                    update += self._sample_noise(seed, index, base) * float(r_norm)
                update /= self.population_size
                base += self.alpha * update
            # Reload every tensor so the model ends the step holding the new
            # base weights rather than the last population member's.
            param.copy_(base)

    def _save_finetuned(self):
        """Save the fine-tuned model and tokenizer under ``output_directory``."""
        self.model.save_pretrained(
            "{}/models/finetuned_qwen_es_random_seed_pop{}_iter{}_sigma{}_alpha{}".format(
                self.output_directory,
                self.population_size,
                self.num_interations,
                self.sigma,
                self.alpha,
            )
        )
        # Format the path before handing it to save_pretrained.
        self.tokenizer.save_pretrained(
            "{}/models/finetuned_qwen_es".format(self.output_directory)
        )

    def run(self):
        """Run ``num_iterations`` ES iterations, then save model and tokenizer.

        The model's parameters are updated in place; after each iteration they
        hold the new base weights.
        """
        # parameters() yields each tensor once, so tied weights (e.g. shared
        # input/output embeddings) receive one consistent perturbation/update.
        params = list(self.model.parameters())
        base_weights = [param.detach().clone() for param in params]

        for iteration in range(self.num_interations):
            # One seed per population member; the noise itself is regenerated
            # from these seeds during the update instead of being stored.
            seeds = np.random.randint(
                0, 2**31, size=self.population_size, dtype=np.int64
            ).tolist()

            rewards = []
            for seed in seeds:
                self._load_perturbed_weights(params, base_weights, seed)
                rewards.append(self._dataset_reward())

            # Standardise rewards; the epsilon guards against zero variance.
            rewards_array = np.array(rewards, dtype=np.float32)
            rewards_normalized = (rewards_array - rewards_array.mean()) / (
                rewards_array.std() + 1e-8
            )

            self._apply_es_update(params, base_weights, seeds, rewards_normalized)

            mean_reward = rewards_array.mean().item()
            print(
                f"Iteration {iteration + 1}/{self.num_interations}, Mean Reward: {mean_reward:.4f}"
            )

        self._save_finetuned()
