import csv
import datetime
import inspect
import logging
import os
import random
import re
from copy import deepcopy
from types import MethodType
from typing import Dict

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from decord import VideoReader
from diffusers import (AutoencoderKL, DDIMScheduler, MotionAdapter,
                       UNet2DConditionModel, UNetMotionModel)
from diffusers.optimization import get_scheduler
from diffusers.pipelines import AnimateDiffPipeline
from diffusers.utils import export_to_gif
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange
from modelscope import snapshot_download
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import RandomSampler
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from swift import LoRAConfig, Swift, get_logger, push_to_hub
from swift.aigc.utils import AnimateDiffArguments
from swift.utils import get_dist_setting, get_main, is_dist

# Module-level logger obtained from swift's `get_logger`; used by the dataset
# to report unreadable samples and by the trainer to report the LoRA config.
logger = get_logger()


class AnimateDiffDataset(Dataset):
    """Video/caption dataset described by a CSV annotation file.

    Every CSV row must provide at least the ``name`` (caption) and
    ``contentUrl`` columns. The video belonging to a row is looked up in
    ``video_folder`` under the last ``/``-separated component of
    ``contentUrl``; rows whose file does not exist are skipped.

    Args:
        csv_path: Path of the annotation CSV (parsed with ``csv.DictReader``).
        video_folder: Directory that holds the video files.
        sample_size: Output frame size. An int ``s`` is expanded to
            ``(s, s)``; any other value is converted with ``tuple`` and
            handed to ``CenterCrop`` (its first element is used for
            ``Resize``).
        sample_stride: Distance, in source frames, between sampled frames.
        sample_n_frames: Number of frames in every returned clip.
        dataset_sample_size: Maximum number of usable rows to keep.
            ``None`` keeps every row whose video file exists.
    """

    # Column names expected in the annotation CSV.
    VIDEO_ID = 'videoid'
    NAME = 'name'
    CONTENT_URL = 'contentUrl'

    def __init__(
        self,
        csv_path,
        video_folder,
        sample_size=256,
        sample_stride=4,
        sample_n_frames=16,
        dataset_sample_size=10000,
    ):
        print(f'loading annotations from {csv_path} ...')
        # newline='' is what the csv module asks for so that quoted fields
        # containing line breaks are parsed correctly.
        with open(csv_path, 'r', newline='') as csvfile:
            rows = list(csv.DictReader(csvfile))

        self.dataset = self._select_available(
            tqdm(rows), video_folder, dataset_sample_size)
        self.length = len(self.dataset)
        print(f'data scale: {self.length}')

        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames

        sample_size = tuple(sample_size) if not isinstance(
            sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    @classmethod
    def _video_path(cls, video_folder, row):
        """Return the local path of the video referenced by a CSV row."""
        file_name = row[cls.CONTENT_URL].split('/')[-1]
        return os.path.join(video_folder, file_name)

    @classmethod
    def _select_available(cls, rows, video_folder, limit):
        """Keep rows whose video file exists, stopping once ``limit`` is hit.

        Args:
            rows: Iterable of CSV row dicts.
            video_folder: Directory searched for the video files.
            limit: Maximum number of rows to return, or ``None`` for no cap.

        Returns:
            A list with at most ``limit`` rows, in their original order.
        """
        kept = []
        for row in rows:
            # Check the cap before appending so that exactly `limit` rows are
            # kept (and none at all when `limit` is 0).
            if limit is not None and len(kept) >= limit:
                break
            if os.path.isfile(cls._video_path(video_folder, row)):
                kept.append(row)
        return kept

    def get_batch(self, idx):
        """Decode one randomly positioned clip of the ``idx``-th video.

        Returns:
            A ``(pixel_values, name)`` tuple. ``pixel_values`` holds
            ``sample_n_frames`` frames divided by 255 and permuted with
            ``(0, 3, 1, 2)`` (presumably frames, channels, height, width -
            confirm against decord's ``get_batch`` layout); ``name`` is the
            caption from the CSV row.

        Raises:
            ValueError: If the video contains no frames.
        """
        video_dict: Dict[str, str] = self.dataset[idx]
        name = video_dict[self.NAME]

        video_path = self._video_path(self.video_folder, video_dict)
        video_reader = VideoReader(video_path)
        video_length = len(video_reader)
        if video_length == 0:
            raise ValueError(f'Video has no frames: {video_path}')

        # Span covered by the clip in source frames; shorter videos are used
        # in full and linspace below then repeats frames as needed.
        clip_length = min(video_length,
                          (self.sample_n_frames - 1) * self.sample_stride + 1)
        start_idx = random.randint(0, video_length - clip_length)
        batch_index = np.linspace(
            start_idx,
            start_idx + clip_length - 1,
            self.sample_n_frames,
            dtype=int)

        frames = video_reader.get_batch(batch_index).asnumpy()
        pixel_values = torch.from_numpy(frames).permute(0, 3, 1,
                                                        2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader
        return pixel_values, name

    def __len__(self):
        """Return the number of usable rows."""
        return self.length

    def __getitem__(self, idx):
        """Return ``{'pixel_values': ..., 'text': ...}`` for one sample.

        If the requested video cannot be loaded the error is logged and a
        randomly chosen other sample is tried instead, so a few broken files
        do not abort training.
        """
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception as e:
                logger.error(f'Error loading dataset batch: {e}')
                idx = random.randint(0, self.length - 1)

        pixel_values = self.pixel_transforms(pixel_values)
        sample = dict(pixel_values=pixel_values, text=name)
        return sample


def save_videos_grid(videos: torch.Tensor,
                     path: str,
                     rescale=False,
                     n_rows=6,
                     duration=4):
    """Tile a batch of videos into one image grid per frame and save them.

    Args:
        videos: Tensor laid out as ``(batch, channels, frames, height,
            width)``.
        path: Output file path handed to ``imageio.mimsave``. Missing parent
            directories are created.
        rescale: If True, map values from ``[-1, 1]`` to ``[0, 1]`` before
            converting to 8-bit.
        n_rows: Forwarded to ``torchvision.utils.make_grid`` as ``nrow``.
        duration: Forwarded unchanged to ``imageio.mimsave``.
    """
    import imageio

    # `.numpy()` below needs a CPU tensor that is detached from autograd.
    videos = rearrange(videos.detach().cpu(), 'b c t h w -> t b c h w')
    outputs = []
    for frame_batch in videos:
        grid = torchvision.utils.make_grid(frame_batch, nrow=n_rows)
        # (c, h, w) -> (h, w, c)
        grid = grid.permute(1, 2, 0).squeeze(-1)
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1
        # Clamp so that slightly out-of-range values saturate instead of
        # wrapping around in the uint8 cast.
        grid = grid.clamp(0.0, 1.0)
        outputs.append((grid * 255).numpy().astype(np.uint8))

    # dirname is '' for a bare file name, and os.makedirs('') raises.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    imageio.mimsave(path, outputs, duration=duration)


def _unwrap_ddp(model):
    """Return ``model.module`` for a DDP-wrapped model, else ``model``."""
    return model.module if isinstance(model, DDP) else model


def _plain_state_dict(self,
                      *unused_args,
                      destination=None,
                      prefix='',
                      keep_vars=False,
                      adapter_name: str = None,
                      **kwargs):
    """Return the module's state dict without LoRA-specific entries.

    Keys containing ``lora`` are dropped and the ``base_layer.`` infix is
    removed from the remaining keys. All arguments are accepted for
    signature compatibility and ignored.
    """
    full_state = self.state_dict_origin()
    return {
        key.replace('base_layer.', ''): value
        for key, value in full_state.items() if 'lora' not in key
    }


def _clone_motion_modules(source):
    """Deep-copy motion modules and make their ``state_dict`` LoRA-free.

    The original bound ``state_dict`` stays reachable on the copy as
    ``state_dict_origin``.
    """
    cloned = deepcopy(source)
    cloned.state_dict_origin = cloned.state_dict
    cloned.state_dict = MethodType(_plain_state_dict, cloned)
    return cloned


def animatediff_sft(args: AnimateDiffArguments) -> None:
    """Train (fully or with LoRA) a ``UNetMotionModel`` on video/caption data.

    The function sets up the (optionally distributed) run, builds the noise
    scheduler, VAE, text encoder and UNet, trains with an MSE loss on the
    predicted noise, saves checkpoints under ``<output_dir>/checkpoints``
    and renders validation GIFs under ``<output_dir>/samples``.

    Args:
        args: Training arguments.

    Raises:
        ValueError: If LoRA is requested without a motion adapter, if
            xformers is requested but not available, or if the scheduler
            reports an unknown prediction type.
        NotImplementedError: If the scheduler uses ``v_prediction``.
    """
    # Initialize distributed training
    if is_dist():
        _, local_rank, num_processes, _ = get_dist_setting()
        global_rank = dist.get_rank()
    else:
        local_rank = 0
        global_rank = 0
        num_processes = 1
    is_main_process = global_rank == 0

    # Per-rank seed so that every process draws different noise/timesteps.
    global_seed = args.seed + global_rank
    torch.manual_seed(global_seed)
    # The dataset and the text dropout use Python's `random`; seed it too so
    # runs are reproducible.
    random.seed(global_seed)

    # Logging folder
    folder_name = datetime.datetime.now().strftime('ad-%Y-%m-%dT%H-%M-%S')
    output_dir = os.path.join(args.output_dir, folder_name)

    # Snapshot of the local variables at this point, logged as wandb config.
    *_, config = inspect.getargvalues(inspect.currentframe())

    if is_main_process and args.use_wandb:
        import wandb
        wandb.init(project='animatediff', name=folder_name, config=config)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO,
    )

    # Handle the output folder creation
    if is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f'{output_dir}/samples', exist_ok=True)
        os.makedirs(f'{output_dir}/sanity_check', exist_ok=True)
        os.makedirs(f'{output_dir}/checkpoints', exist_ok=True)

    # One prompt per line; drop line terminators and blank lines so they do
    # not end up as (empty) prompts.
    with open(args.validation_prompts_path, 'r') as f:
        validation_data = [line.strip() for line in f if line.strip()]

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=args.num_train_timesteps,
        beta_start=args.beta_start,
        beta_end=args.beta_end,
        beta_schedule=args.beta_schedule,
        steps_offset=args.steps_offset,
        clip_sample=args.clip_sample,
    )
    # Use a local directory as-is; anything else is treated as a hub id.
    pretrained_model_path = args.model_id_or_path
    if not os.path.exists(pretrained_model_path):
        pretrained_model_path = snapshot_download(
            args.model_id_or_path, revision=args.model_revision)
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder='vae')
    tokenizer = CLIPTokenizer.from_pretrained(
        pretrained_model_path, subfolder='tokenizer')
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_path, subfolder='text_encoder')

    motion_adapter = None
    if args.motion_adapter_id_or_path is not None:
        if not os.path.exists(args.motion_adapter_id_or_path):
            args.motion_adapter_id_or_path = snapshot_download(
                args.motion_adapter_id_or_path,
                revision=args.motion_adapter_revision)
        motion_adapter = MotionAdapter.from_pretrained(
            args.motion_adapter_id_or_path)
    unet: UNetMotionModel = UNetMotionModel.from_unet2d(
        UNet2DConditionModel.from_pretrained(
            pretrained_model_path, subfolder='unet'),
        motion_adapter=motion_adapter,
        load_weights=True,
    )

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Set unet trainable parameters: only those whose full name matches the
    # `trainable_modules` regular expression.
    unet.requires_grad_(False)
    for name, param in unet.named_parameters():
        if re.fullmatch(args.trainable_modules, name):
            param.requires_grad = True

    # Preparing LoRA
    if args.sft_type == 'lora':
        if args.motion_adapter_id_or_path is None:
            raise ValueError(
                'No AnimateDiff weight found, Please do not use LoRA.')
        lora_config = LoRAConfig(
            r=args.lora_rank,
            target_modules=args.trainable_modules,
            lora_alpha=args.lora_alpha,
            lora_dtype=args.lora_dtype,
            lora_dropout=args.lora_dropout_p)
        unet = Swift.prepare_model(unet, lora_config)
        logger.info(f'lora_config: {lora_config}')

    trainable_params = list(
        filter(lambda p: p.requires_grad, unet.parameters()))
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )

    if is_main_process:
        print(f'trainable params number: {len(trainable_params)}')
        print(
            f'trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M'
        )

    # Enable xformers
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                'xformers is not available. Make sure it is installed correctly'
            )

    # Enable gradient checkpointing
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Move models to GPU
    vae.to(local_rank)
    text_encoder.to(local_rank)

    # Get the training dataset
    train_dataset = AnimateDiffDataset(
        csv_path=args.csv_path,
        video_folder=args.video_folder,
        sample_size=args.sample_size,
        sample_stride=args.sample_stride,
        sample_n_frames=args.sample_n_frames,
        dataset_sample_size=args.dataset_sample_size,
    )

    if not is_dist():
        sampler = RandomSampler(train_dataset)
    else:
        # The sampler seed must be identical on every rank: each rank builds
        # the same permutation and takes its own slice of it. A per-rank
        # seed would make the shards overlap.
        sampler = DistributedSampler(
            train_dataset,
            num_replicas=num_processes,
            rank=global_rank,
            shuffle=True,
            seed=args.seed)

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=args.dataloader_num_workers,
        pin_memory=True,
        drop_last=True,
    )
    steps_per_epoch = len(train_dataloader)
    accum_steps = args.gradient_accumulation_steps

    # Get the training iteration (counted in micro-batches, not optimizer
    # steps).
    max_train_steps = args.num_train_epochs * steps_per_epoch
    print(f'max_train_steps: {max_train_steps}')

    # Scheduler (stepped once per optimizer step, hence the division).
    lr_scheduler = get_scheduler(
        args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=int(args.warmup_ratio * max_train_steps)
        // accum_steps,
        num_training_steps=max_train_steps // accum_steps,
    )

    unet.to(local_rank)
    if is_dist():
        unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)

    num_train_epochs = args.num_train_epochs

    # Train!
    total_batch_size = args.batch_size * num_processes * accum_steps

    if is_main_process:
        logging.info('***** Running training *****')
        logging.info(f'  Num examples = {len(train_dataset)}')
        logging.info(f'  Num Epochs = {num_train_epochs}')
        logging.info(
            f'  Instantaneous batch size per device = {args.batch_size}')
        logging.info(
            f'  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}'
        )
        logging.info(
            f'  Gradient Accumulation steps = {args.gradient_accumulation_steps}'
        )
        logging.info(f'  Total optimization steps = {max_train_steps}')
    global_step = 0
    first_epoch = 0

    # Only show the progress bar on the main process.
    progress_bar = tqdm(
        range(global_step, max_train_steps), disable=not is_main_process)
    progress_bar.set_description('Steps')

    # Support mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if args.mixed_precision else None

    for epoch in range(first_epoch, num_train_epochs):
        if is_dist():
            train_dataloader.sampler.set_epoch(epoch)

        unet.train()

        for step, batch in enumerate(train_dataloader):
            is_last_batch = step == steps_per_epoch - 1

            # Randomly blank captions so the model also sees the
            # unconditional case.
            if args.text_dropout_rate > 0:
                batch['text'] = [
                    name if random.random() > args.text_dropout_rate else ''
                    for name in batch['text']
                ]

            # Data batch sanity check: dump the very first batch as GIFs.
            if epoch == first_epoch and step == 0:
                pixel_values, texts = batch['pixel_values'].cpu(
                ), batch['text']
                pixel_values = rearrange(pixel_values,
                                         'b f c h w -> b c f h w')
                for idx, (pixel_value,
                          text) in enumerate(zip(pixel_values, texts)):
                    pixel_value = pixel_value[None, ...]
                    file_name = '-'.join(text.replace('/', '').split(
                    )[:10]) if not text == '' else f'{global_rank}-{idx}'
                    save_videos_grid(
                        pixel_value,
                        f'{output_dir}/sanity_check/{file_name}.gif',
                        rescale=True)

            # Convert videos to latent space
            pixel_values = batch['pixel_values'].to(local_rank)
            video_length = pixel_values.shape[1]
            with torch.no_grad():
                pixel_values = rearrange(pixel_values,
                                         'b f c h w -> (b f) c h w')
                latents = vae.encode(pixel_values).latent_dist
                latents = latents.sample()
                latents = rearrange(
                    latents, '(b f) c h w -> b c f h w', f=video_length)
                latents = latents * 0.18215

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]

            # Sample a random timestep for each video
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps, (bsz, ),
                device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at
            # each timestep (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise,
                                                      timesteps)

            # Get the text embedding for conditioning
            with torch.no_grad():
                prompt_ids = tokenizer(
                    batch['text'],
                    max_length=tokenizer.model_max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt').input_ids.to(latents.device)
                encoder_hidden_states = text_encoder(prompt_ids)[0]

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == 'epsilon':
                target = noise
            elif noise_scheduler.config.prediction_type == 'v_prediction':
                raise NotImplementedError
            else:
                raise ValueError(
                    f'Unknown prediction type {noise_scheduler.config.prediction_type}'
                )

            # Predict the noise residual and compute loss
            # Mixed-precision training
            with torch.cuda.amp.autocast(enabled=args.mixed_precision):
                model_pred = unet(noisy_latents, timesteps,
                                  encoder_hidden_states).sample
                loss = F.mse_loss(
                    model_pred.float(), target.float(), reduction='mean')

            # Backpropagate. Gradients of consecutive micro-batches add up,
            # so scale each one to keep the accumulated gradient an average.
            # `loss` itself stays unscaled for logging.
            backward_loss = loss / accum_steps
            if args.mixed_precision:
                scaler.scale(backward_loss).backward()
            else:
                backward_loss.backward()

            # Step once every `accum_steps` micro-batches; also flush at the
            # end of an epoch so no gradients leak into the next one.
            if (step + 1) % accum_steps == 0 or is_last_batch:
                if args.mixed_precision:
                    # Unscale first so clipping sees the true gradient norm.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(unet.parameters(),
                                                   args.max_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(unet.parameters(),
                                                   args.max_grad_norm)
                    optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step()

            progress_bar.update(1)
            global_step += 1

            # Wandb logging
            if is_main_process and args.use_wandb:
                wandb.log({'train_loss': loss.item()}, step=global_step)

            # Save checkpoint
            if is_main_process and (global_step % args.save_steps == 0
                                    or is_last_batch):
                save_path = os.path.join(output_dir, 'checkpoints')
                if is_last_batch:
                    last_path = os.path.join(save_path, 'iter-last')
                    _unwrap_ddp(unet).save_pretrained(last_path)
                    if args.push_to_hub:
                        push_to_hub(
                            repo_name=args.hub_model_id,
                            output_dir=last_path,
                            token=args.hub_token,
                            private=True,
                        )
                    logging.info(
                        f'Saved state to {last_path} on the last step')
                else:
                    iter_save_path = os.path.join(save_path,
                                                  f'iter-{global_step}')
                    _unwrap_ddp(unet).save_pretrained(iter_save_path)
                    if args.push_to_hub and args.push_hub_strategy == 'all_checkpoints':
                        push_to_hub(
                            repo_name=args.hub_model_id,
                            output_dir=iter_save_path,
                            token=args.hub_token,
                            private=True,
                        )
                    logging.info(
                        f'Saved state to {iter_save_path} (global_step: {global_step})'
                    )

            # Periodically validation
            if is_main_process and global_step % args.eval_steps == 0:
                # The dataset accepts an int or a pair for `sample_size`;
                # accept both here as well. A pair is read as
                # (height, width), matching its use in `CenterCrop`.
                if isinstance(args.sample_size, int):
                    height = width = args.sample_size
                else:
                    height, width = args.sample_size

                # Copy the motion modules while the adapter weights are
                # merged into the base layers, then restore the training
                # state.
                # NOTE(review): the LoRA-free `state_dict` on the copies is
                # presumably what lets the pipeline load them into plain
                # motion modules - confirm against AnimateDiffPipeline.
                Swift.merge(unet)
                val_motion_adapter = MotionAdapter(
                    motion_num_attention_heads=args.motion_num_attention_heads,
                    motion_max_seq_length=args.motion_max_seq_length)

                module = _unwrap_ddp(unet)
                val_motion_adapter.mid_block.motion_modules = \
                    _clone_motion_modules(module.mid_block.motion_modules)
                for dst_block, src_block in zip(
                        val_motion_adapter.down_blocks, module.down_blocks):
                    dst_block.motion_modules = _clone_motion_modules(
                        src_block.motion_modules)
                for dst_block, src_block in zip(val_motion_adapter.up_blocks,
                                                module.up_blocks):
                    dst_block.motion_modules = _clone_motion_modules(
                        src_block.motion_modules)
                Swift.unmerge(unet)

                validation_pipeline = AnimateDiffPipeline(
                    unet=UNet2DConditionModel.from_pretrained(
                        pretrained_model_path, subfolder='unet'),
                    vae=vae,
                    tokenizer=tokenizer,
                    motion_adapter=val_motion_adapter,
                    text_encoder=text_encoder,
                    scheduler=noise_scheduler,
                ).to('cuda')
                validation_pipeline.enable_vae_slicing()
                validation_pipeline.enable_model_cpu_offload()

                for idx, prompt in enumerate(validation_data):
                    # Fresh, identically seeded generator per prompt so the
                    # samples are comparable across evaluations.
                    output = validation_pipeline(
                        prompt=prompt,
                        negative_prompt='bad quality, worse quality',
                        num_frames=args.sample_n_frames,
                        height=height,
                        width=width,
                        guidance_scale=args.guidance_scale,
                        num_inference_steps=args.num_inference_steps,
                        generator=torch.Generator('cpu').manual_seed(
                            global_seed),
                    )
                    frames = output.frames[0]
                    export_to_gif(
                        frames,
                        f'{output_dir}/samples/sample-{global_step}-{idx}.gif')
                unet.train()

            logs = {
                'step_loss': loss.detach().item(),
                'lr': lr_scheduler.get_last_lr()[0]
            }
            progress_bar.set_postfix(**logs)

            if global_step >= max_train_steps:
                break

    if is_dist():
        dist.destroy_process_group()


# Public entry point built by swift's `get_main` from the argument class and
# the training function above.
# NOTE(review): presumably parses `AnimateDiffArguments` and then calls
# `animatediff_sft` - confirm in `swift.utils.get_main`.
animatediff_main = get_main(AnimateDiffArguments, animatediff_sft)
