import sys
import os
# Put the parent of this file's directory at the front of sys.path
# (presumably so the `main.*` imports below resolve when the script is run
# directly from its own folder).
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import matplotlib
matplotlib.use('Agg')
from main.utils import prepare_images_for_saving, draw_valued_array, draw_probability_histogram
from main.sd_image_dataset import SDImageDatasetLMDB
from transformers import CLIPTokenizer, AutoTokenizer
from accelerate.utils import ProjectConfiguration
from diffusers.optimization import get_scheduler
from main.utils import SDTextDataset, cycle 
from main.sd_unified_model import SDUniModel
from accelerate.utils import set_seed,FullyShardedDataParallelPlugin
from accelerate import Accelerator
import argparse 
import shutil 
import wandb 
import torch 
import time 
import os
from copy import deepcopy

import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="diffusers.configuration_utils")


class Trainer:
    """Data-loading smoke test: builds the prompt and real-image loaders and iterates them.

    No model is constructed here. ``train`` only pulls batches from both
    dataloaders and verifies that the Python type of every value in the
    real-image batch stays the same from one batch to the next.
    """

    def __init__(self, args):
        """Create the accelerator, tokenizers, datasets and endlessly cycled dataloaders.

        Args:
            args: Parsed command-line namespace. The attributes read here are
                ``log_path``, ``gradient_accumulation_steps``, ``seed``,
                ``sdxl``, ``model_id``, ``revision``, ``train_prompt_path``,
                ``real_image_path``, ``batch_size`` and ``num_workers``.
        """
        self.args = args

        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        accelerator_project_config = ProjectConfiguration(logging_dir=args.log_path)
        accelerator = Accelerator(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            mixed_precision="no",
            project_config=accelerator_project_config,
            kwargs_handlers=None,
        )
        # Offset the seed by the process index so every process is seeded differently.
        set_seed(args.seed + accelerator.process_index)

        print(accelerator.state)

        self.sdxl = args.sdxl

        # NOTE(review): `is_sdxl` is hard-coded below (False for the prompt
        # dataset, True for the image dataset) instead of following
        # `args.sdxl` -- confirm this is intentional.
        tokenizer = CLIPTokenizer.from_pretrained(
            args.model_id, subfolder="tokenizer"
        )

        dataset = SDTextDataset(args.train_prompt_path, tokenizer, is_sdxl=False)
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=args.batch_size, shuffle=True, drop_last=True
        )
        dataloader = accelerator.prepare(dataloader)
        self.dataloader = cycle(dataloader)

        tokenizer_one = AutoTokenizer.from_pretrained(
            args.model_id, subfolder="tokenizer", revision=args.revision, use_fast=False
        )
        tokenizer_two = AutoTokenizer.from_pretrained(
            args.model_id, subfolder="tokenizer_2", revision=args.revision, use_fast=False
        )

        # also load the training dataset images, this will be useful for GAN loss
        real_dataset = SDImageDatasetLMDB(
            args.real_image_path,
            is_sdxl=True,
            tokenizer_one=tokenizer_one,
            tokenizer_two=tokenizer_two,
        )

        # Honour --num_workers instead of a hard-coded worker count, so the
        # number of loader processes can be matched to the host machine.
        real_dataloader = torch.utils.data.DataLoader(
            real_dataset,
            num_workers=args.num_workers,
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=True,
        )
        real_dataloader = accelerator.prepare(real_dataloader)
        self.real_dataloader = cycle(real_dataloader)

    def train_one_step(self):
        """Pull one batch from each dataloader and return the real-image batch.

        Any exception raised by either loader propagates unchanged to the
        caller, so the original error and traceback are preserved.

        Returns:
            The batch produced by ``self.real_dataloader``.
        """
        real_train_dict = next(self.real_dataloader)
        # The prompt batch is fetched only to exercise that loader; its
        # content is not used.
        next(self.dataloader)
        return real_train_dict

    def train(self):
        """Iterate ``500000 // batch_size`` steps, checking batch value types.

        The key/type pairs of the first batch are printed. For every later
        batch, each key present in the previous batch must map to a value of
        exactly the same type.

        Raises:
            TypeError: If a value's type differs from the previous batch.
            KeyError: If a key of the previous batch is missing from the
                current one.
        """
        from tqdm import tqdm

        num_steps = 500000 // self.args.batch_size
        # Only the types are compared, so remember those rather than keeping
        # a deep copy of the whole previous batch.
        prev_types = None
        for index in tqdm(range(num_steps)):
            real_train_dict = self.train_one_step()
            curr_types = {key: type(value) for key, value in real_train_dict.items()}

            if prev_types is not None:
                for key, expected_type in prev_types.items():
                    actual_type = type(real_train_dict[key])
                    if actual_type is not expected_type:
                        raise TypeError(
                            f"batch {index}: value for key {key!r} has type "
                            f"{actual_type}, previous batch had {expected_type}"
                        )

            if index == 0:
                for key, value_type in curr_types.items():
                    print(f'{key}:{value_type}')

            prev_types = curr_types

def parse_args(argv=None):
    """Parse and validate the command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed options. ``local_rank`` is replaced by
        the ``LOCAL_RANK`` environment variable when that variable is set to
        something other than -1.

    Raises:
        SystemExit: Via ``parser.error`` (exit status 2, usage message on
            stderr) when an unsupported option combination is given. Unlike
            ``assert``, these checks still run under ``python -O``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="runwayml/stable-diffusion-v1-5")
    parser.add_argument("--output_path", type=str, default="/mnt/localssd/test_stable_diffusion_coco")
    parser.add_argument("--log_path", type=str, default="/mnt/localssd/log_stable_diffusion_coco")
    parser.add_argument("--train_iters", type=int, default=1000000)
    parser.add_argument("--log_iters", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--resolution", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-5)
    # NOTE: the flag name is misspelled ("initialie"); kept as-is because
    # renaming it would break existing command lines.
    parser.add_argument("--initialie_generator", action="store_true")
    parser.add_argument("--checkpoint_path", type=str, default=None)
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--wandb_entity", type=str)
    parser.add_argument("--wandb_project", type=str)
    parser.add_argument("--wandb_iters", type=int, default=100)
    parser.add_argument("--wandb_name", type=str, required=True)
    parser.add_argument("--max_grad_norm", type=float, default=10.0, help="max grad norm for network")
    parser.add_argument("--warmup_step", type=int, default=500, help="warmup step for network")
    parser.add_argument("--min_step_percent", type=float, default=0.02, help="minimum step percent for training")
    parser.add_argument("--max_step_percent", type=float, default=0.98, help="maximum step percent for training")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--use_fp16", action="store_true")
    parser.add_argument("--num_train_timesteps", type=int, default=1000)
    parser.add_argument("--ckpt_only_path", type=str, default=None, help="checkpoint (no optimizer state) only path")
    parser.add_argument("--train_prompt_path", type=str)
    parser.add_argument("--latent_resolution", type=int, default=64)
    parser.add_argument("--real_guidance_scale", type=float, default=6.0)
    parser.add_argument("--fake_guidance_scale", type=float, default=1.0)
    parser.add_argument("--grid_size", type=int, default=2)
    parser.add_argument("--no_save", action="store_true", help="don't save ckpt for debugging only")
    parser.add_argument("--cache_dir", type=str, default="/mnt/localssd/cache")
    parser.add_argument("--log_loss", action="store_true", help="log loss at every iteration")
    parser.add_argument("--num_workers", type=int, default=12)
    parser.add_argument("--latent_channel", type=int, default=4)
    parser.add_argument("--max_checkpoint", type=int, default=150)
    parser.add_argument("--dfake_gen_update_ratio", type=int, default=1)
    parser.add_argument("--generator_lr", type=float)
    parser.add_argument("--guidance_lr", type=float)
    parser.add_argument("--cls_on_clean_image", action="store_true")
    parser.add_argument("--gen_cls_loss", action="store_true")
    parser.add_argument("--gen_cls_loss_weight", type=float, default=0)
    parser.add_argument("--guidance_cls_loss_weight", type=float, default=0)
    parser.add_argument("--sdxl", action="store_true")
    parser.add_argument("--fsdp", action="store_true")
    parser.add_argument("--generator_ckpt_path", type=str)
    parser.add_argument("--conditioning_timestep", type=int, default=999)
    parser.add_argument("--tiny_vae", action="store_true")
    parser.add_argument("--gradient_checkpointing", action="store_true", help="apply gradient checkpointing for dfake and generator. this might be a better option than FSDP")
    parser.add_argument("--dm_loss_weight", type=float, default=1.0)

    parser.add_argument("--denoising", action="store_true", help="train the generator for denoising")
    parser.add_argument("--denoising_timestep", type=int, default=1000)
    parser.add_argument("--num_denoising_step", type=int, default=1)
    parser.add_argument("--denoising_loss_weight", type=float, default=1.0)

    parser.add_argument("--diffusion_gan", action="store_true")
    parser.add_argument("--diffusion_gan_max_timestep", type=int, default=0)
    parser.add_argument("--revision", type=str)

    parser.add_argument("--real_image_path", type=str)
    parser.add_argument("--gan_alone", action="store_true", help="only use the gan loss without dmd")
    parser.add_argument("--backward_simulation", action="store_true")

    parser.add_argument("--generator_lora", action="store_true")
    parser.add_argument("--lora_rank", type=int, default=64)
    parser.add_argument("--lora_alpha", type=float, default=64)
    parser.add_argument("--lora_dropout", type=float, default=0.0)

    args = parser.parse_args(argv)

    # A LOCAL_RANK set in the environment takes precedence over --local_rank.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.gradient_accumulation_steps != 1:
        parser.error("grad accumulation not supported yet")

    if args.fsdp and args.gradient_checkpointing:
        parser.error("currently, we don't support both options. open an issue for details.")

    # Checked before the modulo below, which would otherwise raise
    # ZeroDivisionError for a ratio of 0.
    if args.dfake_gen_update_ratio < 1:
        parser.error("dfake_gen_update_ratio should be a positive integer")

    if args.wandb_iters % args.dfake_gen_update_ratio != 0:
        parser.error("wandb_iters should be a multiple of dfake_gen_update_ratio")

    return args

if __name__ == "__main__":
    # Script entry point: parse the CLI options, build the trainer from
    # them, then run its loop.
    cli_args = parse_args()
    Trainer(cli_args).train()