# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import inspect

import numpy as np
# import tyro
import json
import os
import typing
from typing import Optional, Union
import tqdm
from datetime import timedelta

import datasets
import torch
import transformers
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import is_deepspeed_available, InitProcessGroupKwargs
from transformers import LlamaTokenizer, get_scheduler, AutoTokenizer
from torch.optim import AdamW
from modeling.optimizer.anyprecision_optimizer import AnyPrecisionAdamW
from sklearn.metrics import classification_report, accuracy_score
from peft import PeftModel

from modeling.ft_datasets.instruction_dataset import InstructionDataset
from modeling.lluga.modeling_gllama import UniGTE
from modeling.common import freeze, print_trainable_parameters
from utils.data_utils import output_decode

from trl.core import (
    set_seed,
)
from trl.trainer import BaseTrainer, RunningMoments

import torch.distributed as dist

if is_deepspeed_available():
    import deepspeed


def print_rank_0(*args):
    """Print ``args`` (flushed) on global rank 0 only.

    When ``torch.distributed`` is unavailable or no process group has been
    initialised (e.g. single-process runs), every caller prints.

    Args:
        *args: forwarded unchanged to ``print``.
    """
    # Explicit check instead of a bare ``except``: the old form also swallowed
    # KeyboardInterrupt and printed twice if ``print`` itself raised.
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return
    print(*args, flush=True)


class MyTrainer(BaseTrainer):
    """Accelerate-based trainer / evaluator for the ``UniGTE`` graph-text model.

    Construction wires up, in order: the ``Accelerator`` (with a wandb tracker),
    the tokenizer(s), the dataloader(s), the model and -- when not in inference
    mode -- the optimizer and LR scheduler, all passed through
    ``accelerator.prepare``.

    In training mode ``self.dataloader`` is a single ``DataLoader`` and
    ``self.names`` is ``None``; in inference mode (``training_args.inference``)
    ``self.dataloader`` is a list of per-dataset loaders whose names are stored
    in ``self.names``.
    """

    # Batch keys forwarded to ``self.model`` for both training and generation.
    _MODEL_INPUT_KEYS = (
        'graph_embeds',
        'graph_attention_mask',
        'rel_position',
        'edge_attr',
        'edge_type',
        'input_ids',
        'prompt_answer_ids',
        'labels',
        'attention_mask',
        'text_length',
    )

    # Fallback checkpoint for the memory-token embedding loaded when
    # ``training_args.fix_mem`` is set; override with ``training_args.mem_embed_path``.
    _DEFAULT_MEM_EMBED_PATH = '/UniGTE/ft-models/icae_lora/memory_token_embed.pt'

    def __init__(
        self,
        model_args,
        data_args,
        training_args,
    ):
        """Build accelerator, tokenizers, data, model and (for training) optimizer.

        Args:
            model_args: model configuration; read here: ``model_arch``,
                ``gt_enocder_path``, ``llm_path``, ``memory_token_nums``.
            data_args: data configuration; ``datasets`` is a comma-separated list
                of dataset names (also used as the wandb group).
            training_args: training configuration (seed, batch sizes, learning
                rates, ``inference`` flag, output/result directories, ...).
        """
        self.model_args = model_args
        self.data_args = data_args
        self.training_args = training_args

        # Seed first so dataset shuffling and model init are reproducible.
        set_seed(self.training_args.seed)

        # Step 1: accelerator. The 2-hour process-group timeout presumably
        # tolerates long collective waits (e.g. rank 0 saving) -- confirm.
        init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
        self.accelerator = Accelerator(
            log_with="wandb",
            kwargs_handlers=[ddp_kwargs, init_kwargs],
            gradient_accumulation_steps=self.training_args.gradient_accumulation_steps,
        )
        self.accelerator.init_trackers(
            project_name=f"{self.training_args.project_name}",
            init_kwargs={
                "wandb": {
                    "group": self.data_args.datasets,
                    "name": self.training_args.run_name,
                    "config": self.model_args,
                }
            },
        )

        # Step 2: tokenizers (encoder side, plus a decoder one if the LLM differs).
        self.tokenizer, self.tokenizer_dec = self._prepare_tokenizer()

        # Step 3: dataloader(s).
        self.dataloader, self.names = self._prepare_dataloader()

        # Step 4: model. Keep the trainable parameters as a list: a bare
        # ``filter`` can be consumed only once and prints as ``<filter object>``.
        self.model = self._create_model()
        self.model_params = [p for p in self.model.parameters() if p.requires_grad]
        print_rank_0(f'trainable parameter tensors: {len(self.model_params)}')

        # Step 5: optimizer / scheduler (training only), then let accelerate wrap.
        if not self.training_args.inference:
            self.optimizer, self.lr_scheduler = self._prepare_optimizer_and_scheduler()
            (
                self.model,
                self.optimizer,
                self.dataloader,
                self.lr_scheduler,
            ) = self.accelerator.prepare(
                self.model,
                self.optimizer,
                self.dataloader,
                self.lr_scheduler,
            )
        else:
            # ``self.dataloader`` is a list here; it stays a list after prepare.
            self.model, *self.dataloader = self.accelerator.prepare(self.model, *self.dataloader)
            # ``generate`` is called on the underlying (unwrapped) module.
            self.model = self.accelerator.unwrap_model(self.model)

        self.is_distributed = self.accelerator.num_processes > 1
        # Step counter used only for tensorboard logging (see ``log_stats``).
        self.current_step = 0

        # Sequential/pipeline-parallel models are pinned to cuda:0.
        if not getattr(self.model, "is_sequential_parallel", False):
            self.current_device = self.accelerator.device
        else:
            self.current_device = torch.device("cuda:0")
        print('device:', self.current_device)

        self.running = RunningMoments(self.accelerator)

    def _model_inputs(self, batch, extra_keys=()):
        """Select the model's keyword arguments from a collated ``batch`` mapping."""
        return {key: batch[key] for key in self._MODEL_INPUT_KEYS + tuple(extra_keys)}

    def _prepare_tokenizer(self):
        """Load the encoder tokenizer and, if needed, a separate decoder tokenizer.

        Both tokenizers receive the same additional special tokens:
        ``memory_token_nums`` memory placeholders plus ``<-FineTune->``.

        Returns:
            ``(tokenizer, tokenizer_dec)``; ``tokenizer_dec`` is ``None`` when the
            encoder and the LLM are loaded from the same path.
        """
        if self.model_args.model_arch == 'llama3':
            tokenizer = AutoTokenizer.from_pretrained(self.model_args.gt_enocder_path)
            # Use EOS as the padding token.
            tokenizer.pad_token_id = tokenizer.eos_token_id
        else:
            tokenizer = LlamaTokenizer.from_pretrained(self.model_args.gt_enocder_path)
            tokenizer.pad_token = tokenizer.unk_token
        # Truncate from the left, i.e. keep the end of over-long inputs.
        tokenizer.truncation_side = 'left'
        # Memory placeholders (note the leading space in the token text) plus a
        # fine-tuning marker token.
        special = {
            'additional_special_tokens':
                [' <MEM {}>'.format(i) for i in range(self.model_args.memory_token_nums)] + ['<-FineTune->']
        }
        tokenizer.add_special_tokens(special)

        if self.model_args.gt_enocder_path != self.model_args.llm_path:
            tokenizer_dec = LlamaTokenizer.from_pretrained(self.model_args.llm_path)
            tokenizer_dec.pad_token = tokenizer_dec.unk_token
            tokenizer_dec.add_special_tokens(special)
        else:
            tokenizer_dec = None

        return tokenizer, tokenizer_dec

    def _prepare_dataloader(self):
        """Build the dataloader(s) for the current mode.

        Returns:
            Inference mode: ``(loaders, names)`` with one unshuffled, non-dropping
            ``DataLoader`` per comma-separated entry of ``data_args.datasets``.
            Training mode: ``(loader, None)`` with one shuffled, drop-last loader.
        """
        if not self.training_args.inference:
            dataset = InstructionDataset([self.tokenizer, self.tokenizer_dec], self.accelerator, self.model_args,
                                         self.data_args, self.training_args, mode='train')
            loader = torch.utils.data.DataLoader(
                dataset,
                batch_size=self.training_args.per_device_train_batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=True,
                drop_last=True,
                pin_memory=True,
            )
            return loader, None

        # Named ``dataset_names`` so the imported ``datasets`` module is not shadowed;
        # surrounding whitespace and empty entries (e.g. from "a, b,") are ignored.
        dataset_names = [name.strip() for name in self.data_args.datasets.split(',') if name.strip()]
        dataloaders = []
        for dataset_name in dataset_names:
            dataset = InstructionDataset([self.tokenizer, self.tokenizer_dec], self.accelerator, self.model_args,
                                         self.data_args, self.training_args, mode='test', dataset=dataset_name)
            dataloaders.append(torch.utils.data.DataLoader(
                dataset,
                batch_size=self.training_args.per_device_eval_batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=False,
                drop_last=False,
                pin_memory=True,
            ))
        return dataloaders, list(dataset_names)

    def _create_model(self):
        """Instantiate ``UniGTE``.

        Inference: load the checkpoint from ``training_args.output_dir``.
        Training: build a fresh model, then optionally
            * ``fix_encoder`` -- freeze ``model.gllama``;
            * ``fix_mem`` -- load the memory-token embedding from
              ``training_args.mem_embed_path`` (default ``_DEFAULT_MEM_EMBED_PATH``)
              and freeze it.
        """
        if self.training_args.inference:
            model = UniGTE.load_model(self.model_args, self.training_args, self.training_args.output_dir)
            if self.model_args.model_arch == 'llama3':
                # Give generation the tokenizer's pad id (EOS for llama3).
                model.decoder.generation_config.pad_token_id = self.tokenizer.pad_token_id
        else:
            model = UniGTE(self.model_args, self.training_args)
            if self.training_args.fix_encoder:
                freeze(model.gllama)
            if self.training_args.fix_mem:
                mem_embed_path = getattr(self.training_args, 'mem_embed_path', None) or self._DEFAULT_MEM_EMBED_PATH
                # map_location='cpu': without it every rank materialises the tensors on
                # the device they were saved from before copying them into the module.
                mem_state = torch.load(mem_embed_path, map_location='cpu')
                model.memory_token_embed.load_state_dict(mem_state)
                freeze(model.memory_token_embed)
            print_trainable_parameters(model)
        print_rank_0(model)

        return model

    def _trainable_module_summary(self):
        """JSON list of the distinct 3-level name prefixes of trainable parameters."""
        prefixes = {
            '.'.join(name.split('.')[0:3])
            for name, param in self.model.named_parameters() if param.requires_grad
        }
        return json.dumps(sorted(prefixes), indent=2)

    def _prepare_optimizer_and_scheduler(self):
        """Create the optimizer and LR scheduler (before ``accelerator.prepare``).

        bf16: ``AnyPrecisionAdamW`` with three parameter groups -- base
        (``learning_rate``), projector (``proj_learning_rate``) and graph
        position (``graph_learning_rate``) -- keeping moments in bfloat16.
        Otherwise: a single-group ``AdamW`` at ``learning_rate``.

        Returns:
            ``(optimizer, lr_scheduler)``.
        """
        if self.training_args.bf16:
            base_params, proj_params, graph_params = [], [], []
            for name, param in self.model.named_parameters():
                if not param.requires_grad:
                    continue
                if name.startswith(('projector.', 'edge_projector.')):
                    proj_params.append(param)
                elif name.startswith(('graph_position_theta', 'rel_position_encoder.')):
                    graph_params.append(param)
                else:
                    base_params.append(param)
            print_rank_0(self._trainable_module_summary(), len(proj_params), len(base_params))

            optimizer = AnyPrecisionAdamW(
                [
                    {"params": base_params, 'lr': self.training_args.learning_rate},
                    {"params": proj_params, 'lr': self.training_args.proj_learning_rate},
                    {"params": graph_params, 'lr': self.training_args.graph_learning_rate},
                ],
                momentum_dtype=torch.bfloat16,
                variance_dtype=torch.bfloat16,
                use_kahan_summation=False,
            )
        else:
            # Previously ``self.model.parameters().items()``, which always raised
            # AttributeError (a generator has no ``items``).
            print_rank_0(self._trainable_module_summary())
            optimizer = AdamW(
                [p for p in self.model.parameters() if p.requires_grad],
                lr=self.training_args.learning_rate,
            )

        # ``len(self.dataloader)`` is measured before ``accelerator.prepare``.
        epoch_steps = len(self.dataloader)
        train_steps = epoch_steps * self.training_args.num_train_epochs
        num_training_steps = train_steps // self.training_args.gradient_accumulation_steps
        warmup_steps = self.training_args.warmup_ratio * num_training_steps

        print_rank_0('*' * 20, epoch_steps, train_steps, warmup_steps, self.accelerator.distributed_type, '*' * 20)

        # NOTE(review): the accelerate-wrapped scheduler steps once per process for
        # every optimizer step, and ``num_training_steps`` is derived from the
        # unsharded loader length, so it is likely already in scheduler-step units.
        # Scaling only the warmup by ``num_processes`` may stretch warmup to
        # ``warmup_ratio * num_processes`` of training -- confirm this is intended.
        num_warmup_steps = warmup_steps * self.accelerator.num_processes

        lr_scheduler = get_scheduler(
            self.training_args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

        return optimizer, lr_scheduler

    def log_stats(self, stats: dict):
        """Log ``stats`` through the accelerator trackers (main process only).

        With tensorboard an explicit, incrementing step is passed; other
        trackers receive ``step=None``.
        """
        if not self.accelerator.is_main_process:
            return

        logs = collections.OrderedDict()
        logs.update(stats)

        if self.training_args.log_with == "tensorboard":
            self.current_step += 1

        self.accelerator.log(
            logs,
            step=self.current_step if self.training_args.log_with == "tensorboard" else None,
        )

    def train(self, save_directory=""):
        """Train for ``num_train_epochs`` epochs with gradient accumulation.

        Args:
            save_directory: if non-empty, a checkpoint is written after every
                epoch to ``{save_directory}/epoch_XXX`` (zero-padded epoch index).
        """
        self.model.train()
        accu_loss = []
        total_step = 0

        for epoch in range(int(self.training_args.num_train_epochs)):
            batches = enumerate(self.dataloader)
            if self.accelerator.is_main_process:
                batches = tqdm.tqdm(batches, total=len(self.dataloader), colour="blue", desc=f'epoch: {epoch}')

            for step, batch in batches:
                with self.accelerator.accumulate(self.model):
                    output = self.model(**self._model_inputs(batch))
                    loss = output['loss']
                    self.accelerator.backward(loss)

                    # Clip only once the accumulated gradients are complete/synced;
                    # iterate all groups (the non-bf16 optimizer has just one).
                    if self.accelerator.sync_gradients:
                        for group in self.optimizer.param_groups:
                            self.accelerator.clip_grad_norm_(group['params'], self.training_args.max_grad_norm)

                    self.optimizer.step()
                    self.lr_scheduler.step()
                    # Zero AFTER stepping: the wrapped optimizer only zeroes on sync
                    # steps, so zeroing before the forward pass discarded all but the
                    # last micro-batch's gradients.
                    self.optimizer.zero_grad()

                    all_losses = self.accelerator.gather(loss)
                    accu_loss.append(np.mean(all_losses.to(torch.float).detach().cpu().numpy()))

                if (self.training_args.gradient_accumulation_steps <= 1 or
                        (step + 1) % self.training_args.gradient_accumulation_steps == 0):
                    mean_loss = float(np.mean(accu_loss))
                    current_lr = float(np.asarray(self.lr_scheduler.get_last_lr())[0])
                    print_rank_0(f'loss = {mean_loss:.5}, lr = {current_lr:.3}')

                    self.log_stats({
                        'loss': mean_loss,
                        'learning rate': current_lr,
                    })
                    accu_loss.clear()

                    total_step += 1

            if save_directory:
                self.save_model(output_dir=f'{save_directory}/epoch_{str(epoch + 1000)[1:]}')

    def evaluation(self):
        """Generate on every inference dataset and dump decoded results.

        For each dataset ``name`` writes ``preds.txt`` and ``labels.txt`` (JSON)
        under ``{training_args.result_dir}/{name}`` on the local main process.
        All ranks must call this: ``evaluation_dataset`` performs collectives.
        """
        for name, dataloader in zip(self.names, self.dataloader):
            eval_output, eval_label = self.evaluation_dataset(dataloader)

            if not self.accelerator.is_local_main_process:
                continue

            result_dir = os.path.join(self.training_args.result_dir, name)
            os.makedirs(result_dir, exist_ok=True)

            # NOTE(review): decoding uses the encoder tokenizer even when a separate
            # ``tokenizer_dec`` exists -- confirm generated ids share its vocabulary.
            eval_pred, eval_decode_label = output_decode(eval_output, eval_label, self.tokenizer)
            with open(os.path.join(result_dir, 'preds.txt'), 'w', encoding='utf-8') as f:
                json.dump(eval_pred, f)
            with open(os.path.join(result_dir, 'labels.txt'), 'w', encoding='utf-8') as f:
                json.dump(eval_decode_label, f)

    def evaluation_dataset(self, dataloader):
        """Generate over ``dataloader`` and gather predictions/labels from all ranks.

        Returns:
            ``(eval_output, eval_label)``: lists with one numpy array per batch of
            gathered generated ids and gathered labels (``-100`` replaced by the
            pad id). With several processes the final batch is truncated to the
            dataset length, dropping the extra samples (presumably duplicates
            added so every process gets a full batch).
        """
        self.model.eval()
        samples_seen = 0
        eval_output = []
        eval_label = []
        pad_id = self.tokenizer.pad_token_id

        # Set once for the whole loop (was re-set every batch); warn_only lets
        # ops without a deterministic implementation still run.
        torch.use_deterministic_algorithms(True, warn_only=True)

        batches = enumerate(dataloader)
        if self.accelerator.is_main_process:
            batches = tqdm.tqdm(batches, total=len(dataloader), colour="green")

        with torch.no_grad():
            for step, batch in batches:
                model_kwargs = self._model_inputs(batch, extra_keys=('prompt_attention_mask',))
                results = self.model.generate(**model_kwargs)
                # Sequences differ in length across ranks; pad before gathering.
                results = self.accelerator.pad_across_processes(results, dim=1, pad_index=pad_id)
                results_gathered = self.accelerator.gather(results).cpu().numpy()

                labels = self.accelerator.pad_across_processes(batch['labels'], dim=1, pad_index=pad_id)
                labels_gathered = self.accelerator.gather(labels).cpu().numpy()

                if self.accelerator.num_processes > 1:
                    if step == len(dataloader) - 1:
                        remaining = len(dataloader.dataset) - samples_seen
                        results_gathered = results_gathered[:remaining]
                        labels_gathered = labels_gathered[:remaining]
                    else:
                        samples_seen += len(results_gathered)
                # -100 marks ignored label positions; map to pad so they can be decoded.
                labels_gathered = np.where(labels_gathered != -100, labels_gathered, pad_id)
                self.accelerator.print(f"label={self.tokenizer.batch_decode(labels_gathered, skip_special_tokens=True)}, pred={self.tokenizer.batch_decode(results_gathered, skip_special_tokens=True)}")

                eval_output.append(results_gathered)
                eval_label.append(labels_gathered)

        self.model.train()

        return eval_output, eval_label

    def save_model(self, output_dir):
        """Save the (unwrapped) model to ``output_dir``.

        ``accelerator.get_state_dict`` is called on every rank (it may need to
        collect sharded weights); only the main process writes. All ranks then
        wait at a barrier.
        """
        import datetime
        start = datetime.datetime.now()
        state_dict = self.accelerator.get_state_dict(self.model)

        if self.accelerator.is_main_process:
            print_rank_0(f'get state dict time cost {datetime.datetime.now() - start}')
            unwrap_model = self.accelerator.unwrap_model(model=self.model)
            unwrap_model.save_model(output_dir, state_dict=state_dict)
        self.accelerator.wait_for_everyone()