import os
import argparse
import math
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import (
    MixedPrecision
)
from torch.distributed.fsdp.wrap import (
   transformer_auto_wrap_policy
)
from torch.utils.data import Dataset, DataLoader
from functools import partial
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from itertools import chain
from dataset.wikitext import LazyWikitext, AQSGD_Wikitext, LargeBatchWikitext
from tqdm import tqdm
from monitor_logger import monitor
from compressor.compressor import Compressor
from torch.distributed.fsdp import StateDictType
import random
import numpy as np

from transformers.models.llama.modeling_llama import LlamaDecoderLayer
import yaml
from gsm8k_utils import is_correct


def load_config(config_path="llama2_finetune_config.yaml"):
    """Read the YAML experiment config at ``config_path`` and return the parsed object.

    ``yaml.safe_load`` is used, so only plain YAML types (dicts, lists, scalars)
    are constructed. The file is decoded as UTF-8 explicitly instead of with the
    platform's default locale encoding, so non-ASCII values (e.g. run names)
    load the same way on every OS.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

def set_seed(seed):
    """
    Seed every random number generator this script relies on.

    Python's ``random``, NumPy's global RNG and torch (CPU plus all CUDA
    devices) all receive ``seed``. Afterwards, cuDNN is switched to
    deterministic kernels with autotuning disabled.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Deterministic cuDNN kernels are slower but repeatable. Benchmark mode
    # would pick algorithms per input shape, and those picks can vary between runs.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def prepare_dist_env(args):
    """Seed RNGs, bind this process to its GPU and join the NCCL process group.

    Reads torchrun-style environment variables:
      * ``LOCAL_RANK`` – index of the CUDA device on this node;
      * ``RANK`` – global rank across all nodes (falls back to ``LOCAL_RANK``);
      * ``WORLD_SIZE`` – total number of processes.
    The ``env://`` init method additionally reads ``MASTER_ADDR``/``MASTER_PORT``.

    Args:
        args: namespace with a ``seed`` attribute (``None`` skips seeding).
    """
    # Set seed before initializing distributed environment
    if args.seed is not None:
        set_seed(args.seed)

    # -1 means "unset". torch.cuda.set_device ignores negative indices, and
    # init_process_group treats -1 as "resolve from the environment".
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    # The process-group rank must be the GLOBAL rank. On multi-node jobs
    # LOCAL_RANK repeats on every node, so using it would give duplicate ranks.
    global_rank = int(os.environ.get("RANK", local_rank))
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=world_size,
        rank=global_rank
    )

    print(f"[Rank {dist.get_rank()}] local_rank: {local_rank}, "
          f"cuda device: {torch.cuda.current_device()}")


def fsdp_wrap_llama(model):
    """Shard ``model`` with FSDP, making every ``LlamaDecoderLayer`` its own FSDP unit.

    The ``MixedPrecision`` policy uses bfloat16 for parameters, gradient
    reduction and buffers, and casts forward inputs to that dtype. The
    ``sharding_strategy`` is passed as ``None``, which leaves the choice to
    FSDP's default. The wrapped module is placed on the current CUDA device.

    Returns:
        The FSDP-wrapped model.
    """
    bf16 = torch.bfloat16
    precision = MixedPrecision(
        param_dtype=bf16,
        reduce_dtype=bf16,   # dtype used when communicating gradients
        buffer_dtype=bf16,
        cast_forward_inputs=True,
    )
    # Auto-wrap: each decoder layer becomes an independently sharded unit.
    per_layer_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={LlamaDecoderLayer},
    )
    return FSDP(
        model,
        auto_wrap_policy=per_layer_policy,
        sharding_strategy=None,
        device_id=torch.cuda.current_device(),
        mixed_precision=precision,
    )


def train_one_epoch(model, dataloader, optimizer, rank, config):
    """Train ``model`` for one pass over ``dataloader`` with compression hooks attached.

    For each entry of ``config['compression_config']``, a forward hook built by
    ``create_compression_hook`` is registered on
    ``model.model.layers[layer_idx]``. All hooks are removed when the epoch
    ends, including on error.

    When a batch carries ``"indices"``, they are exposed to the hooks as a
    ``current_indices`` attribute on every hooked layer, for that step only.

    Args:
        model: causal LM (FSDP-wrapped in this script) whose decoder layers are
            reachable as ``model.model.layers``.
        dataloader: yields dicts with ``input_ids``, ``labels`` and optionally
            ``indices``.
        optimizer: optimizer over ``model``'s parameters.
        rank: process rank; only rank 0 shows the progress bar and logs
            per-step loss.
        config: parsed config; only ``compression_config`` is read here.

    Returns:
        Mean of this rank's per-step losses, or 0.0 if ``dataloader`` is empty.
    """
    compression_config = config.get('compression_config', {})
    # Resolve every configured layer once. The same module gets the forward hook
    # and the per-step ``current_indices`` attribute. ``int()`` also accepts
    # indices written as strings in the YAML. Nothing touches ``model.model``
    # when no compression is configured.
    targets = [
        (int(layer_config['layer_idx']), layer_config,
         model.model.layers[int(layer_config['layer_idx'])])
        for layer_config in compression_config.values()
    ]

    hook_handles = []
    try:
        # NOTE(review): hooks (and the Compressor each one builds lazily) are
        # recreated on every call, i.e. every epoch. Any Compressor state
        # (e.g. error-feedback buffers) therefore does not survive across
        # epochs. Confirm this is intended.
        for layer_idx, layer_config, layer in targets:
            hook_fn = create_compression_hook(layer_idx, layer_config)
            hook_handles.append(layer.register_forward_hook(hook_fn))

        model.train()
        total_loss = 0.0
        num_steps = 0

        iterator = tqdm(dataloader, desc="Training", disable=rank != 0)

        for batch in iterator:
            input_ids = batch["input_ids"].cuda()
            labels = batch["labels"].cuda()

            indices = batch.get("indices", None)
            if indices is not None:
                for _, _, layer in targets:
                    setattr(layer, 'current_indices', indices)

            try:
                outputs = model(input_ids=input_ids, labels=labels)
                loss = outputs.loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            finally:
                # Clear the per-step indices even if the step raised, so they
                # never leak into a later forward pass. The hasattr check
                # tolerates several config entries that target the same layer.
                if indices is not None:
                    for _, _, layer in targets:
                        if hasattr(layer, 'current_indices'):
                            delattr(layer, 'current_indices')

            loss_value = loss.item()  # one device->host sync per step
            total_loss += loss_value
            num_steps += 1

            if rank == 0:
                iterator.set_postfix({'loss': f'{loss_value:.4f}'})
                monitor.log_metrics({
                    "train/loss": loss_value,
                })

        return total_loss / num_steps if num_steps else 0.0

    finally:
        for handle in hook_handles:
            handle.remove()


def get_datasets(tokenizer, args):
    """
    Load, tokenize and chunk the dataset selected by ``args.dataset``.

    * ``wikitext`` (``wikitext-2-v1``): texts are tokenized, concatenated and
      split into ``args.block_size`` token chunks (the tail shorter than a block
      is dropped unless everything is shorter than one block).
    * ``arxiv`` (``ds3lab/ac-sgd-arxiv21``): abstracts prefixed with
      ``"Abstract: "``, chunked like wikitext. The test split doubles as the
      validation split.
    * ``gsm8k`` (``main``): each example becomes one ``block_size`` row
      ``"Question: {question}" + "\\nAnswer: {answer}" + EOS``, right-padded
      with the pad token. Padded positions get label -100 and attention 0.
      10% of the train split (seeded by ``args.seed``) is held out as
      validation.

    Args:
        tokenizer: HF tokenizer; ``pad_token_id`` must be set for GSM8K.
        args: namespace with ``dataset``, ``cache_dir``, ``block_size``,
            ``seed``, ``preprocessing_num_workers`` and ``overwrite_cache``.

    Returns:
        ``DatasetDict`` whose splits have ``input_ids``, ``attention_mask`` and
        ``labels`` columns.

    Raises:
        ValueError: if ``args.dataset`` is not one of the names above.
    """
    dataset_name = args.dataset
    is_gsm8k = dataset_name == "gsm8k"

    if dataset_name == "wikitext":
        raw_datasets = load_dataset(
            "wikitext",
            "wikitext-2-v1",
            cache_dir=args.cache_dir,
            trust_remote_code=True
        )
    elif dataset_name == "arxiv":
        raw_datasets = load_dataset(
            "ds3lab/ac-sgd-arxiv21",
            # Honour --cache_dir when given; otherwise keep the previous default.
            cache_dir=args.cache_dir or "arxiv_cache",
            trust_remote_code=True,
        )
        raw_datasets["validation"] = raw_datasets["test"]
    elif is_gsm8k:
        raw_datasets = load_dataset(
            "gsm8k",
            "main",
            cache_dir=args.cache_dir,
            trust_remote_code=True
        )
        # GSM8K only ships train/test: hold out 10% of train as validation.
        train_val_split = raw_datasets["train"].train_test_split(
            test_size=0.1,
            seed=args.seed
        )
        raw_datasets["train"] = train_val_split["train"]
        raw_datasets["validation"] = train_val_split["test"]
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    if dist.get_rank() == 0:
        print("Dataset loaded:", raw_datasets)
        for split in raw_datasets:
            print(f"{split} split size:", len(raw_datasets[split]))

    def encode(texts):
        """Tokenize a batch of strings, truncated to ``args.block_size``, unpadded."""
        return tokenizer(
            texts,
            truncation=True,
            max_length=args.block_size,
            padding=False,
            return_tensors=None,
        )

    def tokenize_function(examples):
        if is_gsm8k:
            # Same "Question: ... \nAnswer:" layout as the GSM8K generation
            # prompt used at evaluation time.
            result = dict(encode(["Question: " + q for q in examples["question"]]))
            # Tokenize each answer next to its own question. The answer
            # column is dropped by this map, so it cannot be looked up later.
            answer_ids = tokenizer(
                [f"\nAnswer: {a}" for a in examples["answer"]],
                add_special_tokens=False,
            )["input_ids"]
            # EOS teaches the model where an answer ends.
            eos = [] if tokenizer.eos_token_id is None else [tokenizer.eos_token_id]
            result["answer_ids"] = [ids + eos for ids in answer_ids]
            return result
        if dataset_name == "arxiv":
            # Only the abstract is used for arXiv.
            return encode([f"Abstract: {a}" for a in examples["abstract"]])
        return encode(examples["text"])  # wikitext

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=args.preprocessing_num_workers,
        remove_columns=raw_datasets["train"].column_names,  # Dynamically get columns to remove
        load_from_cache_file=not args.overwrite_cache,
        desc="Running tokenizer on dataset",
    )

    def pack_gsm8k(examples):
        """Join each question with its own answer into one padded ``block_size`` row."""
        block = args.block_size
        input_ids, attention_mask, labels = [], [], []
        for question_ids, answer_ids in zip(examples["input_ids"], examples["answer_ids"]):
            tokens = (question_ids + answer_ids)[:block]
            n_pad = block - len(tokens)
            input_ids.append(tokens + [tokenizer.pad_token_id] * n_pad)
            attention_mask.append([1] * len(tokens) + [0] * n_pad)
            # The pad token may be the EOS token. -100 keeps padding out of the loss.
            labels.append(tokens + [-100] * n_pad)
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

    def group_texts(examples):
        """Concatenate every column and cut it into ``block_size`` chunks."""
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])

        # Drop the incomplete tail block (unless there is less than one block).
        if total_length >= args.block_size:
            total_length = (total_length // args.block_size) * args.block_size

        result = {
            k: [
                t[i : i + args.block_size]
                for i in range(0, total_length, args.block_size)
            ]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    lm_datasets = tokenized_datasets.map(
        pack_gsm8k if is_gsm8k else group_texts,
        batched=True,
        num_proc=args.preprocessing_num_workers,
        # ``answer_ids`` is only an intermediate column for GSM8K packing.
        remove_columns=["answer_ids"] if is_gsm8k else None,
        load_from_cache_file=not args.overwrite_cache,
        desc=f"Grouping texts in chunks of {args.block_size}",
    )

    if dist.get_rank() == 0:
        print("Final processed dataset:", lm_datasets)
        print("Train examples:", len(lm_datasets["train"]))
        print("Validation examples:", len(lm_datasets["validation"]))

    return lm_datasets


def create_compression_hook(layer_idx, layer_config):
    """Build a forward hook that compresses a layer's hidden-state output.

    The returned hook owns one ``Compressor``. It is created lazily on the
    first call, from the shape of that first hidden-state output and the
    ``forward*``/``backward*`` keys of ``layer_config``, and then reused.
    NOTE(review): later batches with a different shape (e.g. a smaller final
    batch) are passed to the same Compressor. Confirm it supports that.

    If the hooked module sets a ``current_indices`` attribute, it is passed to
    the compressor as ``indices`` (otherwise ``None``).

    Args:
        layer_idx: index of the hooked layer (not used by the hook itself).
        layer_config: per-layer compression settings.

    Returns:
        A ``(module, inputs, output)`` forward hook. Its return value replaces
        the layer output: tuples keep their trailing elements, and bare
        tensors stay bare tensors.
    """
    compressor = None

    def compression_hook(module, input_tensor, output):
        nonlocal compressor
        # Decoder layers may return a tuple whose first element is the hidden
        # states, or (in some transformers versions) the hidden-state tensor
        # itself. Indexing a bare tensor with [0] would pick a batch row.
        is_tuple = isinstance(output, tuple)
        hidden_states = output[0] if is_tuple else output

        if compressor is None:
            compressor = Compressor(
                input_shape=hidden_states.shape,
                forward=layer_config['forward'],
                forward_params=layer_config['forward-params'],
                backward=layer_config['backward'],
                backward_params=layer_config['backward-params'],
                forward_EF=layer_config['forward-EF'],
                backward_EF=layer_config['backward-EF'],
                forward_EF_method=layer_config['forward-EF-method'],
                backward_EF_method=layer_config['backward-EF-method']
            )

        indices = getattr(module, 'current_indices', None)
        compressed = compressor(hidden_states, indices=indices)
        if is_tuple:
            # Preserve whatever else the layer returned. The caller may index
            # those extra outputs (e.g. attention weights or cache entries).
            return (compressed,) + tuple(output[1:])
        return compressed

    return compression_hook

def setup_monitor(config):
    """Configure the shared ``monitor`` logger from the experiment config.

    Reads ``config['use_wandb']`` (default ``False``) and the optional
    ``config['wandb']`` section: ``name`` (default ``'default-run'``) and
    ``project`` (default ``'llama-compression'``). Local logs go to ``log``.

    Returns:
        The run name handed to ``monitor.setup``.
    """
    wandb_section = config.get('wandb', {})
    enable_wandb = config.get('use_wandb', False)
    run_name = wandb_section.get('name', 'default-run')

    monitor.setup(
        exp_name=run_name,
        use_wandb=enable_wandb,
        project_name=wandb_section.get('project', 'llama-compression'),
        log_dir="log",
    )
    return run_name

def _evaluate_gsm8k_generation(model, raw_eval_dataset, tokenizer, rank, dataset_name):
    """Greedy-decode each raw GSM8K sample and return ``{"eval/gsm8k_accuracy": acc}``.

    Every rank iterates the whole ``raw_eval_dataset``. Correct/total counts
    are then summed across ranks before taking the ratio.
    """
    gsm8k_correct = 0
    gsm8k_total = len(raw_eval_dataset)
    iterator = tqdm(raw_eval_dataset, desc="Evaluating GSM8K (generate)", disable=(rank != 0))

    for sample in iterator:
        prompt = "Question: " + sample["question"] + "\nAnswer:"
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
        gen_out = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=False
        )
        pred_text = tokenizer.decode(gen_out[0], skip_special_tokens=True)
        gsm8k_correct += is_correct(pred_text, sample["answer"])

    if dist.is_initialized():
        # float64 keeps integer counts exact.
        counts = torch.tensor([gsm8k_correct, gsm8k_total], device="cuda", dtype=torch.float64)
        dist.all_reduce(counts, op=dist.ReduceOp.SUM)
        gsm8k_correct = counts[0].item()
        gsm8k_total = int(counts[1].item())

    gsm8k_acc = gsm8k_correct / gsm8k_total if gsm8k_total > 0 else 0.0
    metrics = {"eval/gsm8k_accuracy": gsm8k_acc}

    if rank == 0:
        print(f"[{dataset_name}] Evaluation metrics:", metrics)

    if dist.is_initialized():
        dist.barrier()
    torch.cuda.empty_cache()
    return metrics


def evaluate_model(model, eval_dataloader, rank, raw_eval_dataset=None, tokenizer=None, dataset_name=None):
    """Evaluate ``model`` and return a metrics dict.

    Two modes:

    * GSM8K generation, used when ``dataset_name == "gsm8k"`` and both
      ``raw_eval_dataset`` and ``tokenizer`` are given. Greedily generates up to
      128 tokens per ``"Question: ...\\nAnswer:"`` prompt, scores them with
      ``is_correct`` and returns ``{"eval/gsm8k_accuracy": ...}``. If this path
      raises an ``Exception``, a warning is printed and the loss-based
      evaluation runs instead.
    * Loss-based: sample-weighted mean loss, perplexity and next-token
      accuracy over ``eval_dataloader``, summed across ranks. The loss is
      averaged over the samples actually evaluated. These include any
      duplicates ``DistributedSampler`` adds to balance ranks.

    Returns:
        The metrics dict (the loss-based dict is also sent to ``monitor``).
    """
    # Compression hooks are registered and removed inside train_one_epoch, so
    # nothing has to be detached here.
    model.eval()

    if dataset_name == "gsm8k" and raw_eval_dataset is not None and tokenizer is not None:
        try:
            return _evaluate_gsm8k_generation(model, raw_eval_dataset, tokenizer, rank, dataset_name)
        except Exception as exc:
            # Best-effort path: report the failure instead of hiding it, then
            # fall back. Only Exception is caught, so Ctrl-C still works.
            # NOTE(review): if only some ranks fail here, ranks that already
            # entered the all_reduce above may hang. Consider validating
            # inputs up front. Also, generate() on an FSDP-wrapped model may
            # need full (unsharded) parameters — confirm.
            print(f"[Rank {rank}] GSM8K generation eval failed ({exc!r}); "
                  "falling back to loss-based evaluation")

    total_loss = 0.0
    total_correct = 0
    total_tokens = 0
    total_samples = 0

    with torch.no_grad():
        iterator = tqdm(eval_dataloader, desc="Evaluating", disable=rank != 0)
        for batch in iterator:
            input_ids = batch["input_ids"].cuda()
            labels = batch["labels"].cuda()

            outputs = model(input_ids=input_ids, labels=labels)

            # Next-token accuracy: the logits at position t predict label t+1.
            predictions = outputs.logits[..., :-1, :].argmax(dim=-1)
            shift_labels = labels[..., 1:]
            mask = shift_labels != -100
            total_correct += ((predictions == shift_labels) & mask).sum().item()
            total_tokens += mask.sum().item()

            # .item() already synchronises with the device, so no explicit sync is needed.
            batch_size = input_ids.size(0)
            total_loss += outputs.loss.item() * batch_size
            total_samples += batch_size

    # Aggregate across processes. float64 keeps large token counts exact.
    if dist.is_initialized():
        totals = torch.tensor(
            [total_loss, total_correct, total_tokens, total_samples],
            device='cuda',
            dtype=torch.float64,
        )
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)
        total_loss, total_correct, total_tokens, total_samples = totals.tolist()

    avg_loss = total_loss / total_samples if total_samples > 0 else float("nan")
    accuracy = total_correct / total_tokens if total_tokens > 0 else 0

    try:
        perplexity = math.exp(avg_loss)
    except OverflowError:
        perplexity = float("inf")

    metrics = {
        "eval/loss": avg_loss,
        "eval/perplexity": perplexity,
        "eval/accuracy": accuracy
    }

    monitor.log_metrics(metrics)
    torch.cuda.empty_cache()
    if dist.is_initialized():
        dist.barrier()
    return metrics

def _wrap_train_dataset(train_dataset, training_config, batch_size):
    """Wrap ``train_dataset`` in the sampling wrapper selected by ``training_config``.

    The flags are checked in order: ``lazy_sampling`` (only the ``constant``
    schedule is handled), ``aq_sgd``, ``large_batch``. If none is enabled, the
    dataset is returned unchanged. ``batch_size`` should match the
    DataLoader's batch size.
    """
    if training_config.get('lazy_sampling', False):
        lazy_params = training_config.get('lazy_sampling_params', {})
        if lazy_params.get('schedule') == 'constant':
            p_t = lambda x: lazy_params.get('p_t', 0.5)
            return LazyWikitext(
                train_dataset,
                p_t=p_t,
                batch_size=batch_size,
                with_idx=True
            )
        # NOTE(review): other schedules are not handled; the dataset is used unwrapped.
        return train_dataset
    if training_config.get('aq_sgd', False):
        return AQSGD_Wikitext(
            train_dataset,
            batch_size=batch_size,
            with_idx=True
        )
    if training_config.get('large_batch', False):
        large_batch_params = training_config.get('large_batch_params', {})
        return LargeBatchWikitext(
            train_dataset,
            batch_size=batch_size,
            k=large_batch_params.get('k', 1),
            with_idx=True
        )
    return train_dataset


def _make_indexed_collator(pad_token_id):
    """Return a collate_fn that right-pads ``input_ids``/``labels`` to the batch's max length.

    ``input_ids`` are padded with ``pad_token_id`` and ``labels`` with -100.
    If features carry ``"indices"``, they are collected into ``batch["indices"]``
    as a list of ints.
    """
    def indexed_collator(features):
        max_len = max(len(f["input_ids"]) for f in features)

        padded_input_ids = []
        padded_labels = []
        indices_list = []

        for f in features:
            input_ids = f["input_ids"]
            labels = f["labels"]

            padded_ids = input_ids + [pad_token_id] * (max_len - len(input_ids))
            padded_lbl = labels + [-100] * (max_len - len(labels))

            padded_input_ids.append(torch.tensor(padded_ids, dtype=torch.long))
            padded_labels.append(torch.tensor(padded_lbl, dtype=torch.long))

            if "indices" in f:
                indices_list.append(int(f["indices"]))

        batch = {
            "input_ids": torch.stack(padded_input_ids),
            "labels": torch.stack(padded_labels),
        }
        if len(indices_list) > 0:
            batch["indices"] = indices_list

        return batch

    return indexed_collator


def main():
    """Fine-tune a causal LM with FSDP as described by a YAML config.

    Parses CLI flags and reads hyper-parameters from the config's ``training``
    section. Then initialises the process group, builds tokenized datasets and
    distributed dataloaders, wraps the model with FSDP, and trains/evaluates
    once per epoch. Finally, rank 0 saves a full (unsharded) state dict.
    """
    # Needed only for the final checkpoint.
    from torch.distributed.fsdp import FullStateDictConfig

    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_dir", type=str, default=None,
                        help="Where to store the cached datasets")
    parser.add_argument("--overwrite_cache", action="store_true",
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument("--preprocessing_num_workers", type=int, default=4,
                        help="Number of processes for data preprocessing")
    parser.add_argument(
        "--config_path",
        type=str,
        default="gpt2_finetune_config.yaml",
        help="Path to the configuration file."
    )
    args = parser.parse_args()

    config = load_config(config_path=args.config_path)
    training_config = config.get("training", {})

    args.block_size = training_config.get("block_size", 1024)
    args.batch_size = training_config.get("batch_size", 2)
    args.epochs = training_config.get("epochs", 4)
    args.lr = training_config.get("learning_rate", 1e-4)
    args.model_name_or_path = training_config.get("model", "NousResearch/Llama-2-7b-hf")
    args.seed = config.get("seed", 42)
    args.dataset = training_config.get("dataset", "wikitext")

    prepare_dist_env(args)
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    if rank == 0:
        print(f"Loading tokenizer and model from {args.model_name_or_path} ...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    if tokenizer.pad_token is None:
        # Padding is required by the collator; reuse EOS when no pad token exists.
        tokenizer.pad_token = tokenizer.eos_token

    datasets = get_datasets(tokenizer, args)

    # Wrappers get the same batch size as the DataLoader (previously they
    # defaulted to 8 while the DataLoader defaulted to 2).
    train_dataset = _wrap_train_dataset(datasets["train"], training_config, args.batch_size)
    eval_dataset = datasets["validation"]

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        drop_last=False
    )
    eval_sampler = torch.utils.data.distributed.DistributedSampler(
        eval_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False,
        drop_last=False
    )

    collate_fn = _make_indexed_collator(tokenizer.pad_token_id)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        collate_fn=collate_fn,
        pin_memory=True
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=args.batch_size,
        sampler=eval_sampler,
        collate_fn=collate_fn,
        pin_memory=True
    )

    # NOTE(review): weights are loaded as float16, but fsdp_wrap_llama's
    # MixedPrecision policy computes/reduces in bfloat16. The sharded
    # parameters the optimizer updates therefore stay float16. Confirm this
    # is intended (float32 or bfloat16 storage are the usual choices).
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=torch.float16,
        device_map=None,
    )
    model.cuda()
    fsdp_model = fsdp_wrap_llama(model)

    optimizer = torch.optim.SGD(fsdp_model.parameters(), lr=args.lr, momentum=0.9)

    setup_monitor(config)

    for epoch in range(args.epochs):
        train_sampler.set_epoch(epoch)
        eval_sampler.set_epoch(epoch)

        avg_loss = train_one_epoch(fsdp_model, train_dataloader, optimizer, rank, config)
        monitor.log_metrics({
            "train/epoch_avg_loss": avg_loss,
        })

        # NOTE(review): eval_dataset is the tokenized split; get_datasets
        # removes the raw text columns (e.g. GSM8K "question"/"answer"). So
        # evaluate_model's GSM8K generation branch cannot read them and falls
        # back to loss-based evaluation. Pass the raw validation split to
        # enable it.
        eval_metrics = evaluate_model(fsdp_model, eval_dataloader, rank, raw_eval_dataset=eval_dataset, tokenizer=tokenizer, dataset_name=args.dataset)
        print(f"Rank {rank} - Epoch {epoch} eval_metrics: {eval_metrics}")

    # Gather the full state dict on rank 0 only, offloaded to CPU. This avoids
    # materialising the whole unsharded model on every GPU.
    save_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, save_config):
        state_dict = fsdp_model.state_dict()  # collective: every rank must call it
        if rank == 0:
            save_path = "fsdp_llama_finetuned.pt"
            torch.save(state_dict, save_path)
            print(f"Saved model state dict to {save_path}")

    dist.destroy_process_group()


if __name__ == "__main__":
    # Run the fine-tuning pipeline only when executed as a script (e.g. via
    # torchrun), never as a side effect of importing this module.
    main()
