import torch
import wandb
import pandas as pd
from typing import Optional
from transformers import Trainer, TrainingArguments
from tqdm import tqdm


class CustomTrainingArguments(TrainingArguments):
    """``TrainingArguments`` carrying one extra option, ``max_steps_per_epoch``.

    All positional and keyword arguments other than ``max_steps_per_epoch``
    are forwarded untouched to ``TrainingArguments``.

    Args:
        max_steps_per_epoch: Optional per-epoch step cap (keyword-only).
            ``None`` (the default) means no value was configured. This class
            only stores the value; nothing in this block enforces it.
    """

    def __init__(self, *args, max_steps_per_epoch: Optional[int] = None, **kwargs) -> None:
        # Let the parent consume and validate every standard argument first.
        super().__init__(*args, **kwargs)
        # NOTE(review): stored as a plain instance attribute, not a dataclass
        # field, so it is presumably invisible to field-based helpers of the
        # parent class (serialisation, argument parsing) -- confirm if needed.
        self.max_steps_per_epoch = max_steps_per_epoch


class CustomTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        """Build the trainer and set up the per-step metric history.

        All arguments are forwarded unchanged to ``Trainer.__init__``.

        Side effects:
            * ``self.log_history`` is initialised to an empty list; it
              receives one metrics dict per logged step.
            * If a ``tokenizer`` keyword was supplied and the base class did
              not already populate ``self.processing_class``, the tokenizer
              itself is stored there so later code can rely on the attribute
              existing.
        """
        super().__init__(*args, **kwargs)
        self.log_history = []

        # The previous implementation imported ``transformers.ProcessingClass``,
        # which does not exist, so this branch always raised ImportError.
        # ``processing_class`` on a Trainer is the tokenizer/processor object
        # itself, so assign it directly -- and only when the base class has not
        # already set it, to avoid clobbering what Trainer configured.
        tokenizer = kwargs.get("tokenizer")
        if tokenizer is not None and getattr(self, "processing_class", None) is None:
            self.processing_class = tokenizer

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the forward pass, log training metrics, and return the loss.

        Args:
            model: Callable invoked as ``model(**inputs)``; its result must
                support ``outputs["loss"]``.
            inputs: Mapping of keyword arguments for the forward pass.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just ``loss``.
            num_items_in_batch: Accepted for signature compatibility with the
                base class; not used here.

        Returns:
            ``outputs["loss"]``, or ``(outputs["loss"], outputs)`` when
            ``return_outputs`` is true.
        """
        outputs = model(**inputs)
        loss = outputs["loss"]

        # Labels equal to ``ignore_index`` are excluded from the error metric.
        # Default to -100; switch to the pad id only when the processing class
        # wraps a tokenizer that actually defines one. A ``None`` pad id must
        # not be used: comparing labels against ``None`` disables the mask.
        ignore_index = -100
        tokenizer = getattr(getattr(self, "processing_class", None), "tokenizer", None)
        pad_token_id = getattr(tokenizer, "pad_token_id", None)
        if pad_token_id is not None:
            ignore_index = pad_token_id

        # The stock Trainer also routes evaluation/prediction steps through
        # compute_loss; the metrics are all named "train/...", so only record
        # them while the model is in training mode. Objects without a
        # ``training`` attribute keep the old always-log behaviour.
        if getattr(model, "training", True):
            self.log_metrics(outputs, inputs, ignore_index=ignore_index)

        return (loss, outputs) if return_outputs else loss

    def log_metrics(self, outputs, inputs, ignore_index=-100):
        """Record per-step training metrics locally and in Weights & Biases.

        Computes ``train/loss`` (mean of ``outputs["loss"]``) and, when both
        ``inputs["labels"]`` and ``outputs["logits"]`` are present,
        ``train/error``: the fraction of non-ignored label positions whose
        arg-max prediction differs from the label. The metrics dict is
        appended to ``self.log_history`` and sent to the active wandb run.
        Only the main process (``is_world_process_zero``) logs.

        Args:
            outputs: Mapping holding ``"loss"`` and optionally ``"logits"``.
            inputs: Mapping optionally holding ``"labels"``.
            ignore_index: Label value excluded from the error metric.

        Note:
            The base ``Trainer`` exposes a method of the same name with the
            signature ``log_metrics(split, metrics)``. A call whose first
            argument is a string is treated as that API and delegated to the
            parent, so existing ``trainer.log_metrics("train", metrics)`` code
            keeps working.

            NOTE(review): labels and logits are compared position-by-position
            with no shift; this assumes the model already aligns them --
            confirm for causal-LM style models.
        """
        if isinstance(outputs, str):
            # Stock ``Trainer.log_metrics(split, metrics)`` call shape.
            return super().log_metrics(outputs, inputs)

        if not self.is_world_process_zero():
            return

        # Pure bookkeeping: make sure nothing here extends the autograd graph.
        with torch.no_grad():
            metrics = {"train/loss": outputs["loss"].mean().item()}

            if "labels" in inputs and "logits" in outputs:
                labels = inputs["labels"]
                keep = labels != ignore_index
                # With no supervised position the mean of an empty tensor is
                # NaN; skip the metric instead of logging NaN.
                if keep.any():
                    predictions = torch.argmax(outputs["logits"][keep], dim=-1)
                    metrics["train/error"] = (predictions != labels[keep]).float().mean().item()

        self.log_history.append(metrics)

        # wandb.log raises when no run has been initialised; the local history
        # above is still kept in that case.
        if wandb.run is not None:
            wandb.log(metrics)
        # Kept from the original implementation: release cached CUDA blocks
        # after every logged step.
        torch.cuda.empty_cache()

    @torch.no_grad()
    def evaluate_test_greedy(self, tokenizer):
        """Greedy-decode the evaluation set and score exact sequence matches.

        For every batch from ``self.get_eval_dataloader()`` the model's
        ``greedy_generate`` is called with the batch's ``encoder_input`` and
        ``encoder_padding_mask``. The first generated position is dropped,
        predictions and labels are cut to their common length, and a sample
        counts as correct only if every remaining position matches.

        Args:
            tokenizer: Object providing
                ``decode(ids, skip_special_tokens=True)``; used only to build
                the human-readable result table.

        Returns:
            ``(accuracy, df)``: ``accuracy`` is the fraction of exactly
            matching samples (``0`` when the dataloader yields nothing);
            ``df`` is a ``pandas.DataFrame`` with columns ``input_text``,
            ``target_text``, ``pred_text`` and ``incorrect_flag``.
        """
        device = self.args.device
        total_correct, total_sample = 0, 0
        results = []
        self.model.eval()

        for batch in self.get_eval_dataloader():
            output = self.model.greedy_generate(
                encoder_input=batch["encoder_input"].to(device),
                max_length=self.model.max_sequence_length,
                encoder_attention_mask=None,
                encoder_padding_mask=batch["encoder_padding_mask"].to(device),
            )
            # Skip position 0 of the generated sequence (presumably the decoder
            # start token -- TODO confirm against greedy_generate).
            predictions = output[:, 1:]
            labels = batch["labels"].to(device)

            # Compare only over the length both tensors share.
            min_len = min(int(predictions.shape[1]), int(labels.shape[1]))
            predictions = predictions[:, :min_len]
            labels = labels[:, :min_len]

            # One vectorised exact-match test per row, reused for both the
            # running accuracy and the per-sample table.
            matches = (predictions == labels).all(dim=1).tolist()
            total_correct += sum(matches)
            total_sample += len(matches)

            for source_ids, label, pred, matched in zip(batch["encoder_input"], labels, predictions, matches):
                results.append(
                    {
                        "input_text": tokenizer.decode(source_ids, skip_special_tokens=True),
                        "target_text": tokenizer.decode(label, skip_special_tokens=True),
                        "pred_text": tokenizer.decode(pred, skip_special_tokens=True),
                        # NOTE(review): despite the column name, this is True
                        # when the prediction MATCHES the label. Kept as-is
                        # because downstream consumers may rely on it.
                        "incorrect_flag": matched,
                    }
                )

        df = pd.DataFrame(results)
        # Guard against an empty dataloader instead of dividing by zero.
        return total_correct / total_sample if total_sample > 0 else 0, df

    @torch.no_grad()
    def evaluate_gpt2(self, tokenizer, max_length=128):
        """Generate continuations with a decoder-only model and score them.

        Each evaluation batch must provide ``input_ids``, ``attention_mask``
        and ``labels``. For every row, the prompt is everything before the
        first label that is not ``-100`` and the target is the set of labels
        that are not ``-100``. Prompts are left-padded to a common width,
        passed to ``self.model.generate``, and the decoded continuation is
        compared with the decoded target as plain text.

        Args:
            tokenizer: Object providing ``pad_token_id``, ``eos_token_id`` and
                ``decode(ids, skip_special_tokens=True)``.
            max_length: Forwarded to ``generate`` as ``max_length`` (prompt
                plus continuation). Defaults to the previously hard-coded 128.

        Returns:
            ``(accuracy, df)``: ``accuracy`` is the fraction of samples whose
            generated text equals the target text (``0`` when there are no
            samples); ``df`` is a ``pandas.DataFrame`` with columns
            ``input_text``, ``target_text``, ``pred_text`` and
            ``incorrect_flag``.
        """
        # Tokenizers without a pad token (common for GPT-2 style vocabularies)
        # fall back to EOS. Resolve this once so manual prompt padding and
        # generate() agree; padding with a ``None`` id would raise.
        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        device = self.args.device
        total_correct, total_sample = 0, 0
        results = []
        self.model.eval()

        for batch in self.get_eval_dataloader():
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            prompts, prompt_masks, targets = [], [], []
            for ids_row, mask_row, label_row in zip(input_ids, attention_mask, labels):
                supervised = label_row != -100
                targets.append(label_row[supervised])

                # The prompt ends where the first supervised label starts.
                # Counting ignored labels instead would over-estimate the
                # prompt length whenever labels also carry trailing -100s.
                supervised_positions = torch.nonzero(supervised, as_tuple=True)[0]
                if supervised_positions.numel() > 0:
                    prompt_len = int(supervised_positions[0])
                else:
                    prompt_len = label_row.shape[0]

                prompts.append(ids_row[:prompt_len])
                prompt_masks.append(mask_row[:prompt_len])

            # Left-pad: a decoder-only model continues from the LAST position,
            # so every prompt must end in a real token at the same column.
            max_prompt_len = max(len(prompt) for prompt in prompts)
            padded_ids = input_ids.new_full((len(prompts), max_prompt_len), pad_token_id)
            padded_mask = attention_mask.new_zeros((len(prompts), max_prompt_len))
            for row, (prompt, prompt_mask) in enumerate(zip(prompts, prompt_masks)):
                start = max_prompt_len - len(prompt)
                padded_ids[row, start:] = prompt
                padded_mask[row, start:] = prompt_mask

            # Only the generated token ids are consumed below, so attention
            # maps are no longer requested from generate().
            outputs = self.model.generate(
                input_ids=padded_ids,
                attention_mask=padded_mask,
                max_length=max_length,
                pad_token_id=pad_token_id,
                return_dict_in_generate=True,
            )

            for sequence, prompt, target in zip(outputs.sequences, prompts, targets):
                # All prompts occupy the first ``max_prompt_len`` columns.
                generated_text = tokenizer.decode(sequence[max_prompt_len:], skip_special_tokens=True)
                input_text = tokenizer.decode(prompt, skip_special_tokens=True)
                target_text = tokenizer.decode(target, skip_special_tokens=True)
                is_match = generated_text == target_text

                results.append(
                    {
                        "input_text": input_text,
                        "target_text": target_text,
                        "pred_text": generated_text,
                        # NOTE(review): despite the column name, this is True
                        # when the generation MATCHES the target. Kept as-is
                        # because downstream consumers may rely on it.
                        "incorrect_flag": is_match,
                    }
                )

                if is_match:
                    total_correct += 1
                total_sample += 1

        df = pd.DataFrame(results)
        return total_correct / total_sample if total_sample > 0 else 0, df
