import os
import torch
import lightning as L
from typing import List, Sequence
from datasets import Dataset
from datasets import load_dataset
from tqdm import tqdm
import csv
import time
from omegaconf import DictConfig
import hydra

from .classifier import CLIPZeroShotClassifier

# Default data directory: the 'data' folder inside the parent of this file's
# directory, i.e. <dirname(dirname(__file__))>/data.
DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')

def load_adv_text_dataset(path: str | Sequence[str]) -> Dataset:
    """
    Read adversarial prompts from one or more CSV files.

    The CSV(s) are expected to contain an 'adv_text' column. The 'train' split
    produced by ``load_dataset("csv", ...)`` is returned.
    """
    splits = load_dataset("csv", data_files=path)
    train_split = splits['train']
    return train_split

def load_human_text_dataset(path: str | Sequence[str], classes: Sequence[str]) -> Dataset:
    """
    Load a human-written prompt dataset from a JSON file.

    The JSON is expected to map class names to lists of records; each record
    must contain an 'index' key and is expected to contain a 'prompt' key.
    Class names that do not appear in ``classes`` are skipped.

    Args:
        path: Path of the JSON file (passed directly to ``open``).
        classes: Ordered class names; a class's position is its integer label.

    Returns:
        A ``Dataset`` whose rows are the original records plus 'classname',
        'label' (index into ``classes``) and 'j' (a copy of the record's 'index').

    Raises:
        KeyError: if a kept record has no 'index' key.
    """
    import json

    # JSON is UTF-8 by specification; don't rely on the platform default encoding.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    class_to_label = {classname: idx for idx, classname in enumerate(classes)}
    rows = []
    for classname, data_list in data.items():
        if classname not in class_to_label:
            continue
        label = class_to_label[classname]
        for item in data_list:
            # Build a fresh dict instead of mutating the parsed JSON records in place.
            rows.append({**item, 'classname': classname, 'label': label, 'j': item['index']})
    return Dataset.from_list(rows)

def check_classname(text, classname):
    """Return True when ``classname`` occurs in ``text`` (substring/membership test)."""
    return classname in text

def save_pil_image(image, j, classname, successful, save_dir):
    """
    Save ``image`` as ``<save_dir>/<classname>/<successful>/<j>.png``.

    Missing directories are created.

    Args:
        image: Object with a ``save(path)`` method (a PIL image in this module).
        j: Sample identifier used as the file stem.
        classname: Sub-directory for the sample's class.
        successful: Sub-directory for the attack outcome (e.g. 'success'/'failed').
        save_dir: Root directory for saved images.
    """
    file_path = os.path.join(save_dir, classname, successful)
    # exist_ok=True replaces the racy "exists? then makedirs" check.
    os.makedirs(file_path, exist_ok=True)
    image.save(os.path.join(file_path, f"{j}.png"))

def default_collate_fn(batch):
    """
    Turn a list of per-sample dicts into a dict of per-field lists.

    The first sample's keys define the output fields; every sample must
    provide those keys.
    """
    field_names = batch[0].keys()
    return {name: [sample[name] for sample in batch] for name in field_names}

def _as_list(value):
    """Return ``value`` if it is already a list, otherwise ``value.tolist()``."""
    return value if isinstance(value, list) else value.tolist()


def _load_attack_dataset(dataset_type: str, data_path: str, classes: List[str]) -> Dataset:
    """Load the dataset for ``dataset_type`` and rename its prompt column to 'text'."""
    if dataset_type == 'adv':
        dataset = load_adv_text_dataset(data_path)
        return dataset.rename_column('adv_text', 'text')
    if dataset_type == 'human':
        dataset = load_human_text_dataset(data_path, classes)
        dataset = dataset.rename_column('prompt', 'text')
        print(dataset[0])
        return dataset
    # Previously an unknown type fell through and later crashed with an
    # UnboundLocalError on `dataset`; fail early with a clear message instead.
    raise ValueError(f"Unknown dataset_type {dataset_type!r}; expected 'adv' or 'human'.")


def attack(
    model, 
    dataset_type: str,
    data_path: str,
    classes: List[str],
    templates: List[str],
    img_size: int,
    batch_size: int = 20,
    model_name: str = 'ViT-B/16',
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
    save: bool = False,
    save_dir: str = os.path.join(DATA_DIR, '0'),
    save_name: str = 'adv_text_attack_results',
    start: int = 0,
    end: int = None,
    **kwargs
    ):
    """
    Generate images from prompts and check whether a CLIP zero-shot classifier
    still predicts each sample's label.

    A sample counts as a successful attack when the prediction differs from
    its label. One CSV row is appended to ``<save_dir>/<save_name>.csv`` per
    sample, and images are optionally saved under ``<save_dir>/images``.

    Required dataset format (per row, after loading/renaming):
    {
        'text': str,  # prompt passed to the generator
        'label': int,  # index into ``classes``
        'classname': str,  # class name
        'j': int,  # sample identifier (used for image file names)
    }

    Args:
        model: Pipeline exposing ``to``, ``generate`` and ``image_processor``.
        dataset_type: 'adv' (CSV with 'adv_text') or 'human' (JSON with 'prompt').
        data_path: Path of the dataset file.
        classes: Ordered class names used for labels and the classifier.
        templates: Prompt templates for the zero-shot classifier.
        img_size: Generated image height and width.
        batch_size: Number of prompts generated per ``model.generate`` call.
        model_name: Classifier backbone name passed to ``CLIPZeroShotClassifier``.
        device: Device for the model and classifier.
        save: Whether to save generated images.
        save_dir: Output directory for the CSV and images.
        save_name: CSV file stem.
        start: First dataset index to process (clamped to >= 0).
        end: One past the last index to process; None means the end of the dataset.
        **kwargs: Extra keyword arguments forwarded to ``model.generate``.

    Raises:
        ValueError: if ``dataset_type`` is neither 'adv' nor 'human'.
    """
    dataset = _load_attack_dataset(dataset_type, data_path, classes)

    # Clamp the requested [start, end) window to the dataset's index range.
    dataset_size = len(dataset)
    if end is None:
        end = dataset_size
    start = max(0, start)
    end = min(dataset_size, end)

    if start >= end:
        print(f"Warning: start ({start}) >= end ({end}). No data to process.")
        return

    dataset = dataset.select(range(start, end))
    print(f"Processing dataset from index {start} to {end-1} (total: {len(dataset)} samples)")

    model.to(device)
    classifier = CLIPZeroShotClassifier(classes, templates, img_size, model_name, device)

    os.makedirs(save_dir, exist_ok=True)
    save_img_dir = os.path.join(save_dir, 'images')

    # Rows are appended below; only write the header when creating the file.
    attack_res_csv = os.path.join(save_dir, f'{save_name}.csv')
    if not os.path.exists(attack_res_csv):
        with open(attack_res_csv, 'w', newline='', encoding='utf-8') as f:
            csv.writer(f).writerow(['text', 'successful', 'j', 'prediction', 'classname', 'gen_time'])

    total_processed = 0
    successful_count = 0
    failed_count = 0
    error_count = 0  # not incremented in this function; reported as-is
    class_stats = {}  # {classname: {'total': int, 'success': int, 'failed': int}}

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=default_collate_fn,
    )

    for batch_data in tqdm(dataloader, desc="Processing batches"):
        texts = _as_list(batch_data['text'])
        labels = _as_list(batch_data['label'])
        classnames = _as_list(batch_data['classname'])
        js = _as_list(batch_data['j'])

        batch_gen_start = time.perf_counter()
        batch_images = model.generate(
            texts,
            num_images_per_prompt=1,
            output_type="pil",
            height=img_size,
            width=img_size,
            classnames=classnames,
            **kwargs
        ).images
        batch_elapsed = time.perf_counter() - batch_gen_start
        # One generate() call covers the whole batch. Record each prompt's
        # average share rather than repeating the full batch time on every row.
        avg_gen_time = batch_elapsed / max(len(texts), 1)

        batch_image_numpy = model.image_processor.pil_to_numpy(batch_images)
        batch_image_tensor = model.image_processor.numpy_to_pt(batch_image_numpy)

        batch_logits = classifier.get_logits(batch_image_tensor)
        batch_predictions = batch_logits.argmax(dim=1)

        for text, label, classname, j, image, prediction in zip(
            texts, labels, classnames, js, batch_images, batch_predictions
        ):
            stats = class_stats.setdefault(classname, {'total': 0, 'success': 0, 'failed': 0})

            # Normalise label/prediction to plain Python scalars before comparing.
            if isinstance(label, torch.Tensor):
                label = label.item()
            elif isinstance(label, (list, tuple)):
                label = label[0]
            if isinstance(prediction, torch.Tensor):
                prediction = prediction.item()

            # The attack succeeds when the classifier no longer predicts the label.
            if prediction != label:
                successful = 'success'
                successful_count += 1
                stats['success'] += 1
            else:
                successful = 'failed'
                failed_count += 1
                stats['failed'] += 1
            stats['total'] += 1

            prediction_str = classes[prediction] if isinstance(prediction, int) else prediction

            if save:
                save_pil_image(image, str(j), classname, successful, save_img_dir)

            # Write each row immediately so results are persisted incrementally.
            write_single_result_to_csv(
                [text, successful, j, prediction_str, classname, f"{avg_gen_time:.5f}"],
                attack_res_csv,
            )

            total_processed += 1
            if total_processed % 100 == 0:
                current_success_rate = successful_count / total_processed * 100
                print(f"Processed {total_processed} samples, Success rate: {current_success_rate:.2f}%")

    print_final_statistics(total_processed, successful_count, failed_count, error_count, class_stats)
    print(f"Attack results CSV saved to {attack_res_csv}")

def write_single_result_to_csv(result, csv_file):
    """Append ``result`` as one CSV row to ``csv_file`` (UTF-8; created if missing)."""
    with open(csv_file, mode='a', encoding='utf-8', newline='') as handle:
        row_writer = csv.writer(handle)
        row_writer.writerow(result)

def print_final_statistics(total, successful, failed, errors, class_stats=None, per_class=False):
    """
    Print overall attack statistics and, when available, a per-class summary.

    Args:
        total: Number of processed samples.
        successful: Number of successful attacks.
        failed: Number of failed attacks.
        errors: Number of errored samples; excluded from the overall success rate.
        class_stats: Optional ``{classname: {'total': int, 'success': int, 'failed': int}}``.
            When None or empty, the per-class sections are skipped (previously
            this raised NameError).
        per_class: If True, also print one line per class.
    """
    if total == 0:
        print("No results to display")
        return

    print("\n=== Final Attack Statistics ===")
    print(f"Total samples: {total}")
    print(f"Successful attacks: {successful} ({successful/total*100:.2f}%)")
    print(f"Failed attacks: {failed} ({failed/total*100:.2f}%)")
    if errors > 0:
        print(f"Errors: {errors} ({errors/total*100:.2f}%)")
    if total > errors:
        print(f"Overall success rate: {successful/(total-errors)*100:.2f}%")
    else:
        print("Overall success rate: 0.00%")

    if not class_stats:
        return

    print("\n=== Per-Class Attack Statistics ===")
    total_classes = len(class_stats)
    successful_classes = 0
    for classname, stats in sorted(class_stats.items()):
        # A class counts as attacked if at least one of its samples succeeded.
        attacked = stats['success'] > 0
        if attacked:
            successful_classes += 1
        if per_class:
            success_rate = stats['success'] / stats['total'] * 100 if stats['total'] > 0 else 0
            class_status = "SUCCESS" if attacked else "FAILED"
            print(f"{classname:20s} | Total: {stats['total']:4d} | Success: {stats['success']:4d} | "
                  f"Failed: {stats['failed']:4d} | Rate: {success_rate:6.2f}% | Status: {class_status}")

    failed_classes = total_classes - successful_classes
    print("\n=== Class-Level Summary ===")
    print(f"Total classes: {total_classes}")
    print(f"Successfully attacked classes: {successful_classes} ({successful_classes/total_classes*100:.2f}%)")
    print(f"Failed to attack classes: {failed_classes} ({failed_classes/total_classes*100:.2f}%)")

@hydra.main(version_base=None, config_path="../configs")
def main(cfg: DictConfig):
    """
    Hydra entry point: build the pipeline from ``cfg`` and run ``attack``.

    Class names are read from the JSON file at ``cfg.classes_file_path`` when
    that key is present, otherwise taken from ``cfg.classes``. The optional
    ``cfg.other_kwargs`` mapping is forwarded to ``attack`` as extra keyword
    arguments.
    """
    if 'classes_file_path' in cfg:
        import json
        # JSON is UTF-8 by specification; don't rely on the platform default encoding.
        with open(cfg.classes_file_path, 'r', encoding='utf-8') as f:
            classes = json.load(f)
    else:
        classes = cfg.classes
    other_kwargs = cfg.get('other_kwargs', {})
    L.seed_everything(cfg.seed)
    # Imported lazily so the module can be imported without the pipeline package.
    from models.noxeye import NoxEyePipeline
    model = NoxEyePipeline(**cfg.model_kwargs)
    attack(
        model=model,
        dataset_type=cfg.dataset_type,
        data_path=cfg.data_path,
        classes=classes,
        templates=cfg.templates,
        img_size=cfg.img_size,
        save=cfg.save,
        save_dir=cfg.save_dir,
        save_name=cfg.save_name,
        batch_size=cfg.batch_size,
        start=cfg.start,
        end=cfg.end,
        **other_kwargs
    )

if __name__ == "__main__":
    # The @hydra.main decorator builds `cfg` from the config directory and CLI overrides.
    main()  
