
import transformers.modeling_outputs
from weighted_dataset import WeightedDataset
from pruning import PrevPruningSampler, PrevStratifiedSampler, PrevRandomSubsetSampler
import torch.distributed as dist
from image_forward_overload import get_forward_function



import logging
import os
import sys
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Optional
from contextlib import redirect_stdout
from io import StringIO

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from torchvision.transforms.v2 import (
    ToImage,
    ToDtype,
    CenterCrop,
    Compose,
    Lambda,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    RandAugment,
    RandomErasing,
    Identity,
)

import transformers
from transformers import (
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    AutoConfig,
    AutoImageProcessor,
    AutoModelForImageClassification,
    HfArgumentParser,
    TimmWrapperImageProcessor,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from diffusers.training_utils import EMAModel


"""Fine-tuning a Transformers model for image classification"""

# Module-level logger; main() sets its level after the training arguments are parsed.
logger = logging.getLogger(__name__)




# Abort early with an install hint when the installed `datasets` package is older than 2.14.0.
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

# Config classes registered for image classification and their `model_type` strings.
# MODEL_TYPES is interpolated into the help text of `ModelArguments.model_type`.
MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def pil_loader(path: str):
    """Open the image file at ``path`` and return it converted to RGB mode.

    The file is opened in binary mode and closed again once the conversion
    has produced the returned image.
    """
    with open(path, "rb") as handle:
        return Image.open(handle).convert("RGB")


@dataclass
class CustomTrainingArguments(TrainingArguments):
    """``TrainingArguments`` extended with augmentation and EMA options."""

    # Not referenced anywhere else in this file.
    use_cutmixup: Optional[float] = field(default=0.0, metadata={"help": "whether to use cutmix/mixup"})
    # main() passes this value as ``p`` to RandomErasing; a falsy value disables the transform.
    use_erasing: Optional[float] = field(default=0.5, metadata={"help": "whether to use RandomErasing"})
    # When true, main() builds an EMAModel and registers a callback that maintains it.
    use_ema: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use Exponential Moving Average for the final model weights."},
    )
    # The three values below are forwarded to EMAModel as inv_gamma, power and decay.
    ema_inv_gamma: Optional[float] = field(
        default=1.0, metadata={"help": "The inverse gamma value for the EMA decay."}
    )
    ema_power: Optional[float] = field(default=0.75, metadata={"help": "The power value for the EMA decay."})
    ema_max_decay: Optional[float] = field(
        default=0.9999, metadata={"help": "The maximum decay magnitude for EMA."}
    )


@dataclass
class DataTrainingArguments:
    """Options describing where the data comes from and how it is sampled.

    At least one of ``dataset_name``, ``train_dir`` or ``validation_dir`` must
    be given; otherwise construction fails with ``ValueError``.
    """

    dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Name of a dataset from the hub (could be your own, possibly private dataset hosted on the hub)."
        },
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
    )
    train_dir: Optional[str] = field(
        default=None,
        metadata={"help": "A folder containing the training data."},
    )
    validation_dir: Optional[str] = field(
        default=None,
        metadata={"help": "A folder containing the validation data."},
    )
    train_val_split: Optional[float] = field(
        default=0.15,
        metadata={"help": "Percent to split off of train for validation."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set."
        },
    )
    image_column_name: str = field(
        default="image",
        metadata={"help": "The name of the dataset column containing the image data. Defaults to 'image'."},
    )
    label_column_name: str = field(
        default="label",
        metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'."},
    )
    # main() treats any value other than 'pruning' or 'stratified' as the random-subset sampler.
    sampling_tag: str = field(
        default="stratified",
        metadata={"help": "Sampling strategy for weighted training. Options: 'pruning', 'stratified', 'normal'"},
    )
    sampling_ratio: float = field(
        default=1.0,
        metadata={
            "help": "Ratio of the training set to sample for each epoch (for weighted/stratified/pruning sampling)."
        },
    )
    # Used only for its truthiness in main(): a non-zero value registers the stable-rank callback.
    calc_stable_rank: Optional[int] = field(
        default=0,
        metadata={"help": "whether to compute stable rank to dynamically allocate coreset size."},
    )
    corrupt_ratio: float = field(
        default=0.0,
        metadata={"help": "Ratio of training samples to corrupt with random labels for noise testing."},
    )

    def __post_init__(self):
        """Reject a configuration that names no data source at all."""
        data_sources = (self.dataset_name, self.train_dir, self.validation_dir)
        if all(source is None for source in data_sources):
            raise ValueError(
                "You must specify either a dataset name from the hub or a train and/or validation directory."
            )


@dataclass
class ModelArguments:
    """Arguments for model configuration and loading.

    Every field is read by ``main()`` when it builds the config, the model and
    the image processor.
    """

    model_name_or_path: str = field(
        default="google/vit-base-patch16-224-in21k",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # Annotated Optional because the default is None, matching the other nullable fields.
    image_processor_name: Optional[str] = field(
        default=None, metadata={"help": "Name or path of preprocessor config."}
    )
    # Annotated Optional because the default is None, matching the other nullable fields.
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `hf auth login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
    )
    # When false, main() builds the model from its config only (random initialization).
    load_pretrained_weights: bool = field(
        default=False,
        metadata={"help": "Whether to load pretrained weights or just model architecture (random initialization)."},
    )


def main():
    """Run the full image-classification pipeline from the command line.

    Steps, in order: parse the three argument dataclasses (from a single
    ``.json`` file or from CLI flags), configure logging, detect a resumable
    checkpoint, load the dataset, build the model and image transforms,
    optionally corrupt a fraction of the training labels, create the
    ``Trainer`` with its callbacks and a custom train sampler, then train
    and/or evaluate and finally write or push the model card.

    Raises:
        ValueError: if the output directory is non-empty without a checkpoint
            and ``--overwrite_output_dir`` is not set, if the configured
            image/label column is missing from the dataset, or if
            ``--do_train`` / ``--do_eval`` is requested without the matching split.
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, CustomTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # A single ``.json`` argument is read as a file holding every argument.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    send_example_telemetry("run_image_classification", model_args, data_args)

    # ------------------------------------------------------------------ logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.bf16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # ------------------------------------------------------- checkpoint detection
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    set_seed(training_args.seed)

    # ------------------------------------------------------------------- dataset
    if data_args.dataset_name is not None:
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
        )
    else:
        # Local folders are loaded through the "imagefolder" builder.
        data_files = {}
        if data_args.train_dir is not None:
            data_files["train"] = os.path.join(data_args.train_dir, "**")
        if data_args.validation_dir is not None:
            data_files["validation"] = os.path.join(data_args.validation_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=model_args.cache_dir,
        )

    dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names
    if data_args.image_column_name not in dataset_column_names:
        raise ValueError(
            f"--image_column_name {data_args.image_column_name} not found in dataset '{data_args.dataset_name}'. "
            "Make sure to set `--image_column_name` to the correct image column - one of "
            f"{', '.join(dataset_column_names)}."
        )
    if data_args.label_column_name not in dataset_column_names:
        raise ValueError(
            f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. "
            "Make sure to set `--label_column_name` to the correct label column - one of "
            f"{', '.join(dataset_column_names)}."
        )

    def collate_fn(examples):
        """Batch examples into the tensors handed to the model.

        ``sample_idx`` and ``weight`` fall back to 0 and 1.0 for examples that
        do not carry them (e.g. plain evaluation examples).
        """
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        labels = torch.tensor([example[data_args.label_column_name] for example in examples])
        sample_indices = torch.tensor([example.get("sample_idx", 0) for example in examples])
        weights = torch.tensor([example.get("weight", 1.0) for example in examples])
        return {
            "pixel_values": pixel_values,
            "labels": labels,
            "sample_idx": sample_indices,
            "weight": weights,
        }

    has_val = "validation" in dataset
    has_test = "test" in dataset
    # Only carve a validation set out of train when the dataset ships neither validation nor test.
    data_args.train_val_split = None if (has_val or has_test) else data_args.train_val_split
    if has_test and not has_val:
        dataset["validation"] = dataset["test"]
    if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
        split = dataset["train"].train_test_split(data_args.train_val_split)
        dataset["train"] = split["train"]
        dataset["validation"] = split["test"]

    # A split may legitimately be absent (validation-only data, or a split ratio of 0),
    # so report "n/a" for it instead of failing with KeyError.
    train_size = len(dataset["train"]) if "train" in dataset else "n/a"
    val_size = len(dataset["validation"]) if "validation" in dataset else "n/a"
    logger.info(f">>> Train set size: {train_size}, Validation set size: {val_size}")

    # Same fallback as the column-name lookup above: use validation when there is no train split.
    label_split = "train" if "train" in dataset else "validation"
    labels = dataset[label_split].features[data_args.label_column_name].names
    label2id, id2label = {}, {}
    for i, label in enumerate(labels):
        label2id[label] = str(i)
        id2label[str(i)] = label

    metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)

    def compute_metrics(p):
        """Compute accuracy from the arg-max over the class axis of the predictions."""
        return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

    # --------------------------------------------------------------------- model
    # stdout is swallowed while the config and model are created; the config is
    # identical for both initialization modes, so it is built once.
    with redirect_stdout(StringIO()):
        config = AutoConfig.from_pretrained(
            model_args.config_name or model_args.model_name_or_path,
            num_labels=len(labels),
            label2id=label2id,
            id2label=id2label,
            finetuning_task="image-classification",
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
        )
        if model_args.load_pretrained_weights:
            model = AutoModelForImageClassification.from_pretrained(
                model_args.model_name_or_path,
                from_tf=bool(".ckpt" in model_args.model_name_or_path),
                config=config,
                cache_dir=model_args.cache_dir,
                revision=model_args.model_revision,
                token=model_args.token,
                trust_remote_code=model_args.trust_remote_code,
                ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
            )
        else:
            model = AutoModelForImageClassification.from_config(
                config,
            )
    if model_args.load_pretrained_weights:
        logger.info("Loaded model with pretrained weights")
    else:
        logger.info("Loaded model with random initialization (no pretrained weights)")

    if training_args.use_ema:
        ema_model = EMAModel(
            model.parameters(),
            decay=training_args.ema_max_decay,
            use_ema_warmup=True,
            inv_gamma=training_args.ema_inv_gamma,
            power=training_args.ema_power,
            model_cls=AutoModelForImageClassification,
            model_config=model.config,
        )

    # ---------------------------------------------------------------- transforms
    image_processor = AutoImageProcessor.from_pretrained(
        model_args.image_processor_name or model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )

    if isinstance(image_processor, TimmWrapperImageProcessor):
        # timm processors ship their own train/val pipelines.
        _train_transforms = image_processor.train_transforms
        _val_transforms = image_processor.val_transforms
    else:
        if "shortest_edge" in image_processor.size:
            size = image_processor.size["shortest_edge"]
        else:
            size = (image_processor.size["height"], image_processor.size["width"])

        if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std"):
            normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
        else:
            normalize = Lambda(lambda x: x)
        # NOTE(review): when `size` comes from "shortest_edge" the train pipeline has no crop
        # after Resize, so non-square inputs may produce tensors of different shapes that
        # collate_fn cannot stack - confirm this is intended for the processors in use.
        _train_transforms = Compose(
            [
                ToImage(),
                ToDtype(torch.uint8, scale=True),
                Resize(size),
                RandAugment(2, 28),
                RandomHorizontalFlip(),
                ToDtype(torch.float32, scale=True),
                normalize,
                # `use_erasing` doubles as the erasing probability; a falsy value disables it.
                RandomErasing(p=training_args.use_erasing) if training_args.use_erasing else Identity(),
            ]
        )
        _val_transforms = Compose(
            [
                ToImage(),
                ToDtype(torch.uint8, scale=True),
                Resize(size),
                CenterCrop(size),
                ToDtype(torch.float32, scale=True),
                normalize,
            ]
        )

    def train_transforms(example_batch):
        """Attach ``pixel_values`` computed with the training pipeline to a batch."""
        example_batch["pixel_values"] = [
            _train_transforms(pil_img) for pil_img in example_batch[data_args.image_column_name]
        ]
        return example_batch

    def val_transforms(example_batch):
        """Attach ``pixel_values`` computed with the evaluation pipeline to a batch."""
        example_batch["pixel_values"] = [
            _val_transforms(pil_img) for pil_img in example_batch[data_args.image_column_name]
        ]
        return example_batch

    if training_args.do_train:
        if "train" not in dataset:
            raise ValueError("--do_train requires a train dataset")
        if data_args.max_train_samples is not None:
            dataset["train"] = (
                dataset["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
            )
        if data_args.corrupt_ratio and data_args.corrupt_ratio > 0.0:
            # Label noise: pick a seeded random subset of examples and permute the labels
            # among them, so the multiset of labels is unchanged.
            num_train = len(dataset["train"])
            num_corrupt = int(num_train * data_args.corrupt_ratio)
            if num_corrupt > 0:
                rng = np.random.default_rng(training_args.seed)
                selected_indices = rng.choice(num_train, size=num_corrupt, replace=False)
                current_labels = dataset["train"][data_args.label_column_name]
                selected_labels = [int(current_labels[i]) for i in selected_indices]
                permuted_order = rng.permutation(num_corrupt)
                permuted_labels = [selected_labels[i] for i in permuted_order]
                idx_to_new_label = {int(idx): int(lbl) for idx, lbl in zip(selected_indices, permuted_labels)}

                def _corrupt_example(example, idx):
                    """Overwrite the label of ``example`` when ``idx`` was selected for corruption."""
                    if idx in idx_to_new_label:
                        example[data_args.label_column_name] = idx_to_new_label[idx]
                    return example

                dataset["train"] = dataset["train"].map(_corrupt_example, with_indices=True)
                logger.info(f"Corrupted {num_corrupt}/{num_train} training labels (ratio={data_args.corrupt_ratio}).")

        dataset["train"].set_transform(train_transforms)

        # The wrapped dataset is also attached to the model so that code holding only the
        # model (e.g. SyncScoresCallback below) can reach it.
        weighted_train_dataset = WeightedDataset(dataset["train"])
        model.trainset = weighted_train_dataset

    if training_args.do_eval:
        if "validation" not in dataset:
            raise ValueError("--do_eval requires a validation dataset")
        if data_args.max_eval_samples is not None:
            dataset["validation"] = (
                dataset["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
            )

        dataset["validation"].set_transform(val_transforms)

        if has_test:
            dataset["test"].set_transform(val_transforms)

    # ----------------------------------------------------------------- callbacks
    def sync_scores_across_ranks(trainset):
        """Average ``trainset.scores`` over all ranks; no-op outside distributed runs."""
        if not dist.is_initialized():
            return
        scores_tensor = torch.tensor(trainset.scores, dtype=torch.float32, device="cuda")
        dist.all_reduce(scores_tensor, op=dist.ReduceOp.AVG)
        trainset.scores = scores_tensor.cpu().numpy()

    class SyncScoresCallback(transformers.TrainerCallback):
        """Synchronise per-sample scores across ranks on every evaluation event."""

        def on_evaluate(self, args, state, control, **kwargs):
            model_in_trainer = kwargs.get("model", None)
            if hasattr(model_in_trainer, "trainset") and hasattr(model_in_trainer.trainset, "scores"):
                sync_scores_across_ranks(model_in_trainer.trainset)

    class FinalLogCallback(transformers.TrainerCallback):
        """Force a log, an evaluation and a save once the last epoch has finished."""

        def __init__(self, num_epochs):
            self.num_epochs = num_epochs

        def on_epoch_end(self, args, state, control, **kwargs):
            if state.epoch >= self.num_epochs:
                print(">>> FinalLogCallback action")
                control.should_log = True
                control.should_evaluate = True
                control.should_save = True

    class EvalSubTrainCallback(transformers.TrainerCallback):
        """Evaluate on a random training subset whenever an evaluation is due at epoch end."""

        def __init__(self, trainer) -> None:
            super().__init__()
            self._trainer = trainer

        def on_epoch_end(self, args, state, control, **kwargs):
            if control.should_evaluate:
                # Snapshot the control flags before the nested evaluate() call and hand the
                # snapshot back as this callback's return value.
                control_copy = deepcopy(control)
                eval_dataset_length = len(self._trainer.eval_dataset) if self._trainer.eval_dataset else 1000
                sample_size = min(eval_dataset_length, len(self._trainer.train_dataset))

                random_indices = torch.randperm(len(self._trainer.train_dataset))[:sample_size]
                train_subset = torch.utils.data.Subset(self._trainer.train_dataset, random_indices)
                self._trainer.evaluate(eval_dataset=train_subset, metric_key_prefix="train")
                return control_copy

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=weighted_train_dataset if training_args.do_train else None,
        eval_dataset=dataset["validation"] if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        processing_class=image_processor,
        data_collator=collate_fn,
        callbacks=[SyncScoresCallback(), FinalLogCallback(training_args.num_train_epochs)],
    )
    trainer.add_callback(EvalSubTrainCallback(trainer))

    if training_args.use_ema:
        class EMACallback(transformers.TrainerCallback):
            """Maintain the EMA weights and swap them into the model around eval/save events.

            NOTE(review): per the TrainerCallback documentation ``on_evaluate`` and
            ``on_save`` fire *after* the evaluation/save has run, so the swap below
            takes effect afterwards and lasts until the next ``on_epoch_begin`` -
            confirm this is the intended schedule.
            """

            def __init__(self, ema_model):
                self.ema_model = ema_model

            def on_train_begin(self, args, state, control, model=None, **kwargs):
                if model is not None:
                    self.ema_model.to(model.device)
                    self.ema_model.store(model.parameters())

            def on_epoch_begin(self, args, state, control, model=None, **kwargs):
                # EMAModel.restore() raises RuntimeError when nothing has been stored since
                # the previous restore (it clears its stash after use). That happens on any
                # epoch not preceded by an evaluate/save event, so only restore when a
                # stash actually exists.
                if model is not None and self.ema_model.temp_stored_params is not None:
                    self.ema_model.restore(model.parameters())

            def on_step_end(self, args, state, control, model=None, **kwargs):
                if model is not None:
                    self.ema_model.step(model.parameters())

            def on_evaluate(self, args, state, control, model=None, **kwargs):
                if model is not None:
                    self.ema_model.store(model.parameters())
                    self.ema_model.copy_to(model.parameters())

            def on_save(self, args, state, control, model=None, **kwargs):
                if model is not None:
                    self.ema_model.store(model.parameters())
                    self.ema_model.copy_to(model.parameters())

        trainer.add_callback(EMACallback(ema_model))
        trainer.ema_model = ema_model

    if data_args.calc_stable_rank:
        from trainer_sr import StableRankCallback, PEFTEigenvalueCalculator
        stable_rank_callback = StableRankCallback(
            calculator_class=PEFTEigenvalueCalculator,
            trainer=trainer
        )
        trainer.add_callback(stable_rank_callback)

    # ------------------------------------------------- loss and sampler overrides
    from image_forward_overload import compute_loss, LabelSmoother
    trainer.label_smoother = LabelSmoother(epsilon=training_args.label_smoothing_factor)
    # Bind the project-specific loss as a method on this Trainer instance only.
    trainer.compute_loss = compute_loss.__get__(trainer, trainer.__class__)
    print(">>> Compute loss method replaced successfully <<<")

    def _pruning_train_sampler(self, sampler=None):
        return PrevPruningSampler(self.train_dataset, ratio=data_args.sampling_ratio, num_epochs=training_args.num_train_epochs, delta=1.0)

    def _distributed_train_sampler(self, sampler=None):
        return PrevRandomSubsetSampler(self.train_dataset, ratio=data_args.sampling_ratio, num_epochs=training_args.num_train_epochs, delta=1.0)

    def _get_stratified_train_sampler(self, sampler=None):
        return PrevStratifiedSampler(self.train_dataset, trainer=trainer, ratio=data_args.sampling_ratio, num_epochs=training_args.num_train_epochs, c=1.0, delta=1.0)

    sampling_tag = data_args.sampling_tag

    # Any tag other than 'pruning' or 'stratified' selects the random-subset sampler.
    if sampling_tag == "pruning":
        trainer._get_train_sampler = _pruning_train_sampler.__get__(trainer, trainer.__class__)
        trainer.create_model_card(dataset_tags="pruning")
    elif sampling_tag == "stratified":
        trainer._get_train_sampler = _get_stratified_train_sampler.__get__(trainer, trainer.__class__)
        trainer.create_model_card(dataset_tags="stratified")
    else:
        trainer._get_train_sampler = _distributed_train_sampler.__get__(trainer, trainer.__class__)
        trainer.create_model_card(dataset_tags="normal")

    def _call_with_ema_weights(fn, **fn_kwargs):
        """Call ``fn(**fn_kwargs)`` with the EMA weights loaded into the model.

        When EMA is disabled the call is made directly. Otherwise the current
        weights are stored, the EMA weights are copied in, ``fn`` runs, and the
        stored weights are restored afterwards. Returns whatever ``fn`` returns.
        """
        if not (training_args.use_ema and hasattr(trainer, "ema_model")):
            return fn(**fn_kwargs)
        trainer.ema_model.store(trainer.model.parameters())
        trainer.ema_model.copy_to(trainer.model.parameters())
        result = fn(**fn_kwargs)
        trainer.ema_model.restore(trainer.model.parameters())
        return result

    # ------------------------------------------------------------------ training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)

        # The saved final model carries the EMA weights when EMA is enabled.
        _call_with_ema_weights(trainer.save_model)

        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

    # ---------------------------------------------------------------- evaluation
    if training_args.do_eval:
        metrics = _call_with_ema_weights(trainer.evaluate)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

        if has_test:
            logger.info("*** Running test evaluation ***")
            test_metrics = _call_with_ema_weights(
                trainer.evaluate, eval_dataset=dataset["test"], metric_key_prefix="test"
            )
            trainer.log_metrics("test", test_metrics)
            trainer.save_metrics("test", test_metrics)

    # ---------------------------------------------------------------- model card
    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "tasks": "image-classification",
        "dataset": data_args.dataset_name,
        "tags": ["image-classification", "vision"],
    }
    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()