import argparse
from pathlib import Path
import time
from glob import glob
import os
import shutil
from torch.backends import cudnn
import random
import numpy as np

import torch
import wandb  # Quit early if user doesn't have wandb installed.
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchvision import transforms, utils
import torch.nn.functional as F

#
from models import distributed_utils
from models.loader import TextImageDataset
from models.pretrain_dataset import AllDataset

# libraries needed for webdataset support
from torchvision import transforms as T
from PIL import Image
from io import BytesIO

#from clip import clip
import copy

# transformer model
from models import CoTransformer
from transformers import (
    DataCollatorForLanguageModeling,
    DataCollatorForWholeWordMask,
    BertTokenizer,
)
from transformers.optimization import AdamW

# argument parsing

# Base command-line interface: checkpoint locations, data locations and W&B options.
parser = argparse.ArgumentParser()

# --model_load_path and --transformer_path cannot be given together.
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument(
    '--model_load_path',
    type=str,
    default='',
    help='path to your trained Transformer',
)
group.add_argument(
    '--transformer_path',
    type=str,
    help='path to your partially trained Transformer',
)

parser.add_argument(
    '--data_dir',
    type=str,
    required=True,
    help='path to your folder of frames',
)
parser.add_argument(
    '--json_file',
    type=str,
    required=True,
    help='path to your json file of captions and shots',
)
parser.add_argument(
    '--transformer_output_file_name',
    type=str,
    default="checkpoints/transformer_pretrained",
    help='output_file_name',
)
parser.add_argument(
    '--seed',
    type=int,
    default=1024,
    help='Seed for random number',
)
parser.add_argument(
    '--wandb_name',
    default='dalle_train_transformer',
    help='Name W&B will use when saving results.\ne.g. `--wandb_name "coco2017-full-sparse"`',
)
parser.add_argument(
    '--wandb_entity',
    default=None,
    help='(optional) Name of W&B team/entity to log to.',
)

# Let the distributed helpers register their own options on the parser.
parser = distributed_utils.wrap_arg_parser(parser)

train_group = parser.add_argument_group('Training settings')

# (flag, keyword arguments) pairs; they are registered in exactly this order.
_TRAIN_ARGUMENTS = (
    ('--phase', dict(default='train', type=str, help='train or test')),
    ('--seq_len', dict(type=int, default=10, help="Max length of sequence")),
    ('--epochs', dict(default=50, type=int, help='Number of epochs')),
    ('--save_every_n_steps', dict(default=1000, type=int, help='Save a checkpoint every n steps')),
    ('--keep_n_checkpoints', dict(default=None, type=int, help='(Careful) Deletes old deepspeed checkpoints if there are more than n')),
    ('--batch_size', dict(default=16, type=int, help='Batch size')),
    ('--ga_steps', dict(default=1, type=int, help='Number of steps to accumulate gradients across per each iteration. DeepSpeed only.')),
    ('--learning_rate', dict(default=5e-5, type=float, help='Learning rate')),
    ('--loss_weight', dict(default=0.1, type=float, help='weight to balance the loss')),
    ('--clip_grad_norm', dict(default=0.5, type=float, help='Clip gradient norm')),
    ('--lr_decay', dict(dest='lr_decay', action='store_true')),
)

for _flag, _options in _TRAIN_ARGUMENTS:
    train_group.add_argument(_flag, **_options)

model_group = parser.add_argument_group('Model settings')

# (flag, keyword arguments) pairs; they are registered in exactly this order.
_MODEL_ARGUMENTS = (
    ('--clip_model', dict(default="ViT-B/32", type=str, help='Name of CLIP')),
    ('--hidden_size', dict(default=512, type=int, help='Model dimension')),
    ('--image_size', dict(default=256, type=int, help='Size of image')),
    ('--mlp_ratio', dict(default=4, type=int, help='mlp_ratio')),
    ('--drop_rate', dict(default=0.1, type=float, help='drop_rate')),
    ('--num_heads', dict(default=8, type=int, help='Model number of heads')),
    ('--num_layers', dict(default=2, type=int, help='Model depth')),
    ('--topk', dict(default=4, type=int)),
    ('--nce_T', dict(type=float, default=0.05, help="Temperature for nec Loss")),
    ('--ratio', dict(type=float, default=0.15, help="Ratio for random mask")),
)

for _flag, _options in _MODEL_ARGUMENTS:
    model_group.add_argument(_flag, **_options)

# Parse the process command line once; the rest of the script reads `args`.
args = parser.parse_args()

# helpers
def check_length(sequence, mask, seq_len):
    """Truncate or zero-pad `sequence` and `mask` along dim 0 to `seq_len` rows.

    If `sequence` already has at least `seq_len` rows, both tensors are cut
    to their first `seq_len` rows. Otherwise both get the same number of
    all-zero rows appended (as many as `sequence` is short of `seq_len`).
    The padding keeps the dtype and device of the tensor it extends.

    Args:
        sequence: tensor whose first dimension is the sequence axis.
        mask: tensor padded/truncated in lockstep with `sequence`.
        seq_len: target length of the first dimension.

    Returns:
        The `(sequence, mask)` pair after truncation or padding.
    """
    assert isinstance(sequence, torch.Tensor)

    if sequence.shape[0] >= seq_len:
        return sequence[:seq_len], mask[:seq_len]

    # Pad once instead of concatenating one row at a time; new_zeros also
    # works for an empty tensor, where indexing row 0 would raise.
    missing = seq_len - sequence.shape[0]
    sequence_pad = sequence.new_zeros((missing, *sequence.shape[1:]))
    mask_pad = mask.new_zeros((missing, *mask.shape[1:]))

    sequence = torch.cat([sequence, sequence_pad], dim=0)
    mask = torch.cat([mask, mask_pad], dim=0)
    return sequence, mask

def set_requires_grad(model, value):
    """Set `requires_grad` to `value` on every parameter of `model`.

    `model` only needs a `parameters()` method yielding the parameters.
    """
    for param in model.parameters():
        param.requires_grad = value

def exists(val):
    """Return True when `val` is not None (falsy values such as 0 still count)."""
    return val is not None

def get_trainable_params(model):
    """Return a list of the parameters of `model` that have `requires_grad` set."""
    return [params for params in model.parameters() if params.requires_grad]

def cp_path_to_dir(cp_path, tag):
    """Map a checkpoint file path to its `<stem>-<tag>-cp` directory path.

    A path that already points at an existing directory is returned as a
    `Path` without modification. Otherwise the file extension is dropped
    and `-{tag}-cp` is appended to what remains.
    """
    checkpoint = cp_path if isinstance(cp_path, Path) else Path(cp_path)
    if checkpoint.is_dir():
        return checkpoint

    without_suffix = checkpoint.parent / checkpoint.stem
    return Path(f'{without_suffix}-{tag}-cp')

# constants derived from the parsed command line

# checkpoint locations
TRANSFORMER_OUTPUT_FILE_NAME = f'{args.transformer_output_file_name}.pt'
TRANSFORMER_PATH = args.transformer_path
# resume only when --transformer_path was given
RESUME = TRANSFORMER_PATH is not None

# schedule
EPOCHS = args.epochs
BATCH_SIZE = args.batch_size

# optimisation
LEARNING_RATE = args.learning_rate
GRAD_CLIP_NORM = args.clip_grad_norm
LR_DECAY = args.lr_decay

# checkpointing
SAVE_EVERY_N_STEPS = args.save_every_n_steps
KEEP_N_CHECKPOINTS = args.keep_n_checkpoints

# model shape
DEPTH = args.num_layers
HEADS = args.num_heads

# fail early when the data folder is missing
assert Path(args.data_dir).exists(), f'The path {args.data_dir} was not found.'

# pick the distributed backend requested on the command line and start it
distr_backend = distributed_utils.set_backend_from_args(args)
distr_backend.initialize()

# load the partially trained transformer checkpoint, or start from scratch
if not RESUME:
    # fresh run: the model is configured directly from the parsed arguments
    transformer_params = dict(args=args)
    resume_epoch = 0
else:
    transformer_path = Path(TRANSFORMER_PATH)
    assert transformer_path.exists(), 'TRANSFORMER model file does not exist'

    # NOTE(review): 'hparams' may hold an argparse.Namespace (see the fresh-run
    # branch), so this is a full pickle load; confirm it still works with the
    # torch.load defaults of the PyTorch version in use.
    loaded_obj = torch.load(str(transformer_path), map_location='cpu')

    hparams = loaded_obj['hparams']
    weights = loaded_obj['weights']
    # optional entries: absent keys yield None / epoch 0
    opt_state = loaded_obj.get('opt_state')
    scheduler_state = loaded_obj.get('scheduler_state')
    resume_epoch = loaded_obj.get('epoch', 0)

    # shallow copy so the loaded object is not shared with the model kwargs
    transformer_params = dict(**hparams)
    print("load partially model")


# create dataset and dataloader

# With the Horovod backend the DataLoader does not shuffle; a
# DistributedSampler is used instead (see below).
is_shuffle = not distributed_utils.using_backend(distributed_utils.HorovodBackend)

ds = AllDataset(args=args)
assert len(ds) > 0, 'dataset is empty'

if distr_backend.is_root_worker():
    print(f'{len(ds)} image-text pairs found for training')

data_sampler = None
if not is_shuffle:
    # each worker draws its own shard, keyed by rank
    data_sampler = torch.utils.data.distributed.DistributedSampler(
        ds,
        num_replicas=distr_backend.get_world_size(),
        rank=distr_backend.get_rank(),
    )

# incomplete trailing batches are dropped
dl = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=is_shuffle,
    drop_last=True,
    sampler=data_sampler,
)


# build the CoTransformer from its hyper-parameters and move it to the GPU
transformer = CoTransformer(**transformer_params).cuda()

# when resuming, restore the weights read from the checkpoint
if RESUME:
    transformer.load_state_dict(weights)

# optimizer
lr = LEARNING_RATE
wd = 0.01

# parameters whose name contains any of these substrings get no weight decay
no_decay = [
    "bias",
    "LayerNorm.bias",
    "LayerNorm.weight",
    "norm.bias",
    "norm.weight",
    "norm1.bias",
    "norm1.weight",
    "norm2.bias",
    "norm2.weight",
]
# parameters whose name contains any of these substrings train at lr * lr_mult
head_names = ["vqa_classifier", "nlvr2_classifier"]
lr_mult = 1
# not referenced anywhere in this section
end_lr = 0
decay_power = 1


def _params_where(decayed, is_head):
    """Return the transformer parameters with the given decay/head membership.

    `decayed` selects parameters whose name matches none of `no_decay`;
    `is_head` selects parameters whose name matches one of `head_names`.
    """
    return [
        p
        for n, p in transformer.named_parameters()
        if (not any(nd in n for nd in no_decay)) == decayed
        and any(bb in n for bb in head_names) == is_head
    ]


# The group order defines the layout of the optimizer state dict, so it must
# stay stable for checkpoints to remain loadable.
optimizer_grouped_parameters = [
    {"params": _params_where(True, False), "weight_decay": wd, "lr": lr},
    {"params": _params_where(False, False), "weight_decay": 0.0, "lr": lr},
    {"params": _params_where(True, True), "weight_decay": wd, "lr": lr * lr_mult},
    {"params": _params_where(False, True), "weight_decay": 0.0, "lr": lr * lr_mult},
]

opt = AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE, eps=1e-8, betas=(0.9, 0.98))

# Checkpoints carry the optimizer state under 'opt_state'; apply it when
# resuming so training continues with the saved moments instead of restarting
# them. Note this also restores the per-group settings stored in the checkpoint.
if RESUME and opt_state:
    opt.load_state_dict(opt_state)

# experiment tracker: only the root worker opens a W&B run
if distr_backend.is_root_worker():
    model_config = {'depth': DEPTH, 'heads': HEADS}
    run = wandb.init(
        project=args.wandb_name,
        entity=args.wandb_entity,
        resume=False,
        config=model_config,
    )

# distribute
if distr_backend.is_root_worker():
    print("number of total gpus : ", distr_backend.get_world_size())

# hand model, optimizer and dataloader to the backend; it returns the
# objects the training loop should use from here on
distributed_parts = distr_backend.distribute(
    args=args,
    model=transformer,
    optimizer=opt,
    model_parameters=get_trainable_params(transformer),
    training_data=dl,
)
distr_model, distr_opt, distr_dl, distr_scheduler = distributed_parts

print("load all model")

def save_model(path, epoch=0):
    """Write a training checkpoint to `path` (root worker only).

    The checkpoint is a dict with the model hyper-parameters ('hparams'),
    the given `epoch`, the transformer weights ('weights'), the optimizer
    state ('opt_state') and the scheduler state ('scheduler_state', None
    when the backend returned no scheduler). Non-root workers return
    without writing anything.

    Args:
        path: destination file for `torch.save`.
        epoch: epoch number recorded in the checkpoint.
    """
    if not distr_backend.is_root_worker():
        return

    # torch.save does not create missing directories, and the default output
    # path lives under "checkpoints/", so make sure the parent exists first.
    if isinstance(path, (str, os.PathLike)):
        Path(path).parent.mkdir(parents=True, exist_ok=True)

    save_obj = {
        'hparams': transformer_params,
        'epoch': epoch,
        'weights': transformer.state_dict(),
        'opt_state': opt.state_dict(),
        'scheduler_state': distr_scheduler.state_dict() if distr_scheduler else None,
    }
    torch.save(save_obj, path)

# training

def _build_itm_sample(shots_i, mask_i, false_shots_i, false_mask_i, matched):
    """Build one image-text-matching input from a single batch element.

    For a matched (positive) sample the shots are kept and the mask is cut
    to a random prefix of its valid length. For a mismatched (negative)
    sample a random run of shots from `false_shots_i` is spliced into the
    sequence at a random position, and the result is truncated/padded back
    to `args.seq_len`.

    Returns:
        The `(shots, mask)` pair for this sample.
    """
    valid_len = mask_i.sum(dim=-1).int().item()

    if matched:
        keep = random.randint(1, valid_len)
        # Work on a copy: mask_i is a view into the batch mask, and the
        # shuffle task below still needs the untouched mask.
        prefix_mask = mask_i.clone()
        prefix_mask[keep:] = 0
        return shots_i, prefix_mask

    # random insert position; when the sequence is already full, move away
    # from the very end so the inserted shots are not truncated off again
    begin_pos = random.randint(0, valid_len)
    if valid_len == args.seq_len and begin_pos == valid_len:
        begin_pos -= random.randint(1, valid_len)

    # random start and length of the run taken from the false sequence
    false_len = false_mask_i.sum(dim=-1).int().item()
    f_pos = 0 if false_len == 1 else random.randint(0, false_len - 1)
    if false_len - f_pos <= 1:
        f_length = 1
    else:
        f_length = random.randint(1, min(false_len - f_pos, false_len // 2))
    f_end = f_pos + f_length

    mixed_shots = torch.cat(
        [shots_i[:begin_pos], false_shots_i[f_pos:f_end], shots_i[begin_pos:]], dim=0)
    mixed_mask = torch.cat(
        [mask_i[:begin_pos], false_mask_i[f_pos:f_end], mask_i[begin_pos:]], dim=0)
    return check_length(mixed_shots, mixed_mask, args.seq_len)


def _build_shuffle_sample(shots_i, mask_i):
    """Build one shot-order-prediction input from a single batch element.

    The first `valid_len` shots (valid_len = sum of the mask) are randomly
    permuted; the remaining positions keep their place. The label of each
    position is the original index of the shot now sitting there, and is
    set to -100 wherever the mask is zero.

    Returns:
        The `(shuffled_shots, mask, labels)` triple for this sample.
    """
    valid_len = mask_i.sum(dim=-1).int().item()
    order = torch.cat([
        torch.randperm(valid_len),
        torch.arange(valid_len, mask_i.shape[-1], dtype=torch.long),
    ])
    labels = order.clone()
    labels[~mask_i.bool()] = -100
    return shots_i[order], mask_i, labels


is_root = distr_backend.is_root_worker()

save_model(TRANSFORMER_OUTPUT_FILE_NAME, epoch=resume_epoch)

# Pre-bind `epoch` so the final save still works when resume_epoch >= EPOCHS
# and the loop body never runs.
epoch = resume_epoch

for epoch in range(resume_epoch, EPOCHS):
    if data_sampler is not None:
        data_sampler.set_epoch(epoch)

    for i, (shots, false_shots, texts, shot_mask, false_mask, text_mask) in enumerate(distr_dl):
        # start of a 10-iteration timing window
        if i % 10 == 0 and is_root:
            t = time.time()

        # create text-image pairs for itm task
        assert shots.shape[0] == false_shots.shape[0] and shots.shape[0] > 2

        length = shots.shape[0]
        pos_len = length // 2
        neg_len = length - pos_len

        # itm: half matched, half mismatched, in random order. The labels
        # stay on the CPU while the Python loop reads them and are moved to
        # the GPU afterwards, avoiding one device sync per sample.
        itm_labels = torch.cat([torch.ones(pos_len), torch.zeros(neg_len)])
        itm_labels = itm_labels[torch.randperm(itm_labels.size(0))]

        itm_samples = [
            _build_itm_sample(shots[idx], shot_mask[idx], false_shots[idx], false_mask[idx], matched)
            for idx, matched in enumerate(itm_labels.bool().tolist())
        ]
        itm_shot_list, itm_mask_list = zip(*itm_samples)
        itm_shots = torch.stack(itm_shot_list).cuda()
        itm_mask = torch.stack(itm_mask_list).cuda()
        itm_labels = itm_labels.cuda()

        # shuffle: built after all ITM samples, in a separate pass
        shuffle_samples = [
            _build_shuffle_sample(shots[idx], shot_mask[idx]) for idx in range(length)
        ]
        shuffle_shot_list, shuffle_mask_list, shuffle_label_list = zip(*shuffle_samples)
        shuffle_shots = torch.stack(shuffle_shot_list).cuda()
        s_mask = torch.stack(shuffle_mask_list).cuda()
        s_labels = torch.stack(shuffle_label_list).cuda()

        texts = texts.cuda()
        text_mask = text_mask.cuda()

        itm_loss, shuffle_loss = distr_model(
            itm_shots, shuffle_shots, texts, itm_labels, s_labels, text_mask, itm_mask, s_mask)

        loss = itm_loss + args.loss_weight * shuffle_loss

        loss.backward()
        # NOTE: gradient clipping is disabled here, so GRAD_CLIP_NORM is unused.
        #clip_grad_norm_(distr_model.parameters(), GRAD_CLIP_NORM)
        distr_opt.step()
        distr_opt.zero_grad()

        # Collective loss, averaged
        avg_loss = distr_backend.average_all(loss)
        avg_shuffle_loss = distr_backend.average_all(shuffle_loss)
        avg_itm_loss = distr_backend.average_all(itm_loss)

        log = {}

        if i % 10 == 0 and is_root:
            print(epoch, i, f'loss - {avg_loss.item()}')

            log = {
                **log,
                'epoch': epoch,
                'iter': i,
                'loss': avg_loss.item(),
                'itm_loss': avg_itm_loss.item(),
                'shuffle_loss': avg_shuffle_loss.item(),
                'lr': distr_opt.state_dict()['param_groups'][0]['lr']
            }

        # end of the timing window opened at i % 10 == 0
        if i % 10 == 9 and is_root:
            sample_per_sec = BATCH_SIZE * 10 / (time.time() - t)
            log["sample_per_sec"] = sample_per_sec
            print(epoch, i, f'sample_per_sec - {sample_per_sec}')

        if is_root:
            wandb.log(log)

    # Derive the per-epoch name from the configured base name; splitting the
    # full file name on '.' breaks for paths such as "./ckpt/model".
    epoch_name = f'{args.transformer_output_file_name}_{epoch}.pt'
    save_model(epoch_name, epoch=epoch)

    if is_root:
        # save trained model to wandb as an artifact every epoch's end;
        # upload the checkpoint written just above, not the one saved
        # before training started
        model_artifact = wandb.Artifact('trained-dalle', type='model', metadata=dict(model_config))
        model_artifact.add_file(epoch_name)
        run.log_artifact(model_artifact)

# final checkpoint under the configured output name
save_model(TRANSFORMER_OUTPUT_FILE_NAME, epoch=epoch)
if is_root:
    wandb.save(TRANSFORMER_OUTPUT_FILE_NAME)
    model_artifact = wandb.Artifact('trained-dalle', type='model', metadata=dict(model_config))
    model_artifact.add_file(TRANSFORMER_OUTPUT_FILE_NAME)
    run.log_artifact(model_artifact)

    wandb.finish()
