import random
import os
import numpy as np
import math
import wandb
from copy import deepcopy
import torch
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from torchvision.transforms import ToPILImage

# Maps each task name (presumably the GLUE benchmark tasks) to the pair of
# input keys used for that task; the second key is None for tasks that have
# a single text input (cola, sst2).
task_to_keys = dict(
    cola=("sentence", None),
    mnli=("premise", "hypothesis"),
    mrpc=("sentence1", "sentence2"),
    qnli=("question", "sentence"),
    qqp=("question1", "question2"),
    rte=("sentence1", "sentence2"),
    sst2=("sentence", None),
    stsb=("sentence1", "sentence2"),
    wnli=("sentence1", "sentence2"),
)

def cosine_scheduler(base_value, final_value, max_steps, warmup_steps=0, start_warmup_value=0):
    """Build a per-step schedule: optional linear warmup, then cosine decay.

    Args:
        base_value: Value reached at the end of warmup and at the start of the decay.
        final_value: Value the cosine curve decays towards.
        max_steps: Total number of steps (warmup included).
        warmup_steps: Length of the linear warmup from ``start_warmup_value``
            to ``base_value``; no warmup when 0.
        start_warmup_value: First value of the warmup ramp.

    Returns:
        1-D numpy array: ``warmup_steps`` ramp values followed by
        ``max_steps - warmup_steps`` cosine values. Step ``i`` of the decay is
        ``final + 0.5 * (base - final) * (1 + cos(pi * i / n))``, where ``n`` is
        the decay length, so the last value stays slightly above ``final_value``.
    """
    if warmup_steps > 0:
        ramp = np.linspace(start_warmup_value, base_value, warmup_steps)
    else:
        ramp = np.array([])

    decay_steps = np.arange(max_steps - warmup_steps)
    cosine_term = 1 + np.cos(np.pi * decay_steps / len(decay_steps))
    decay = final_value + 0.5 * (base_value - final_value) * cosine_term

    return np.concatenate((ramp, decay))

def exponential_decay_scheduler(base_value, max_steps, decay_constant=None, half_ratio=0.25, patience_steps=None):
    """Build a per-step schedule: a constant "patience" phase, then exponential decay.

    Exactly one of ``decay_constant`` / ``half_ratio`` must be set. Because
    ``half_ratio`` defaults to 0.25, pass ``half_ratio=None`` when supplying an
    explicit ``decay_constant``.

    Args:
        base_value: Value held during the patience phase and the starting value
            of the decay.
        max_steps: Total length of the returned schedule.
        decay_constant: Rate ``k`` in ``base_value * exp(-k * t)``, where ``t``
            counts steps since the decay phase began.
        half_ratio: Fraction of ``max_steps`` after which the decayed value has
            halved, i.e. the half-life is ``half_ratio * max_steps`` steps.
        patience_steps: Number of leading steps held at ``base_value``;
            ``None`` means 0. Clamped to ``[0, max_steps]``.

    Returns:
        1-D float numpy array of length ``max(max_steps, 0)``.

    Raises:
        AssertionError: if both or neither of ``decay_constant`` / ``half_ratio`` are given.
        ValueError: if ``half_ratio`` is not positive.
    """
    assert (decay_constant is not None) != (half_ratio is not None), (
        "Specify exactly one of decay_constant or half_ratio"
    )
    max_steps = max(int(max_steps), 0)
    if max_steps == 0:
        return np.empty(0)

    # The original default of None crashed in np.ones((None,)); treat it as "no patience".
    patience_steps = 0 if patience_steps is None else min(max(int(patience_steps), 0), max_steps)

    if half_ratio is not None:
        half_life = half_ratio * max_steps
        if half_life <= 0:
            raise ValueError(f"half_ratio must be positive, got {half_ratio!r}")
        # exp(-k * half_life) == 0.5  =>  k = ln(2) / half_life
        decay_constant = math.log(2) / half_life

    patience_schedule = np.full(patience_steps, base_value, dtype=float)
    iters = np.arange(max_steps - patience_steps)
    decay_schedule = base_value * np.exp(-decay_constant * iters)

    return np.concatenate((patience_schedule, decay_schedule))

class EMATeacherCallback(TrainerCallback):
    """Drive the EMA teacher held by ``model.teacher`` from Trainer hooks.

    Relies on ``model.teacher`` providing ``create_ema_teacher``,
    ``create_ema_schedule``, ``set_ema_schedule_step`` and ``step_ema_teacher``.
    """

    def __init__(self, resume=False) -> None:
        super().__init__()
        # When truthy, an already-present ``ema_teacher`` is kept at train start.
        self.resume = resume

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model=None, **kwargs):
        """(Re)create the EMA teacher unless resuming with one already present, then build its schedule."""
        teacher = model.teacher
        has_ema_teacher = getattr(teacher, "ema_teacher", None) is not None
        if not (has_ema_teacher and self.resume):
            teacher.create_ema_teacher()

        teacher.create_ema_schedule(state.max_steps)

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model=None, **kwargs):
        """Point the EMA schedule at the current global step."""
        current_step = state.global_step
        model.teacher.set_ema_schedule_step(current_step)

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model=None, **kwargs):
        """Advance the EMA teacher after each training step."""
        teacher = model.teacher
        teacher.step_ema_teacher()

class CrossAttentionMaskingNoiseCallback(TrainerCallback):
    """Anneal the cross-attention masking noise scale over training.

    At train start, a per-step schedule decaying from 1.0 (cosine or
    exponential) is built. It spans ``schedule_length_ratio * max_steps`` steps,
    or ``max_steps`` when the ratio is None. At each step begin, the scheduled
    value for ``state.global_step`` is pushed to the first transform of the
    student's cross-attention ``train_noising_function``. Once the schedule is
    exhausted, the scale is held at 0.0. On the main process the value is also
    logged to wandb.

    ``resume`` is stored but not used by this callback.
    """

    _DECAY_TYPES = ("exponential", "cosine")

    def __init__(self, resume=False, decay_type="cosine", schedule_length_ratio=None) -> None:
        super().__init__()
        # Fail fast: an unknown decay type used to leave ``self.schedule`` unset and
        # only surface later as an AttributeError in on_step_begin.
        if decay_type not in self._DECAY_TYPES:
            raise ValueError(f"Unsupported decay_type {decay_type!r}; expected one of {self._DECAY_TYPES}")
        self.resume = resume
        self.decay_type = decay_type
        self.schedule_length_ratio = schedule_length_ratio
        self.run = None
        self.target = None
        self.schedule = None

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model=None, **kwargs):
        # Only the main process logs; other ranks keep ``self.run = None``.
        self.run = wandb.run if state.is_world_process_zero else None

        attention = model.teacher.student_model.encoder.cross_attention.attention.attention
        noising_function = attention.train_noising_function
        assert noising_function is not None, "Cross-attention masking noise requires a train_noising_function"
        self.target = noising_function.transforms[0]

        if self.schedule_length_ratio is not None:
            schedule_length = round(state.max_steps * self.schedule_length_ratio)
        else:
            schedule_length = state.max_steps

        if self.decay_type == "exponential":
            self.schedule = exponential_decay_scheduler(1.0, schedule_length)
        elif self.decay_type == "cosine":
            self.schedule = cosine_scheduler(1.0, 0.0, schedule_length)
        else:
            raise ValueError(f"Unsupported decay_type {self.decay_type!r}; expected one of {self._DECAY_TYPES}")

    def on_step_begin(self, args: TrainerState, state: TrainerState, control: TrainerControl, model=None, **kwargs):
        step = state.global_step
        # Past the end of the schedule the noise is switched off entirely.
        noise_scale = float(self.schedule[step]) if step < len(self.schedule) else 0.0
        self.target.set_noise_scale(noise_scale)
        # NOTE(review): unlike the other callbacks in this file this logs with the default
        # commit=True, which advances wandb's step every training step — confirm intended.
        if state.is_world_process_zero and self.run is not None:
            self.run.log({"masking_noise_scale": noise_scale})
    
class EMADistanceCallback(TrainerCallback):
    """Periodically log the mean absolute parameter difference between student and EMA teacher.

    Every ``log_step_interval`` global steps (main process only), this logs
    ``sum(|student_param - ema_param|) / total_student_numel`` to wandb under
    ``"EMA Distance"`` with ``commit=False``.
    """

    def __init__(self, log_step_interval=1000) -> None:
        super().__init__()
        self.log_step_interval = log_step_interval

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.is_world_process_zero:
            self.run = wandb.run if wandb.run is not None else wandb.init()

    @torch.no_grad()
    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model=None, **kwargs):
        if state.global_step % self.log_step_interval or not state.is_world_process_zero:
            return
        # The getattr default makes the assert reachable; previously a missing attribute
        # raised AttributeError before the assert could fire.
        ema_model = getattr(model.teacher, "ema_teacher", None)
        assert ema_model is not None, "EMADistanceCallback requires model.teacher.ema_teacher"
        student_model = model.teacher.student_model

        # Parameters are only read (under no_grad), so there is no need to deepcopy
        # either model, which previously doubled memory on every logging step.
        total_diff, num_params = 0, 0
        for student_param, ema_param in zip(student_model.parameters(), ema_model.parameters()):
            # Accumulate in float32 so half-precision weights don't lose precision in the sum.
            total_diff = total_diff + (student_param.float() - ema_param.float()).abs().sum()
            num_params += student_param.numel()

        if num_params == 0:
            return
        self.run.log({"EMA Distance": (total_diff / num_params).item()}, commit=False)
    
class KillCallback(TrainerCallback):
    """Request that training stop once ``num_steps`` training steps have completed."""

    def __init__(self, num_steps):
        super().__init__()
        self.num_steps = num_steps

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Assumes global_step already includes the step that just finished when
        # on_step_end fires (as in the HF Trainer). With that, ``>=`` stops after
        # exactly ``num_steps`` steps; the previous ``>`` ran one extra step.
        if state.global_step >= self.num_steps:
            control.should_training_stop = True

class WandbTextReconstructionVisualizer(TrainerCallback):
    """Log masked-text reconstructions from the teacher's reconstruction decoder to wandb.

    Every ``log_step_interval`` steps (main process only), ``num_samples``
    training items are each run ``repeats`` times through a deep copy of the
    model in train mode. Each pass appends one row to a table that grows across
    calls. The row holds the original text, the text with masked positions
    replaced by ``"#"``, the decoder's argmax predictions (unmasked positions
    copied from the input when ``fill_unmasked``), and the fraction of
    non-padding positions that were masked.

    A position counts as masked when its column of the first cross-attention
    tensor sums to zero. For every token id in ``mask_tracking``, the share of
    masked tokens with that id is also logged.
    """

    def __init__(self, log_step_interval=1000, num_samples=5, random_sample=False, fill_unmasked=True, repeats=2, mask_tracking=(38,)):
        super().__init__()
        self.log_step_interval = log_step_interval
        self.num_samples = num_samples
        self.random_sample = random_sample
        self.fill_unmasked = fill_unmasked
        self.table = wandb.Table(columns=["step", "original_text", "transformed_text", "reconstruction_predictions", "true_mask_ratio"])
        self.repeats = repeats
        # Copied into a fresh list so the (immutable) default is never shared or mutated.
        self.mask_tracking = list(mask_tracking or ())

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.is_world_process_zero:
            self.run = wandb.run if wandb.run is not None else wandb.init()

    def _select_samples(self, dataset):
        """Return up to ``num_samples`` dataset items, random or from the front of the dataset."""
        k = min(self.num_samples, len(dataset))
        if self.random_sample:
            # random.sample needs a population sequence, not an int length.
            indices = random.sample(range(len(dataset)), k=k)
        else:
            indices = range(k)
        return [dataset[i] for i in indices]

    @torch.no_grad()
    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, model=None, tokenizer=None, **kwargs):
        if state.global_step % self.log_step_interval or not state.is_world_process_zero:
            return
        # Some transformers versions pass the tokenizer as ``processing_class`` instead.
        if tokenizer is None:
            tokenizer = kwargs.get("processing_class")
        # Weird hack required due to current wandb table issues
        self.table = wandb.Table(
            columns=self.table.columns, data=self.table.data
        )
        if getattr(model.teacher, "reconstruction_decoder", None) is None:
            print("Cannot visualize reconstructions without reconstruction decoder")
            return
        if tokenizer is None:
            print("Cannot visualize reconstructions without a tokenizer")
            return
        # deepcopy to avoid propagating any changes to the model being trained
        visualization_model = deepcopy(model)
        visualization_model.train()
        device = visualization_model.device
        samples = self._select_samples(kwargs["train_dataloader"].dataset)
        mask_token_id = tokenizer.convert_tokens_to_ids("#")

        true_mask_ratios = []
        counts = {token_id: 0 for token_id in self.mask_tracking}
        total_masked_tokens = 0
        for sample in samples:
            sample_attention_mask = torch.Tensor(sample["attention_mask"])
            # flatten() (not squeeze()) keeps the index 1-D even with a single entry.
            non_pad_indices = sample_attention_mask.nonzero().flatten()
            for _ in range(self.repeats):
                inputs = torch.LongTensor(sample["input_ids"]).unsqueeze(0).to(device)
                # clone so a model that edits the mask in place can't affect later repeats
                attention_mask = sample_attention_mask.unsqueeze(0).to(device).clone()
                full_model_outputs = visualization_model(
                    inputs=inputs,
                    attention_mask=attention_mask,
                    output_attentions=True,
                )
                decoder_outputs = visualization_model.teacher.reconstruction_decoder(full_model_outputs)

                inputs = inputs.squeeze(0)
                text = tokenizer.decode(inputs[non_pad_indices])

                cross_attention = full_model_outputs.cross_attentions[0]
                # A key position that receives zero total attention is treated as masked.
                masked_bool = cross_attention.reshape((-1, cross_attention.shape[-1])).sum(dim=0) == 0
                masked_indices = masked_bool.nonzero().flatten()

                if self.mask_tracking:
                    for token_id in inputs[masked_indices].tolist():
                        total_masked_tokens += 1
                        if token_id in counts:
                            counts[token_id] += 1

                transformed_inputs = inputs.clone()
                transformed_inputs[masked_indices] = mask_token_id
                transformed_text = tokenizer.decode(transformed_inputs[non_pad_indices])

                # Fraction of non-padding positions that were masked, as a plain float for wandb.
                true_mask_ratio = ((attention_mask.long() & masked_bool).sum() / attention_mask.sum()).item()
                true_mask_ratios.append(true_mask_ratio)

                decoder_predictions = torch.argmax(decoder_outputs.logits, dim=-1).squeeze()
                if self.fill_unmasked:
                    unmasked_indices = (~masked_bool).nonzero().flatten()
                    decoder_predictions[unmasked_indices] = inputs[unmasked_indices]

                predicted_text = tokenizer.decode(decoder_predictions[non_pad_indices])
                self.table.add_data(state.global_step, text, transformed_text, predicted_text, true_mask_ratio)

        step_logs = {"Text Reconstruction Visualizations": self.table}
        if true_mask_ratios:
            step_logs["True Mask Ratio"] = sum(true_mask_ratios) / len(true_mask_ratios)
        # Skip the per-token ratios when nothing was masked instead of dividing by zero.
        if total_masked_tokens:
            step_logs.update({f"mask_char_ratio_\"{tokenizer.convert_ids_to_tokens(k)}\"": v / total_masked_tokens for k, v in counts.items()})
        self.run.log(step_logs, commit=False)

class WandbImageReconstructionVisualizer(TrainerCallback):
    """Log masked-image reconstructions from the teacher's reconstruction decoder to wandb.

    Every ``log_step_interval`` steps (main process only), ``num_samples``
    training items are each run ``repeats`` times through a deep copy of the
    model in train mode. Each pass appends one row to a table that grows across
    calls. The row holds the original image, the image with masked pixels
    painted ``replace_color``, the decoder's reconstruction (unmasked pixels
    copied from the input when ``fill_unmasked``), and the fraction of masked
    positions.

    A position counts as masked when its column of the first cross-attention
    tensor sums to zero. When ``image_mean`` is set, inputs and predictions are
    de-normalised with ``x * image_std + image_mean`` before display.
    """

    def __init__(
        self,
        log_step_interval=1000,
        num_samples=5,
        random_sample=False,
        fill_unmasked=True,
        repeats=2,
        image_height=224,
        image_width=224,
        image_mean=None,
        image_std=None,
        replace_color=None,
    ):
        super().__init__()
        self.log_step_interval = log_step_interval
        self.num_samples = num_samples
        self.random_sample = random_sample
        self.fill_unmasked = fill_unmasked
        self.table = wandb.Table(columns=["step", "original_image", "transformed_image", "reconstruction_predictions", "true_mask_ratio"])
        self.repeats = repeats
        self.image_height = image_height
        self.image_width = image_width
        self.image_mean = image_mean
        self.image_std = image_std
        # Colour painted over masked positions; None keeps the previous default of (0, 1, 1)
        # without sharing a single tensor default across instances.
        self.replace_color = torch.Tensor([0, 1, 1]) if replace_color is None else replace_color

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.is_world_process_zero:
            self.run = wandb.run if wandb.run is not None else wandb.init()

    def _select_samples(self, dataset):
        """Return up to ``num_samples`` dataset items, random or from the front of the dataset."""
        k = min(self.num_samples, len(dataset))
        if self.random_sample:
            # random.sample needs a population sequence, not an int length.
            indices = random.sample(range(len(dataset)), k=k)
        else:
            indices = range(k)
        return [dataset[i] for i in indices]

    def _denormalize(self, pixels):
        """Undo mean/std normalisation when ``image_mean`` is set; otherwise return ``pixels`` as-is."""
        if self.image_mean is None:
            return pixels
        return pixels * self.image_std + self.image_mean

    def _to_pil(self, flat_pixels, num_channels):
        """Reshape flattened pixels into a PIL image, clamping float values into [0, 1].

        Presumably ``flat_pixels`` is ``(image_height * image_width, num_channels)`` — TODO confirm.
        """
        # NOTE(review): (H, W, C).permute(2, 1, 0) yields (C, W, H), whereas ToPILImage expects
        # (C, H, W); kept as-is to match the rest of the file — confirm the flattening order.
        image = flat_pixels.reshape((self.image_height, self.image_width, num_channels)).permute(2, 1, 0)
        if image.is_floating_point():
            # ToPILImage scales floats by 255 and casts to uint8; out-of-range values would wrap.
            image = image.clamp(0, 1)
        return ToPILImage()(image)

    @torch.no_grad()
    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, model=None, **kwargs):
        if state.global_step % self.log_step_interval or not state.is_world_process_zero:
            return
        # Weird hack required due to current wandb table issues
        self.table = wandb.Table(
            columns=self.table.columns, data=self.table.data
        )
        if getattr(model.teacher, "reconstruction_decoder", None) is None:
            print("Cannot visualize reconstructions without reconstruction decoder")
            return
        # deepcopy to avoid propagating any changes to the model being trained
        visualization_model = deepcopy(model)
        visualization_model.train()
        device = visualization_model.device
        samples = self._select_samples(kwargs["train_dataloader"].dataset)

        true_mask_ratios = []
        for sample in samples:
            for _ in range(self.repeats):
                inputs = sample["inputs"].unsqueeze(0).to(device)
                # clone so a model that edits the mask in place can't affect later repeats
                attention_mask = sample["attention_mask"].clone().unsqueeze(0).to(device)
                full_model_outputs = visualization_model(
                    inputs=inputs,
                    attention_mask=attention_mask,
                    output_attentions=True,
                )
                decoder_outputs = visualization_model.teacher.reconstruction_decoder(full_model_outputs)

                inputs = self._denormalize(inputs.squeeze())
                num_channels = inputs.shape[-1]

                cross_attention = full_model_outputs.cross_attentions[0]
                # A key position that receives zero total attention is treated as masked.
                masked_bool = cross_attention.reshape((-1, cross_attention.shape[-1])).sum(dim=0) == 0
                masked_indices = masked_bool.nonzero().flatten()

                transformed_inputs = inputs.clone()
                transformed_inputs[masked_indices] = self.replace_color.to(transformed_inputs.device)

                # Plain float so wandb records a number rather than a tensor.
                true_mask_ratio = (masked_bool.sum() / masked_bool.numel()).item()
                true_mask_ratios.append(true_mask_ratio)

                decoder_predictions = self._denormalize(decoder_outputs["logits"].squeeze())
                if self.fill_unmasked:
                    unmasked_indices = (~masked_bool).nonzero().flatten()
                    decoder_predictions[unmasked_indices] = inputs[unmasked_indices]

                self.table.add_data(
                    state.global_step,
                    wandb.Image(self._to_pil(inputs, num_channels)),
                    wandb.Image(self._to_pil(transformed_inputs, num_channels)),
                    wandb.Image(self._to_pil(decoder_predictions, num_channels)),
                    true_mask_ratio,
                )

        step_logs = {"Image Reconstruction Visualizations": self.table}
        if true_mask_ratios:
            step_logs["True Mask Ratio"] = sum(true_mask_ratios) / len(true_mask_ratios)
        self.run.log(step_logs, commit=False)

class ImageLoggingCallback(TrainerCallback):
    """Log the de-normalised images of the first train and first eval batch to wandb, once each.

    Images are read from ``batch["inputs"]``, viewed as
    ``(image_height, image_width, 3)`` and de-normalised with
    ``x * image_std + image_mean``. At most ``num_images`` per batch are logged
    (all of them when None).
    """

    def __init__(
        self,
        image_mean,
        image_std,
        num_images=None,
        image_height=224,
        image_width=224
    ):
        super().__init__()
        self.image_mean = image_mean
        self.image_std = image_std
        self.num_images = num_images
        self.image_height = image_height
        self.image_width = image_width
        self.train_done = False
        self.eval_done = False

    def _log_first_batch(self, dataloader, key):
        """Log images from ``dataloader``'s first batch under ``key``.

        Returns:
            True if something was logged, False if the loader yielded no batch
            (so the caller can retry later instead of crashing).
        """
        run = wandb.run if wandb.run is not None else wandb.init()
        batch = next(iter(dataloader), None)
        if batch is None:
            return False

        images = batch["inputs"]
        if self.num_images is not None:
            images = images[: self.num_images]

        table = wandb.Table(columns=["image"])
        for flat_image in images:
            image = flat_image.view((self.image_height, self.image_width, 3)) * self.image_std + self.image_mean
            # ToPILImage scales floats by 255 and casts to uint8; clamp so rounding can't wrap.
            if image.is_floating_point():
                image = image.clamp(0, 1)
            # NOTE(review): (H, W, C).permute(2, 1, 0) yields (C, W, H) — confirm the flattening order.
            table.add_data(wandb.Image(ToPILImage()(image.permute(2, 1, 0))))
        run.log({key: table}, commit=False)
        return True

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.is_world_process_zero and not self.train_done:
            self.train_done = self._log_first_batch(kwargs["train_dataloader"], "Train Image Log")

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.is_world_process_zero and not self.eval_done:
            self.eval_done = self._log_first_batch(kwargs["eval_dataloader"], "Evaluation Image Log")