import os, copy, json, glob
import wandb, hydra, logging, time
import math, re, random
from collections import defaultdict
from typing import Dict, Any, Optional
import numpy as np
from omegaconf import DictConfig, OmegaConf

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from transformers.trainer_callback import TrainerCallback

from data.base import get_dataloader
from data.graph_dataset import StreamingGraphDataset, GraphDataset, edge_to_str, path_to_str
from graph_tokenizers.numerical_tokenizer import NumericalTokenizer
from models import Transformer
from callbacks.logger import Evaluator
from generate import mp_generate_data

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Custom resolvers usable from the Hydra YAML configs as ``${name:arg1,arg2}``.
# They are registered once, at import time, before ``@hydra.main`` composes the config.
OmegaConf.register_new_resolver("subtract", lambda x, y: int(x) - int(y))
OmegaConf.register_new_resolver("multiply", lambda x, y: round(float(x) * float(y)))
OmegaConf.register_new_resolver("divide", lambda x, y: float(x) / float(y))
OmegaConf.register_new_resolver("halve", lambda x: int(x) // 2)
OmegaConf.register_new_resolver("exp", lambda x, y: int(x**y))
# ``n`` repeated once per layer.
OmegaConf.register_new_resolver("all_ppl", lambda n, layers: [n] * int(layers))
# ``n`` for the first layer, then 1 for each remaining layer.
OmegaConf.register_new_resolver("num_to_ppl", lambda n, layers: [n] + [1] * (int(layers) - 1))
OmegaConf.register_new_resolver("prod_list", lambda x: math.prod(x))
OmegaConf.register_new_resolver("oneminus", lambda *args: 1.0 - sum(args))
# Join list items (or ``item[key]`` for each item) with '-'.
OmegaConf.register_new_resolver(
    "cat",
    lambda lst, key=None: '-'.join(map(str, lst if key is None else (p[key] for p in lst))),
)



def compute_grad_norm(model):
    """Return the global L2 norm of all parameter gradients of ``model``.

    Parameters without a gradient (``grad is None``) or with an empty gradient are
    skipped; sparse gradients are coalesced and only their stored values counted.
    The sum of squares is accumulated in float64.

    ``model.device`` is used for the accumulator when the model exposes it; otherwise
    the device of the first parameter is used, falling back to CPU for a parameterless
    model. This lets plain ``nn.Module`` objects, which have no ``device`` attribute,
    work too. Each partial result is moved onto the accumulator's device, so models
    whose parameters are spread over several devices are also handled.

    Returns:
        The gradient norm as a Python float (0.0 when there are no gradients).
    """
    with torch.no_grad():
        device = getattr(model, "device", None)
        if device is None:
            first_param = next(iter(model.parameters()), None)
            device = first_param.device if first_param is not None else torch.device("cpu")
        total_norm = torch.zeros((), device=device, dtype=torch.float64)
        for p in model.parameters():
            if p.grad is None or p.grad.numel() == 0:
                continue
            g = p.grad.detach()
            if g.is_sparse:
                g = g.coalesce().values()
            grad_norm = g.float().norm(2)
            total_norm += (grad_norm * grad_norm).to(device=total_norm.device, dtype=total_norm.dtype)
        return total_norm.sqrt().item()

def compute_param_norm(model):
    """Return the global L2 norm of all parameters of ``model``.

    Empty parameters are skipped; sparse parameters are coalesced and only their
    stored values counted. The sum of squares is accumulated in float64.

    ``model.device`` is used for the accumulator when the model exposes it; otherwise
    the device of the first parameter is used, falling back to CPU for a parameterless
    model. Partial results are moved onto the accumulator's device, so models spread
    over several devices are also handled.

    Returns:
        The parameter norm as a Python float.
    """
    with torch.no_grad():
        device = getattr(model, "device", None)
        if device is None:
            first_param = next(iter(model.parameters()), None)
            device = first_param.device if first_param is not None else torch.device("cpu")
        total_norm = torch.zeros((), device=device, dtype=torch.float64)
        for p in model.parameters():
            if p.numel() == 0:
                continue
            param = p.detach()
            if param.is_sparse:
                param = param.coalesce().values()
            param_norm = param.float().norm(2)
            total_norm += (param_norm * param_norm).to(device=total_norm.device, dtype=total_norm.dtype)
        return total_norm.sqrt().item()


def get_lr(iter_num, config):
    """Return the learning rate for iteration ``iter_num``.

    If ``config.decay_lr`` is false, the constant ``config.learning_rate`` is returned.
    Otherwise the schedule is:
      * linear warmup ``learning_rate * (iter_num + 1) / warmup_steps`` for
        ``iter_num < warmup_steps``;
      * ``min_lr`` once ``iter_num >= lr_decay_iters``;
      * cosine decay from ``learning_rate`` to ``min_lr`` in between.

    The ``>=`` on the decay bound matters when ``lr_decay_iters == warmup_steps``.
    With ``>``, the cosine ratio would divide by zero at that iteration. The value
    there is ``min_lr`` either way, because the cosine reaches ``min_lr`` at ratio 1.
    """
    if not config.decay_lr:
        return config.learning_rate
    # warmup + cosine
    if iter_num < config.warmup_steps:
        return config.learning_rate * (iter_num + 1) / config.warmup_steps
    if iter_num >= config.lr_decay_iters:
        return config.min_lr
    ratio = (iter_num - config.warmup_steps) / (config.lr_decay_iters - config.warmup_steps)
    coef = 0.5 * (1.0 + math.cos(math.pi * ratio))
    return config.min_lr + coef * (config.learning_rate - config.min_lr)

def get_model(cfg: DictConfig, tokenizer: Any):
    """Build the model named by ``cfg.model.name``.

    * ``"transformer"`` -> the project's ``Transformer(cfg.model, tokenizer)``.
    * ``"gpt2"`` -> pretrained Hugging Face causal LM, with the tokenizer's pad token
      (and its id) copied onto the model config.

    Raises:
        ValueError: for any other model name.
    """
    requested = cfg.model.name
    if requested == "transformer":
        return Transformer(cfg.model, tokenizer)
    if requested == "gpt2":
        hf_model = AutoModelForCausalLM.from_pretrained(requested)
        hf_model.config.pad_token_id = tokenizer.pad_token_id
        hf_model.config.pad_token = tokenizer.pad_token
        return hf_model
    raise ValueError(f"Model {cfg.model.name} not found")

def get_tokenizer(cfg: DictConfig):
    """Return the tokenizer selected by ``cfg.model.tokenizer``.

    ``"numerical"`` builds a ``NumericalTokenizer`` sized by the largest ``m`` found
    in ``cfg.data.params``. Any other value loads ``AutoTokenizer`` for
    ``cfg.model.name``. For GPT-style names, padding is moved to the left, and the EOS
    token is reused as the pad token when none is defined.
    """
    if cfg.model.tokenizer == "numerical":
        largest_m = max([p.m for p in cfg.data.params])
        return NumericalTokenizer(largest_m)

    hf_tokenizer = AutoTokenizer.from_pretrained(cfg.model.name)
    if 'gpt' not in cfg.model.name:
        return hf_tokenizer

    hf_tokenizer.padding_side = "left"
    if hf_tokenizer.pad_token_id is None:
        hf_tokenizer.pad_token = hf_tokenizer.eos_token
        hf_tokenizer.pad_token_id = hf_tokenizer.eos_token_id
    return hf_tokenizer

def get_training_args(cfg: DictConfig):
    """Translate the Hydra config into Hugging Face ``TrainingArguments``.

    Used only on the ``Trainer`` code path (non-``transformer`` models). Settings:
      * evaluation runs every ``cfg.train.eval_interval`` steps and once before training;
      * checkpoints are saved once per epoch;
      * fp16 is enabled whenever CUDA is available;
      * reporting goes to wandb only when ``cfg.use_wandb`` is set.

    NOTE(review): this reads ``cfg.train.learning_rate`` / ``weight_decay`` /
    ``warmup_steps`` / ``run_name``. The manual loop reads ``cfg.optimizer.*`` and
    ``cfg.wandb.run_name`` instead. Confirm both sets of keys exist and agree.
    """
    train_cfg = cfg.train
    batch_size = train_cfg.batch_size
    arg_values = dict(
        output_dir=cfg.save_dir,
        num_train_epochs=train_cfg.num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=train_cfg.learning_rate,
        weight_decay=train_cfg.weight_decay,
        gradient_accumulation_steps=train_cfg.gradient_accumulation_steps,
        max_grad_norm=train_cfg.max_grad_norm,
        eval_strategy="steps",
        eval_steps=train_cfg.eval_interval,
        eval_on_start=True,
        save_strategy="epoch",
        save_total_limit=cfg.eval.save_total_limit,
        fp16=torch.cuda.is_available(),
        report_to="wandb" if cfg.use_wandb else "none",
        logging_dir=os.path.join(cfg.save_dir, "logs"),
        logging_steps=train_cfg.log_interval,
        remove_unused_columns=False,
        warmup_steps=train_cfg.warmup_steps,
        run_name=train_cfg.run_name,
    )
    return TrainingArguments(**arg_values)
    

def get_lengths(example_graph, tokenizer):
    """Return ``(edge_token_count, path_token_count)`` for one graph example.

    The edge prompt is built by ``edge_to_str`` from the graph's edge list, source and
    goal. The path is sampled, not read: for each of the ``path_length`` steps, one node
    is drawn with ``random.sample`` from ``policy_nodes[step]``. This draws from the
    global ``random`` state, so calling it advances that RNG. Both strings are
    tokenized with ``tokenizer.encode`` and the token counts returned.
    """
    edge_token_count = len(tokenizer.encode(
        edge_to_str(example_graph['edge_list'], example_graph['source'], example_graph['goal'])
    ))

    num_steps = example_graph['path_length']
    candidates_per_step = example_graph['policy_nodes']
    sampled_path = []
    for step in range(num_steps):
        sampled_path.append(random.sample(candidates_per_step[step], 1)[0])

    path_token_count = len(tokenizer.encode(path_to_str(sampled_path)))
    return edge_token_count, path_token_count

def set_global_seed(seed):
    """Seed Python's ``random``, NumPy and torch (CPU and CUDA) with ``seed``.

    Also puts cuDNN in deterministic mode and disables its autotuner, trading speed
    for run-to-run reproducibility.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def _save_eval_arrays(save_file, eval_outputs, eval_logprobs, include_logprobs, include_generations):
    """Write evaluation log-probabilities into one compressed ``.npz`` archive.

    Args:
        save_file: destination path handed to ``np.savez_compressed``.
        eval_outputs: per-validation-set outputs of ``Evaluator.evaluate``. Read only when
            ``include_generations``; each entry must provide ``['greedy']['logprobs']``
            and ``['temp']['logprobs']``.
        eval_logprobs: per-validation-set log-prob arrays. Read only when ``include_logprobs``.
        include_logprobs: store ``eval_{i}`` entries.
        include_generations: store ``eval_{i}_greedy`` and ``eval_{i}_temp`` entries.
    """
    arrays = {}
    # Inputs that were not requested are never iterated. This matters if the evaluator
    # returns None for them (not visible from here).
    if include_logprobs:
        arrays.update({f"eval_{i}": logprobs for i, logprobs in enumerate(eval_logprobs)})
    if include_generations:
        arrays.update({f"eval_{i}_greedy": outputs['greedy']['logprobs'] for i, outputs in enumerate(eval_outputs)})
        arrays.update({f"eval_{i}_temp": outputs['temp']['logprobs'] for i, outputs in enumerate(eval_outputs)})
    np.savez_compressed(save_file, **arrays)


@hydra.main(config_path="../configs", config_name="train", version_base=None)
def train(cfg: DictConfig) -> None:
    """Hydra entry point: prepare data, build the model, and train with periodic evaluation.

    Non-``transformer`` models are handed to a Hugging Face ``Trainer``. The
    ``transformer`` model is trained by the manual loop below (AdamW, the ``get_lr``
    schedule and a GradScaler). That loop evaluates every ``cfg.eval.eval_interval``
    iterations and can also save log-probs, generations and checkpoints.

    ``cfg`` is mutated in place (struct mode is disabled) to record derived values such
    as ``data.descrip``, ``model.max_length`` and ``wandb.run_name``.
    """
    set_global_seed(cfg.seed)
    train_base_seed = cfg.seed + 1337
    val_base_seed = cfg.seed + 7331

    # Allow the derived keys assigned below to be written into the config.
    OmegaConf.set_struct(cfg, False)

    def _join(lst, key=None):
        """Join items (or ``item[key]``) of ``lst`` with '-', like the ``cat`` resolver."""
        return '-'.join([str(p) for p in lst]) if key is None else '-'.join([str(p[key]) for p in lst])

    if cfg.data.descrip in ["horizon", "finetune"]:
        cfg.data.descrip += f"_layers={cfg.data.params[0].layers}_ratio={_join(cfg.data.ratios)}_pass={_join(cfg.data.params,'num_pass_layers')}_rule={_join(cfg.data.params,'node_rule')}_load={cfg.model.load_model}"
        if 'horizon' in cfg.data.descrip:
            cfg.data.descrip += f"_shuffle={_join(cfg.data.params,'edge_shuffle_rule')}"
    logger.info(f"Data description: {cfg.data.descrip}")

    logger.info(f"\tUsing data.ratios: {[round(r, 2) for r in cfg.data.ratios]}")
    ratios = cfg.data.ratios
    assert len(ratios) == len(cfg.data.params), "data.ratios must have the same length as data.params"

    ##### Set up device
    logger.info("========== Initializing... ==========")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")
    # os.cpu_count() returns None when the CPU count cannot be determined.
    cfg.train.num_workers = max(1, (os.cpu_count() or 1) // 4)
    logger.info(f"Setting num_workers to {cfg.train.num_workers}")
    cfg.eval.logprob_dir = os.path.join(cfg.root, cfg.eval.logprob_dir)
    logger.info(f"Setting logprob_dir to {cfg.eval.logprob_dir}")
    cfg.eval.checkpoint_dir = os.path.join(cfg.root, cfg.eval.checkpoint_dir)
    logger.info(f"Setting checkpoint_dir to {cfg.eval.checkpoint_dir}")

    ##### Set up wandb run name and filename
    model_str = f"D={cfg.model.depth}_d={cfg.model.width}_H={cfg.model.num_heads}"
    data_str = cfg.data.descrip
    if cfg.train.train_online and not cfg.train.online.generate_chunks_first:
        train_str = "online"
    elif cfg.train.train_online and cfg.train.online.generate_chunks_first:
        train_str = f"offline_n={cfg.train.online.num_chunks}x{cfg.train.online.chunk_size}_shuff={cfg.train.online.shuffle_input}"
    else:
        train_str = f"offline_n={cfg.train.offline.sample_size}"
    opt_str = f"lr={cfg.optimizer.learning_rate}_b={cfg.train.batch_size}"
    run_name = f"{data_str}_{train_str}_{opt_str}_{model_str}"
    cfg.wandb.run_name = f"{cfg.wandb.run_name}_{run_name}" if cfg.wandb.run_name is not None else run_name
    logger.info(f"Run name: {cfg.wandb.run_name}")

    ##### Set up tokenizer
    tokenizer = get_tokenizer(cfg)

    logger.info("\n========== Preparing training data... ==========")
    start_time = time.time()

    ##### Load train data
    train_start_time = time.time()
    train_data = None  # stays None for online training: the streaming dataset produces data itself
    if cfg.train.offline.load_data:  # Load train data from file
        train_filename = os.path.join(cfg.root, cfg.data.data_dir, cfg.train.train_data_name)
        # Log the path actually opened (it is built from cfg.train.train_data_name).
        logger.info(f"Loading data from file {train_filename}")
        train_size = cfg.train.online.chunk_size if cfg.train.train_online else cfg.train.offline.sample_size
        with open(train_filename, 'r') as f:
            train_data = json.load(f)[:train_size]
        logger.info(f"Loaded {len(train_data)} train graphs from {train_filename}")
    elif not cfg.train.train_online:  # Offline: generate train data
        logger.info(f"Generating {cfg.train.offline.sample_size} train graphs for offline training")
        train_data = mp_generate_data(cfg.data.name, cfg.train.offline.sample_size, cfg.data, chunk_idx=0, seed=train_base_seed)
    train_time = time.time() - train_start_time
    logger.info(f"Training data prepared in {train_time:.2f}s")

    ##### Load validation data
    logger.info("\n========== Preparing evaluation data... ==========")
    val_start_time = time.time()
    logger.info(f"Generating {cfg.eval.sample_size} validation graphs with seed {val_base_seed}")
    val_datas = []
    val_datas_idxs = []  # index into cfg.data.params for each entry of val_datas
    evaluate_zeros = cfg.data.evaluate_zeros
    for param_idx, param in enumerate(cfg.data.params):
        if not evaluate_zeros and np.isclose(ratios[param_idx], 0.0):
            logger.info(f"\t Skipping param {param_idx} data since ratio {ratios[param_idx]} is 0")
            continue
        logger.info(f"\t Generating param {param_idx} data")
        val_data = mp_generate_data(cfg.data.name, cfg.eval.sample_size, param, chunk_idx=0, seed=val_base_seed, eval_mode=True)
        val_datas.append(val_data)
        val_datas_idxs.append(param_idx)
    val_time = time.time() - val_start_time
    logger.info(f"Validation data prepared in {val_time:.2f}s")

    logger.info("========== Checking path and edge lengths in validation data ==========")
    max_edge_length = 0
    max_path_length = 0
    for val_idx, val_data in enumerate(val_datas):
        logger.info(f"Checking validation data {val_idx}")
        for k, v in cfg.data.params[val_datas_idxs[val_idx]].items():
            logger.info(f"\t{k} = {v}")
        lengths = [get_lengths(graph, tokenizer) for graph in val_data]
        edge_lengths = [length[0] for length in lengths]
        path_lengths = [length[1] for length in lengths]
        max_edge_length = max(max_edge_length, max(edge_lengths))
        max_path_length = max(max_path_length, max(path_lengths))
        _edge_lengths = [len(graph['edge_list']) for graph in val_data]
        logger.info(f"# paths: {val_data[0]['num_paths']}, # edge tuples: {max(_edge_lengths)}, # edges: {max(edge_lengths)}, # pathnodes: {max(path_lengths)}\n")
    logger.info(f"Max edge length: {max_edge_length}, Max path length: {max_path_length}")

    cfg.model.target_max_length = max_path_length
    logger.info(f"Set target max length to {cfg.model.target_max_length}")

    logger.info(f"Checking if max length {cfg.model.max_length} is sufficient")
    threshold = max_edge_length + max_path_length + 8
    if cfg.model.max_length < threshold:
        logger.info(f"\tMax length {cfg.model.max_length} is insufficient")
        if cfg.model.max_length <= 0:
            cfg.model.max_length = 1  # doubling a non-positive length would never reach the threshold
        while cfg.model.max_length < threshold and cfg.model.max_length < 10000:
            cfg.model.max_length *= 2
            logger.info(f"\tNew max length {cfg.model.max_length}")
    assert cfg.model.max_length >= threshold, f"max length {cfg.model.max_length} < edge length {max_edge_length} + path length {max_path_length} + 8"
    logger.info(f"Set max length to {cfg.model.max_length}")

    logger.info("\n========== Setting up dataloaders... ==========")
    loader_start_time = time.time()

    ##### Set up train dataloader
    logger.info("Setting up train dataloader")
    train_dataset, train_dataloader = get_dataloader(
        StreamingGraphDataset if cfg.train.train_online else GraphDataset,
        tokenizer=tokenizer,
        cfg=cfg,
        batch_size=cfg.train.batch_size,
        shuffle_loader=True,  # Randomly shuffle the order of samples
        num_paths=cfg.train.num_paths,
        eval_mode=False,
        graph_type=cfg.data.name,
        data=train_data,
        use_cycle=True,
        base_seed=train_base_seed,
    )

    ##### Set up validation dataloader
    logger.info("Setting up validation dataloader")
    val_dataloaders = []
    val_dataset = None  # the last validation dataset; passed to the HF Trainer path below
    for val_idx, val_data in enumerate(val_datas):
        num_paths = val_data[0]['num_paths']
        # Keep prompt_batch_size * repeats near 2048 sequences, with at least one repeat.
        max_repeats = max(1, 2048 // cfg.eval.prompt_batch_size)
        val_batch_size = int(min(num_paths, max_repeats))
        logger.info(f"\tDataloader for data {val_idx} with num_paths {num_paths} and batch size {cfg.eval.prompt_batch_size}*{val_batch_size}")
        val_dataset, val_dataloader = get_dataloader(
            GraphDataset,
            data=val_data,
            tokenizer=tokenizer,
            cfg=cfg,
            batch_size=cfg.eval.prompt_batch_size * val_batch_size,
            use_cycle=False,
            shuffle_loader=False,
            num_paths=1000,
            eval_mode=True,
            graph_type=cfg.data.name,
            base_seed=val_base_seed,
        )
        val_dataloaders.append(val_dataloader)

    loader_time = time.time() - loader_start_time
    logger.info(f"Dataloaders prepared in {loader_time:.2f}s")

    total_time = time.time() - start_time
    logger.info(f"Total dataloader initialization took {total_time:.2f}s")

    logger.info("========== Initializing model and optimizers... ==========")

    logger.info("Initializing model")
    # NOTE(review): the model is never moved to `device` here; presumably Transformer
    # places itself (compute_grad_norm reads model.device) -- confirm.
    model = get_model(cfg, tokenizer)

    logger.info("Initializing optimizers")
    use_amp = bool(cfg.optimizer.use_amp)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.optimizer.learning_rate, betas=tuple(cfg.optimizer.betas), weight_decay=cfg.optimizer.weight_decay)
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

    if cfg.model.load_model:
        logger.info("Loading model from checkpoint")
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
        checkpoint = torch.load(cfg.model.load_model, map_location=device)
        model.model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        logger.info(f"Loaded model from {cfg.model.load_model}")

    logger.info("========== Preparing evaluation and  wandb... ==========")
    if cfg.use_wandb:
        wandb.login(key=cfg.wandb.key, host=cfg.wandb.host)
        wandb.init(
            project=cfg.wandb.project,
            config=OmegaConf.to_container(cfg, resolve=True),
            name=cfg.wandb.run_name
        )

    save_model_interval = cfg.eval.save_model_interval
    eval_interval = cfg.eval.eval_interval
    logprob_interval = cfg.eval.logprob_interval
    generation_interval = cfg.eval.generation_interval

    save_logprobs = bool(cfg.eval.logprob_interval > 0)
    save_generations = bool(cfg.eval.save_generations)
    compute_generations = bool(cfg.eval.generation_interval > 0)
    save_model = bool(save_model_interval > 0)

    # Logprob/generation iterations must coincide with evaluation iterations.
    assert (not save_logprobs) or eval_interval <= logprob_interval
    assert (not save_logprobs) or logprob_interval % eval_interval == 0
    assert (not compute_generations) or eval_interval <= generation_interval
    assert (not compute_generations) or generation_interval % eval_interval == 0

    # Both logprob and generation archives are written into logprob_dir.
    if save_logprobs or save_generations:
        os.makedirs(cfg.eval.logprob_dir, exist_ok=True)
        with open(os.path.join(cfg.eval.logprob_dir, 'config.json'), 'w') as f:
            json.dump(OmegaConf.to_container(cfg, resolve=True), f, indent=2)
        logger.info(f"Saved config to {os.path.join(cfg.eval.logprob_dir, 'config.json')}")

    if save_model:
        os.makedirs(cfg.eval.checkpoint_dir, exist_ok=True)
        logger.info(f"Saving model to {cfg.eval.checkpoint_dir}")

    evaluator = Evaluator(
        model=model,
        tokenizer=tokenizer,
        eval_dataloaders=val_dataloaders,
        loss_fn=loss_fn,
        dataset_type='star',
        use_tqdm=True,
        cfg=cfg
    )
    logger.info(f"Evaluator initialized with num_paths={evaluator.num_paths} paths")

    logger.info("========== Starting training... ==========")
    if cfg.model.name != "transformer":
        training_args = get_training_args(cfg)
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            tokenizer=tokenizer,
        )
        trainer.train()
        if cfg.use_wandb:
            wandb.finish()
        return

    logger.info("Iteration 0 evaluation")
    eval_outputs, eval_logprobs = evaluator.evaluate(
        iters=0,
        compute_logprobs=save_logprobs,
        compute_generations=save_generations,
        compute_accuracies=True
    )

    if save_logprobs or save_generations:
        save_file = os.path.join(cfg.eval.logprob_dir, f"logprobs_{0}.npz")
        logger.info(f"\tSaving logprobs at iteration {0} to {save_file}")
        _save_eval_arrays(save_file, eval_outputs, eval_logprobs, save_logprobs, save_generations)
    logger.info("Iteration 0 evaluation complete")

    iters = 0
    while iters < cfg.train.max_iters:
        lr = get_lr(iters, cfg.optimizer)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        batch = next(train_dataloader)
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(device)

        # fp16 autocast only together with loss scaling; without the scaler, fp16
        # gradients can underflow.
        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
            outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
            loss = loss_fn(outputs, batch['labels'])

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)  # unscales gradients in place before stepping
        scaler.update()

        if iters % cfg.eval.log_interval == 0:
            grad_norm = compute_grad_norm(model)
            param_norm = compute_param_norm(model)
            logger.info(f"Iter: {iters}, Loss: {loss.item():.4f}, LR: {lr:.6f}, Grad Norm: {grad_norm:.4f}, Param Norm: {param_norm:.4f}")

            if cfg.use_wandb:
                wandb.log({
                    "train/loss": loss.item(),
                    "train/lr": lr,
                    "train/grad_norm": grad_norm,
                    "train/param_norm": param_norm,
                }, step=iters)

        iters += 1

        # log evaluation metrics
        if iters % eval_interval == 0:
            iter_save_logprobs = save_logprobs and iters % logprob_interval == 0
            iter_compute_generations = compute_generations and iters % generation_interval == 0
            # Saving requires generations to have been computed this iteration. This also
            # avoids `iters % generation_interval` when generation_interval == 0.
            iter_save_generations = save_generations and iter_compute_generations
            eval_outputs, eval_logprobs = evaluator.evaluate(
                iters,
                compute_logprobs=iter_save_logprobs,
                compute_generations=iter_compute_generations,
                compute_accuracies=True
            )

            # Save logprobs and generations
            if iter_save_logprobs or iter_save_generations:
                save_file = os.path.join(cfg.eval.logprob_dir, f"logprobs_{iters}.npz")
                logger.info(f"Saving logprobs at iteration {iters} to {save_file}")
                _save_eval_arrays(save_file, eval_outputs, eval_logprobs, iter_save_logprobs, iter_save_generations)

        if save_model and iters % save_model_interval == 0:
            checkpoint = {
                'epoch': iters,
                'model_state_dict': model.model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            checkpoint_file = os.path.join(cfg.eval.checkpoint_dir, f"checkpoint_{iters}.pth")
            torch.save(checkpoint, checkpoint_file)
            logger.info(f"Saved model at iteration {iters} to {checkpoint_file}")

    # Clean up data directory after training
    chunk_dir = os.path.join(cfg.root, 'data', cfg.wandb.run_name)
    if os.path.exists(chunk_dir):
        logger.info(f"Cleaning up chunk directory: {chunk_dir}")
        for chunk_file in glob.glob(os.path.join(chunk_dir, "*.json")):
            os.remove(chunk_file)
        try:
            os.rmdir(chunk_dir)
        except OSError as err:
            # Non-JSON leftovers keep the directory non-empty; results are already saved,
            # so do not crash at the very end of training.
            logger.warning(f"Could not remove chunk directory {chunk_dir}: {err}")

    if cfg.use_wandb:
        wandb.finish()

# Script entry point. `train` is wrapped by @hydra.main, so it is called without
# arguments here; Hydra composes `cfg` (including any CLI overrides) and passes it in.
if __name__ == "__main__":
    train()