import os
import sys
import torch
import random
import numpy as np
from tap import Tap
from typing import Optional, Union
from collections import OrderedDict

from utils import dist


class Args(Tap):
    """Training / evaluation configuration, parsed from the command line by Tap.

    Every annotated class attribute below becomes a command-line flag, and the
    inline ``#`` comment on a field is what Tap shows as that flag's help text,
    so field order and same-line comments are significant. ``device`` and
    ``dtype`` are derived at start-up by ``init_dist_and_get_args`` rather than
    meant to be passed by hand.
    """
    model: str = 'vitamin_large' # 'vitamin_base', 'vitamin_large', xxx
    exp_name: str = 'unitok_large'
    output_dir: str = 'local_output'
    resume_from: str = ''  # if specified, load this checkpoint; if not, load the latest checkpoint in output_dir (if exists)
    lpips_path: str = 'external/lpips_with_vgg.pth'
    dino_path: str = 'external/dinov2_vits14_pretrain.pth'
    fid_eval_src: str = ''
    fid_eval_dst: str = ''
    vis_img_dir: str = 'asset/vis_imgs/'
    fid_feature_extractor: str = 'external/weights-inception-2015-12-05-6726825d.pth'
    clip_pretrain_path: str = ''

    # speed-up
    fp16: bool = False  # whether to use FP16
    bf16: bool = True  # whether to use BF16
    tf32: bool = True  # whether to use TensorFloat32
    compile_model: bool = False  # whether to use torch.compile()
    ddp_static: bool = False  # whether to use static graph in DDP
    grad_ckpt: bool = True  # gradient checkpointing
    grad_accu: int = 1  # gradient accumulation
    device: str = 'cpu' # will be set automatically
    dtype: torch.dtype = torch.float32 # will be set automatically

    # data
    train_data: str = None
    val_data: str = None
    dataset_type: str = 'webdataset'
    imagenet_val: str = None
    imagenet_v2: str = None
    subset_ratio: float = 1.0
    img_size: int = 256
    resize_ratio: float = 1.125  # only applicable to 'img' dataset_type
    hflip: bool = False
    workers: int = 8  # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
    train_num_samples: int = 1280_000_000
    train_data_upsampling_factors: str = None
    dataset_resampled: bool = False
    use_aug: bool = False

    # quantizer
    vocab_size: int = 32768
    vocab_width: int = 64
    vocab_norm: bool = True
    vq_beta: float = 0.25  # commitment loss weight
    num_codebooks: int = 8
    quant_proj: str = 'attn'

    # model
    embed_dim: int = 768
    num_query: int = 0
    use_clip_pretrain: bool = False
    patch_size: int = 16
    drop_path: float = 0.1
    text_width: int = 768
    text_heads: int = 12
    text_layers: int = 12
    text_vocab_size: int = 49408
    text_context_length: int = 77

    # CLIP
    local_loss: bool = True
    gather_with_grad: bool = True
    pretrained_clip: str = None
    pretrained_clip_text: str = None
    lock_text: bool = False
    lock_text_unlocked_layers: int = 0
    lock_text_freeze_layer_norm: bool = False
    force_custom_text: bool = False
    force_custom_vision: bool = False
    zeroshot_eval_freq: int = 1

    # discriminator
    dino_depth: int = 12
    dino_kernel_size: int = 9
    disc_norm: str = 'gn'  # gn: group norm, bn: batch norm, sbn: sync batch norm, hbn: hybrid sync batch norm
    disc_aug_prob: float = 1.0
    disc_specnorm: bool = False
    step_disc_every: int = 1

    # initialization
    vae_init: float = -0.5  # <0: xavier_normal_(gain=abs(init)); >0: trunc_normal_(std=init)
    vocab_init: float = -1  # <0: uniform(-abs(init)*base, abs(init)*base), where base = 20/vocab_size; >0: trunc_normal_(std=init)
    disc_init: float = -0.5  # <0: xavier_normal_(gain=abs(init)); >0: trunc_normal_(std=init)

    # optimization
    epoch: int = 1  # number of epochs
    local_bs: int = 64  # batch size per device; if this is specified, --global_bs will be ignored
    vae_local_bs: int = 64 # sub-batch size for vae loss calculation
    global_bs: int = 0  # global batch size (exclusive to --local_bs)
    lr: float = 5e-4  # learning rate
    wd: float = 0.02  # weight decay
    disc_lr: float = 2e-5  # disc lr
    disc_wd: float = 0.2
    grad_clip: float = 10  # <=0 for not using grad clip
    ema: float = 0.9999  # ema ratio
    warmup_iter: int = None
    warmup_ep: float = 0.01  # lr warmup: epochs
    disc_start_ep: float = 0.375  # start using disc loss for VAE after xxx epochs;
    disc_warmup_ep: float = 0.03  # disc loss warm up epochs;
    schedule: str = 'cos'  # lr schedule type
    lr_start_ratio: float = 0.  # lr warmup: initial lr ratio
    lr_end_ratio: float = 0.1  # lr schedule: final lr ratio
    disc_lr_end_ratio: float = 0.1
    custom_lr_multiplier: float = None
    optimizer: str = 'adamw'
    optim_eps: float = 1e-6
    fuse_opt: bool = False  # whether to use fused optimizer
    optim_beta: str = '0.9_0.95'  # beta1, beta2 of optimizer
    disc_optim_beta: str = '0.5_0.9'  # beta1, beta2 of disc optimizer

    # loss
    l1: float = 0.2  # L1 rec loss weight
    l2: float = 1.0  # L2 rec loss weight
    lp: float = 1.0  # lpips loss weight
    lpr: int = 48    # only calculate lpips >= this image resolution
    ld: float = 0.4  # discriminator loss weight; if <0: NO ADAPTIVE WEIGHT
    le: float = 0.0  # VQ entropy loss weight
    lq: float = 1.0
    lc: float = 1.0  # CLIP loss weight
    e_temp: float = 0.01
    gada: int = 1
    bcr: float = 4.  # balanced Consistency Regularization, used on small dataset with low reso, StyleSwin: 10.0
    bcr_cut: float = 0.2  # cutout ratio (0.5: 50% width)
    dcrit: str = 'hg'  # hg hinge, sp softplus, ln linear

    # wandb log
    report_wandb: bool = True
    wandb_notes: str = None
    run_id: str = None

    # debug
    eval_per_epoch: int = 8
    dbg_unused_param: bool = False
    dbg_nan: bool = False  # 'KEVIN_LOCAL' in os.environ
    seed: int = None
    deterministic: bool = False
    same_seed_for_all_ranks: int = 0  # this is only for distributed sampler


    def seed_everything(self):
        """Configure cuDNN and seed the Python, NumPy and torch RNGs.

        cuDNN is always enabled with autotuning (``benchmark``) on. When
        ``self.seed`` is None nothing else happens; in particular
        ``self.deterministic`` is ignored in that case. Otherwise each rank is
        seeded with ``self.seed + rank * 10000`` so that ranks draw different
        random streams. If ``self.deterministic`` is also set, cuDNN autotuning
        is turned off and PyTorch is switched to deterministic algorithms.
        """
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        if self.seed is None:
            return

        if self.deterministic:
            # Deterministic cuBLAS needs this workspace config. Export it before
            # opting into deterministic algorithms so it is already in place
            # for any subsequent cuBLAS work.
            os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
            torch.use_deterministic_algorithms(True)

        seed = self.seed + dist.get_rank() * 10000
        # NOTE: PYTHONHASHSEED only takes effect for Python interpreters started
        # after this point. The current process's hash seed was fixed at startup.
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]:  # for random augmentation
        """Return a torch.Generator seeded with ``seed * world_size + rank``.

        For a fixed ``self.seed`` this value is distinct on every rank. Returns
        None when no seed is configured.
        """
        if self.seed is None:
            return None
        g = torch.Generator()
        g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())
        return g

    def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
        """Snapshot every declared field except ``device``.

        Args:
            key_ordered: Return an OrderedDict (True) or a plain dict (False).
                In both cases keys follow ``self.class_variables`` order.
        """
        d = (OrderedDict if key_ordered else dict)()
        for k in self.class_variables.keys():
            if k not in {'device'}:  # excluded from the snapshot (treated as not serializable)
                d[k] = getattr(self, k)
        return d

    def load_state_dict(self, state_dict):
        """Restore attributes from a mapping such as the one ``state_dict`` returns.

        Every key is applied with ``setattr``, including keys that are not
        declared fields. If an assignment fails, the offending key and value
        are printed and the original exception propagates unchanged.
        """
        for k, v in state_dict.items():
            try:
                setattr(self, k, v)
            except Exception:
                print(f'k={k}, v={v}')
                raise  # bare raise: re-raise with the original traceback untouched

    @staticmethod
    def set_tf32(tf32: bool):
        """Enable or disable TF32 for cuDNN and CUDA matmuls, then print the settings.

        This is a no-op when CUDA is unavailable.
        """
        if torch.cuda.is_available():
            torch.backends.cudnn.allow_tf32 = bool(tf32)
            torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
            # set_float32_matmul_precision is absent in older torch releases
            if hasattr(torch, 'set_float32_matmul_precision'):
                torch.set_float32_matmul_precision('high' if tf32 else 'highest')
                print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
            print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
            print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')

    def __str__(self):
        """Render all fields, one ``name: value`` per line, inside braces."""
        s = []
        for k in self.class_variables.keys():
            # Fields hidden from the printout. 'dbg_ks_fp' is not declared on
            # this class; presumably it is a leftover or comes from a subclass.
            if k not in {'device', 'dbg_ks_fp'}:
                s.append(f'  {k:20s}: {getattr(self, k)}')
        s = '\n'.join(s)
        return f'{{\n{s}\n}}\n'


def init_dist_and_get_args():
    """Parse ``Args`` from the command line and bring up the distributed runtime.

    The steps run in this order:
      1. Strip a launcher-injected ``--local-rank=``/``--local_rank=`` token
         from ``sys.argv``.
      2. Parse the known flags; unknown ones are only warned about.
      3. Create ``output_dir`` and initialise torch.distributed.
      4. Configure TF32 and seeding.
      5. Fill in ``device``, the batch sizes and ``dtype``.

    Returns:
        The populated ``Args`` instance.
    """
    # A launcher may append --local-rank=N / --local_rank=N. Remove the first
    # such token so it never reaches the parser.
    for idx, token in enumerate(sys.argv):
        if token.startswith(('--local-rank=', '--local_rank=')):
            del sys.argv[idx]
            break

    args = Args(explicit_bool=True).parse_args(known_only=True)

    # With known_only=True, unrecognised flags land in extra_args instead of
    # raising. Print them loudly so typos are not silently ignored.
    if args.extra_args:
        print('======================================================================================')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
        print('=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
        print('======================================================================================\n\n')

    # init_distributed_mode receives output_dir as local_out_path, so the
    # directory has to exist first.
    os.makedirs(args.output_dir, exist_ok=True)
    dist.init_distributed_mode(local_out_path=args.output_dir, timeout_minutes=30)

    # Process-wide numeric / RNG setup. seed_everything reads the rank, so it
    # must run after distributed init.
    args.set_tf32(args.tf32)
    args.seed_everything()
    args.device = dist.get_device()

    # local_bs == 0 means "derive the per-rank batch from global_bs".
    # NOTE(review): global_bs is then recomputed as local_bs * world_size, which
    # does not fold grad_accu back in. Presumably it counts samples per
    # micro-step rather than per optimizer step; confirm against the trainer.
    world_size = dist.get_world_size()
    if args.local_bs == 0:
        per_rank_bs = args.global_bs / args.grad_accu / world_size
        args.local_bs = max(1, round(per_rank_bs))
    args.global_bs = args.local_bs * world_size

    # Mixed precision: fp16 wins if both flags are set. Otherwise dtype keeps
    # its declared default.
    if args.fp16:
        args.dtype = torch.float16
    elif args.bf16:
        args.dtype = torch.bfloat16

    return args
