import os
import glob
import logging
import datetime

from transformers import TrainerCallback
from transformers import TrainingArguments, TrainerState, TrainerControl
import os

from dataclasses import dataclass
import transformers
from typing import List, Tuple, Dict

# Label value presumably used to mask positions out of the loss; it matches the
# default ``ignore_index`` of PyTorch's cross-entropy loss. Not referenced in
# this module — TODO confirm usage in callers.
IGNORE_INDEX: int = -100

# NOTE: this is the *root* logger (no name given), so records go to whatever
# handlers/level the root logger is configured with elsewhere.
logger = logging.getLogger(name=None)

class LoggerCallback(TrainerCallback):
    """Log loss, learning rate, elapsed time and an ETA at each logging event.

    Records are written through the module-level ``logger`` and only on the
    local main process (``state.is_local_process_zero``). The remaining time
    is a linear extrapolation of the average time per completed step.
    """

    @staticmethod
    def _format_hhmm(delta):
        """Format a ``datetime.timedelta`` as ``HH:MM``; hours are not wrapped at 24."""
        minutes = delta.seconds // 60
        return '%.2d:%.2d' % (minutes // 60 + delta.days * 24, minutes % 60)

    def on_train_begin(self, args, state, control, **kwargs):
        # Reference point for the elapsed-time / ETA figures reported in ``on_log``.
        self.start_time = datetime.datetime.now()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Emit one progress line for log events that contain a ``loss`` entry."""
        if not state.is_local_process_zero:
            return

        # ``logs`` defaults to None; only report events that carry a training loss.
        if not logs or 'loss' not in logs:
            return

        # The ETA divides by the number of completed steps; nothing to extrapolate yet.
        if state.global_step <= 0:
            return

        loss_msg = ' '.join('%s: %s' % (k, v) for k, v in logs.items() if 'loss' in k)
        now = datetime.datetime.now()
        elapsed = now - self.start_time
        # Clamp so overshooting max_steps (or an unset max_steps) never yields a negative ETA.
        steps_left = max(state.max_steps - state.global_step, 0)
        remaining = elapsed * steps_left / state.global_step
        eta = now + remaining
        # ``state.epoch`` may be None; avoid a '%.2f' formatting error inside logging.
        epoch = state.epoch if state.epoch is not None else 0.0

        logger.info(
            'step: %d epoch: %.2f %s lr: %.4g passed time: %s rest time: %s eta: %s',
            state.global_step, epoch, loss_msg, logs.get('learning_rate', 0),
            self._format_hhmm(elapsed), self._format_hhmm(remaining),
            eta.strftime('%m/%d %H:%M'),
        )

class RemoveStateCallback(TrainerCallback):
    """Delete resumable training state from saved checkpoint directories.

    After each save, the state files of the checkpoint ``save_steps`` steps
    earlier are removed (assumes checkpoints are written every ``save_steps``
    steps — TODO confirm). At the end of training the final checkpoint is
    stripped too. Only the world main process deletes anything. The model
    weights themselves are not in the removal list.
    """

    def remove_state(self, args, step):
        """Remove optimizer/scheduler/DeepSpeed/RNG state from ``checkpoint-<step>``.

        Deletion is best-effort: missing paths are skipped and failures are
        logged as warnings instead of raised.
        """
        import shutil

        step = int(step)

        if step <= 0:
            return

        step_dir = os.path.join(args.output_dir, f'checkpoint-{step}')
        logger.info('Remove state in %s', step_dir)

        remove_paths = [
            os.path.join(step_dir, 'latest'),  # deepspeed state
            os.path.join(step_dir, f'global_step{step}'),  # deepspeed state
            os.path.join(step_dir, 'optimizer.pt'),  # optimizer state
            os.path.join(step_dir, 'scheduler.pt'),  # scheduler state
            os.path.join(step_dir, 'generation_config.json'),  # generation config
            os.path.join(step_dir, 'trainer_state.json'),  # trainer state
            os.path.join(step_dir, 'training_args.bin'),  # training args
            os.path.join(step_dir, 'zero_to_fp32.py'),
        ]
        # RNG state files (one or more, matched by the rng_state_*.pth pattern).
        remove_paths.extend(glob.glob(os.path.join(step_dir, 'rng_state_*.pth')))

        for path in remove_paths:
            # Delete directly rather than via a shell command: an unquoted
            # ``rm -rf <path>`` string breaks on (or is hijacked by) spaces and
            # shell metacharacters in ``output_dir``.
            try:
                if os.path.isdir(path) and not os.path.islink(path):
                    shutil.rmtree(path)
                elif os.path.lexists(path):
                    os.remove(path)
            except OSError as err:
                logger.warning('Failed to remove %s: %s', path, err)

    def on_save(self, args, state, control, **kwargs):
        # Strip the previous checkpoint, keeping the just-saved one resumable.
        if not state.is_world_process_zero:
            return

        self.remove_state(args, state.global_step - state.save_steps)

    def on_train_end(self, args, state, control, **kwargs):
        # Training is over, so the final checkpoint's state is no longer needed.
        if not state.is_world_process_zero:
            return

        self.remove_state(args, state.global_step)


class StepCheckpointCallback(TrainerCallback):
    """Save a model checkpoint every ``save_interval`` global steps.

    Checkpoints are written with the attached trainer's ``save_model`` into
    ``<output_dir>/checkpoint-step-<global_step>``. At the same step,
    ``control.should_save`` is set, which asks the Trainer to run its own save
    too. If ``external_validation`` is true, ``control.should_training_stop``
    is also set after each such save. This is presumably so an external
    process can evaluate the checkpoint — TODO confirm.
    """

    def __init__(self, save_interval, output_dir, external_validation):
        # A falsy or negative ``save_interval`` disables step checkpoints.
        self.save_interval = save_interval
        self.output_dir = output_dir
        # Must be attached via ``set_trainer`` before the first save.
        self.trainer = None
        self.external_validation = external_validation

    def set_trainer(self, trainer):
        """Attach the trainer whose ``save_model`` is used for checkpoints."""
        logger.debug("set_trainer in step checkpoint callback")
        self.trainer = trainer

    @staticmethod
    def _rank():
        """Return this process's distributed rank, or 0 if torch.distributed is not initialized."""
        import torch.distributed as dist
        return dist.get_rank() if dist.is_initialized() else 0

    def save_model(self, checkpoint_dir, state: TrainerState):
        """Save the trainer's model to ``checkpoint_dir`` and log the location.

        Raises:
            RuntimeError: if ``set_trainer`` has not been called.
        """
        if self.trainer is None:
            raise RuntimeError(
                "StepCheckpointCallback.set_trainer() must be called before save_model()"
            )
        self.trainer.save_model(checkpoint_dir)
        epoch = int(state.epoch) if state.epoch is not None else 0
        logger.info(
            "rank: %d Global step: %d (Epoch %d). Saved checkpoint at %s",
            self._rank(), state.global_step, epoch, checkpoint_dir,
        )

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Checkpoint when ``global_step`` is a multiple of ``save_interval``."""
        logger.debug(
            "[on_step_end] rank=%d global_step=%s save_interval=%s epoch=%s",
            self._rank(), state.global_step, self.save_interval, state.epoch,
        )
        # Guard against modulo-by-zero / None when step checkpoints are disabled.
        if not self.save_interval or self.save_interval < 0:
            return None
        if state.global_step % self.save_interval != 0:
            return None

        checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-step-{state.global_step}")
        # NOTE(review): every rank calls save_model here, presumably because
        # trainer.save_model coordinates distributed saving itself — TODO confirm.
        self.save_model(checkpoint_dir, state)
        control.should_save = True
        if self.external_validation:
            control.should_training_stop = True
        return control


@dataclass
class DataCollatorForLeanFinderDualLossDataset:
    """
    Collate contrastive + DPO examples into padded token tensors.

    Each feature is indexed as ``f[0]`` (query entry), ``f[1]`` (list of
    passage entries) and ``f[2]`` (dict with ``"prompt"``, ``"chosen"`` and
    ``"rejected"`` texts). Query and passage entries are indexed with ``[0]``
    to obtain their text. They are therefore presumably sequences such as
    ``(text, ...)`` tuples, not plain strings as the ``features`` annotation
    suggests — TODO confirm against the dataset.

    Every text is tokenized with special tokens and truncated to
    ``max_len - 1`` tokens. ``tokenizer.eos_token_id`` is then appended, and
    the batch is padded (length rounded up to a multiple of
    ``pad_to_multiple_of``) into PyTorch tensors with attention masks.
    """
    tokenizer: transformers.PreTrainedTokenizer
    # Max tokens (including the appended EOS) for contrastive queries and DPO prompts.
    query_max_len: int = 210
    # Max tokens (including the appended EOS) for passages and DPO chosen/rejected texts.
    passage_max_len: int = 610
    # Padded batch length is rounded up to a multiple of this value.
    pad_to_multiple_of: int = 16

    def _encode(self, texts, max_len):
        """Tokenize ``texts``, append EOS to each sequence and pad into ``pt`` tensors.

        Raises:
            ValueError: if the tokenizer defines no EOS token.
        """
        eos_id = self.tokenizer.eos_token_id
        if eos_id is None:
            raise ValueError("tokenizer.eos_token_id is None; cannot terminate sequences with EOS")
        encoded = self.tokenizer(
            texts,
            padding=False,
            truncation=True,
            # Reserve one position for the EOS token appended below.
            max_length=max_len - 1,
            return_attention_mask=False,
            return_token_type_ids=False,
            add_special_tokens=True,
        )
        encoded['input_ids'] = [ids + [eos_id] for ids in encoded['input_ids']]
        return self.tokenizer.pad(
            encoded,
            padding=True,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_attention_mask=True,
            return_tensors='pt',
        )

    def __call__(self, features: List[Tuple[str, List[str], Dict]]):
        # Contrastive part: one query per feature; all passages flattened in feature order.
        contrastive_queries = [f[0][0] for f in features]
        contrastive_passages = [p[0] for f in features for p in f[1]]
        query_batch = self._encode(contrastive_queries, self.query_max_len)
        passage_batch = self._encode(contrastive_passages, self.passage_max_len)

        # DPO part: prompts use the query budget, responses the passage budget.
        dpo_samples = [f[2] for f in features]
        prompt_batch = self._encode([s["prompt"] for s in dpo_samples], self.query_max_len)
        chosen_batch = self._encode([s["chosen"] for s in dpo_samples], self.passage_max_len)
        rejected_batch = self._encode([s["rejected"] for s in dpo_samples], self.passage_max_len)

        return {
            "contrastive_query_input_ids": query_batch["input_ids"],
            "contrastive_query_attention_mask": query_batch["attention_mask"],
            "contrastive_passage_input_ids": passage_batch["input_ids"],
            "contrastive_passage_attention_mask": passage_batch["attention_mask"],
            "prompt_input_ids": prompt_batch["input_ids"],
            "prompt_attention_mask": prompt_batch["attention_mask"],
            "chosen_input_ids": chosen_batch["input_ids"],
            "chosen_attention_mask": chosen_batch["attention_mask"],
            "rejected_input_ids": rejected_batch["input_ids"],
            "rejected_attention_mask": rejected_batch["attention_mask"],
        }






    
    