import os
import json
import hydra

import torch
import datasets
import numpy as np

from typing import Dict, List, Tuple, Any
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from torch.distributed import init_process_group, destroy_process_group

from data_module import TextForgetDatasetQA, convert_raw_data_to_model_format
from utils import get_model_identifiers_from_yaml

class TextForgetDatasetQA(Dataset):
    """
    Dataset for text forgetting or data attribution in question-answering tasks.

    Each item pairs one forget-side sample (index ``idx`` into the forget split)
    with one retain-side sample chosen via a per-index seeded offset. When
    ``loss_type == "idk"`` the forget-side answer is replaced by a line sampled
    from ``data/idontknow.jsonl``.

    NOTE: this local definition shadows the ``TextForgetDatasetQA`` imported from
    ``data_module`` at the top of the file.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: Any,
        model_family: str,
        max_length: int = 512,
        split: str = "forget10",
        loss_type: str = "idk"
    ):
        """
        Load the forget split and its complementary retain split.

        Args:
            data_path: Dataset name/path passed to ``datasets.load_dataset``.
            tokenizer: Tokenizer used to encode question/answer pairs.
            model_family: Model family identifier (used to look up prompt tags).
            max_length: Maximum (padded) sequence length.
            split: Forget split name of the form ``"forget<NN>"``; the retain split
                ``"retain<100-NN>"`` (zero-padded to two digits) is loaded alongside.
            loss_type: ``"idk"`` replaces forget answers with "I don't know"-style
                responses; any other value keeps the original forget answers.

        Raises:
            FileNotFoundError: If ``loss_type == "idk"`` and the idk file is missing.
        """
        super(TextForgetDatasetQA, self).__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.forget_data = datasets.load_dataset(data_path, split)["train"]

        # "forget10" -> "retain90", "forget01" -> "retain99".
        retain_percent = 100 - int(split.replace("forget", ""))
        retain_split = f"retain{str(retain_percent).zfill(2)}"
        self.retain_data = datasets.load_dataset(data_path, retain_split)["train"]

        self.model_configs = get_model_identifiers_from_yaml(model_family)
        self.loss_type = loss_type

        if self.loss_type == "idk":
            self.split1, self.split2 = "idk", "retain"
            self.idontknowfile = "data/idontknow.jsonl"
            try:
                with open(self.idontknowfile, "r", encoding="utf-8") as f:
                    self.idk = f.readlines()
            except FileNotFoundError as err:
                raise FileNotFoundError(f"IDontKnow file not found at {self.idontknowfile}") from err
        else:
            self.split1, self.split2 = "forget", "retain"

    def __len__(self) -> int:
        """Return the number of samples in the forget split."""
        return len(self.forget_data)

    def __getitem__(self, idx: int) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Return ``[split1_sample, split2_sample]`` for index ``idx``.

        Each sample is the 4-tuple produced by ``convert_raw_data_to_model_format``.
        Random choices (the retain offset and the idk answer) are drawn from a
        private generator seeded with ``idx``, so results are reproducible per
        index without touching the process-wide torch RNG.
        """
        rets = []
        # Produces the same stream as torch.manual_seed(idx) + the default CPU
        # generator, but locally: reseeding the global RNG on every item access
        # would silently perturb every other consumer of torch randomness.
        rng = torch.Generator().manual_seed(int(idx))

        for data_type in [self.split1, self.split2]:
            if data_type == "retain":
                data = self.retain_data
                # Deterministic (per idx) but shifted retain index.
                rand_offset = torch.randint(0, len(self.retain_data), (1,), generator=rng).item()
                cur_idx = (idx + rand_offset) % len(self.retain_data)
            else:
                # Both "idk" and "forget" take the question from the forget split.
                data = self.forget_data
                cur_idx = idx

            question = data[cur_idx]['question']
            answer = data[cur_idx]['answer']

            if data_type == "idk":
                # Replace the true answer with a sampled "I don't know" response.
                rand_pos = torch.randint(0, len(self.idk), (1,), generator=rng).item()
                answer = self.idk[rand_pos].strip()

            rets.append(
                convert_raw_data_to_model_format(
                    self.tokenizer,
                    self.max_length,
                    question,
                    answer,
                    self.model_configs
                )
            )
        return rets

def convert_raw_data_to_model_format(
    tokenizer: Any,
    max_length: int,
    question: str,
    answer: str,
    model_configs: Dict[str, str]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Convert a raw question/answer pair into padded model inputs and labels.

    The prompt is ``question_start_tag + question + question_end_tag`` followed by
    ``answer_tag + answer``. The encoding is truncated to ``max_length`` and then
    right-padded with ``tokenizer.eos_token_id``.

    Labels equal the input ids, with two exceptions. The question-prefix positions
    are set to -100. When padding was added, the first padded position is labelled
    with the EOS id and the remaining padded positions are set to -100.

    Args:
        tokenizer: The tokenizer to use.
        max_length: Target sequence length for every returned tensor.
        question: The question text.
        answer: The answer text.
        model_configs: Mapping with ``question_start_tag``, ``question_end_tag``
            and ``answer_tag`` entries.

    Returns:
        Tuple of 1-D tensors ``(input_ids, labels, attention_mask, token_type_ids)``,
        each of length ``max_length``.
    """
    new_question = model_configs['question_start_tag'] + question + model_configs['question_end_tag']
    new_answer = model_configs['answer_tag'] + answer
    full_text = new_question + new_answer

    # Number of leading tokens that belong to the question (masked out of the loss).
    num_question_tokens = len(tokenizer.tokenize(new_question, add_special_tokens=True))

    encoded = tokenizer(
        full_text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
        return_token_type_ids=True
    )
    input_ids = encoded['input_ids']

    pad_length = max_length - len(input_ids)
    pad_input_ids = input_ids + [tokenizer.eos_token_id] * pad_length
    pad_token_type_ids = encoded['token_type_ids'] + [0] * pad_length
    pad_attention_mask = encoded['attention_mask'] + [0] * pad_length

    if pad_length == 0:
        label = list(input_ids)
    else:
        # Keep one EOS target after the text; ignore the rest of the padding.
        label = input_ids + [tokenizer.eos_token_id] + [-100] * (pad_length - 1)

    # A very long question can tokenize to more than max_length tokens while the
    # encoding itself was truncated; clamp so we never index past the label list.
    num_masked = min(num_question_tokens, len(label))
    label[:num_masked] = [-100] * num_masked

    return (
        torch.tensor(pad_input_ids),
        torch.tensor(label),
        torch.tensor(pad_attention_mask),
        torch.tensor(pad_token_type_ids)
    )


class AbstractDataAttribution:
    """
    Abstract base class for data attribution methods.

    Holds the run configuration and resolves the torchrun-style process layout
    (``LOCAL_RANK`` / ``RANK`` / ``WORLD_SIZE``). It also provides model loading,
    dataloader construction and batch conversion for subclasses.
    """

    def __init__(
        self,
        model_family: str,
        model_path: str,
        split: str,
        attribution_method: str,
        unify_method: str,
        unify_tau: float,
        forget_loss: str = "grad_ascent",
        data_path: str = "locuslab/TOFU",
        max_length: int = 500,
        batch_size: int = 1,
        use_flash_attention_2: bool = False
    ):
        """
        Initialize the data attribution class.

        Args:
            model_family: Model family identifier.
            model_path: Path to the model checkpoint (also where outputs are written).
            split: Forget split to use (e.g. ``"forget10"``).
            attribution_method: Method for attribution calculation.
            unify_method: Method for unifying attribution scores.
            unify_tau: Temperature/exponent parameter for unification.
            forget_loss: Loss type passed to the dataset as ``loss_type``.
            data_path: Path/name of the dataset.
            max_length: Maximum sequence length.
            batch_size: Batch size for processing.
            use_flash_attention_2: Request Flash-Attention 2. It is only honoured if
                the model family's config also enables it.
        """
        self.model_family = model_family
        self.model_path = model_path
        self.split = split
        self.unify_method = unify_method
        self.attribution_method = attribution_method
        self.forget_loss = forget_loss
        self.data_path = data_path
        self.max_length = max_length
        self.batch_size = batch_size
        self.tau = unify_tau
        self.use_flash_attention_2 = use_flash_attention_2

        # Process layout. LOCAL_RANK selects the GPU on this node. RANK is the global
        # index across all nodes; it falls back to LOCAL_RANK when the launcher does
        # not export it, as in single-node runs.
        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
        self.rank = int(os.environ.get("RANK", self.local_rank))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.is_distributed = self.world_size > 1

        # torch.device() never validates the index, so check it explicitly to get
        # the intended CPU fallback instead of a crash at first use.
        cuda_available = torch.cuda.is_available()
        if cuda_available and self.local_rank < torch.cuda.device_count():
            self.device = torch.device(f"cuda:{self.local_rank}")
        else:
            if cuda_available:
                print(f"Warning: CUDA device {self.local_rank} not available. Falling back to CPU.")
            self.device = torch.device("cpu")

    def load_model(self) -> None:
        """
        Load the causal-LM checkpoint and its tokenizer.

        Weights come from ``self.model_path``; the config and tokenizer come from the
        family's ``hf_key``. Single-process runs use ``device_map='auto'``. In
        multi-process runs each rank loads a full copy onto ``self.device``.

        Raises:
            RuntimeError: If loading the config, model or tokenizer fails.
        """
        # The CUDA caching allocator reads this variable when it is first used, so
        # set it before any CUDA call below.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

        if self.device.type == "cuda":
            torch.cuda.set_device(self.device)
            torch.cuda.empty_cache()

        model_cfg = get_model_identifiers_from_yaml(self.model_family)
        model_id = model_cfg["hf_key"]
        use_flash_attention_2 = self.use_flash_attention_2 and model_cfg["flash_attention2"] == "true"

        # Single process: let accelerate place the weights. Multi-process: no device
        # map, so each rank holds a full replica that is moved to its own device.
        device_map = 'auto' if self.world_size <= 1 else None

        try:
            config = AutoConfig.from_pretrained(model_id)

            print(f">>> [Rank {self.local_rank}] Loading from checkpoint {self.model_path}")
            model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                config=config,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                use_flash_attention_2=use_flash_attention_2,
                low_cpu_mem_usage=True,  # Reduces CPU memory during loading
                device_map=device_map,
            )
            # Do not move a model that was dispatched with a device map. Accelerate
            # refuses .to() when layers are offloaded, and moving it would also undo
            # a split across GPUs.
            if device_map is None:
                model = model.to(self.device)
            self.model = model

            if hasattr(self.model, 'gradient_checkpointing_enable'):
                self.model.gradient_checkpointing_enable()

            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            print(f">>> [Rank {self.local_rank}] Model Checkpoint and Tokenizer loaded.")
        except Exception as e:
            raise RuntimeError(f"Failed to load model or tokenizer: {str(e)}") from e

    def load_data(self) -> None:
        """
        Build ``self.loader_forget`` / ``self.loader_retain`` over the raw splits.

        In distributed mode each loader uses a non-shuffling ``DistributedSampler``
        keyed on the global rank, so ranks on different nodes get disjoint shards.
        Also records ``self.len_ds_forget`` (the full forget-split size).
        """
        ds = TextForgetDatasetQA(
            self.data_path,
            tokenizer=self.tokenizer,
            model_family=self.model_family,
            max_length=self.max_length,
            split=self.split,
            loss_type=self.forget_loss,
        )

        ds_forget, ds_retain = ds.forget_data, ds.retain_data

        if self.is_distributed:
            # The sampler's rank must be the global rank; LOCAL_RANK repeats across nodes.
            sampler_forget = DistributedSampler(
                ds_forget,
                num_replicas=self.world_size,
                rank=self.rank,
                shuffle=False
            )
            sampler_retain = DistributedSampler(
                ds_retain,
                num_replicas=self.world_size,
                rank=self.rank,
                shuffle=False
            )
            self.loader_forget = DataLoader(ds_forget, batch_size=self.batch_size, sampler=sampler_forget)
            self.loader_retain = DataLoader(ds_retain, batch_size=self.batch_size, sampler=sampler_retain)
        else:
            self.loader_forget = DataLoader(ds_forget, batch_size=self.batch_size, shuffle=False)
            self.loader_retain = DataLoader(ds_retain, batch_size=self.batch_size, shuffle=False)

        self.len_ds_forget = len(ds_forget)
        print(f">>> [Rank {self.local_rank}] Data Loader constructed.")

    def process_batch(self, batch: Dict[str, Any]) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[str], List[str]
    ]:
        """
        Process a batch of data into model-compatible format.

        Args:
            batch: A collated batch from the dataloader with ``'question'`` and
                ``'answer'`` sequences.

        Returns:
            Tuple of ``(input_ids, labels, attention_mask, token_type_ids, questions,
            answers)``. The tensors are stacked along dim 0 and moved to ``self.device``.
        """
        # The prompt-tag config is fixed for the run; load it once rather than per batch.
        model_configs = getattr(self, "_model_configs", None)
        if model_configs is None:
            model_configs = get_model_identifiers_from_yaml(self.model_family)
            self._model_configs = model_configs

        cvt_bch = {
            "input_ids": [],
            "labels": [],
            "attention_mask": [],
            "token_type_ids": [],
            "qs": [],
            "as": []
        }

        for i in range(min(self.batch_size, len(batch['question']))):
            question = batch['question'][i]
            answer = batch['answer'][i]

            input_ids, labels, attention_mask, token_type_ids = convert_raw_data_to_model_format(
                self.tokenizer,
                self.max_length,
                question,
                answer,
                model_configs
            )

            cvt_bch['input_ids'].append(input_ids)
            cvt_bch['labels'].append(labels)
            cvt_bch['attention_mask'].append(attention_mask)
            cvt_bch['token_type_ids'].append(token_type_ids)
            cvt_bch['qs'].append(question)
            cvt_bch['as'].append(answer)

        return (
            torch.stack(cvt_bch["input_ids"], dim=0).to(self.device),
            torch.stack(cvt_bch['labels'], dim=0).to(self.device),
            torch.stack(cvt_bch['attention_mask'], dim=0).to(self.device),
            torch.stack(cvt_bch['token_type_ids'], dim=0).to(self.device),
            cvt_bch['qs'],
            cvt_bch['as']
        )


class DataAttribution(AbstractDataAttribution):
    """
    Implementation of data attribution methods for measuring sample influence.

    Each forget-set question is scored in one of two ways:

    - ``g_norm``: the mean per-parameter gradient norm.
    - ``g_prod``: the dot product of its gradient with the average retain-set
      gradient.

    Both the raw scores and a tau-dependent normalised ("unified") version are
    written as JSON under ``model_path``.
    """

    def get_attribution(self) -> None:
        """
        Compute attribution scores for the forget set and persist them.

        Raises:
            ValueError: If ``attribution_method`` is not ``"g_norm"`` or ``"g_prod"``.
        """
        dict_q_score: Dict[str, float] = {}

        if self.attribution_method == "g_norm":
            self._calculate_gradient_norm_scores(dict_q_score)
        elif self.attribution_method == "g_prod":
            self._calculate_gradient_product_scores(dict_q_score)
        else:
            raise ValueError(f"Unknown attribution method: {self.attribution_method}")

        self._save_and_process_results(dict_q_score)

    def _calculate_gradient_norm_scores(self, dict_q_score: Dict[str, float]) -> None:
        """
        Score each forget batch by the mean L2 norm of its per-parameter gradients.

        Args:
            dict_q_score: Output mapping, filled in place with ``question -> score``.
                Only the first question of each batch is recorded, so per-sample
                scores require ``batch_size == 1``.
        """
        # The retain loader is not used by this method; drop it to free memory.
        if hasattr(self, 'loader_retain'):
            del self.loader_retain

        if hasattr(self.model, 'gradient_checkpointing_enable'):
            self.model.gradient_checkpointing_enable()

        for batch in tqdm(self.loader_forget, desc=f'[Rank {self.local_rank}] Running data-wise grad for forget set'):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            input_id, label, attention_mask, _, question, _ = self.process_batch(batch=batch)

            loss = self.model(input_ids=input_id, labels=label, attention_mask=attention_mask).loss
            grads = torch.autograd.grad(loss, self.model.parameters())

            # Mean over parameter tensors of each tensor's L2 norm (not the global norm).
            grad_norms = [torch.norm(grad, p=2).item() for grad in grads]
            dict_q_score[question[0]] = np.mean(grad_norms)

            del grads, loss, input_id, label, attention_mask, grad_norms
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def _calculate_gradient_product_scores(self, dict_q_score: Dict[str, float]) -> None:
        """
        Score each forget batch by <avg retain gradient, batch gradient>.

        The average retain gradient is loaded from ``_get_avg_grad_path()`` if present,
        otherwise computed and cached there first.

        Args:
            dict_q_score: Output mapping, filled in place with ``question -> score``
                (first question of each batch only).

        Raises:
            ValueError: If the cached average gradient does not have one tensor per
                model parameter (e.g. a stale file from a different model).
        """
        avg_grad_path = self._get_avg_grad_path()

        if os.path.exists(avg_grad_path):
            # Keep the average on CPU; each tensor is moved to its gradient's device on use.
            avg_grad_retain = torch.load(avg_grad_path, map_location='cpu')
            print(f">>> [Rank {self.local_rank}] Loaded avg_grad from {avg_grad_path}")
        else:
            self._calculate_avg_grad(avg_grad_path)
            avg_grad_retain = torch.load(avg_grad_path, map_location='cpu')

        num_params = sum(1 for _ in self.model.parameters())
        if len(avg_grad_retain) != num_params:
            raise ValueError(
                f"Average gradient at {avg_grad_path} has {len(avg_grad_retain)} tensors "
                f"but the model has {num_params} parameters; delete the stale file and rerun."
            )

        for batch in tqdm(self.loader_forget, desc=f'[Rank {self.local_rank}] Running data-wise grad for forget set'):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            input_id, label, attention_mask, _, question, _ = self.process_batch(batch=batch)

            loss = self.model(input_ids=input_id, labels=label, attention_mask=attention_mask).loss
            grads = torch.autograd.grad(loss, self.model.parameters())

            # Full contraction over all dims = flattened dot product per parameter tensor.
            score = 0.0
            for avg_grad, grad in zip(avg_grad_retain, grads):
                score += torch.tensordot(avg_grad.to(grad.device), grad, dims=grad.dim()).item()

            dict_q_score[question[0]] = score

            del grads, loss, input_id, label, attention_mask
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def _get_avg_grad_path(self) -> str:
        """
        Return the cache file for the average retain gradient.

        In distributed mode the file is per local rank, because ``loader_retain``
        is built with a ``DistributedSampler``. Each rank's average therefore
        covers only its own shard of the retain set.
        """
        if self.is_distributed:
            return os.path.join(self.model_path, f"RetainAvgGrad_split_{self.split}_rank_{self.local_rank}.pth")
        return os.path.join(self.model_path, f"RetainAvgGrad_split_{self.split}.pth")

    def _calculate_avg_grad(self, avg_grad_path: str) -> None:
        """
        Compute the mean per-batch gradient over ``self.loader_retain`` and save it.

        Gradients are accumulated as a running float32 sum on the CPU. The sum is
        divided by the exact number of batches at the end, so every batch carries
        equal weight. Only one extra model-sized buffer is held in host memory.

        Args:
            avg_grad_path: Destination for ``torch.save``. The file holds a list with
                one tensor per entry of ``self.model.parameters()``, cast back to the
                gradients' dtype.

        Raises:
            ValueError: If the retain loader yields no batches.
        """
        grad_sum = None      # float32 CPU tensors aligned with model.parameters()
        grad_dtypes = None   # original gradient dtypes, restored before saving
        num_batches = 0

        for step, batch in enumerate(tqdm(self.loader_retain, desc=f'[Rank {self.local_rank}] Running AVG grad for retain set')):
            if step % 5 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()

            input_id, label, attention_mask, _, _, _ = self.process_batch(batch=batch)

            loss = self.model(input_ids=input_id, labels=label, attention_mask=attention_mask).loss
            grads = torch.autograd.grad(loss, self.model.parameters())

            if grad_sum is None:
                grad_dtypes = [g.dtype for g in grads]
                grad_sum = [g.detach().to(device="cpu", dtype=torch.float32) for g in grads]
            else:
                for acc, g in zip(grad_sum, grads):
                    acc += g.detach().to(device="cpu", dtype=torch.float32)
            num_batches += 1

            del grads, loss, input_id, label, attention_mask
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if grad_sum is None:
            raise ValueError(f"Retain loader is empty; cannot compute average gradient for {avg_grad_path}")

        avg_grad_retain = [(acc / num_batches).to(dtype) for acc, dtype in zip(grad_sum, grad_dtypes)]
        del grad_sum

        torch.save(avg_grad_retain, avg_grad_path)
        print(f">>> [Rank {self.local_rank}] Saved avg_grad to {avg_grad_path}")

        del avg_grad_retain
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _save_and_process_results(self, dict_q_score: Dict[str, float]) -> None:
        """
        Save the raw scores, unify them, and save the unified scores.

        In distributed mode every rank takes part in ``all_gather_object``. Only
        global rank 0 merges the per-rank dicts and writes the files.

        Args:
            dict_q_score: Attribution scores keyed by question.
        """
        if self.is_distributed:
            all_dicts = [None] * self.world_size
            torch.distributed.all_gather_object(all_dicts, dict_q_score)

            # Use the global rank: LOCAL_RANK == 0 holds once per node in multi-node runs.
            if torch.distributed.get_rank() != 0:
                return

            merged_dict = {}
            for part in all_dicts:
                merged_dict.update(part)
            dict_q_score = merged_dict

        self._save_scores(dict_q_score, is_unified=False)
        unified_scores = self.unify(dict_q_score)
        self._save_scores(unified_scores, is_unified=True)

    def _save_scores(self, scores: Dict[str, float], is_unified: bool) -> None:
        """
        Write scores as JSON under ``self.model_path``.

        An existing original-scores file is left untouched. The unified file, whose
        name includes the method and tau, is always overwritten. Write errors are
        printed rather than raised.

        Args:
            scores: Scores to save.
            is_unified: Whether these are unified scores.
        """
        if is_unified:
            filename = f"{self.split}_{self.attribution_method}_{self.unify_method}t{self.tau}_influence_dict.json"
        else:
            filename = f"{self.split}_{self.attribution_method}_influence_dict_original.json"

        save_path = os.path.join(self.model_path, filename)

        if not is_unified and os.path.exists(save_path):
            print(f">>> [Rank {self.local_rank}] File '{save_path}' already exists. Skipping...")
            return

        try:
            with open(save_path, "w") as f:
                json.dump(scores, f, indent=4)

            file_type = "Unified" if is_unified else "Original"
            print(f">>> [Rank {self.local_rank}] {file_type} Score saved at {save_path}")
        except Exception as e:
            print(f">>> [Rank {self.local_rank}] Error saving scores: {str(e)}")

    def unify(self, score_dict: Dict[str, float]) -> Dict[str, float]:
        """
        Turn raw scores into weights that sum to ``self.len_ds_forget``.

        - ``"exp"``: weight is proportional to ``exp(-score / tau)``, computed with
          a max-shift for stability. Lower raw scores get larger weights.
        - ``"power"``: weight is proportional to ``(score - min + 1e-6) ** (-tau)``.
          The lowest-scoring key gets the largest weight.

        Args:
            score_dict: Raw scores keyed by question (not modified).

        Returns:
            A new dict with the same keys and normalised weights. It is empty if
            ``score_dict`` is empty.

        Raises:
            ValueError: If ``unify_method`` is not ``"exp"`` or ``"power"``.
        """
        if self.unify_method not in ("exp", "power"):
            raise ValueError(f"Unknown unification method: {self.unify_method}")

        if not score_dict:
            print(f">>> [Rank {self.local_rank}] WARNING: no scores to unify; returning an empty dict.")
            return {}

        keys = list(score_dict.keys())
        values = np.array(list(score_dict.values()))

        if self.unify_method == "exp":
            values_scaled = -values / self.tau
            weights = np.exp(values_scaled - np.max(values_scaled))
        else:
            eps = 1e-6  # keeps the minimum value's base strictly positive
            weights = (values - np.min(values) + eps) ** (-self.tau)

        denominator = np.sum(weights)
        unified_dict = {
            key: weight / denominator * self.len_ds_forget
            for key, weight in zip(keys, weights)
        }

        # Sanity check: normalisation should make the weights sum to the dataset size.
        sum_scores = sum(unified_dict.values())
        eps = 1e-5
        if not (1 - eps < sum_scores / self.len_ds_forget < 1 + eps):
            print(f"WARNING: Scores not properly normalized. Sum is {sum_scores}, expected {self.len_ds_forget}")

        print(f">>> [Rank {self.local_rank}] Unification Completed using method {self.unify_method}")
        return unified_dict


@hydra.main(version_base=None, config_path="config", config_name="data_attribution")
def main(cfg: Any) -> None:
    """
    Hydra entry point: run data attribution once per value in ``cfg.tau_values``.

    Reads ``LOCAL_RANK`` / ``WORLD_SIZE`` from the environment. When
    ``WORLD_SIZE > 1`` it initialises an NCCL process group for the whole run and
    always tears it down on exit. A failure for one tau is reported with a
    traceback, and the loop continues with the next tau.

    Args:
        cfg: Hydra config. It must provide ``model_family``, ``model_path``,
            ``split``, ``attribution_method``, ``unify_method``,
            ``use_flash_attention_2`` and ``tau_values``.
    """
    import traceback

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    # Set before any CUDA call so the caching allocator picks it up.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    if world_size > 1:
        try:
            # Bind this process to its GPU before NCCL setup.
            torch.cuda.set_device(local_rank)
            init_process_group(backend="nccl")
        except Exception as e:
            print(f"Error initializing process group: {str(e)}")
            return

    try:
        if local_rank == 0:
            print(f"Starting data attribution with {world_size} GPUs")

        # NOTE(review): only unification depends on tau, yet every iteration reloads
        # the model and recomputes all gradients. Computing the scores once and
        # re-unifying per tau would avoid that repeated work.
        for tau in tqdm(cfg.tau_values, desc=f"[Rank {local_rank}] Processing tau values"):
            attribution = None
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                attribution = DataAttribution(
                    model_family=cfg.model_family,
                    model_path=cfg.model_path,
                    split=cfg.split,
                    attribution_method=cfg.attribution_method,
                    unify_method=cfg.unify_method,
                    unify_tau=tau,
                    use_flash_attention_2=cfg.use_flash_attention_2,
                    batch_size=1,  # scores are recorded per batch, so keep one sample per batch
                )

                attribution.load_model()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                attribution.load_data()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                attribution.get_attribution()
                print(f">>> [Rank {local_rank}] Data Attribution Completed for tau={tau}.")
            except Exception as e:
                print(f">>> [Rank {local_rank}] Error during attribution with tau={tau}: {str(e)}")
                traceback.print_exc()
            finally:
                # Drop the model reference *before* empty_cache. Otherwise, on the
                # error path, the model's memory is still referenced and cannot be
                # released.
                del attribution
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
    finally:
        if world_size > 1:
            destroy_process_group()

if __name__ == "__main__":
    # main() is decorated with @hydra.main, which supplies the `cfg` argument,
    # so it is invoked here without arguments.
    main()