import os
import argparse
import torch
import torch.nn.functional as F
from transformers import Trainer, TrainingArguments, TrainerCallback
from tqdm import tqdm
import numpy as np
import cvxpy as cp

# Import custom modules
from models.qwen_lora_model import QwenLoraModule
from dataset.DataLoader import get_distillation_dataloader
from dataset import SNLIDataset, MTBenchDataset, SummEvalDataset


class DistributionAlignmentTrainer(Trainer):
    """Trainer that aligns the model's choice distribution with human label distributions.

    For every sample the logits at the last attended position are restricted to
    the first token of each answer choice. The loss is
    ``alpha * KL + (1 - alpha) * CE`` where

    * KL is ``F.kl_div(log_softmax(choice_logits), target, reduction='sum')``.
      ``target`` is the human distribution, or - when ``alpha > 0`` and
      ``epsilon != 0`` - a worst-case perturbation of it found by projected
      gradient ascent inside an L2 ball of radius ``epsilon``.
    * CE is the cross entropy over the full vocabulary against the first token
      of the choice with the highest human probability.

    ``alpha`` is expected to lie in ``[0, 1]``.
    """

    def __init__(self, alpha=0.8, epsilon=0.25,
                 adv_steps=5, adv_lr=0.05, use_cvxpy=True, *args, **kwargs):
        """
        Args:
            alpha: Weight of the KL term; the CE term is weighted ``1 - alpha``.
            epsilon: Radius of the perturbation; ``0`` disables the adversarial search.
            adv_steps: Maximum number of projected gradient ascent steps per sample.
            adv_lr: Step size applied to the normalized ascent direction.
            use_cvxpy: Project with CVXPY (exact) instead of the approximate projection.
            *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.alpha = alpha  # Weight for KL divergence loss
        self.epsilon = epsilon  # Maximum scale of perturbation
        self.adv_steps = adv_steps  # Number of adversarial training steps
        self.adv_lr = adv_lr  # Learning rate for adversarial perturbation
        self.use_cvxpy = use_cvxpy  # Whether to use CVXPY for precise projection

    def initialize_perturbed_distribution(self, human_probs):
        """
        Randomly initialize a perturbed probability distribution.

        Adds uniform noise in ``[-epsilon, epsilon]`` to every entry, clamps to
        ``>= 1e-7`` and renormalizes by the total sum. The result is NOT
        guaranteed to lie within L2 distance ``epsilon`` of ``human_probs``;
        callers that need a feasible point must project it afterwards.

        Returns:
            A detached copy that requires grad.
        """
        noise = torch.rand_like(human_probs) * 2 * self.epsilon - self.epsilon
        perturbed_probs = human_probs + noise
        perturbed_probs = torch.clamp(perturbed_probs, min=1e-7)
        perturbed_probs = perturbed_probs / perturbed_probs.sum()

        return perturbed_probs.clone().detach().requires_grad_(True)

    def project_to_constraints(self, perturbed_probs, orig_probs):
        """
        Project updated distribution to constraint space:
        1. L2 distance to ``orig_probs`` no more than epsilon
        2. Maintain non-negativity
        3. Sum equals 1

        Dispatches to the CVXPY projection when ``use_cvxpy`` is set, otherwise
        to the approximate projection. Both inputs are expected to be 1-D
        tensors, as produced by ``compute_loss``.
        """
        if self.use_cvxpy:
            # CVXPY works on NumPy arrays; the result is moved back to the input's device.
            return self._project_cvxpy(perturbed_probs.detach().cpu().numpy(),
                                       orig_probs.detach().cpu().numpy(),
                                       device=perturbed_probs.device)
        return self._project_approximate(perturbed_probs, orig_probs)

    def _project_approximate(self, perturbed_probs, orig_probs):
        """Approximate projection.

        Used when ``use_cvxpy`` is False and as the fallback when the CVXPY
        solve fails. Rescales the offset onto the L2 ball, then clamps to
        ``>= 1e-7`` and renormalizes; the last two steps are applied after the
        ball projection, so the L2 bound is only approximately enforced.
        """
        with torch.no_grad():
            # 1. Project to L2 ball: ||p - orig_p||_2 <= epsilon
            delta = perturbed_probs - orig_probs
            norm = torch.norm(delta)
            if norm > self.epsilon:
                delta = delta * (self.epsilon / norm)
                perturbed_probs = orig_probs + delta
            # 2. Ensure non-negativity
            perturbed_probs = torch.clamp(perturbed_probs, min=1e-7)
            # 3. Normalize to ensure sum equals 1
            perturbed_probs = perturbed_probs / perturbed_probs.sum()

            return perturbed_probs

    def _project_cvxpy(self, perturbed_probs_np, orig_probs_np, device=None):
        """Exact projection onto {simplex} ∩ {L2 ball around ``orig_probs_np``}.

        Minimizes ``||x - perturbed_probs_np||_2`` subject to ``sum(x) == 1``,
        ``x >= 0`` and ``||x - orig_probs_np||_2 <= epsilon`` using the ECOS
        solver. If the solve raises or yields no value, a ``RuntimeWarning`` is
        emitted and the approximate projection is used instead.

        Returns:
            A float32 tensor (moved to ``device`` when given), clamped to
            ``>= 1e-7`` and renormalized for numerical stability.
        """
        n = len(orig_probs_np)
        x = cp.Variable(n)

        # 1. Set objective function: minimize L2 distance
        objective = cp.Minimize(cp.norm2(x - perturbed_probs_np))

        # 2. Set constraints: in probability simplex and L2 distance not exceeding ε
        constraints = [
            cp.sum(x) == 1,
            x >= 0,
            cp.norm2(x - orig_probs_np) <= self.epsilon
        ]

        # 3. Solve the problem
        prob = cp.Problem(objective, constraints)
        try:
            prob.solve(solver=cp.ECOS)
            if x.value is None:
                raise ValueError("CVXPY returned None")
        except Exception as e:
            # Best-effort fallback is kept, but the failure is surfaced: a missing
            # solver would otherwise silently disable the exact projection.
            # Python's default warning filter shows this once per distinct message.
            import warnings
            warnings.warn(
                f"CVXPY projection failed ({e!r}); using approximate projection instead",
                RuntimeWarning,
            )
            return self._project_approximate(
                torch.as_tensor(perturbed_probs_np, dtype=torch.float32, device=device),
                torch.as_tensor(orig_probs_np, dtype=torch.float32, device=device)
            )

        result = torch.tensor(x.value, dtype=torch.float32)
        if device is not None:
            result = result.to(device)
        # Ensure numerical stability
        result = torch.clamp(result, min=1e-7)
        result = result / result.sum()
        return result

    def _find_adversarial_target(self, choice_logits, human_probs):
        """Search for the feasible target distribution that maximizes the KL term.

        Runs at most ``adv_steps`` steps of projected gradient ascent on
        ``F.kl_div(log_softmax(choice_logits), p, reduction='sum')`` with respect
        to ``p``. The model logits are detached, so no gradient reaches the
        model from this search. The search stops at the first step that does
        not improve on the best KL value seen so far.

        Returns:
            The detached best distribution found.
        """
        model_log_probs = F.log_softmax(choice_logits.detach(), dim=0)

        # Gradients w.r.t. the perturbation are needed even if the caller runs
        # compute_loss under torch.no_grad() (e.g. during evaluation).
        with torch.enable_grad():
            # The random start can lie outside the epsilon ball; project it so that
            # the returned target always satisfies the constraints.
            start = self.initialize_perturbed_distribution(human_probs)
            perturbed_probs = self.project_to_constraints(start.detach(), human_probs)
            perturbed_probs = perturbed_probs.detach().clone().requires_grad_(True)

            best_kl_div = F.kl_div(
                model_log_probs, perturbed_probs.detach(), reduction='sum'
            ).item()
            best_perturbed_probs = perturbed_probs.detach().clone()

            for _ in range(self.adv_steps):
                kl_div = F.kl_div(model_log_probs, perturbed_probs, reduction='sum')
                (grad,) = torch.autograd.grad(kl_div, perturbed_probs)

                with torch.no_grad():
                    candidate = perturbed_probs
                    grad_norm = torch.norm(grad)
                    if grad_norm > 1e-8:  # Avoid division by zero
                        # Normalized gradient ascent on the KL divergence
                        candidate = perturbed_probs + self.adv_lr * (grad / grad_norm)
                    # Project back to constraint space
                    perturbed_probs.copy_(
                        self.project_to_constraints(candidate, human_probs)
                    )

                new_kl_div = F.kl_div(
                    model_log_probs, perturbed_probs.detach(), reduction='sum'
                ).item()

                # Keep the step only if the KL divergence increased
                if new_kl_div > best_kl_div:
                    best_kl_div = new_kl_div
                    best_perturbed_probs = perturbed_probs.detach().clone()
                else:
                    break

        return best_perturbed_probs

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the hybrid KL / CE loss for one batch.

        Args:
            model: Model called with ``input_ids`` and ``attention_mask``; its
                output must expose ``.logits``.
            inputs: Batch dict with ``input_ids``, ``attention_mask``,
                ``human_distribution`` (per sample: mapping choice -> weight) and
                ``choices`` (per sample: iterable of choice strings). The last
                two keys are popped from ``inputs``.
            return_outputs: Also return the model outputs.
            **kwargs: Accepted and ignored.

        Returns:
            ``loss`` or ``(loss, outputs)``.

        Raises:
            ValueError: If ``human_distribution`` or ``choices`` is missing.
        """
        human_distributions = inputs.pop("human_distribution", None)
        choices = inputs.pop("choices", None)
        if human_distributions is None or choices is None:
            raise ValueError(
                "inputs must contain 'human_distribution' and 'choices'"
            )

        attention_mask = inputs["attention_mask"]
        outputs = model(input_ids=inputs["input_ids"], attention_mask=attention_mask)
        logits = outputs.logits

        use_adversarial_target = self.alpha > 0 and self.epsilon != 0

        kl_batch_losses = []
        ce_batch_losses = []

        for i in range(logits.shape[0]):
            # Index of the last attended token. Unlike ``mask.sum() - 1`` this is
            # also correct when the batch is left-padded.
            attended = torch.nonzero(attention_mask[i], as_tuple=False)
            if attended.numel() == 0:
                continue
            last_token_logits = logits[i, int(attended[-1].item()), :]

            # 1. Prepare LLM scores: logit of the first token of every choice
            sample_choices = list(choices[i])
            choice_token_ids = torch.tensor(
                [
                    self.processing_class.encode(choice, add_special_tokens=False)[0]
                    for choice in sample_choices
                ],
                device=last_token_logits.device,
            )
            output_prob = last_token_logits[choice_token_ids]

            # 2. Prepare human probability distribution
            human_probs = torch.tensor(
                [human_distributions[i][choice] for choice in sample_choices],
                device=output_prob.device,
                dtype=torch.float,
            )
            if human_probs.sum() > 0:
                human_probs = human_probs / human_probs.sum()

            # 3. KL divergence against the (possibly worst-case) target
            if use_adversarial_target:
                kl_target = self._find_adversarial_target(output_prob, human_probs)
            else:
                kl_target = human_probs
            kl_batch_losses.append(
                F.kl_div(F.log_softmax(output_prob, dim=0), kl_target, reduction='sum')
            )

            # 4. CE loss over the full vocabulary against the majority choice
            target_token_id = choice_token_ids[torch.argmax(human_probs)]
            ce_batch_losses.append(
                F.cross_entropy(last_token_logits.unsqueeze(0), target_token_id.view(1))
            )

        # 5. Hybrid loss
        if kl_batch_losses and ce_batch_losses:
            kl_loss = torch.stack(kl_batch_losses).mean()
            ce_loss = torch.stack(ce_batch_losses).mean()
            loss = self.alpha * kl_loss + (1 - self.alpha) * ce_loss
        else:
            # No usable sample in this batch: return a zero that is still attached
            # to the graph so the trainer's backward pass does not fail.
            loss = logits.sum() * 0.0

        return (loss, outputs) if return_outputs else loss


class TqdmWithLossAverage(TrainerCallback):
    """Callback that shows a tqdm bar with the latest and averaged training loss.

    The bar is advanced whenever a log entry containing ``"loss"`` arrives and
    once more at the end of training, so it finishes at the final global step.
    Only the world-zero process creates a bar, which avoids one bar per rank in
    distributed runs.
    """

    def __init__(self):
        self.training_bar = None  # tqdm instance; None when no bar is active
        self.current_step = 0     # global step the bar currently reflects
        self.losses = []          # every logged "loss" value, in order

    def _advance_bar(self, state):
        """Move the bar forward to ``state.global_step``."""
        steps_done = state.global_step - self.current_step
        self.current_step = state.global_step
        if steps_done > 0:
            self.training_bar.update(steps_done)

    def on_train_begin(self, args, state, control, **kwargs):
        # Reset so the callback instance can be reused for another train() call.
        self.current_step = 0
        self.losses = []
        if state.is_world_process_zero:
            self.training_bar = tqdm(total=state.max_steps, desc="Training")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if self.training_bar is None or not logs or "loss" not in logs:
            return

        self._advance_bar(state)

        current_loss = logs["loss"]
        self.losses.append(current_loss)

        # Slicing already copes with fewer than 10 recorded values.
        recent_avg = np.mean(self.losses[-10:])
        total_avg = np.mean(self.losses)

        self.training_bar.set_postfix(
            loss=f"{current_loss:.4f}",
            avg_10=f"{recent_avg:.4f}",
            avg_all=f"{total_avg:.4f}"
        )

    def on_train_end(self, args, state, control, **kwargs):
        if self.training_bar is None:
            return
        # Steps taken after the last "loss" log would otherwise leave the bar short.
        self._advance_bar(state)
        self.training_bar.close()
        self.training_bar = None


def train_distribution_alignment(
        raw_dataset, 
        output_dir, 
        model_path=None, 
        training_args=None, 
        batch_size=4,
        alpha=0.5,
        epsilon=0.1,
        adv_steps=5,
        adv_lr=0.1,
        max_length=512):
    """
    Distribution alignment training for LLM-as-a-judge

    Builds the LoRA model module and the distillation dataloader, trains with
    ``DistributionAlignmentTrainer`` and saves the adapter to
    ``<output_dir>/final_checkpoint``.

    Args:
        raw_dataset: Original dataset object
        output_dir: Output directory
        model_path: Model path, forwarded to ``QwenLoraModule``
        training_args: Training parameters (required)
        batch_size: Batch size passed to the dataloader factory
        alpha: Weight for KL divergence loss, default is 0.5
        epsilon: Maximum scale of perturbation (perturbation_scale)
        adv_steps: Number of adversarial training steps
        adv_lr: Learning rate for adversarial perturbation
        max_length: Maximum length of input sequence

    Returns:
        The ``QwenLoraModule`` instance whose model was trained.

    Raises:
        ValueError: If ``training_args`` is None.
    """
    # Validate before loading the model so a missing argument fails immediately.
    if training_args is None:
        raise ValueError("training_args not provided")

    model_module = QwenLoraModule(
        model_path=model_path,
        inference_mode=False
    )
    tokenizer = model_module.tokenizer
    train_dataloader = get_distillation_dataloader(
        raw_dataset, tokenizer, batch_size=batch_size, max_length=max_length
    )

    # Only the dataset and collate function of the dataloader are handed to the
    # trainer; the trainer batches the data itself according to training_args.
    trainer = DistributionAlignmentTrainer(
        alpha=alpha,
        epsilon=epsilon,
        adv_steps=adv_steps,
        adv_lr=adv_lr,
        model=model_module.model,
        args=training_args,
        train_dataset=train_dataloader.dataset,
        # The trainer's compute_loss reads self.processing_class, so pass the
        # tokenizer under that name rather than the deprecated `tokenizer` kwarg.
        processing_class=tokenizer,
        data_collator=train_dataloader.collate_fn,
        callbacks=[TqdmWithLossAverage()]
    )

    trainer.train()

    os.makedirs(output_dir, exist_ok=True)
    model_module.save_adapter(os.path.join(output_dir, "final_checkpoint"))

    return model_module


if __name__ == "__main__":
    # Command-line entry point: parse options, load the selected training set,
    # build the TrainingArguments and run distribution alignment training.
    parser = argparse.ArgumentParser(description="Distribution Alignment Training for LLM-as-a-judge")
    parser.add_argument("--output_dir", type=str, default=os.path.join(os.path.dirname(__file__), "outputs", "qwen2.5-7b"), help="Output directory")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch Size")
    parser.add_argument("--accumulate_grad_batches", type=int, default=8, help="Gradient Accumulation Batches")
    parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--model_path", type=str, default=os.path.join(os.path.dirname(__file__), "models", "Qwen", "Qwen2.5-7B-Instruct"), help="Model path")
    parser.add_argument("--alpha", type=float, default=0.5, help="Weight for kl and ce loss")
    parser.add_argument("--epsilon", type=float, default=0, help="The radius of the perturbation")
    parser.add_argument("--adv_steps", type=int, default=5, help="Number of adversarial training steps")
    parser.add_argument("--adv_lr", type=float, default=0.05, help="Learning rate for adversarial perturbation")
    parser.add_argument("--dataset", type=str, default="SNLI", 
                      choices=["SNLI", "MultiNLI", "MTBench", "SummEval"],
                      help="Training dataset name")

    args = parser.parse_args()

    # batch_size is used as a divisor below; reject values that cannot work.
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")

    # Per dataset: (dataset class, training file name, max sequence length).
    # Only the selected dataset is instantiated, so the other training files
    # are neither read nor required to exist.
    train_dir = os.path.join(os.path.dirname(__file__), "dataset", "train")
    dataset_specs = {
        "SNLI": (SNLIDataset, "snli_train.jsonl", 256),
        "MultiNLI": (SNLIDataset, "multinli_train.jsonl", 400),
        "MTBench": (MTBenchDataset, "mt_bench_train.jsonl", 3072),
        "SummEval": (SummEvalDataset, "summeval_train.jsonl", 1100),
    }

    # Get selected dataset and max length
    dataset_cls, train_file, max_length = dataset_specs[args.dataset]
    selected_dataset = dataset_cls(file_path=os.path.join(train_dir, train_file))

    # Integer quotient of the two options, never less than one step.
    # NOTE(review): this reads --accumulate_grad_batches as a target number of
    # samples per optimizer step rather than a step count - confirm intent.
    accumulation_steps = max(1, args.accumulate_grad_batches // args.batch_size)

    print(f"Using dataset: {args.dataset}, max sequence length: {max_length}")

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        gradient_accumulation_steps=accumulation_steps,
        per_device_train_batch_size=args.batch_size,
        save_strategy="epoch",
        logging_dir=os.path.join(args.output_dir, "logs"),
        logging_steps=20,
        logging_strategy="steps",
        log_level="info",
        learning_rate=args.learning_rate,
        # The collator needs the extra columns consumed by compute_loss.
        remove_unused_columns=False,
        report_to="tensorboard",
        optim="adamw_torch",
        use_liger_kernel=True,
        warmup_ratio=0.05,
    )

    model_module = train_distribution_alignment(
        raw_dataset=selected_dataset,
        output_dir=args.output_dir,
        model_path=args.model_path,
        training_args=training_args,
        batch_size=args.batch_size,
        alpha=args.alpha,
        epsilon=args.epsilon,
        adv_steps=args.adv_steps,
        adv_lr=args.adv_lr,
        max_length=max_length
    )

    print(f"Training completed, model saved to: {args.output_dir}")