import os
import ast
import gc
import json
import pickle as pkl
from datetime import datetime, timedelta, timezone
from argparse import ArgumentParser
from typing import Dict, Any, List

import torch
import deepspeed
import wandb
from huggingface_hub import login
from transformers import (
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)
import torch.nn.functional as F
from deepspeed.utils.logging import LoggerFactory
from torch.utils.data import DistributedSampler
from tqdm import tqdm
import torch.distributed as dist
from datasets import load_dataset

from lorax.custom_hf import (
    LoraxLlamaForCausalLM,
    LoraxQwen2ForCausalLM,
    wrap_linear,
)
from lorax.utils import (
    add_filehandler,
    set_seed,
    get_global_rank,
    get_global_size,
    set_no_grad,
    adjust_deepspeed_config,
    save_pretrain,
    to_device,
)
from lorax.data_modules import (
    prepare_datasets,
    preprocess_train_dataset,
    CustomDataCollatorForSeq2SeqForTrain,
)

# Module-level logger created via DeepSpeed's LoggerFactory. The __main__ block
# attaches a file handler to it (add_filehandler) once the run directory exists.
logger = LoggerFactory.create_logger(__name__)

def parse_args():
    """Build the command-line interface for LoRAX training and parse ``sys.argv``.

    Options are declared below as ``(flag, add_argument keyword arguments)``
    pairs, grouped by concern. Every option except ``--wandb_enable`` (a
    ``store_true`` flag) resolves to ``None`` when omitted, unless an explicit
    default is given. ``--datasets`` and ``--target_modules`` are parsed with
    ``ast.literal_eval`` so they accept Python literals such as lists.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    option_table = (
        # Model.
        ("--model_name", dict(type=str)),
        # Data.
        ("--datasets", dict(type=ast.literal_eval)),
        ("--max_seq_len", dict(type=int)),
        ("--num_workers", dict(type=int)),
        # Output, logging and experiment tracking.
        ("--output_dir", dict(type=str)),
        ("--logging_steps", dict(type=int)),
        ("--save_steps", dict(type=int)),
        ("--wandb_enable", dict(action="store_true")),
        ("--wandb_key", dict(type=str)),
        ("--wandb_project", dict(type=str, default="lorax")),
        ("--wandb_run_name", dict(type=str, default="genetics_cochrane_genetics")),
        # Optimisation.
        ("--num_train_epochs", dict(type=int)),
        ("--per_device_train_batch_size", dict(type=int)),
        ("--gradient_accumulation_steps", dict(type=int, default=1)),
        ("--learning_rate", dict(type=float)),
        ("--weight_decay", dict(type=float)),
        ("--warmup_steps", dict(type=int)),
        ("--gradient_clipping", dict(type=float, default=1.0)),
        # LoRA adapters.
        ("--lora_r", dict(type=int)),
        ("--num_loras", dict(type=int)),
        ("--target_modules", dict(type=ast.literal_eval)),
        # Contrastive objective.
        ("--contrastive_loss_weight", dict(type=float)),
        ("--contrastive_temp", dict(type=float)),
        ("--contrastive_targets_path", dict(type=str)),
        ("--num_contrastive_targets", dict(type=int)),
        # DeepSpeed / launcher.
        ("--deepspeed_config", dict(type=str)),
        ("--local_rank", dict(type=int, default=0)),
    )

    cli = ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def train_once(
    model: "PreTrainedModel",
    batch: Dict[str, Any],
    contrastive_targets: List[str],
    constrastive_loss_weight: float,
    temperature: float,
):
    """Run one training micro-step: LM loss plus a weighted contrastive loss.

    The model is called three times:

    1. On the source inputs, with a ``representations`` list that the model fills.
    2. On the target inputs, with a second ``representations`` list.
    3. On the LM inputs with ``labels``, to obtain ``outputs.loss``.

    For each (source, target) representation pair, the hidden state at each
    row's last non-padding position is L2-normalised. An in-batch
    cross-entropy is then computed over the temperature-scaled similarity
    matrix, where source row ``i`` is the positive for target row ``i``.

    Args:
        model: The DeepSpeed engine returned by ``deepspeed.initialize``. It
            is called like the underlying model and trained through
            ``model.backward`` / ``model.step``.
        batch: Collated batch already on the training device. The
            ``source_*`` and ``target_*`` keys are popped, so the dict is
            mutated. ``input_ids``, ``attention_mask``, ``labels`` and
            ``lora_B_idx`` are read.
        contrastive_targets: Passed through to the model to select which
            representations it records. Per the original comments, the model
            records one ``(batch, seq_len, lora_r)`` tensor per target
            (TODO confirm against the model implementation).
        constrastive_loss_weight: Multiplier applied to the averaged
            contrastive loss.
        temperature: Divisor applied to the cosine similarities before the
            softmax.

    Returns:
        ``(ori_loss, contrastive_loss)`` as Python floats. ``contrastive_loss``
        is averaged over ``len(contrastive_targets)`` and weighted. It is
        ``0.0`` when ``contrastive_targets`` is empty.
    """
    def get_last_hidden_state(hidden_states, attention_mask):
        # Pick, for every row, the hidden state at its last non-pad position.
        # This assumes right padding (the tokenizer is configured with
        # padding_side = "right" in __main__), so that position is
        # sum(attention_mask) - 1.
        last_idx = attention_mask.sum(dim=-1) - 1
        rows = torch.arange(hidden_states.size(0), device=hidden_states.device)
        return hidden_states[rows, last_idx]

    model.train()
    source_attention_mask = batch.pop("source_attention_mask")
    target_attention_mask = batch.pop("target_attention_mask")

    source_inputs = {
        "input_ids": batch.pop("source_input_ids"),
        "attention_mask": source_attention_mask,
        "lora_B_idx": batch["lora_B_idx"],
    }
    target_inputs = {
        "input_ids": batch.pop("target_input_ids"),
        "attention_mask": target_attention_mask,
        "lora_B_idx": batch["lora_B_idx"],
    }

    # The model appends the tensors selected by `contrastive_targets` to these
    # lists during its forward pass.
    source_representations = []
    target_representations = []
    model(
        contrastive_targets=contrastive_targets,
        representations=source_representations,
        **source_inputs,
    )
    model(
        contrastive_targets=contrastive_targets,
        representations=target_representations,
        **target_inputs,
    )

    # Accumulate in float32 on the training device. A CPU int64 scalar
    # accumulator only worked through implicit type/device promotion.
    contrastive_loss = torch.zeros(
        (), dtype=torch.float32, device=source_inputs["input_ids"].device
    )
    for source_rep, target_rep in zip(source_representations, target_representations):
        # Upcast before normalising. __main__ loads the model in bfloat16, and
        # temperature-scaled similarities fed to a softmax lose too much
        # precision in bf16.
        source_states = F.normalize(
            get_last_hidden_state(source_rep, source_attention_mask).float(), p=2, dim=-1
        )
        target_states = F.normalize(
            get_last_hidden_state(target_rep, target_attention_mask).float(), p=2, dim=-1
        )

        sim_matrix = source_states @ target_states.T / temperature  # (batch, batch)

        # Source row i is the positive example for target row i.
        labels = torch.arange(sim_matrix.size(0), device=sim_matrix.device)
        contrastive_loss = contrastive_loss + F.cross_entropy(sim_matrix, labels)

    # Guard on the Python list instead of `contrastive_loss > 0`. The list is
    # what the division actually depends on, and checking it avoids a
    # device-to-host sync.
    if contrastive_targets:
        contrastive_loss = contrastive_loss / len(contrastive_targets) * constrastive_loss_weight

    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["labels"],
        lora_B_idx=batch["lora_B_idx"],
        contrastive_targets=[],
    )

    ori_loss = outputs.loss
    model.backward(ori_loss + contrastive_loss)
    model.step()

    return ori_loss.detach().cpu().item(), contrastive_loss.detach().cpu().item()
    

if __name__ == "__main__":
    # Script entry point. Every process started by the launcher runs this
    # block; rank-0-only work is guarded by `get_global_rank() == 0`.
    args = parse_args()

    # Pin this process to its node-local GPU before the first NCCL collective.
    # The global rank is not a valid CUDA ordinal on multi-node jobs, so the
    # index comes from LOCAL_RANK (exported by the DeepSpeed/torchrun
    # launchers), falling back to --local_rank.
    local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)

    deepspeed.init_distributed()
    dist.barrier()

    # The run directory is named after the start time in UTC+8. Rank 0 picks
    # the timestamp and broadcasts it. If each rank read its own clock, ranks
    # could straddle a second boundary and write into different directories.
    utc_plus_8 = timezone(timedelta(hours=8))  # UTC+8
    start_time_box = [datetime.now(timezone.utc).astimezone(utc_plus_8)]
    dist.broadcast_object_list(start_time_box, src=0)
    start_time = start_time_box[0]
    start_time_str = start_time.strftime("%Y-%m-%d-%H-%M-%S")
    args.output_dir = os.path.join(args.output_dir, start_time_str)
    os.makedirs(args.output_dir, exist_ok=True)
    add_filehandler(logger, os.path.join(args.output_dir, "logging", f"log-{start_time_str}.log"))

    # NOTE: unpickling can execute arbitrary code. Only point
    # --contrastive_targets_path at a file you produced yourself.
    with open(args.contrastive_targets_path, "rb") as f:
        args.contrastive_targets = pkl.load(f)[:args.num_contrastive_targets]

    def _public_config():
        """Return ``vars(args)`` minus credentials; safe to write to disk, logs and wandb."""
        return {k: v for k, v in vars(args).items() if k != "wandb_key"}

    if get_global_rank() == 0:
        path = os.path.join(args.output_dir, "config.json")
        with open(path, "w") as f:
            json.dump(_public_config(), f, indent=4)
        logger.info(f"Full config saved to {path}")
        logger.info(_public_config())

    set_seed(42)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    # Right padding is relied on by train_once, which locates each row's
    # last token as sum(attention_mask) - 1.
    tokenizer.padding_side = "right"
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    model_kwargs = {
        "pretrained_model_name_or_path": args.model_name,
        "trust_remote_code": False,
        "torch_dtype": torch.bfloat16,
        "device_map": device,
    }

    # Pick the Lorax model class from a substring of the model name.
    if "Qwen2" in args.model_name:
        config = LoraxQwen2ForCausalLM.config_class.from_pretrained(args.model_name)
        config._attn_implementation = 'flash_attention_2'
        model = LoraxQwen2ForCausalLM.from_pretrained(
            **model_kwargs,
            config=config,
        )
    elif "Llama" in args.model_name:
        config = LoraxLlamaForCausalLM.config_class.from_pretrained(args.model_name)
        config._attn_implementation = 'flash_attention_2'
        model = LoraxLlamaForCausalLM.from_pretrained(
            **model_kwargs,
            config=config,
        )
    else:
        raise ValueError(f"Model {args.model_name} not supported.")

    lorax_config = {
        "lora_r": args.lora_r,
        "num_loras": args.num_loras,
    }
    wrap_linear(
        model=model,
        target_modules=args.target_modules,
        config=lorax_config,
    )
    set_no_grad(model, logger)

    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )
    model.enable_input_require_grads()

    # data
    train_dataset = prepare_datasets(args.datasets)
    train_dataset = train_dataset.map(
        preprocess_train_dataset,
        batched=True,
        remove_columns=train_dataset.column_names,
        num_proc=args.num_workers,
        fn_kwargs={
            "tokenizer": tokenizer,
            "cutoff_len": args.max_seq_len,
        }
    )

    if get_global_rank() == 0:
        logger.info(f"Train dataset: {train_dataset}")
        logger.info(f"Train dataset example:\n{train_dataset[0]}")
        logger.info(f"{tokenizer.decode(train_dataset[0]['input_ids'])}")

    if get_global_size() > 1:
        train_sampler = DistributedSampler(
            train_dataset,
            num_replicas=get_global_size(),
            rank=get_global_rank(),
            shuffle=True,
        )
    else:
        train_sampler = None

    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.per_device_train_batch_size,
        sampler=train_sampler,
        # DistributedSampler shuffles on its own. Without a sampler, shuffle
        # here so single-process runs also see a randomised order.
        shuffle=train_sampler is None,
        num_workers=args.num_workers,
        collate_fn=CustomDataCollatorForSeq2SeqForTrain(
            tokenizer,
        )
    )

    # NOTE(review): this counts micro-batches, not optimizer steps. Confirm that
    # adjust_deepspeed_config accounts for gradient_accumulation_steps if this
    # value drives the LR schedule.
    num_train_steps = len(train_dataloader) * args.num_train_epochs
    args.num_train_steps = num_train_steps
    dp_config = adjust_deepspeed_config(args)

    # Biases and norm weights go in a group with weight_decay=0.0. Only
    # trainable parameters (requires_grad) are handed to the optimizer.
    no_decay = ["bias", "norm.weight"]
    regular_lr_group, no_decay_group = [], []
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay) and p.requires_grad:
            no_decay_group.append(p)
        elif p.requires_grad:
            regular_lr_group.append(p)

    optimizer_grouped_parameters = [
        {"params": regular_lr_group, "lr": args.learning_rate},
        {"params": no_decay_group, "weight_decay": 0.0},
    ]

    model, optimizer, _, _ = deepspeed.initialize(
        model=model,
        config=dp_config,
        model_parameters=optimizer_grouped_parameters,
    )

    if get_global_rank() == 0 and args.wandb_enable:
        wandb.login(
            key=args.wandb_key,
        )
        batch_size = args.per_device_train_batch_size * get_global_size() * args.gradient_accumulation_steps
        wandb.init(
            project=args.wandb_project,
            name=(
                f"{args.wandb_run_name}_{start_time_str}"
                f"_bs{batch_size}_lr{args.learning_rate}_epoch{args.num_train_epochs}"
                f"_lora_r{args.lora_r}"
                f"_weight{args.contrastive_loss_weight}"
                f"_temp{args.contrastive_temp}"
                f"_num_contrastive_targets{args.num_contrastive_targets}"
            ),
            # Never ship the API key into the run's stored config.
            config=_public_config(),
        )

    batch_step = 0  # micro-batches processed across all epochs
    # Optimizer-step counter; it is fractional between accumulation
    # boundaries. It is initialised here so the end-of-epoch save is valid
    # even when the dataloader is empty.
    step = 0
    step_loss_list = []
    step_ori_loss_list = []
    step_contrastive_loss_list = []
    dist.barrier()
    for i in range(args.num_train_epochs):
        if train_sampler is not None:
            # Reseed the distributed shuffle for this epoch. Without this,
            # every epoch replays epoch 0's order.
            train_sampler.set_epoch(i)
        for batch in tqdm(
            train_dataloader,
            desc=f"Epoch {i}/{args.num_train_epochs}",
            disable=get_global_rank() != 0,
        ):
            # train
            batch = to_device(obj=batch, device=device)
            step_ori_loss, step_contrastive_loss = train_once(
                model=model,
                batch=batch,
                contrastive_targets=args.contrastive_targets,
                constrastive_loss_weight=args.contrastive_loss_weight,
                temperature=args.contrastive_temp,
            )
            step_loss = step_ori_loss + step_contrastive_loss
            step_loss_list.append(step_loss)
            step_ori_loss_list.append(step_ori_loss)
            step_contrastive_loss_list.append(step_contrastive_loss)
            batch_step += 1

            # At an accumulation boundary, replace the per-micro-batch values
            # with their mean over the accumulation window.
            if batch_step % args.gradient_accumulation_steps == 0:
                step_loss = sum(step_loss_list) / len(step_loss_list)
                step_ori_loss = sum(step_ori_loss_list) / len(step_ori_loss_list)
                step_contrastive_loss = sum(step_contrastive_loss_list) / len(step_contrastive_loss_list)
                step_loss_list = []
                step_ori_loss_list = []
                step_contrastive_loss_list = []

            # Fractional between accumulation boundaries, so the `%` checks
            # below only fire on completed optimizer steps.
            step = batch_step / args.gradient_accumulation_steps

            # log: average the per-rank means across all ranks.
            if step % args.logging_steps == 0 and step != 0:
                dist.barrier()
                step_loss = torch.tensor(step_loss, device=device, dtype=torch.float32)
                step_ori_loss = torch.tensor(step_ori_loss, device=device, dtype=torch.float32)
                step_contrastive_loss = torch.tensor(step_contrastive_loss, device=device, dtype=torch.float32)
                dist.all_reduce(step_loss, op=dist.ReduceOp.SUM)
                dist.all_reduce(step_ori_loss, op=dist.ReduceOp.SUM)
                dist.all_reduce(step_contrastive_loss, op=dist.ReduceOp.SUM)
                dist.barrier()
                lr = optimizer.param_groups[0]["lr"]
                # `_global_grad_norm` is a private DeepSpeed optimizer attribute
                # and is not guaranteed to be set (TODO confirm for this
                # config). Fall back to NaN instead of crashing the `:.4f`
                # format below.
                grad_norm = getattr(optimizer, "_global_grad_norm", None)
                grad_norm = float("nan") if grad_norm is None else float(grad_norm)
                step_loss = (step_loss / get_global_size()).item()
                step_ori_loss = (step_ori_loss / get_global_size()).item()
                step_contrastive_loss = (step_contrastive_loss / get_global_size()).item()

                if get_global_rank() == 0:
                    step_record = {
                        "step": step,
                        "lr": lr,
                        "ori_loss": step_ori_loss,
                        "contrastive_loss": step_contrastive_loss,
                        "loss": step_loss,
                        "grad_norm": grad_norm,
                    }
                    if args.wandb_enable:
                        wandb.log(step_record)
                    text_formatted = f"Step: {step} | LR: {lr:.6f} | Ori loss: {step_ori_loss:.4f} | Contrastive loss: {step_contrastive_loss:.4f} | Loss: {step_loss:.4f} | Grad norm: {grad_norm:.4f}"
                    logger.info(text_formatted)

            # save
            if step % args.save_steps == 0:
                save_pretrain(
                    model=model,
                    output_dir=args.output_dir,
                    epoch=i,
                    step=step,
                    logger=logger,
                )
            dist.barrier()
            gc.collect()
            torch.cuda.empty_cache()

        # Always checkpoint at the end of each epoch.
        save_pretrain(
            model=model,
            output_dir=args.output_dir,
            epoch=i,
            step=step,
            logger=logger,
        )
        dist.barrier()

    if get_global_rank() == 0:
        if args.wandb_enable:
            wandb.finish()
        utc_time = datetime.now(timezone.utc)
        utc_plus_8 = timezone(timedelta(hours=8))  # UTC+8
        end_time = utc_time.astimezone(utc_plus_8)
        logger.info(f"Training finished in {end_time - start_time}")