

import os
import json
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from torchvision.transforms.v2 import (
    ToImage,
    ToDtype,
    CenterCrop,
    Compose,
    Lambda,
    Normalize,
    Resize,
)
import transformers
from transformers import (
    AutoConfig,
    AutoImageProcessor,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint

import sys
sys.path.append('.')
from weighted_dataset import WeightedDataset
from pruning import PrevStratifiedSampler, PrevRandomSubsetSampler
from image_forward_overload import get_forward_function
import torch.distributed as dist

def pil_loader(path: str):
    """Load the image stored at ``path`` and return it converted to RGB mode.

    The file is opened in binary mode and is closed again once the RGB copy
    has been produced.
    """
    with open(path, "rb") as stream:
        # The conversion happens while the underlying stream is still open.
        rgb_image = Image.open(stream).convert("RGB")
    return rgb_image


def collate_fn(examples):
    """Collate a list of example dicts into a single batch dict of tensors.

    Every example must provide ``"pixel_values"`` (a tensor, or any value that
    ``torch.tensor`` accepts) and ``"label"``. The keys ``"sample_idx"`` and
    ``"weight"`` are optional and default to ``0`` and ``1.0`` respectively.

    Returns:
        A dict with ``"pixel_values"`` (the per-example tensors stacked along
        a new leading dimension), ``"labels"``, ``"sample_idx"`` and
        ``"weight"``.
    """
    def _to_tensor(value):
        # Tensors are passed through untouched; anything else is converted.
        return value if isinstance(value, torch.Tensor) else torch.tensor(value)

    batch = {
        "pixel_values": torch.stack(
            [_to_tensor(item["pixel_values"]) for item in examples]
        ),
        "labels": torch.tensor([item["label"] for item in examples]),
        "sample_idx": torch.tensor([item.get("sample_idx", 0) for item in examples]),
        "weight": torch.tensor([item.get("weight", 1.0) for item in examples]),
    }
    return batch

def compute_metrics(p):
    """Return the accuracy metric for an evaluation prediction object.

    ``p.predictions`` is reduced with an argmax over axis 1 to obtain the
    predicted class ids, which are compared against ``p.label_ids`` using the
    ``"accuracy"`` metric loaded through the ``evaluate`` package.
    """
    import evaluate

    accuracy = evaluate.load("accuracy")
    predicted_ids = np.argmax(p.predictions, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=p.label_ids)


class LossCollectorCallback(transformers.TrainerCallback):
    """Trainer callback that records the mean logged training loss per epoch.

    Loss values are taken from the ``on_log`` event rather than by peeking at
    ``state.log_history`` in ``on_step_end``: the Trainer emits the log entry
    for a step *after* ``on_step_end`` has fired, so reading the history there
    yields the previous step's value, attributes the last step of an epoch to
    the following epoch, and re-appends the same entry on every step when
    ``logging_steps > 1``.

    Attributes:
        epoch_losses: Mean loss of every epoch after the first
            ``skip_epochs`` epochs that produced at least one loss value.
        current_epoch_losses: Loss values logged so far in the running epoch.
        skip_epochs: Number of leading epochs excluded from ``epoch_losses``.
        completed_epochs: Number of finished epochs that had logged losses.
    """

    def __init__(self, skip_epochs=3):
        self.epoch_losses = []
        self.current_epoch_losses = []
        self.skip_epochs = skip_epochs
        self.completed_epochs = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Store the training loss of a freshly emitted log entry, if any."""
        # Evaluation and end-of-training summaries carry no plain 'loss' key
        # and are therefore ignored.
        if logs and 'loss' in logs:
            self.current_epoch_losses.append(logs['loss'])

    def on_epoch_end(self, args, state, control, **kwargs):
        """Average the losses gathered during the epoch and reset the buffer."""
        if not self.current_epoch_losses:
            # Nothing was logged during this epoch; it is not counted.
            return

        # Plain float so the value stays JSON-serialisable downstream.
        avg_loss = float(np.mean(self.current_epoch_losses))
        self.completed_epochs += 1
        keep = self.completed_epochs > self.skip_epochs

        if keep:
            self.epoch_losses.append(avg_loss)

        # Print once per job instead of once per rank, matching the rest of
        # the script.
        if getattr(state, "is_world_process_zero", True):
            if keep:
                print(f"Epoch {state.epoch} average training loss: {avg_loss:.4f}")
            else:
                print(f"Epoch {state.epoch} average training loss: {avg_loss:.4f} (skipped for analysis)")

        self.current_epoch_losses = []


def evaluate_sampling_strategy(trainer, dataset, sampling_strategy, sample_ratio, num_trials=10, skip_epochs=3):
    """Train with one sampling strategy and return the collected epoch losses.

    The dataset is wrapped in a ``WeightedDataset`` that is installed both as
    ``trainer.train_dataset`` and as ``trainer.model.trainset``. The trainer's
    ``_get_train_sampler`` is then replaced by a bound function returning the
    sampler for the requested strategy, and ``trainer.train()`` is run with a
    ``LossCollectorCallback`` attached.

    Args:
        trainer: The Trainer instance to run; it is modified in place (train
            dataset, model ``trainset`` attribute and sampler hook).
        dataset: Training dataset to wrap.
        sampling_strategy: ``'random'`` selects ``PrevRandomSubsetSampler``;
            any other value selects ``PrevStratifiedSampler``.
        sample_ratio: Forwarded to the sampler as ``ratio``.
        num_trials: Forwarded to the sampler as ``num_epochs``.
        skip_epochs: Number of leading epochs the loss collector leaves out.

    Returns:
        The list of per-epoch mean losses gathered by the loss collector.
    """
    weighted_dataset = WeightedDataset(dataset)
    trainer.model.trainset = weighted_dataset
    trainer.train_dataset = weighted_dataset

    # Both hooks accept an optional positional argument so they work whether
    # or not the Trainer passes one to ``_get_train_sampler``.
    def _random_train_sampler(self, sampler=None):
        return PrevRandomSubsetSampler(
            self.train_dataset,
            ratio=sample_ratio,
            num_epochs=num_trials,
            delta=1.0
        )

    def _stratified_train_sampler(self, sampler=None):
        return PrevStratifiedSampler(
            self.train_dataset,
            ratio=sample_ratio,
            num_epochs=num_trials,
            delta=1.0,
            c=1.0
        )

    if sampling_strategy == 'random':
        trainer._get_train_sampler = _random_train_sampler.__get__(trainer, trainer.__class__)
    else:
        trainer._get_train_sampler = _stratified_train_sampler.__get__(trainer, trainer.__class__)

    loss_collector = LossCollectorCallback(skip_epochs=skip_epochs)
    trainer.add_callback(loss_collector)

    try:
        trainer.train()
    finally:
        # The trainer is shared between runs: always detach the collector so a
        # failed run cannot leave a stale callback registered on it.
        trainer.remove_callback(loss_collector)

    losses = loss_collector.epoch_losses

    if trainer.is_world_process_zero():
        print(f"  Collected {len(losses)} epoch losses: {[f'{loss:.4f}' for loss in losses]}")

    return losses


def main():
    """Run the stratified-vs-random sampling loss variance analysis.

    Loads the ``cifar10`` dataset and the ``microsoft/resnet-18`` classifier,
    then for every ratio in ``SAMPLE_RATIOS`` calls
    ``evaluate_sampling_strategy`` once with the stratified sampler and once
    with the random sampler. The mean and standard deviation of the collected
    epoch losses are computed and, on the world-zero process, written to
    ``ablation/stratified_loss_results.json``.
    """
    def is_main_process():
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            return torch.distributed.get_rank() == 0
        # The process group may not be initialised yet when the first status
        # messages are printed. Fall back to the RANK variable exported by
        # distributed launchers so that only one process prints; a plain
        # single-process run has no RANK and counts as the main process.
        return int(os.environ.get("RANK", "0")) == 0

    if is_main_process():
        print("Starting Stratified Loss Variance Analysis...")

    BASE_MODEL = "microsoft/resnet-18"
    DATASET_NAME = "cifar10"
    SAMPLE_RATIOS = [0.3,]
    NUM_TRIALS = 10
    SKIP_EPOCHS = 3

    set_seed(1337)

    if is_main_process():
        print("Loading dataset...")
    dataset = load_dataset(DATASET_NAME, cache_dir=None)

    # Reuse the test split when the dataset ships without a validation split.
    if "validation" not in dataset:
        dataset['validation'] = dataset['test']

    if is_main_process():
        print(f"Train set size: {len(dataset['train'])}, Validation set size: {len(dataset['validation'])}")

    # Build the label <-> id mappings handed to the model config.
    labels = dataset["train"].features["label"].names
    label2id, id2label = {}, {}
    for i, label in enumerate(labels):
        label2id[label] = str(i)
        id2label[str(i)] = label

    if is_main_process():
        print(f"Dataset columns: {dataset['train'].column_names}")
        print(f"Labels: {labels}")

    if is_main_process():
        print("Loading model...")
    config = AutoConfig.from_pretrained(
        BASE_MODEL,
        num_labels=len(labels),
        label2id=label2id,
        id2label=id2label,
        finetuning_task="image-classification",
    )

    model = AutoModelForImageClassification.from_pretrained(
        BASE_MODEL,
        config=config,
        ignore_mismatched_sizes=True,
    )

    # Replace the model's forward with the project-specific implementation.
    forward_func = get_forward_function(model)
    model.forward = forward_func.__get__(model, model.__class__)

    image_processor = AutoImageProcessor.from_pretrained(BASE_MODEL)

    if "shortest_edge" in image_processor.size:
        size = image_processor.size["shortest_edge"]
    else:
        size = (image_processor.size["height"], image_processor.size["width"])

    if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std"):
        normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
    else:
        # Identity transform when the processor defines no normalisation stats.
        normalize = Lambda(lambda x: x)

    val_transforms = Compose([
        ToImage(),
        ToDtype(torch.uint8, scale=True),
        Resize(size),
        CenterCrop(size),
        ToDtype(torch.float32, scale=True),
        normalize,
    ])

    def transforms(example_batch):
        # Applied lazily by the dataset; reads the images from the "img" column.
        example_batch["pixel_values"] = [
            val_transforms(pil_img) for pil_img in example_batch["img"]
        ]
        return example_batch

    dataset["train"].set_transform(transforms)

    training_args = TrainingArguments(
        output_dir="./temp_output",
        per_device_train_batch_size=1024,
        per_device_eval_batch_size=1024,
        dataloader_num_workers=16,
        dataloader_pin_memory=True,
        dataloader_prefetch_factor=2,
        # Keep the extra columns (e.g. "img") that the transform relies on.
        remove_unused_columns=False,
        # The string "none" disables all reporting integrations; passing None
        # makes Transformers fall back to every installed integration.
        report_to="none",
        # NOTE(review): a zero learning rate presumably keeps the weights fixed
        # so that only the sampling changes between epochs - confirm.
        learning_rate=0.0,
        num_train_epochs=NUM_TRIALS,
        logging_steps=1,
        save_steps=10000,
        eval_steps=None,
        save_total_limit=0,
        load_best_model_at_end=False,
        metric_for_best_model=None,
        greater_is_better=None,
    )

    def sync_scores_across_ranks(trainset):
        # Average the per-sample scores over all ranks; no-op when not
        # running distributed.
        if not (dist.is_available() and dist.is_initialized()):
            return
        scores_tensor = torch.tensor(trainset.scores, dtype=torch.float32, device='cuda')
        dist.all_reduce(scores_tensor, op=dist.ReduceOp.AVG)
        trainset.scores = scores_tensor.cpu().numpy()

    class SyncScoresCallback(transformers.TrainerCallback):
        def on_epoch_end(self, args, state, control, **kwargs):
            # Only models that carry a train set with scores are synchronised.
            model_in_trainer = kwargs.get('model', None)
            if hasattr(model_in_trainer, 'trainset') and hasattr(model_in_trainer.trainset, 'scores'):
                sync_scores_across_ranks(model_in_trainer.trainset)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        compute_metrics=compute_metrics,
        data_collator=collate_fn,
        callbacks=[SyncScoresCallback()]
    )

    results = {
        "dataset": DATASET_NAME,
        "model": BASE_MODEL,
        "num_trials": NUM_TRIALS,
        "skip_epochs": SKIP_EPOCHS,
        "sample_ratios": SAMPLE_RATIOS,
        "results": {}
    }

    for ratio in SAMPLE_RATIOS:
        if trainer.is_world_process_zero():
            print(f"\nEvaluating sample ratio: {ratio}")
        results["results"][str(ratio)] = {}

        if trainer.is_world_process_zero():
            print("  Stratified sampling...")
        stratified_losses = evaluate_sampling_strategy(
            trainer, dataset["train"], "stratified", ratio, NUM_TRIALS, SKIP_EPOCHS
        )

        if trainer.is_world_process_zero():
            print("  Random sampling...")
        random_losses = evaluate_sampling_strategy(
            trainer, dataset["train"], "random", ratio, NUM_TRIALS, SKIP_EPOCHS
        )

        # Coerce to built-in floats so json.dump never depends on the numpy
        # scalar type the loss collector happened to produce.
        random_losses = [float(loss) for loss in random_losses]
        stratified_losses = [float(loss) for loss in stratified_losses]

        random_mean = np.mean(random_losses)
        random_std = np.std(random_losses)
        stratified_mean = np.mean(stratified_losses)
        stratified_std = np.std(stratified_losses)

        if trainer.is_world_process_zero():
            print(f"  Random - Mean: {random_mean:.4f}, Std: {random_std:.4f}")
            print(f"  Stratified - Mean: {stratified_mean:.4f}, Std: {stratified_std:.4f}")

        results["results"][str(ratio)]["random"] = {
            "losses": random_losses,
            "mean": float(random_mean),
            "std": float(random_std)
        }
        results["results"][str(ratio)]["stratified"] = {
            "losses": stratified_losses,
            "mean": float(stratified_mean),
            "std": float(stratified_std)
        }

    # Only one process writes the result file.
    if trainer.is_world_process_zero():
        output_file = "ablation/stratified_loss_results.json"
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_file, 'w') as f:
            json.dump(results, f, indent=2)

        print(f"\nResults saved to: {output_file}")
        print("Analysis completed!")


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


