"""
Common Components for Federated Learning and Bandit Algorithms

This module provides shared components and utilities for federated learning
experiments, including API interfaces, data processing, and evaluation tools.
"""

import json
import random
import torch
import numpy as np
import sys
import os
import nest_asyncio
import matplotlib.pyplot as plt
from torch.optim import SGD
import pandas as pd
from tqdm import tqdm
from torch.optim import LBFGS
import torch.optim as optim
from sklearn.preprocessing import MinMaxScaler

nest_asyncio.apply()
current_dir = os.path.dirname(os.path.abspath(__file__))

experiments_dir = os.path.join(current_dir, '..')  # Assume experiments directory is one level up
sys.path.append(experiments_dir)

from automatic_prompt_engineer import config, llm

cwd = os.getcwd()
sys.path.append(cwd)

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
from torch import nn
from backpack import backpack, extend
from automatic_prompt_engineer import ape, data
from data.instruction_induction.load_data import load_data
from evaluation.instruction_induction.exec_accuracy import \
    exec_accuracy_evaluator, exec_evaluator
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
from automatic_prompt_engineer import evaluate, template, data
import re

import argparse
from experiments.evaluation.instruction_induction.utility import set_all_seed
import datetime
import torch.nn.functional as F

# API Keys - Support for OpenAI and OpenRouter
# Note: Set these environment variables before running the script
# os.environ["OPENAI_API_KEY"] = "your_openai_key"
# os.environ["OPENROUTER_API_KEY"] = "your_openrouter_key"

# Optional flag read from the environment (None when unset).
# NOTE(review): not referenced anywhere in the visible part of this module.
SMOKE_TEST = os.environ.get("SMOKE_TEST")
# Default tensor keyword arguments: first CUDA device when one is available,
# otherwise CPU, with float32 precision.
tkwargs = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dtype": torch.float32,
}

# NOTE(review): module-level model name; LMForwardAPI takes its own
# `model_name` argument, so confirm which callers read this global.
model_name = "vicuna"
# Set the TOKENIZERS_PARALLELISM environment variable for the tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Evaluation backend selector; LMForwardAPI.__init__ and LMForwardAPI.eval
# branch on this value.
api_model = 'chatgpt'
# NOTE(review): alpha and sigma are not used in the visible part of this
# module; presumably algorithm hyper-parameters read by importers - confirm.
alpha = 1
sigma = 1


class FederatedLinearDuelingBanditEnvironment:
    """
    Federated Linear Dueling Bandit Environment.

    This class simulates a federated learning environment for dueling bandit problems,
    where multiple clients participate in preference-based learning with contextual arms.
    Each client receives its own random context matrix, and pairwise feedback is
    sampled from a logistic (Bradley-Terry style) model over the supplied scores.
    """

    def __init__(self, feature_dim, num_arms, num_clients, noise=0.1, device=None):
        """
        Initialize the federated dueling bandit environment.

        Args:
            feature_dim: Dimension of feature vectors
            num_arms: Number of arms (actions) available
            num_clients: Number of federated clients
            noise: Multiplier applied to the score difference before the sigmoid
                in ``get_preference`` (default: 0.1). Larger values push the
                preference probability towards 0 or 1; smaller values push it
                towards 0.5.
            device: Device on which contexts are generated. Defaults to CUDA
                when available and CPU otherwise, so the environment also runs
                on machines without a GPU.
        """
        self.feature_dim = feature_dim
        self.num_arms = num_arms
        self.num_clients = num_clients
        self.noise = noise
        if device is None:
            # Previously hard-coded to "cuda", which raised on CPU-only hosts.
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

    def generate_context(self):
        """
        Generate a set of feature vectors (arms) for all clients.

        Returns:
            List of ``num_clients`` standard-normal tensors, each of shape
            ``(num_arms, feature_dim)`` and placed on ``self.device``
        """
        return [
            torch.randn(self.num_arms, self.feature_dim, device=self.device)
            for _ in range(self.num_clients)
        ]

    def get_preference(self, arm1_idx, arm2_idx, score_list):
        """
        Simulate a pairwise comparison between two arms with logistic noise.

        The probability that arm1 wins is
        ``sigmoid((score[arm1] - score[arm2]) * self.noise)``.

        Args:
            arm1_idx: Index of the first arm
            arm2_idx: Index of the second arm
            score_list: Indexable collection of scalar scores for all arms

        Returns:
            int: 1 if arm1 is preferred, 0 otherwise (a single Bernoulli draw
            from NumPy's global random state)
        """
        # as_tensor accepts Python/NumPy scalars and existing tensors alike,
        # without the copy-construct warning torch.tensor() emits for tensors.
        arm1_utility = torch.as_tensor(score_list[arm1_idx])
        arm2_utility = torch.as_tensor(score_list[arm2_idx])

        scaled_gap = (arm1_utility - arm2_utility) * self.noise
        win_probability = torch.sigmoid(scaled_gap).detach().cpu().item()

        return np.random.binomial(n=1, p=win_probability)


def extract_sub_sentence(long_sentence):
    """
    Extract prompts from XML-like tags in a sentence.

    Matching is non-greedy, so several tagged prompts in one string are
    returned separately, and ``re.DOTALL`` lets a prompt span multiple lines
    (without it, a tagged prompt containing a newline was silently missed).

    Args:
        long_sentence: Input string containing <prompt>...</prompt> tags

    Returns:
        List of extracted prompt strings in order of appearance; empty when
        no complete tag pair is present
    """
    return re.findall(r'<prompt>(.*?)</prompt>', long_sentence, flags=re.DOTALL)


def mean_pooling(model_output, attention_mask):
    """
    Average token embeddings over the positions selected by the attention mask.

    Args:
        model_output: Indexable model output whose first element holds the
            per-token embeddings
        attention_mask: Mask tensor; it is broadcast over the embedding axis
            and used as per-token weights

    Returns:
        Tensor of mask-weighted mean embeddings, one row per input sequence
    """
    token_embeddings = model_output[0]

    # Broadcast the mask across the embedding dimension so it can weight tokens.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()

    weighted_sum = (token_embeddings * mask).sum(dim=1)
    # Clamp the denominator so an all-zero mask yields zeros instead of NaN.
    token_count = mask.sum(dim=1).clamp(min=1e-9)
    return weighted_sum / token_count


def get_sen_embedding(model, tokenizer, sentences):
    """
    Generate sentence embeddings using a pre-trained transformer model.

    Sentences are tokenized with padding and truncation, passed through the
    model without gradient tracking, mean-pooled with the attention mask and
    L2-normalized.

    Args:
        model: Pre-trained transformer model
        tokenizer: Corresponding tokenizer
        sentences: List of sentences to encode

    Returns:
        Normalized sentence embeddings tensor (one unit-norm row per sentence),
        located on the model's device when the model exposes one
    """
    # Tokenize sentences
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

    # The tokenizer output is not placed on the model's device automatically;
    # move it there when the model exposes a `device` attribute so a model on
    # GPU does not fail with a device mismatch.
    model_device = getattr(model, 'device', None)
    if model_device is not None:
        encoded_input = {name: tensor.to(model_device) for name, tensor in encoded_input.items()}

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Perform pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    # Normalize embeddings
    return F.normalize(sentence_embeddings, p=2, dim=1)


class LMForwardAPI:
    """
    Language Model Forward API for Prompt Engineering.

    This class provides an interface for interacting with language models
    in the context of automated prompt engineering and evaluation. It can
    generate an initial pool of prompts, evolve prompts through LLM-driven
    crossover/mutation, and score prompts on the evaluation data while caching
    every score in ``prompts_set``.
    """

    def __init__(self, model_name='vicuna', eval_data=None, init_prompt=None, init_qa_gen=None, conf=None,
                 base_conf=None, prompt_gen_data=None, n_prompt_tokens=None, few_shot_data=None,
                 random_proj=None, intrinsic_dim=None, magnitude=None, norm_method=None, eval_extra_msg=''):
        """
        Initialize the Language Model Forward API.

        Args:
            model_name: Name of the language model to use (default: 'vicuna');
                accepted for interface compatibility, not read by this class
            eval_data: Evaluation dataset
            init_prompt: Sequence whose first element is the initial prompt
                text (required in practice; ``None`` raises TypeError)
            init_qa_gen: Zero-argument callable returning the Q&A text that is
                appended to the initial prompt (required in practice)
            conf: Configuration dictionary
            base_conf: Base configuration file path
            prompt_gen_data: Data for prompt generation; also used as the
                few-shot data when ``few_shot_data`` is None
            n_prompt_tokens: Number of prompt tokens (not read by this class)
            few_shot_data: Few-shot learning data
            random_proj: Random projection method (not read by this class)
            intrinsic_dim: Intrinsic dimension for dimensionality reduction
                (not read by this class)
            magnitude: Scale applied to normalized scores in ``eval``
            norm_method: Normalization method ('standard', 'minmax' or None)
            eval_extra_msg: Extra message for evaluation
        """
        self.init_qa_gen = init_qa_gen
        self.init_prompt = init_prompt[0]
        self.init_token = self.init_prompt + self.init_qa_gen()
        self.count = 0
        self.eval_extra_msg = eval_extra_msg

        ## eval preparation
        self.conf = config.update_config(conf, base_conf)
        self.eval_data = eval_data
        self.eval_template = template.EvalTemplate("Instruction: [PROMPT]\n\nInput: [INPUT]\n Output: [OUTPUT]")
        self.demos_template = template.DemosTemplate("Input: [INPUT]\nOutput: [OUTPUT]")

        if api_model in ['llama', 'flan-t5']:
            self.api_model = exec_evaluator(api_model, self.conf)

        # Always define few_shot_data: previously the attribute was only set
        # when the argument was None, so passing explicit few-shot data made
        # eval() fail with AttributeError.
        self.few_shot_data = prompt_gen_data if few_shot_data is None else few_shot_data

        self.best_train_perf = 0.0
        self.best_dev_perf = 0.0
        self.best_last_perf = 10
        self.best_prompt = None
        self.num_call = 0
        self.best_instruction = None
        self.prompts_set = dict()   # prompt text -> stored (possibly normalized) score
        self.prompts_list = []      # (index, prompt text, stored score) in evaluation order
        self.parents = []           # current parent population (prompt texts)
        self.best_score = 0
        # Normalization statistics; must be assigned by the caller before
        # eval() is used with norm_method 'standard' / 'minmax'.
        self.score_mean = None
        self.score_std = None
        self.score_min = None
        self.score_max = None
        self.magnitude = magnitude
        self.norm_method = norm_method
        self.init_user_prompt = None

    def update_init_token(self):
        """
        Update the initial token by drawing a fresh Q&A text from ``init_qa_gen``.
        """
        self.init_token = self.init_prompt + self.init_qa_gen()

    def initialize_prompts(self, num_init, task, method):
        """
        Initialize a set of prompts using the specified method.

        Generation repeats until ``num_init`` distinct prompts were collected
        (duplicates returned by the model are merged).

        Args:
            num_init: Number of initial prompts to generate
            task: Task name for prompt generation
            method: Method to use ('induction' or 'rephrase')

        Returns:
            List of initialized prompt strings

        Raises:
            ValueError: If prompts still have to be generated and ``method`` is
                neither 'induction' nor 'rephrase' (this used to loop forever)
        """
        ini_prompts_his = {}
        print(self.conf['generation']['model'])
        model = llm.model_from_config(self.conf['generation']['model'])
        if method == 'rephrase':
            # Seed prompt: every later prompt is a rephrasing of this one.
            model_outputs = model.generate_text(self.init_token, 1, 0.5)
            ini_prompts_his[model_outputs[0]] = 0
            self.init_user_prompt = model_outputs[0]
        while len(ini_prompts_his) < num_init:
            if method == 'induction':
                if task in ['sum', 'first_word_letter', 'periodic_elements', 'active_to_passive']:
                    # For these tasks an induced prompt is additionally rephrased.
                    random_prompt = model.generate_text(self.init_token, 1, 1, use_seed=False)[0]
                    model_outputs = model.generate_text(
                        "Rephrase the following instruction: " + random_prompt + "\n the rephrased instruction is: ", 1,
                        1, use_seed=False)
                else:
                    model_outputs = model.generate_text(self.init_token, 1, 0.5)
                ini_prompts_his[model_outputs[0]] = 0
                self.update_init_token()
                print(f'{task}: {len(ini_prompts_his)}')
            elif method == 'rephrase':
                # Third positional argument of generate_text; presumably the
                # sampling temperature (confirm against llm.generate_text).
                third_arg = 1.5 if task in ['odd_one_out', 'orthography_starts_with'] else 1
                model_outputs = model.generate_text(
                    "Rephrase the following instruction: " + self.init_user_prompt + "\n the rephrased instruction is: ",
                    1, third_arg, use_seed=False)
                ini_prompts_his[model_outputs[0]] = 0
                print(f'{task}: {len(ini_prompts_his)}')
            else:
                # Neither branch would ever add a prompt, so fail loudly
                # instead of spinning in this loop forever.
                raise ValueError(f"Unknown method {method!r}; expected 'induction' or 'rephrase'")
        return list(ini_prompts_his.keys())

    def selection(self, num_next_gen):
        """
        Select parent pairs for evolution based on their performance scores.

        Parents are sampled with probability proportional to their stored
        score. If any score is negative (possible with 'standard'
        normalization), all scores are shifted by the minimum first so that
        the weights are valid probabilities; if all weights are zero the
        choice is uniform.

        Args:
            num_next_gen: Number of next generation prompts to generate

        Returns:
            List of parent pairs for crossover (each a NumPy array of two
            prompts; the two may coincide when too few parents carry a
            non-zero weight)
        """
        scores = np.array([self.prompts_set[parent] for parent in self.parents], dtype=float)
        num_parents = len(self.parents)

        # Negative scores would give invalid (or inverted) probabilities.
        if num_parents and scores.min() < 0:
            scores = scores - scores.min()

        total = np.sum(scores)
        if total == 0:
            probability = np.ones(num_parents) / num_parents
        else:
            probability = scores / total

        all_parents = []
        for _ in range(num_next_gen):
            try:
                parent_pair = np.random.choice(self.parents, size=2, replace=False, p=probability)
            except ValueError:
                # Fewer than two parents with non-zero probability: fall back
                # to sampling with replacement.
                parent_pair = np.random.choice(self.parents, size=2, replace=True, p=probability)
            all_parents.append(parent_pair)
        return all_parents

    def evolution(self, all_parents):
        """
        Evolve new prompts through crossover and mutation of parent prompts.

        For every pair the model is asked to cross over and mutate the two
        prompts; the text inside the first <prompt>...</prompt> pair of the
        reply is kept, or the whole reply when no tags are found.

        Args:
            all_parents: List of parent pairs for evolution

        Returns:
            List of evolved prompt strings, one per parent pair
        """
        next_gens = []
        model = llm.model_from_config(self.conf['evaluation']['model'])

        # Named prompt_template so it does not shadow the imported `template` module.
        prompt_template = "Please follow the instruction step-by-step to generate a better prompt.\n1. Cross over the following prompts and generate a new prompt:\nPrompt 1: [prompt_id1].\nPrompt 2: [prompt_id2].\n2. Mutate the prompt generated in Step 1 and generate a final prompt bracketed with <prompt> and </prompt>."
        for parents_ in all_parents:
            query = prompt_template.replace('[prompt_id1]', parents_[0])
            query = query.replace('[prompt_id2]', parents_[1])
            model_outputs = model.generate_text(query, 1, 0)
            tagged_prompts = extract_sub_sentence(model_outputs[0])
            if tagged_prompts:
                child = tagged_prompts[0]
                print(f"EVOL: {child}")
            else:
                child = model_outputs[0]
            next_gens.append(child)
        return next_gens

    def update(self, next_gens):
        """
        Update the parent population with evolved prompts based on performance.

        Every new prompt is evaluated (and thereby cached), then the best
        ``len(self.parents)`` prompts among parents and children become the
        new parents. ``self.this_iter_best`` is set to the best child score.

        Args:
            next_gens: List of evolved prompt strings
        """
        next_gens_scores = [self.eval([gen_]) for gen_ in next_gens]
        self.this_iter_best = np.max(next_gens_scores)

        num_parents = len(self.parents)
        candidates = self.parents + next_gens
        all_scores = [self.prompts_set[prompt] for prompt in candidates]
        # argsort is ascending, so the tail holds the highest-scoring prompts.
        idx_rank = np.argsort(all_scores)
        self.parents = [candidates[idx_] for idx_ in idx_rank[-num_parents:]]

    def eval(self, instruction=None, test=False):
        """
        Evaluate a prompt instruction and return its performance score.

        Scores of already evaluated prompts are served from ``prompts_set``.
        Otherwise the prompt is scored with ``exec_accuracy_evaluator``; unless
        ``test`` is set, the best-so-far bookkeeping is updated with the raw
        score, and the score is then normalized according to ``norm_method``
        and cached.

        Args:
            instruction: List containing the instruction string to evaluate
            test: Whether this is a test evaluation (default: False); test
                evaluations are neither normalized nor cached

        Returns:
            Performance score of the instruction (the cached or normalized
            value for non-test evaluations, the raw value for test evaluations)

        Raises:
            NotImplementedError: If the module-level ``api_model`` is not 'chatgpt'
        """
        prompt = instruction[0]
        if prompt in self.prompts_set:
            dev_perf = self.prompts_set[prompt]
        else:
            if api_model in ['chatgpt']:
                print(self.eval_extra_msg)
                dev_perf, _ = exec_accuracy_evaluator(instruction, self.eval_template, self.eval_data,
                                                      self.demos_template, self.few_shot_data, self.conf['evaluation'],
                                                      self.eval_extra_msg)
                dev_perf = dev_perf.sorted()[1][0]
            else:
                raise NotImplementedError

            if not test:
                if dev_perf >= self.best_last_perf:
                    self.count += 1

                # Best-so-far tracking uses the raw (un-normalized) score.
                if dev_perf >= self.best_dev_perf:
                    self.best_dev_perf = dev_perf
                    self.best_instruction = instruction

                if self.norm_method == 'standard':
                    dev_perf = self.magnitude * (dev_perf - self.score_mean) / self.score_std
                elif self.norm_method == 'minmax':
                    dev_perf = self.magnitude * (dev_perf - self.score_min) / (self.score_max - self.score_min)
                self.prompts_set[prompt] = dev_perf
                self.prompts_list.append((len(self.prompts_list), prompt, dev_perf))
                # The stored score is printed for both the "loss" and "perf" fields.
                print('Dev loss: {}. Dev perf: {}. Best dev perf: {}'.format(
                    round(float(dev_perf), 4),
                    round(float(dev_perf), 4),
                    round(float(self.best_dev_perf), 4)))
                print('********* Done *********')
        return dev_perf

    def return_best_prompt(self):
        """
        Return the best performing prompt instruction.

        Returns:
            Best instruction found so far (the list passed to ``eval``), or
            None if nothing was recorded yet
        """
        return self.best_instruction

    def return_prompts_set(self):
        """
        Return the dictionary of all evaluated prompts and their scores.

        Returns:
            Dictionary mapping prompts to their performance scores
        """
        return self.prompts_set

    def return_prompts_list(self):
        """
        Return the list of all evaluated prompts with metadata.

        Returns:
            List of tuples containing (index, prompt, score)
        """
        return self.prompts_list
