import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader, Dataset
import numpy as np
from fine_llm import read_text_files
import torch
import numpy as np
import sacrebleu


def init_teacher_models(names):
    """Load one pretrained causal language model per checkpoint name.

    Args:
        names: Iterable of model identifiers or checkpoint paths; each one is
            passed to ``AutoModelForCausalLM.from_pretrained``. A single
            string is treated as one name instead of being iterated
            character by character.

    Returns:
        dict: Maps every name to the model loaded for it, in iteration order.
    """
    if isinstance(names, str):
        names = [names]

    # Hand back the mapping of loaded models (not this function object).
    return {name: AutoModelForCausalLM.from_pretrained(name) for name in names}


def calculate_crps(teacher_probs, student_probs):
    """Score the gap between two probability vectors via their sorted CDFs.

    Each input is sorted in ascending order and cumulatively summed; the
    absolute difference of the two cumulative curves is then integrated with
    the trapezoidal rule using unit spacing.

    Args:
        teacher_probs: Array-like of teacher probabilities.
        student_probs: Array-like of student probabilities; must broadcast
            against ``teacher_probs`` after the cumulative sum.

    Returns:
        The trapezoidal integral of ``|cdf_teacher - cdf_student|``.
    """
    teacher_cdf = np.cumsum(np.sort(teacher_probs))
    student_cdf = np.cumsum(np.sort(student_probs))
    cdf_gap = np.abs(teacher_cdf - student_cdf)

    # NumPy 2.0 renamed ``trapz`` to ``trapezoid`` and deprecated the old
    # name; prefer the new one when it exists, fall back otherwise.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return integrate(cdf_gap, dx=1)

def calculate_gini_coefficient(data):
    """Compute the Gini coefficient of a collection of values.

    The values are flattened, sorted in ascending order and cumulatively
    summed; the coefficient is
    ``(n + 1 - 2 * sum(cumsum) / cumsum[-1]) / n``.

    Args:
        data: Array-like of numeric values.

    Returns:
        The Gini coefficient. ``0.0`` is returned for empty input and for
        input whose total is zero, where the ratio is otherwise undefined.
    """
    values = np.sort(np.asarray(data, dtype=float).ravel())
    n = values.size
    if n == 0:
        return 0.0

    running_total = np.cumsum(values)
    total = running_total[-1]
    if total == 0:
        # Avoid 0/0: with no mass at all there is no inequality to measure.
        return 0.0

    return (n + 1 - 2 * np.sum(running_total) / total) / n

def calculate_pearson_coefficient(x, y):
    """Compute the Pearson correlation coefficient between ``x`` and ``y``.

    Args:
        x: Array-like of numeric values.
        y: Array-like of numeric values, broadcastable against ``x``.

    Returns:
        The sum of products of deviations divided by the square root of the
        product of the summed squared deviations. ``0.0`` is returned when
        that denominator is zero (at least one input has no variance), where
        the coefficient is otherwise undefined.
    """
    x_arr = np.asarray(x, dtype=float)
    y_arr = np.asarray(y, dtype=float)

    x_dev = x_arr - x_arr.mean()
    y_dev = y_arr - y_arr.mean()

    denominator = np.sqrt(np.sum(x_dev ** 2) * np.sum(y_dev ** 2))
    if denominator == 0:
        # A constant input would otherwise produce 0/0 = nan.
        return 0.0

    return np.sum(x_dev * y_dev) / denominator

def calculate_herfindahl_index(market_shares):
    """Return the Herfindahl index: the sum of the squared shares.

    Args:
        market_shares: Iterable of numeric shares.

    Returns:
        The sum of ``share ** 2`` over all shares (``0`` for empty input).
    """
    total = 0
    for share in market_shares:
        total = total + share ** 2
    return total


def calculate_bleu_score(reference, candidate):
    """Compute a corpus-level BLEU score with sacrebleu.

    Every item of ``reference`` and ``candidate`` is joined with single
    spaces before scoring, so each item is presumably a sequence of token
    strings -- TODO confirm against callers.

    Args:
        reference: List of reference translations.
        candidate: List of candidate translations.

    Returns:
        The ``score`` attribute of the result of ``sacrebleu.corpus_bleu``,
        called with the joined references as the only reference stream.
    """
    def _join_segments(segments):
        # Turn each segment into one space-separated string.
        return [" ".join(pieces) for pieces in segments]

    references = _join_segments(reference)
    hypotheses = _join_segments(candidate)

    result = sacrebleu.corpus_bleu(hypotheses, [references])
    return result.score

class Agent:
    """Agent holding a small policy network that selects one teacher index.

    Args:
        num_teachers: Number of teachers; sets the size of the weight list
            and the policy network's output dimension.
        alpha: Stored on the instance; not read by any method of this class.
        epsilon: Stored on the instance; not read by any method of this class.
    """

    def __init__(self, num_teachers, alpha, epsilon):
        self.num_teachers = num_teachers
        self.alpha = alpha
        self.epsilon = epsilon
        # Initial weights: uniform distribution over the teachers.
        self.weights = [1.0 / num_teachers] * num_teachers
        # ``torch.nn`` is referenced through the already-imported ``torch``
        # module because the file never imports a bare ``nn`` name.
        self.policy_network = torch.nn.Sequential(
            # Input is concatenated teacher and student probabilities.
            torch.nn.Linear(num_teachers * 2, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, num_teachers),
            torch.nn.Softmax(dim=-1),
        )
        self.optimizer = torch.optim.Adam(self.policy_network.parameters(), lr=0.001)

    def compute_reward(self, teachers_output, student_output):
        """Return the reward for the given outputs.

        Delegates to the module-level ``F1`` function with both arguments
        passed through unchanged.
        """
        return F1(teachers_output, student_output)

    def update_weight(self, state):
        """Sample a teacher index from the policy for ``state``.

        Args:
            state: Values convertible to a float32 tensor whose last
                dimension is ``num_teachers * 2``.

        Returns:
            tuple: ``(action, log_prob)`` where ``action`` is the sampled
            index as a Python int and ``log_prob`` is the tensor holding the
            log-probability of that action under the policy.
        """
        # Imported here because the file's import block does not provide it.
        from torch.distributions import Categorical

        state_tensor = torch.tensor(state, dtype=torch.float32)
        action_probs = self.policy_network(state_tensor)
        action_dist = Categorical(action_probs)
        action = action_dist.sample()  # Sample an action based on the policy

        # Log-probability of the selected action.
        log_prob = action_dist.log_prob(action)

        return action.item(), log_prob


# Checkpoint directories of the teacher models. They are passed as a single
# list because init_teacher_models takes one iterable of names.
teachers_model = init_teacher_models([
    'save_models/1_model',
    'save_models/2_model',
    'save_models/3_model',
    'save_models/4_model',
])

class KnowledgeDistillationLoss(torch.nn.Module):
    """Multi-teacher distillation loss: weighted similarity plus a hinged
    adversarial term.

    Args:
        num_teachers: Number of teacher outputs consumed per call.
        beta: Mixing factor; ``beta`` scales the similarity term and
            ``1 - beta`` scales the adversarial term.
        lambda_adv: Multiplier of the squared perturbation norm subtracted
            inside the adversarial hinge.
    """

    def __init__(self, num_teachers,  beta=0.5, lambda_adv=1.0):
        super().__init__()
        self.num_teachers = num_teachers
        self.beta = beta
        self.lambda_adv = lambda_adv

    def forward(self, w,student_outputs, teacher_outputs, ground_truth_logits, adversarial_perturbation):
        """Compute the combined loss.

        Args:
            w: Indexable collection of ``num_teachers`` teacher weights.
            student_outputs: Student output tensor.
            teacher_outputs: Indexable collection of ``num_teachers`` teacher
                output tensors (assumed to have the student's shape --
                TODO confirm).
            ground_truth_logits: Tensor used for the translation term.
            adversarial_perturbation: Either one tensor shared by all
                teachers or a sequence with one tensor per teacher.

        Returns:
            ``beta * similarity + (1 - beta) * adversarial`` as a tensor.
        """
        mse = torch.nn.functional.mse_loss
        cross_entropy = torch.nn.functional.cross_entropy

        # Accept one shared perturbation or one perturbation per teacher.
        if torch.is_tensor(adversarial_perturbation):
            perturbations = [adversarial_perturbation] * self.num_teachers
        else:
            perturbations = list(adversarial_perturbation)

        similarity_loss = 0.0
        adversarial_loss = 0.0
        for i in range(self.num_teachers):
            teacher = teacher_outputs[i]

            # Similarity term (L_similarity).
            # NOTE(review): cross_entropy receives the teacher as ``input``
            # and the student as ``target`` here, as in the original code --
            # confirm this argument order is intended.
            similarity_loss = similarity_loss + w[i] * (
                mse(teacher, student_outputs) + cross_entropy(teacher, student_outputs)
            )

            # Adversarial term (L_adv), evaluated per teacher because
            # cross_entropy cannot take the whole teacher collection.
            # NOTE(review): weighting each hinge by w[i] mirrors the
            # similarity term and is an assumption -- confirm.
            margin = (cross_entropy(student_outputs, teacher)
                      - self.lambda_adv * torch.norm(perturbations[i]) ** 2)
            # clamp keeps the hinge on the margin's own device and dtype.
            adversarial_loss = adversarial_loss + w[i] * torch.clamp(margin, min=0.0)

        # Translation term (L_translation).
        # NOTE(review): this value is computed but is not part of the
        # returned total, as in the original -- confirm whether it should be.
        translation_loss = cross_entropy(ground_truth_logits, student_outputs)

        # Combine the losses.
        total_loss = self.beta * similarity_loss + (1 - self.beta) * adversarial_loss

        return total_loss
    


def F1(teachers_output, student_output, alpha=1.0):
    """Combine several metrics into a single scalar reward.

    The reward is
    ``-crps - gini - pearson + herfindahl + alpha * bleu`` where each metric
    comes from the matching ``calculate_*`` helper in this module. Both
    arguments are forwarded to those helpers unchanged.

    Args:
        teachers_output: Teacher outputs; used for every metric.
        student_output: Student outputs; used for CRPS, Pearson and BLEU.
        alpha: Weight of the BLEU term. It is a keyword with a default so
            the function no longer depends on a module-level ``alpha``,
            which this file never defines.
            NOTE(review): 1.0 is a placeholder default -- confirm.

    Returns:
        The combined reward.
    """
    # Distance between the teacher and student cumulative curves.
    crps = calculate_crps(teachers_output, student_output)

    # Concentration measures of the teacher output alone.
    gini_coefficient = calculate_gini_coefficient(teachers_output)
    herfindahl_index = calculate_herfindahl_index(teachers_output)

    # Linear correlation between teacher and student outputs.
    pearson_coefficient = calculate_pearson_coefficient(teachers_output, student_output)

    # Corpus BLEU with the teacher output as reference.
    bleu_score = calculate_bleu_score(teachers_output, student_output)

    return -crps - gini_coefficient - pearson_coefficient + herfindahl_index + alpha * bleu_score


# NOTE(review): alpha, epsilon, beta, lambda_adv, num_epochs, dataset,
# student_model, ground_truth_logits, state and
# calculate_adversarial_perturbation are not defined anywhere in this file;
# they must be provided before this script can run.

# Work on a plain list of models whether the teachers were collected as a
# name -> model mapping or as a sequence (iterating a dict yields its keys).
if isinstance(teachers_model, dict):
    teacher_list = list(teachers_model.values())
else:
    teacher_list = list(teachers_model)
num_teachers = len(teacher_list)

# Initialize the agent
agent = Agent(num_teachers, alpha, epsilon)

# Initialize the knowledge distillation loss function
distillation_loss_fn = KnowledgeDistillationLoss(num_teachers, beta, lambda_adv)

# Teachers are only used for inference; switch them to eval mode once.
for teacher_model in teacher_list:
    teacher_model.eval()

# Start from uniform weights once, outside the loops, so the per-batch
# increment at the end of each iteration carries over to the next one.
agent_weights = [1.0 / num_teachers] * num_teachers

# Training loop
for epoch in range(num_epochs):
    for data in dataset:
        # Forward pass: collect teacher outputs and their perturbations.
        teacher_outputs = []
        adversarial_perturbations = []

        for teacher_model in teacher_list:
            teacher_output = teacher_model(data)
            teacher_outputs.append(teacher_output)

            adversarial_perturbation = calculate_adversarial_perturbation(data, teacher_output)
            adversarial_perturbations.append(adversarial_perturbation)

        student_output = student_model(data)  # Get student model's output

        # Calculate the knowledge distillation loss
        distillation_loss = distillation_loss_fn(agent_weights, student_output, teacher_outputs, ground_truth_logits, adversarial_perturbations)

        # Compute the reward
        reward = agent.compute_reward(teacher_outputs, student_output)

        # Backpropagate and update agent's policy.
        # NOTE(review): agent_weights are plain Python floats, so this loss
        # does not depend on the policy network's parameters and the
        # optimizer step cannot change the policy. ``reward`` and
        # ``log_prob`` are also unused; a policy-gradient loss built from
        # them may be what was intended -- confirm.
        loss = -distillation_loss
        agent.optimizer.zero_grad()  # do not accumulate gradients across batches
        loss.backward()
        agent.optimizer.step()

        # Sample which teacher's weight to adjust.
        action, log_prob = agent.update_weight(state)

        # Update teacher weights for the next iteration
        agent_weights[action] += epsilon

        # Perform any additional updates or logging as needed
