import os
import torch
import torch.nn as nn

from torch.utils.data import Sampler

import dataclasses
from einops import rearrange
import json
from typing import Dict, List, Optional, Union
import numpy as np
import gcsfs
from google.cloud import storage
import io
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler

# from bifrost.train.trainer_old import Trainer, logger
from transformers.trainer import Trainer, logger

from transformers.utils import is_torch_tpu_available, is_sagemaker_mp_enabled
from transformers.trainer_utils import has_length
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    from torch_xla.core.xla_model import xrt_world_size, all_reduce

from ezcolorlog import root_logger as logger

from bifrost.utils import IS_XLA_AVAILABLE, count_params


from packaging import version
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")

    from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
from typing import List, Optional

from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers.utils import is_apex_available
if is_apex_available():
    from apex import amp

import random
# Module-level GCS filesystem handle (the GCP project id is hard-coded here).
# `BifrostTrainer._save_checkpoint` opens checkpoint files through it via `fs.open(...)`.
fs = gcsfs.GCSFileSystem(project='multiflow-440903') ###

# User home directory with a trailing "/"; printed once when this module is imported.
# NOTE(review): not referenced elsewhere in this file; presumably used by importers — verify.
HOME_DIR = os.path.expanduser("~") + "/"
print("HOME_DIR = ", HOME_DIR)


def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU clone of ``param``, gathering it first if DeepSpeed ZeRO-3 partitioned it.

    Parameters managed by DeepSpeed ZeRO-3 carry a ``ds_id`` attribute and are gathered
    with ``zero.GatheredParameters`` before being copied. Any other tensor is copied
    directly, so DeepSpeed is only imported when such a parameter is actually seen
    (plain tensors work even when DeepSpeed is not installed).

    Args:
        param: A tensor / parameter, possibly ZeRO-3 partitioned.
        ignore_status: When False, print a notice if a ZeRO-3 parameter is reported as
            ``NOT_AVAILABLE`` before it is gathered.
        name: Optional parameter name included in that notice.

    Returns:
        A detached clone of the (full) parameter on CPU.
    """
    if not hasattr(param, "ds_id"):
        # Ordinary tensor: no DeepSpeed dependency needed.
        return param.detach().cpu().clone()

    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        print(name, 'no ignore status')
    # The clone is taken while the full parameter is gathered; it stays valid after
    # the context re-partitions the parameter.
    with zero.GatheredParameters([param]):
        return param.data.detach().cpu().clone()


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Collect CPU copies of the parameters whose name contains any of ``keys_to_match``.

    Args:
        named_params: Iterable of ``(name, tensor)`` pairs, e.g. ``model.named_parameters()``.
        keys_to_match: Name substrings; a parameter is kept if its name contains at least one.

    Returns:
        Dict mapping each selected name to the CPU clone produced by ``maybe_zero_3``
        (which gathers ZeRO-3 partitioned parameters first).
    """
    # First pass: select the matching parameters (later duplicates of a name win).
    selected = {}
    for param_name, tensor in named_params:
        if any(key in param_name for key in keys_to_match):
            selected[param_name] = tensor

    # Second pass: materialize each selected parameter on CPU.
    return {
        param_name: maybe_zero_3(tensor, ignore_status=True, name=param_name).cpu()
        for param_name, tensor in selected.items()
    }


def split_to_even_chunks(indices, lengths, num_chunks):
    """
    Partition ``indices`` into ``num_chunks`` chunks whose summed lengths are roughly balanced.

    If ``len(indices)`` is not a multiple of ``num_chunks``, the indices are simply dealt
    out round-robin. Otherwise each index is greedily placed into the chunk with the
    smallest accumulated length (first such chunk on ties), and a chunk stops accepting
    indices once it holds ``len(indices) // num_chunks`` of them, so all chunks end up
    with the same number of indices.

    Args:
        indices: Sequence of sample indices (callers pass them sorted longest-first).
        lengths: Indexable giving the length of each index.
        num_chunks: Number of chunks to produce.

    Returns:
        A list of ``num_chunks`` lists of indices.
    """
    total = len(indices)
    if total % num_chunks:
        # Uneven split: fall back to round-robin striding.
        return [indices[start::num_chunks] for start in range(num_chunks)]

    capacity = total // num_chunks
    buckets = [[] for _ in range(num_chunks)]
    loads = [0] * num_chunks
    for idx in indices:
        # min() over positions yields the first position with the smallest load.
        target = min(range(num_chunks), key=loads.__getitem__)
        buckets[target].append(idx)
        loads[target] += lengths[idx]
        if len(buckets[target]) == capacity:
            # Bucket is full: make sure it is never selected again.
            loads[target] = float("inf")

    return buckets


def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
    """Length-grouped shuffle that keeps multimodal and language-only samples in separate megabatches.

    The sign of each length encodes the modality: positive for multimodal samples,
    negative for language-only samples (the magnitude is the length). Each modality is
    length-grouped on its own, cut into megabatches of ``world_size * batch_size``
    indices, and the megabatches of both modalities (except the last one of each) are
    shuffled together. The two trailing megabatches are merged into one mixed megabatch
    appended at the end.

    Args:
        lengths: Non-zero signed lengths, one per dataset sample.
        batch_size: Per-device batch size.
        world_size: Number of data-parallel replicas.
        generator: Optional ``torch.Generator`` used for every random draw, so a seeded
            generator yields a reproducible ordering.

    Returns:
        A flat list of dataset indices.
    """
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    assert all(l != 0 for l in lengths), "Should not have zero length."
    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
        # All samples share one modality. Use magnitudes so that language-only
        # (negative) lengths are still sorted longest-first and chunk totals are
        # balanced correctly.
        return get_length_grouped_indices([abs(l) for l in lengths], batch_size, world_size, generator=generator)
    mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
    lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])

    # Propagate the caller's generator so every random draw honours its seed.
    mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=generator)]
    lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=generator)]
    megabatch_size = world_size * batch_size
    mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
    lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]

    # The last megabatch of each modality may be partial; merge both leftovers into one
    # mixed megabatch placed at the very end instead of shuffling it in.
    last_mm = mm_megabatches[-1]
    last_lang = lang_megabatches[-1]
    additional_batch = last_mm + last_lang
    megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
    megabatch_indices = torch.randperm(len(megabatches), generator=generator)
    megabatches = [megabatches[i] for i in megabatch_indices]

    if len(additional_batch) > 0:
        megabatches.append(sorted(additional_batch))

    return [i for megabatch in megabatches for i in megabatch]


def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
    """Randomly permute the indices, then group them so each device batch holds similar lengths.

    The permutation uses torch (so the seed set by a distributed sampler applies). The
    permuted indices are cut into megabatches of ``world_size * batch_size``; each
    megabatch is sorted longest-first and split into ``world_size`` length-balanced
    chunks by ``split_to_even_chunks``.

    Args:
        lengths: Per-sample lengths.
        batch_size: Per-device batch size.
        world_size: Number of data-parallel replicas.
        generator: Optional ``torch.Generator`` for the permutation.
        merge: Unused; accepted for signature compatibility.

    Returns:
        A flat list of dataset indices.
    """
    n = len(lengths)
    order = torch.randperm(n, generator=generator).tolist()
    step = world_size * batch_size

    flat = []
    for start in range(0, n, step):
        group = sorted(order[start:start + step], key=lambda idx: lengths[idx], reverse=True)
        for chunk in split_to_even_chunks(group, lengths, world_size):
            flat.extend(chunk)
    return flat


class LengthGroupedSampler(Sampler):
    r"""
    Length-aware sampler: yields dataset indices so that each per-device batch holds
    samples of similar length, while the overall order stays randomized.

    Args:
        batch_size: Per-device batch size.
        world_size: Number of data-parallel replicas.
        lengths: Per-sample lengths (required). With ``group_by_modality=True`` the sign
            of each length selects the modality (see ``get_modality_length_grouped_indices``).
        generator: Optional ``torch.Generator`` forwarded to the grouping function.
        group_by_modality: When True, keep modalities in separate megabatches.

    Raises:
        ValueError: if ``lengths`` is None.
    """

    def __init__(
        self,
        batch_size: int,
        world_size: int,
        lengths: Optional[List[int]] = None,
        generator=None,
        group_by_modality: bool = False,
    ):
        if lengths is None:
            raise ValueError("Lengths must be provided.")

        self.lengths = lengths
        self.batch_size = batch_size
        self.world_size = world_size
        self.generator = generator
        self.group_by_modality = group_by_modality

    def __len__(self):
        # One index is yielded per sample.
        return len(self.lengths)

    def __iter__(self):
        grouping_fn = (
            get_modality_length_grouped_indices
            if self.group_by_modality
            else get_length_grouped_indices
        )
        ordered = grouping_fn(self.lengths, self.batch_size, self.world_size, generator=self.generator)
        return iter(ordered)


def _fetch_gradients(optimizer, param_to_name, selected_module_names):
    gradients = []
    for param_group in optimizer.param_groups:
        for group, params in param_group.items():
            if group == 'params':
                for p in params:
                    # Use the mapping to get the module name
                    module_name = param_to_name.get(p, "")
                    # Check if the module name matches your criteria
                    if isinstance(p, torch.Tensor) and p.grad is not None and any(selected_name in module_name for selected_name in selected_module_names):
                        p.grad = p.grad.to(torch.float32)
                        gradients.append(p.grad.data)
    return gradients


# Reduce type passed to torch_xla's all_reduce.
REDUCE_SUM = 'sum'


def reduce_gradients(optimizer, param_to_name, selected_module_names, groups=None, pin_layout=True):
    """Average the selected gradients across XLA replicas.

    The gradients chosen by ``_fetch_gradients`` are summed across replicas with
    ``scale=1 / world_size``, i.e. averaged; torch_xla's list form of ``all_reduce``
    updates them in place. Nothing happens with a single replica.

    NOTE: ``xrt_world_size`` / ``all_reduce`` are only imported at module top when a TPU
    is available, so calling this elsewhere raises ``NameError``.
    """
    replicas = xrt_world_size()
    if replicas <= 1:
        return
    all_reduce(
        REDUCE_SUM,
        _fetch_gradients(optimizer, param_to_name, selected_module_names),
        scale=1.0 / replicas,
        groups=groups,
        pin_layout=pin_layout,
    )

def map_params_to_module_names(model_list):
    """Map every parameter tensor of the given models to its dotted, module-qualified name.

    Names are built per owning module as ``"<module_name>.<param_name>"``. Parameters
    owned directly by a root model get no leading dot (``"weight"`` rather than
    ``".weight"``), matching ``nn.Module.named_parameters()``. If a parameter is shared
    by several modules, the last module visited wins.

    Args:
        model_list: Iterable of ``nn.Module`` objects.

    Returns:
        Dict keyed by parameter tensor (by identity) with its qualified name as value.
    """
    param_to_name = {}
    for model in model_list:
        for module_name, module in model.named_modules():
            # The root module's name is "", so only prefix non-root modules.
            prefix = f"{module_name}." if module_name else ""
            for param_name, param in module.named_parameters(recurse=False):
                param_to_name[param] = prefix + param_name
    return param_to_name


@torch.no_grad()
def prepare_inputs_and_labels(
        pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor],
        texts: Union[str, str],
        min_masking_rate: float = 0.0,
        is_train: bool = True,
):
    """Build masked image-token inputs and labels for text-to-image ('t2i') training.

    Runs without gradient tracking (``torch.no_grad``). Steps:
      1. ``vq_model.get_code`` converts the images (or image ids) into image tokens, which
         are shifted by ``len(uni_prompting.text_tokenizer)`` — presumably so image ids do
         not overlap the text vocabulary.
      2. ``mask_or_random_replace_tokens`` produces the masked ``input_ids``, ``labels``, a
         loss weight and the mask probability.
      3. ``uni_prompting((texts, input_ids, labels), 't2i')`` combines them with ``texts``.

    NOTE(review): ``vq_model``, ``uni_prompting``, ``mask_or_random_replace_tokens``,
    ``mask_id``, ``config`` and ``mask_schedule`` are neither defined nor imported in this
    file, so calling this raises ``NameError`` unless they are injected into the module
    globals elsewhere — looks like code carried over from another training script;
    confirm before use. ``min_masking_rate`` is accepted but never used.

    Returns:
        Tuple ``(input_ids, labels, mask_prob, image_tokens)``; the loss weight and the
        ``masks`` returned by ``uni_prompting`` are discarded.
    """
    
    image_tokens = vq_model.get_code(pixel_values_or_image_ids)
    image_tokens = image_tokens + len(uni_prompting.text_tokenizer)

    # create MLM mask and labels
    input_ids, labels, loss_weight, mask_prob = mask_or_random_replace_tokens(
        image_tokens,
        mask_id,
        config,
        mask_schedule=mask_schedule,
        is_train=is_train,
    )
    input_ids, masks, labels = uni_prompting((texts, input_ids, labels), 't2i')

    return input_ids, labels, mask_prob, image_tokens




class BifrostTrainer(Trainer):
    """HF ``Trainer`` subclass used for Bifrost training.

    Customizations in this class:
      * ``training_step`` copies a fixed set of training arguments into the model's
        forward kwargs next to the batch.
      * ``compute_loss`` optionally records per-task losses (``loss_t2i`` / ``loss_lm`` /
        ``loss_mmu``) reported by the model.
      * ``create_optimizer`` decides, by parameter name, which weights are trained.
      * ``_save_checkpoint`` writes per-rank XLA/FSDP shards through the module-level
        ``gcsfs`` handle when FSDP is enabled.
      * ``get_train_dataloader`` unwraps the device loader on TPU.
    """

    # Training-argument attributes copied verbatim (under the same key) into the model
    # inputs by ``training_step``. ``min_masking_rate`` is not listed: it is fixed at 0.0.
    _FORWARDED_ARG_NAMES = (
        'batch_size_t2i',
        'batch_size_lm',
        'batch_size_mmu',
        'label_smoothing',
        'max_seq_length',
        'num_visual_gen_tokens',
        'precise_prompt_mask',
        'add_timestep_token',
        'use_discrete_visual_tokenizer',
        'log_task_specific_loss',
        'add_vision_branch',
        'add_vision_branch_reuse_layernorm',
        'use_lora',
        'skip_text_part2',
        'add_vision_gen_mask_token',
        'add_vision_soi_eoi_tokens',
        'add_vision_soi_token',
        'vision_head_type',
        'vision_pos_emb_type',
        't2i_coeff',
        'lm_coeff',
        'mmu_coeff',
    )

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        """Return no sampler for missing or length-less (e.g. iterable) datasets; else defer to the base class."""
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        return super()._get_train_sampler()

    def initialize_task_specific_loss(self):
        """Reset the per-task loss trackers and enable their recording in ``compute_loss``.

        The trackers start as zero tensors on ``self.args.device``; ``compute_loss`` later
        overwrites them with Python floats.
        """
        self.loss_t2i = torch.tensor(0.0).to(self.args.device)
        self.loss_lm = torch.tensor(0.0).to(self.args.device)
        self.loss_mmu = torch.tensor(0.0).to(self.args.device)
        self.log_task_specific_loss = True

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None) -> torch.Tensor:
        """Run forward and backward on one batch.

        Returns the detached loss divided by ``gradient_accumulation_steps``.
        ``num_items_in_batch`` is accepted for compatibility with the base ``Trainer``
        call signature and is not used.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        # The model reads its multi-task configuration from the forward kwargs.
        for arg_name in self._FORWARDED_ARG_NAMES:
            inputs[arg_name] = getattr(self.args, arg_name)
        inputs['min_masking_rate'] = 0.0

        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)

        return loss.detach() / self.args.gradient_accumulation_steps

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Forward ``inputs`` through ``model`` and return the loss (plus outputs if requested).

        When ``self.log_task_specific_loss`` is truthy (see
        ``initialize_task_specific_loss``) and the model returns a mapping, any of
        ``loss_t2i`` / ``loss_lm`` / ``loss_mmu`` present in it are stored on ``self`` as
        Python floats. ``num_items_in_batch`` is accepted because newer ``Trainer``
        versions pass it; it is ignored.
        """
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        is_mapping = isinstance(outputs, dict)
        loss = outputs["loss"] if is_mapping else outputs[0]

        # Only mappings carry named task losses; on a tuple, `'key' in outputs` would
        # compare the string against each element instead of testing for a key.
        if is_mapping and getattr(self, 'log_task_specific_loss', False):
            for key in ('loss_t2i', 'loss_lm', 'loss_mmu'):
                if key in outputs:
                    setattr(self, key, outputs[key].detach().item())

        return (loss, outputs) if return_outputs else loss

    def create_optimizer(self):
        """
        adapted to show-o optimizer setups

        Builds the optimizer once; later calls return the existing one. Each parameter of
        ``self.model`` is routed by name:

          * trainable -> ``requires_grad=True`` and added to the single optimizer group,
            which uses ``weight_decay=0.0`` and ``lr=self.args.learning_rate`` (weight decay
            is disabled for every parameter);
          * frozen -> ``requires_grad=False`` and left out of the optimizer.

        Raises:
            ValueError: if a parameter name matches none of the known components.
        """
        if is_sagemaker_mp_enabled(): # False
            return super().create_optimizer()
        opt_model = self.model

        if self.optimizer is None:
            # The single optimizer group (no bias/norm/embedding decay split is applied).
            no_decay_params = {"params": [], "weight_decay": 0.0, 'lr': self.args.learning_rate}
            # No routing rule currently places parameters here; still logged below.
            requires_grad_params = {"params": []}
            # Frozen parameters: excluded from the optimizer.
            no_grad_params = {"params": []}

            def set_frozen_without_grad(n, p):
                no_grad_params['params'].append(p)
                p.requires_grad = False
                print("FROZEN W/O GRAD:", n, p.numel())

            def set_trainable(n, p):
                p.requires_grad = True
                no_decay_params['params'].append(p)
                print("TRAINABLE:", n, p.numel())

            # Name fragments of components that are always trained. Checked after the VLM
            # backbone and the frozen tokenizer/VAE rules, before the diffusion decoder.
            trainable_keys = (
                # discrete tokenizer aligner
                'vision_gen_aligner', 'vision_gen_head', 'visual_embed_tokens',
                # vae aligner
                'vision_gen_dec_aligner', 'vision_gen_dec_model', 'vision_gen_enc_aligner', 'vision_gen_enc_model',
                # dit-style vae aligner
                # TODO: note: no trainable params in vision_branch_time_proj, can be deleted
                'vision_branch_x_embedder', 'vision_branch_proj_out', 'vision_branch_time_proj', 'vision_branch_time_embed',
                # mar-style params
                'mask_token', 'diffloss', 'vision_head', 'learnable_pos_emb', 'learnable_2d_query_tokens',
                'vision_embed_proj_in', 'up_sampling_block', 'down_sampling_block',
                # controlnet (also inside the diffusion decoder)
                'controlnet',
            )

            for n, p in opt_model.named_parameters():
                if 'vision_language_model' in n:
                    # vlm backbone: fully trainable, or only LoRA / vision-branch weights
                    if self.args.fully_trainable or any(k in n for k in ('lora_A', 'lora_B', 'vision_branch_')):
                        set_trainable(n, p)
                    else:
                        set_frozen_without_grad(n, p)
                elif 'vq_model' in n or 'vision_gen_vae_model' in n:
                    # discrete tokenizer / VAE
                    set_frozen_without_grad(n, p)
                elif any(k in n for k in trainable_keys):
                    set_trainable(n, p)
                elif 'diffusion_decoder' in n:
                    # rest of the diffusion decoder (its controlnet was matched above)
                    set_frozen_without_grad(n, p)
                else:
                    raise ValueError(f"There should not be parameters under this else branch: {n}")

            for n, p in opt_model.named_parameters():
                # if 'vis' in n:
                print(n.replace("_fsdp_wrapped_module._fpw_module.", "").replace("_fsdp_shard_FSDP_SHARD_SEPARATOR__fsdp_wrapped_module_FSDP_SHARD_SEPARATOR__fpw_module_FSDP_SHARD_SEPARATOR_", "").replace("FSDP_SHARD_SEPARATOR_", ""), p.requires_grad)

            count_params(no_decay_params['params'], text='no_decay_params')
            count_params(requires_grad_params['params'], text = 'requires_grad_params')
            count_params(no_grad_params['params'], text = 'no_grad_params')

            optimizer_grouped_parameters = [no_decay_params]
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
            # The learning rate is carried by the parameter group above.
            optimizer_kwargs.pop('lr', None)

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        return self.optimizer

    @staticmethod
    def remove_prefix(text, prefix='gs://us-central2-storage/'):
        """Return ``text`` without a leading ``prefix`` (unchanged if it does not start with it).

        A staticmethod, so it works both as ``BifrostTrainer.remove_prefix(p)`` and as
        ``self.remove_prefix(p)``.
        """
        if text.startswith(prefix):
            return text[len(prefix):]
        return text

    def _save_checkpoint(self, model, trial, metrics=None):
        """Save a checkpoint.

        With FSDP enabled (``self.args.fsdp`` non-empty) the per-rank XLA shards are
        written to GCS by ``_save_xla_fsdp_checkpoint``. Otherwise the base implementation
        is used. ``metrics`` is forwarded only when given: older ``Trainer`` versions pass
        it (needed for best-checkpoint tracking), newer ones neither pass nor accept it.
        """
        if self.args.fsdp:
            self._save_xla_fsdp_checkpoint(trial)
        elif metrics is not None:
            super(BifrostTrainer, self)._save_checkpoint(model, trial, metrics=metrics)
        else:
            super(BifrostTrainer, self)._save_checkpoint(model, trial)

    def _save_xla_fsdp_checkpoint(self, trial):
        """Write this rank's FSDP checkpoint files under ``<run_dir>/checkpoint-<step>/`` via ``fs``.

        Every rank writes its model shard and optimizer shard (each with the FSDP shard
        metadata) plus its RNG states. The global master additionally writes the
        LR-scheduler state and the trainer-state JSON. Training arguments are not saved.
        """
        import torch_xla.core.xla_model as xm
        from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

        # Names of files
        WEIGHTS_NAME = "pytorch_model.bin"
        SCHEDULER_NAME = "scheduler.pt"
        TRAINER_STATE_NAME = "trainer_state.json"

        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        logger.info(f"Saving model checkpoint to {output_dir}")

        rank = xm.get_ordinal()
        world_size = xm.xrt_world_size()
        shard_tag = f'rank-{rank:08d}-of-{world_size:08d}'
        shard_path = os.path.join(output_dir, f'weights_{shard_tag}-{WEIGHTS_NAME}')
        shard_opt_path = os.path.join(output_dir, f'opt_{shard_tag}-{WEIGHTS_NAME}')
        rng_path = os.path.join(output_dir, f'rng_{shard_tag}-rng.pth')
        lr_path = os.path.join(output_dir, SCHEDULER_NAME)
        trainer_state_path = os.path.join(output_dir, TRAINER_STATE_NAME)

        lr_scheduler_state_dict = self.lr_scheduler.state_dict()
        ckpt = {
            'model': self.model.state_dict(),
            'shard_metadata': self.model.get_shard_metadata(),
        }
        opt_ckpt = {
            'optimizer_state': self.optimizer.state_dict(),
            'shard_metadata': self.model.get_shard_metadata(),
        }

        # Every rank writes its own shard, so all ranks call xm.save with master_only=False.
        with fs.open(shard_path, 'wb') as f:
            xm.save(ckpt, f, master_only=False)
        with fs.open(shard_opt_path, 'wb') as f:
            xm.save(opt_ckpt, f, master_only=False)

        if xm.is_master_ordinal(local=False):
            # Only the global master enters this branch, so use torch.save instead of
            # xm.save: xm.save ends in a cross-replica rendezvous that the other ranks
            # would never join. LR-scheduler state dicts normally hold plain Python
            # values, so no device-to-CPU transfer is needed.
            with fs.open(lr_path, 'wb') as f:
                torch.save(lr_scheduler_state_dict, f)

            json_string = json.dumps(dataclasses.asdict(self.state), indent=2, sort_keys=True) + "\n"
            with fs.open(trainer_state_path, 'w') as f:
                f.write(json_string)

        # Per-rank RNG states (python / numpy / torch cpu / xla).
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
            "xla": xm.get_rng_state(),
        }
        with fs.open(rng_path, 'wb') as f:
            torch.save(rng_states, f)

    def get_train_dataloader(self) -> DataLoader:
        """Return the train dataloader, unwrapped from its device-loader wrapper on TPU.

        On TPU the base implementation's loader is expected to expose the plain loader as
        ``_loader`` (presumably an accelerate device-loader wrapper — confirm against the
        installed version); if that attribute is absent, the loader is returned as-is.
        """
        out = super().get_train_dataloader()
        if is_torch_tpu_available():
            return getattr(out, "_loader", out)
        return out
