import os
import random
from pathlib import Path
from typing import Dict, List, Optional
import torch
import gc


from datasets import Dataset, concatenate_datasets
from dotenv import load_dotenv

from .openmodel import LoRAModelManager
from .train_loop import train
from .utils.configs import (  # NOQA
    ApibenchDataConfig,
    MLLMDataConfig,
    HuggingBench1DataConfig,
    HuggingBench2DataConfig,
)
from .utils.parser import TrainParser
from .utils.prepareDataset import convert_to_conversational, load_dataset_json
from .utils.utility import set_seed
# from .utils.wandb import WandbLogger

# Directory that contains this module (path made absolute, symlinks resolved).
PACKAGE_ROOT = Path(__file__).resolve().parent
# Parent of the package directory; the project's .env file is looked up here.
PROJECT_ROOT = PACKAGE_ROOT.parent
# Load environment variables from the project-level .env file via python-dotenv.
load_dotenv(PROJECT_ROOT / ".env")


# Keep every Hugging Face related cache in one directory that sits next to the
# package directory (``<dir of this file>/../.hf_cache``).
# NOTE(review): these variables are set after `datasets` has already been
# imported above; a library that reads them at import time would not pick them
# up -- confirm the caches really end up here.
cache_dir = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, ".hf_cache")
)
os.makedirs(cache_dir, exist_ok=True)
os.environ.update(
    dict.fromkeys(
        (
            "HF_HOME",
            "HF_HUB_CACHE",
            "TRANSFORMERS_CACHE",
            "TOKENIZERS_CACHE",
            "SENTENCE_TRANSFORMERS_HOME",
        ),
        cache_dir,
    )
)


def get_dataset_config(experience_name: str):
    """Build the dataset configuration object for one experience.

    Args:
        experience_name: One of ``"apibench"``, ``"mllm"``,
            ``"hugging-bench-1"`` or ``"hugging-bench-2"``.

    Returns:
        A new instance of the config class registered for that name.

    Raises:
        ValueError: If ``experience_name`` is not one of the known names.
    """
    # Each factory is wrapped in a lambda so that a config class is only
    # looked up and instantiated when its name is actually requested.
    known_configs = (
        ("apibench", lambda: ApibenchDataConfig()),
        ("mllm", lambda: MLLMDataConfig()),
        ("hugging-bench-1", lambda: HuggingBench1DataConfig()),
        ("hugging-bench-2", lambda: HuggingBench2DataConfig()),
    )
    for known_name, build_config in known_configs:
        if experience_name == known_name:
            return build_config()
    raise ValueError(f"Unknown experience name: {experience_name}")


def sample_replay_data(
    previous_datasets: Dict[str, Dataset],
    replay_percentage: Optional[float] = None,
    replay_num_samples: Optional[int] = None,
    seed: Optional[int] = None,
) -> List[Dataset]:
    """
    Sample replay data from previous experiences.

    For every non-empty dataset a random subset (without replacement) is
    selected. ``replay_num_samples`` takes precedence over
    ``replay_percentage`` when both are given. If neither is given, nothing
    is sampled and an empty list is returned.

    Args:
        previous_datasets: Dictionary mapping experience names to their datasets
        replay_percentage: Fraction of samples to replay (e.g., 0.1 for 10%).
            At least one sample is drawn per dataset; values of 1.0 or more
            replay the whole dataset.
        replay_num_samples: Fixed number of samples to replay per dataset,
            capped at the dataset size
        seed: Random seed for reproducibility. When given, the global
            ``random`` generator is re-seeded before each dataset is sampled.

    Returns:
        List of sampled datasets from previous experiences, in the iteration
        order of ``previous_datasets`` (empty datasets are skipped)
    """
    if not previous_datasets:
        return []

    # No replay configured: nothing to sample from any dataset.
    if replay_num_samples is None and replay_percentage is None:
        return []

    replay_datasets = []

    for exp_name, dataset in previous_datasets.items():
        dataset_size = len(dataset)

        if dataset_size == 0:
            continue

        # Determine number of samples to replay
        if replay_num_samples is not None:
            num_samples = min(replay_num_samples, dataset_size)
        else:
            # At least one sample, but never more than the dataset holds:
            # random.sample raises ValueError when asked for more items than
            # the population contains (e.g. replay_percentage > 1.0).
            num_samples = min(
                dataset_size, max(1, int(dataset_size * replay_percentage))
            )

        # Re-seed before every draw so the indices chosen for a dataset depend
        # only on the seed and that dataset's size, not on how many datasets
        # were sampled before it.
        if seed is not None:
            random.seed(seed)

        indices = random.sample(range(dataset_size), num_samples)
        replay_datasets.append(dataset.select(indices))

        print(
            f"  Replaying {num_samples} samples from {exp_name} (out of {dataset_size} total)"
        )

    return replay_datasets


def _latest_checkpoint_path(base_path: str) -> str:
    """Return the path of the most recent ``checkpoint-*`` entry in *base_path*.

    ``os.listdir`` returns names in arbitrary order, so the candidates are
    ranked explicitly: the highest numeric step suffix wins, and names without
    a numeric suffix rank below every numbered one (ties broken by name).

    Raises:
        FileNotFoundError: If *base_path* holds no ``checkpoint-*`` entry.
    """
    prefix = "checkpoint-"

    def rank(name: str):
        suffix = name[len(prefix):]
        return (int(suffix) if suffix.isdecimal() else -1, name)

    candidates = [name for name in os.listdir(base_path) if name.startswith(prefix)]
    if not candidates:
        raise FileNotFoundError(
            f"No '{prefix}*' entry found in {base_path}; "
            "cannot load the previous experience's adapter."
        )
    return os.path.join(base_path, max(candidates, key=rank))


def _load_conversational_splits(dataset_config, tokenizer, model_index_name, retriever_name):
    """Load one experience's train and validation files as conversational datasets.

    Returns a ``(train, val)`` tuple; the train split is loaded and converted
    before the validation split.
    """
    return tuple(
        convert_to_conversational(
            raw_data=load_dataset_json(split_path),
            tokenizer=tokenizer,
            model_index_name=model_index_name,
            retriever_name=retriever_name,
        )
        for split_path in (dataset_config.train_set, dataset_config.val_set)
    )


def main():
    """Parse the training configuration and run the selected training mode.

    Modes handled (``train_config.mode``):

    * ``"sequential-finetuning"``: train on each experience in turn; from the
      second experience on, the model is rebuilt from the latest checkpoint
      found in the previous experience's output directory.
    * ``"replay"``: train on each experience in turn, adding samples drawn
      from the training data of earlier experiences when
      ``replay_percentage`` or ``replay_num_samples`` is set.
    * ``"merging"``: train on each experience in turn, rebuilding the model
      with ``lora_paths=None`` before every experience.
    * ``"joint-training"``: concatenate the data of all experiences and train
      once.

    Any other mode value results in no training being run.

    Raises:
        ValueError: If an experience name is unknown (from
            ``get_dataset_config``).
        FileNotFoundError: In sequential finetuning, if the previous
            experience's output directory is missing or holds no checkpoint.
    """
    parser = TrainParser()
    train_config = parser.parse_args()

    print(train_config)

    if train_config.seed is not None:
        # Set seed for reproducibility
        set_seed(train_config.seed)

    # NOTE: WandB logging is currently disabled (the WandbLogger import at the
    # top of this file is commented out), so no logger is passed to train().

    lora_paths = [
        f"./cco/experiments/{adapter}" for adapter in train_config.lora_adapters
    ]
    model = LoRAModelManager(config=train_config, lora_paths=lora_paths)

    experiences = train_config.experience_names
    retriever_name = train_config.retriever

    # Store datasets from previous experiences for replay
    previous_experience_datasets: Dict[str, Dataset] = {}

    # Train on each experience sequentially (sequential finetuning, replay or merging)
    if train_config.mode in ["sequential-finetuning", "replay", "merging"]:
        for exp_idx, experience_name in enumerate(experiences):
            print(f"\n{'=' * 80}")
            print(
                f"Training on Experience {exp_idx + 1}/{len(experiences)}: {experience_name}"
            )
            print(f"{'=' * 80}\n")

            # Get dataset configuration for current experience
            dataset_config = get_dataset_config(experience_name)
            # One model index per experience, matched by position.
            model_index_name = (
                train_config.model_indices[exp_idx]
                if train_config.model_indices is not None
                else None
            )

            # Load and convert the training and validation datasets. This uses
            # the tokenizer of the model that is current *before* any reload
            # performed further down.
            dataset_train, dataset_val = _load_conversational_splits(
                dataset_config, model.tokenizer, model_index_name, retriever_name
            )

            # If no_validation is True, combine train and val sets into a single training set
            if train_config.no_validation:
                print(
                    "no_validation is True: Combining train and val sets into a single training set"
                )
                dataset_train = concatenate_datasets([dataset_train, dataset_val])
                dataset_val = None  # Set to None so it's not used for evaluation

            if train_config.mode == "replay":
                # Keep the experience's own training data (before replay data
                # is mixed in) so later experiences replay only original data.
                original_dataset_for_replay = dataset_train

                # Sample replay data from previous experiences if configured
                if exp_idx > 0 and (
                    train_config.replay_percentage is not None
                    or train_config.replay_num_samples is not None
                ):
                    print("\nSampling replay data from previous experiences:")
                    replay_datasets = sample_replay_data(
                        previous_datasets=previous_experience_datasets,
                        replay_percentage=train_config.replay_percentage,
                        replay_num_samples=train_config.replay_num_samples,
                        seed=train_config.seed,
                    )

                    if replay_datasets:
                        # Concatenate replay data with current training data
                        dataset_train = concatenate_datasets(
                            [dataset_train] + replay_datasets
                        )
                        current_size = len(original_dataset_for_replay)
                        replay_size = sum(len(d) for d in replay_datasets)
                        print(
                            f"  Combined dataset size: {len(dataset_train)} (current: {current_size}, replay: {replay_size})"
                        )
            elif train_config.mode == "sequential-finetuning":
                if exp_idx > 0:
                    # Output directory of the previous experience.
                    extra_suffix = (
                        f"-{train_config.extra_info}"
                        if train_config.extra_info != ""
                        else ""
                    )
                    base_path = (
                        f"./cco/experiments/{experiences[exp_idx - 1]}"
                        f"-{train_config.variant_name}{extra_suffix}"
                    )
                    checkpoint_path = _latest_checkpoint_path(base_path)
                    print(f"Loading LoRA adapter from {checkpoint_path} for sequential finetuning.")
                    # Drop the previous model before building the next one.
                    # Collect garbage first so memory held by unreachable
                    # objects is free by the time the CUDA cache is emptied.
                    del model
                    gc.collect()
                    torch.cuda.empty_cache()

                    model = LoRAModelManager(config=train_config, lora_paths=[checkpoint_path])

            elif train_config.mode == "merging":
                # Every experience starts from a model built without adapters.
                del model
                gc.collect()
                torch.cuda.empty_cache()
                model = LoRAModelManager(config=train_config, lora_paths=None)

            print(f"\nSTART TRAINING on {experience_name}")
            train(
                trainConfig=train_config,
                model=model,
                dataset_train=dataset_train,
                dataset_val=dataset_val,
                experience_name=experience_name,
            )

            if train_config.mode == "replay":
                # Register this experience's data only once its training finished.
                previous_experience_datasets[experience_name] = original_dataset_for_replay

            print(f"Completed training on {experience_name}\n")

    if train_config.mode == "joint-training":
        # Joint training uses the first model index for every experience.
        model_index_name = (
            train_config.model_indices[0]
            if train_config.model_indices is not None
            else None
        )

        # Concatenate all datasets for joint training
        all_train_datasets = []
        all_val_datasets = []
        for experience_name in experiences:
            print(f"\nLoading dataset for experience: {experience_name}")
            dataset_config = get_dataset_config(experience_name)

            dataset_train, dataset_val = _load_conversational_splits(
                dataset_config, model.tokenizer, model_index_name, retriever_name
            )
            all_train_datasets.append(dataset_train)
            all_val_datasets.append(dataset_val)

        # Concatenate all training and validation datasets
        # NOTE(review): unlike the sequential modes, no_validation is not
        # applied here -- confirm whether train() handles it for joint training.
        dataset_train = concatenate_datasets(all_train_datasets)
        dataset_val = concatenate_datasets(all_val_datasets)

        print(f"\nSTART JOINT TRAINING: {experiences}")
        train(
            trainConfig=train_config,
            model=model,
            dataset_train=dataset_train,
            dataset_val=dataset_val,
            # train() takes a single experience name; the run is labelled with
            # the last experience, as before.
            experience_name=experiences[-1],
        )
    
# Run the training entry point only when this module is executed as a script.
if __name__ == "__main__":
    main()