import argparse
import sys
import os
from math import ceil

import json
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

from torch import nn, optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from transformers import AutoImageProcessor

from tqdm import tqdm

sys.path.insert(0, "../")
from inference.mm_utils import process_images
from inference.builder import load_pretrained_model

from model import VQVAE
from scheduler import CycleScheduler
import distributed as dist
from time import sleep

##########################################################################################################
##########################################################################################################
##########################################################################################################
def log_metrics(mse_list, cos_list, lr_list, epoch_list, i_list, iva_list, mse, cos, lr, epoch, i, iva):
    """Append one reading to each of the six running metric histories.

    The first six arguments are the history lists (mutated in place); the last
    six are the new values, given in the same order as the lists. Nothing is
    returned.
    """
    histories = (mse_list, cos_list, lr_list, epoch_list, i_list, iva_list)
    readings = (mse, cos, lr, epoch, i, iva)
    for history, reading in zip(histories, readings):
        history.append(reading)


def save_metrics(metric_dict, values_save_dir, plots_save_dir, save_name_preffix, block_size=6):
    """Dump every metric history to a ``.npy`` file and plot all of them in one figure.

    Args:
        metric_dict: mapping ``metric name -> sequence of values``. Each entry
            is saved as ``<values_save_dir>/values/<prefix>_<name>.npy`` and
            drawn in its own row of the figure.
        values_save_dir: directory under which the ``values`` folder is created.
        plots_save_dir: directory under which the ``plots`` folder is created.
        save_name_preffix: string prepended to every file name.
        block_size: height of one subplot row; the figure is four times as wide.

    The single combined figure is written to
    ``<plots_save_dir>/plots/<prefix>_<last metric name>.png``. The file name
    deliberately keeps using the last metric's name so that existing output
    locations do not change. An empty ``metric_dict`` is a no-op.
    """
    if not metric_dict:
        # plt.subplots cannot build a zero-row grid and there is nothing to save.
        return

    # Both output folders are independent of the metric, so create them once.
    values_dir = os.path.join(values_save_dir, 'values')
    plot_dir = os.path.join(plots_save_dir, 'plots')
    os.makedirs(values_dir, mode=0o777, exist_ok=True)
    os.makedirs(plot_dir, mode=0o777, exist_ok=True)

    ncols = 1
    nrows = len(metric_dict)
    # squeeze=False keeps `axes` a 2-D array even when there is a single
    # metric; by default matplotlib would return a bare Axes object that
    # cannot be indexed.
    fig, axes = plt.subplots(ncols=ncols, nrows=nrows, squeeze=False,
                             figsize=(ncols * block_size * 4, nrows * block_size))

    name = save_name_preffix
    for row, (metric_name, metric_list) in enumerate(metric_dict.items()):
        name = save_name_preffix + f'_{metric_name}'
        np.save(os.path.join(values_dir, name + '.npy'), np.asarray(metric_list))

        ax = axes[row, 0]
        ax.plot(metric_list, label=metric_name)
        ax.legend()

    fig.savefig(os.path.join(plot_dir, name + '.png'))
    # Close this figure explicitly so repeated calls do not accumulate figures.
    plt.close(fig)


##########################################################################################################
##########################################################################################################
##########################################################################################################
def calculate_encoding_mask(input, encoding_inds):
    """Return a 2-D 0/1 mask with ones at the flattened positions ``encoding_inds``.

    The mask has the shape of the last two of the first three dimensions of
    ``input`` (``input.shape[1] x input.shape[2]``); ``encoding_inds`` indexes
    into the row-major flattening of that grid. The mask is a new float tensor
    created on the default device.
    """
    height, width = input.shape[1], input.shape[2]
    flat_mask = torch.zeros(height * width)
    flat_mask[encoding_inds] = 1
    return flat_mask.view(height, width)


def draw_in_out_cos(images,
                    reconstructed_images,
                    in_out_cos_list,
                    inputs,
                    encoding_inds_list,
                    examples=10, block_size=6, save_dir=None, save_name=None):
    """Draw a five-column diagnostic grid with one row per example.

    Columns, left to right:
        0. ``images[i]`` as given;
        1. ``reconstructed_images[i]`` moved to CPU, permuted to channel-last
           and min-max normalised for display;
        2. the encoding mask built from ``inputs[i]`` and
           ``encoding_inds_list[i]`` (titled 'gumbel mask');
        3. ``in_out_cos_list[i]`` shown as an image;
        4. the same cosine map flattened and drawn as a line plot.

    Args:
        examples: maximum number of rows; clipped to ``len(images)``.
        block_size: side length of one subplot.
        save_dir, save_name: when both are truthy the figure is written to
            ``save_dir/save_name`` and closed; otherwise it is left open.

    Nothing is drawn when there are no examples to show.
    """
    examples = min(examples, len(images))
    if examples <= 0:
        # plt.subplots cannot build a zero-row grid.
        return

    ncols = 5
    nrows = examples
    # squeeze=False keeps `axes` 2-D even for a single example, so the
    # `axes[i, col]` indexing below stays valid.
    fig, axes = plt.subplots(ncols=ncols, nrows=nrows, squeeze=False,
                             figsize=(ncols * block_size, nrows * block_size))

    for i in range(examples):
        ###############################################################################
        image = images[i]
        axes[i, 0].imshow(image)

        reconstructed_image = reconstructed_images[i].cpu().permute(1, 2, 0)
        lowest = reconstructed_image.min()
        value_range = reconstructed_image.max() - lowest
        if value_range > 0:
            shown_reconstruction = (reconstructed_image - lowest) / value_range
        else:
            # A constant image would divide by zero and produce NaNs.
            shown_reconstruction = torch.zeros_like(reconstructed_image)
        axes[i, 1].imshow(shown_reconstruction)

        ###############################################################################
        input, encoding_inds = inputs[i].cpu(), encoding_inds_list[i].cpu()
        encoding_mask = calculate_encoding_mask(input, encoding_inds)
        axes[i, 2].imshow(encoding_mask)

        ###############################################################################
        in_out_cos = in_out_cos_list[i].cpu()
        axes[i, 3].imshow(in_out_cos)

        ###############################################################################
        flatten_in_out_cos = in_out_cos.view(-1)
        axes[i, 4].plot(flatten_in_out_cos, alpha=0.5, label='in_out_cos')
        axes[i, 4].legend()

        if i == 0:
            axes[i, 0].set_title('image')
            axes[i, 1].set_title('reconstructed_image')
            axes[i, 2].set_title('gumbel mask')
            axes[i, 3].set_title('cosine similarity between\ninput and reconstruction')
            axes[i, 4].set_title('cosine similarity between\ninput and reconstruction\n(flatten version)')

        # Only the four image panels lose their axes; the line plot keeps them.
        for col in range(4):
            axes[i, col].set_axis_off()

    if save_dir and save_name:
        fig.savefig(os.path.join(save_dir, save_name))
        plt.close(fig)


##########################################################################################################
##########################################################################################################
##########################################################################################################
def draw_input_norm(inputs, outputs, examples=10, block_size=6, save_dir=None, save_name=None):
    """Plot, per example, the norms of the input and of the reconstructed output.

    For each of the first ``examples`` items (clipped to ``len(inputs)``) the
    norm over dimension 0 is taken on CPU, flattened, and the 'input' and
    'output' curves are overlaid in that example's row.

    Args:
        block_size: height of one row; the figure is four times as wide.
        save_dir, save_name: when both are truthy the figure is written to
            ``save_dir/save_name`` and closed; otherwise it is left open.

    Nothing is drawn when there are no examples to show.
    """
    ncols = 1
    examples = min(examples, len(inputs))
    if examples <= 0:
        # plt.subplots cannot build a zero-row grid.
        return

    nrows = examples
    # squeeze=False keeps `axes` indexable when only one example is plotted;
    # by default matplotlib would return a bare Axes object.
    fig, axes = plt.subplots(ncols=ncols, nrows=nrows, squeeze=False,
                             figsize=(ncols * block_size * 4, nrows * block_size))

    for i in range(examples):
        ax = axes[i, 0]
        input_norm = inputs[i].cpu().norm(dim=0).view(-1)
        output_norm = outputs[i].cpu().norm(dim=0).view(-1)

        ax.plot(input_norm, alpha=0.5, label='input')
        ax.plot(output_norm, alpha=0.5, label='output')
        ax.legend()

        if i == 0:
            ax.set_title('norms of input and reconstructed features')

    if save_dir and save_name:
        fig.savefig(os.path.join(save_dir, save_name))
        plt.close(fig)


##########################################################################################################
##########################################################################################################
##########################################################################################################
def val(model, eval_dataset, batch_size, device, sample_indices=None):
    """Run the model on a small fixed set of evaluation samples.

    Args:
        model: callable as ``model(features, images)`` returning a 4-tuple
            ``(out, out_sample, input, check_dict)`` and exposing
            ``loss_function``; it is switched to eval mode and NOT switched
            back here.
        eval_dataset: indexable dataset whose items unpack into three values;
            the second (image tensor) and third (feature tensor) are used.
        batch_size: unused; kept for call-site compatibility.
        device: device the selected samples are moved to.
        sample_indices: dataset indices to evaluate. Defaults to the fixed
            set used so far; indices beyond the dataset length are skipped.

    Returns:
        ``(images, interpolated_images, out, input, check_dict, loss)`` where
        ``images`` is ``interpolated_images`` as channel-last CPU data,
        min-max scaled over the whole batch for display.

    Raises:
        ValueError: if none of the requested indices exists in the dataset.
    """
    model.eval()

    if sample_indices is None:
        sample_indices = (0, 1, 100, 200,
                          5100, 5200, 5300,
                          6100, 6200, 6300)

    dataset_len = len(eval_dataset)
    image_batches = []
    feature_batches = []
    for idx in sample_indices:
        if idx >= dataset_len:
            continue
        # The first element of the item is not needed for evaluation.
        _, inter_im, fe = eval_dataset[idx]
        image_batches.append(inter_im.to(device)[None])
        feature_batches.append(fe.to(device)[None])

    if not feature_batches:
        raise ValueError(
            f'none of the validation indices {tuple(sample_indices)} fits a dataset of length {dataset_len}')

    interpolated_images = torch.cat(image_batches, dim=0)
    sample = torch.cat(feature_batches, dim=0)
    # Scale every feature vector to unit norm along dim 1, mirroring the
    # normalisation applied to the training batches.
    sample /= sample.norm(dim=1, keepdim=True)

    with torch.no_grad():
        out, out_sample, model_input, check_dict = model(sample, interpolated_images)
        loss = model.loss_function(interpolated_images, sample, out, out_sample, check_dict['input_vector_amount'])

    images = interpolated_images.permute(0, 2, 3, 1).cpu()
    images = (images - images.min()) / (images.max() - images.min())
    return images, interpolated_images, out, model_input, check_dict, loss


def train(epoch, stage, loader, model, optimizer, scheduler, device, eval_dataset, args, train_tuple, val_tuple, save_delta, batch_size):
    """Train ``model`` for one pass over ``loader``, logging periodically.

    Every batch: features are scaled to unit norm along dim 1, the model and
    its ``loss_function`` are evaluated, ``loss['loss']`` is back-propagated,
    the scheduler (if any) is stepped and then the optimizer.

    On the primary process only, the progress bar is updated every batch and,
    every ``save_delta`` batches, the function:
        * appends the current training metrics to ``train_tuple`` and saves them;
        * runs ``val`` on ``eval_dataset``, appends to ``val_tuple`` and saves;
        * writes the cosine-similarity and norm diagnostic figures.
    All artefacts go under the directory built by ``make_save_path`` from
    ``args.sample_path`` and the stage settings.

    Args:
        epoch: zero-based epoch index (shown and saved as ``epoch + 1`` in names).
        stage: index into ``args.input_vector_amounts`` / ``args.epoch``.
        train_tuple, val_tuple: six lists each, in the order
            (mse, cos, lr, epoch, i, iva); mutated in place.
        batch_size: only used to build the output directory name.
    """
    is_primary = dist.is_primary()
    if is_primary:
        loader = tqdm(loader)

    train_mse_list, train_cos_list, train_lr_list, train_epoch_list, train_i_list, train_iva_list = train_tuple
    val_mse_list, val_cos_list, val_lr_list, val_epoch_list, val_i_list, val_iva_list = val_tuple

    # The output locations depend only on the arguments, so build them once
    # instead of on every batch.
    sample_save_dir = make_save_path(args.sample_path, stage, args.input_vector_amounts[stage], args.epoch[stage],
                                     args.encoder_layers, batch_size, args.n_crops)
    metrics_dir = os.path.join(sample_save_dir, 'metrics')

    for i, (im, fe) in enumerate(loader):
        model.zero_grad()
        fe = fe.to(device)
        fe /= fe.norm(dim=1, keepdim=True)
        im = im.to(device)

        out, out_fe, _, check_dict = model(fe, im)
        loss = model.loss_function(im, fe, out, out_fe, check_dict['input_vector_amount'])
        loss['loss'].backward()

        # NOTE(review): the scheduler is stepped before the optimizer, as in the
        # original code; presumably CycleScheduler sets the learning rate for
        # the upcoming update. Confirm before reordering.
        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        #########################################################
        if not is_primary:
            continue

        lr = optimizer.param_groups[0]["lr"]
        loader.set_description(
            (
                f"epoch: {epoch + 1}; "
                f"msei: {loss['msei'].item():.5f}; "
                f"msef: {loss['msef'].item():.5f}; "
                f"cos: {loss['cos'].item():.5f}; "
                f"input_vector_amount: {loss['iva'].item():.5f}; "
                f"lr: {lr:.7f}"
            )
        )

        if i % save_delta != 0:
            continue

        # Creating the metrics folder also creates sample_save_dir itself.
        os.makedirs(metrics_dir, mode=0o777, exist_ok=True)

        # ---- training metrics -------------------------------------------
        log_metrics(train_mse_list, train_cos_list, train_lr_list, train_epoch_list, train_i_list, train_iva_list,
                    loss['msei'].item(), loss['cos'].item(), lr, epoch, i, check_dict['input_vector_amount'].item())
        save_metrics(
            {
                'mse': train_mse_list,
                'cos': train_cos_list,
                'lr': train_lr_list,
                'epoch': train_epoch_list,
                'i': train_i_list,
                'iva': train_iva_list
            },
            values_save_dir=metrics_dir,
            plots_save_dir=metrics_dir,
            save_name_preffix='train')

        # ---- validation metrics and figures -----------------------------
        model.eval()

        val_images, interpolated_images, val_out, val_input, val_check_dict, val_loss = val(model, eval_dataset, fe.shape[0], device)
        log_metrics(val_mse_list, val_cos_list, val_lr_list, val_epoch_list, val_i_list, val_iva_list,
                    val_loss['msei'].item(), val_loss['cos'].item(), lr, epoch, i, val_check_dict['input_vector_amount'].item())
        save_metrics(
            {
                'mse': val_mse_list,
                'cos': val_cos_list,
                'lr': val_lr_list,
                'epoch': val_epoch_list,
                'i': val_i_list,
                'iva': val_iva_list
            },
            values_save_dir=metrics_dir,
            plots_save_dir=metrics_dir,
            save_name_preffix='val')

        step_tag = f'{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}'

        in_out_cos_save_name = f'{step_tag}_in-out-cos.png'
        draw_in_out_cos(val_images, val_check_dict['transposed_input'], val_check_dict['in_out_cos'], val_input, val_check_dict['encoding_inds'],
                        save_dir=sample_save_dir, save_name=in_out_cos_save_name)

        draw_input_norm_name = f'{step_tag}_input-reconstruction-norm.png'
        draw_input_norm(interpolated_images, val_out, save_dir=sample_save_dir, save_name=draw_input_norm_name)

        # Validation left the model in eval mode; resume training mode.
        model.train()

###########################################################################################################
###########################################################################################################
###########################################################################################################
###########################################################################################################
###########################################################################################################
class FeaturesDataset(Dataset):
    """Dataset pairing precomputed feature tensors with their source image crops.

    ``features_json_path`` points to a JSON object mapping
    ``feature file path -> [image path, crop number]``. Each item loads the
    feature tensor, reshapes it to a square grid, runs the image through
    ``process_images`` and picks crop ``number`` from the result.

    Args:
        features_json_path: path to the JSON map described above.
        mode: ``'train'`` returns ``(image_tensor, feature)``; any other value
            returns ``(raw_image_array, image_tensor, feature)``.
        side_size: target side of the image tensor. ``None`` disables resizing.
        image_processor, cfg: forwarded unchanged to ``process_images``.
    """

    def __init__(self, features_json_path, mode='train', side_size=None, image_processor=None, cfg=None):
        with open(features_json_path, 'r', encoding='utf-8') as map_file:
            # Materialise as a list so entries can be addressed by integer index.
            self.map = list(json.load(map_file).items())
        self.mode = mode

        self.side_size = side_size
        self.image_processor = image_processor
        self.cfg = cfg

    def __len__(self):
        return len(self.map)

    def __getitem__(self, idx):
        feature_path, (image_path, number) = self.map[idx]
        # NOTE(review): torch.load unpickles the file; only point the JSON map
        # at feature files from a trusted source.
        feature = torch.load(feature_path, map_location='cpu').to(torch.float32)

        # The first dimension is a flattened square grid; restore it and put
        # the feature dimension first: (H*W, D) -> (H, W, D) -> (D, H, W).
        side = int(feature.shape[0] ** 0.5)
        feature = feature.reshape(side, side, -1).permute(2, 0, 1)

        # convert() returns an independent in-memory copy, so the file handle
        # can be released straight away instead of lingering until collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        processed_image = process_images([image], self.image_processor, self.cfg)[0, number]
        # Only the last dimension is compared with side_size; presumably the
        # processed crops are square. TODO confirm.
        if self.side_size is not None and processed_image.shape[-1] != self.side_size:
            interpolated_image = F.interpolate(processed_image[None],
                                               size=(self.side_size, self.side_size),
                                               mode='bilinear', align_corners=False)[0]
        else:
            interpolated_image = processed_image

        if self.mode == 'train':
            # image tensor, feature of shape (D, H, W)
            return interpolated_image, feature
        # raw RGB image as an array, image tensor, feature of shape (D, H, W)
        return np.asarray(image), interpolated_image, feature

def calculat_num_parameters(model):
    """Print the element count of every named parameter of ``model`` and the total.

    Counts come from ``Tensor.numel()`` so they are always plain ints, also
    for 0-dimensional parameters (``np.prod`` of an empty shape yields the
    float ``1.0``, which would turn the whole total into a float).

    Returns:
        The total number of parameter elements.
    """
    total = 0
    for name, param in model.named_parameters():
        count = param.numel()
        print(f'{name}: {count}')
        total += count
    print(f'NUM PARAMETERS: {total}')
    return total


def make_save_path(save_path, stage, iva_factor, epoches, encoder_layers, batch_size, n_crops):
    """Return ``save_path`` joined with a folder name that encodes the run settings.

    The folder name lists, in order and separated by underscores: stage,
    iva factor, number of epochs, encoder layers, batch size and crop count.
    No directory is created.
    """
    run_name = '_'.join((
        f'stage-{stage}',
        f'iva-{iva_factor}',
        f'epoches-{epoches}',
        f'encoder-layers-{encoder_layers}',
        f'batch_size-{batch_size}',
        f'n_crops-{n_crops}',
    ))
    return os.path.join(save_path, run_name)


def main(args):
    """Train the VQVAE stage by stage, saving a checkpoint after every epoch.

    For every stage ``s`` (one per entry of ``args.epoch``):
        * the stage is skipped when its final checkpoint already exists;
        * a fresh ``VQVAE`` is built and, for ``s > 0``, initialised
          (non-strictly) from the final checkpoint of stage ``s - 1``, waiting
          for that file to appear if necessary;
        * the model is trained for ``args.epoch[s]`` epochs and the primary
          process writes ``vqvae-<epoch>.pt`` after each one.

    The pretrained model is loaded only to obtain its image processor and
    config and is released immediately afterwards.
    """
    device = args.device

    args.distributed = dist.get_world_size() > 1
    vision_tower_name = args.vision_tower_name
    cache_dir = args.cache_dir
    side_size = args.side_size

    _, model, image_processor, _ = load_pretrained_model(vision_tower_name, cache_dir)
    cfg = model.config
    del model

    batch_size = args.batch_size
    dataset = FeaturesDataset(args.features_json_path, mode='train', image_processor=image_processor, side_size=side_size, cfg=cfg)
    eval_dataset = FeaturesDataset(args.val_features_json_path, mode='val', image_processor=image_processor, side_size=side_size, cfg=cfg)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    loader = DataLoader(dataset, batch_size=batch_size // args.n_gpu, sampler=sampler, num_workers=16)

    # Samplers that shuffle per process (e.g. DistributedSampler) expose
    # set_epoch(); without calling it they replay the same order every epoch.
    set_sampler_epoch = getattr(sampler, 'set_epoch', None)

    for stage in range(len(args.epoch)):
        iva_factor, epoches = args.input_vector_amounts[stage], args.epoch[stage]
        encoder_layers, n_crops = args.encoder_layers, args.n_crops
        save_path = make_save_path(args.save_path, stage, iva_factor, epoches, encoder_layers, batch_size, n_crops)

        # A finished stage is recognised by its last-epoch checkpoint.
        current_w_path = os.path.join(save_path, f'vqvae-{str(epoches).zfill(3)}.pt')
        if os.path.exists(current_w_path):
            continue

        # NOTE(review): the model is not wrapped in DistributedDataParallel
        # here, so with several processes each rank appears to optimise its
        # own copy and only the primary's weights are saved — confirm intended.
        model = VQVAE(iva_factor, encoder_layers, vision_tower_name, cache_dir).to(device)
        if stage > 0:
            prev_iva_factor, prev_epoches = args.input_vector_amounts[stage - 1], args.epoch[stage - 1]
            prev_save_path = make_save_path(args.save_path, stage - 1, prev_iva_factor, prev_epoches, encoder_layers, batch_size, n_crops)

            w_path = os.path.join(prev_save_path, f'vqvae-{str(prev_epoches).zfill(3)}.pt')
            print('--------------->: ', w_path)
            # Poll until the previous stage's final checkpoint has been written.
            while not os.path.exists(w_path):
                sleep(10)
                print(f'Not exists: {w_path}')
            model.load_state_dict(torch.load(w_path, weights_only=True), strict=False)
        calculat_num_parameters(model)

        optimizer = optim.Adam(model.parameters(), lr=args.lr)
        scheduler = None
        if args.sched == "cycle":
            scheduler = CycleScheduler(
                optimizer,
                args.lr,
                n_iter=len(loader) * args.epoch[stage],
                momentum=None,
                warmup_proportion=0.05,
            )

        train_tuple = [], [], [], [], [], []
        val_tuple = [], [], [], [], [], []
        # Offset keeps the shuffling seed distinct across stages as well.
        epoch_offset = sum(args.epoch[:stage])
        for epoch_idx in range(epoches):
            if callable(set_sampler_epoch):
                set_sampler_epoch(epoch_offset + epoch_idx)

            train(epoch_idx, stage, loader, model, optimizer, scheduler, device, eval_dataset, args, train_tuple, val_tuple, args.save_delta, batch_size)

            # Checkpoint after every epoch; only the primary process writes.
            os.makedirs(save_path, mode=0o777, exist_ok=True)
            if dist.is_primary():
                torch.save(model.state_dict(), os.path.join(save_path, f'vqvae-{str(epoch_idx + 1).zfill(3)}.pt'))


def build_parser():
    """Build the command-line parser for the training script.

    ``--epoch`` and ``--input_vector_amounts`` take one value per training
    stage, e.g. ``--epoch 60 20 20``. (With ``type=list`` argparse would split
    the raw string into single characters instead.)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_gpu", type=int, default=6) #8

    # Derive a per-user port in [2**15 + 2**14, 2**16) for the default URL.
    port = (
        2 ** 15
        + 2 ** 14
        + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
    )
    parser.add_argument("--dist_url", default=f"tcp://127.0.0.1:{port}")

    # training parameters
    parser.add_argument("--vision_tower_name", type=str, default='liuhaotian/llava-v1.6-vicuna-7b')
    # parser.add_argument("--vision_tower_name", type=str, default='Lin-Chen/open-llava-next-vicuna-7b')
    parser.add_argument("--epoch", type=int, nargs='+', default=[60, 20, 20])
    parser.add_argument("--input_vector_amounts", type=float, nargs='+', default=[0.6, 0.6, 0.6])
    parser.add_argument("--encoder_layers", type=int, default=4)
    parser.add_argument("--side_size", type=int, default=384)
    parser.add_argument("--batch_size", type=int, default=32) #32
    parser.add_argument("--lr", type=float, default=1e-5) #1e-5 #1e-4 for all patches, 3e-4 default

    parser.add_argument("--n_crops", type=int, default=5)
    parser.add_argument("--sched", type=str, default='cycle')

    # data parameters
    parser.add_argument("--device", type=str, default='cuda')

    # cache dir for image_processor
    parser.add_argument("--cache_dir", type=str, default=None)

    # paths for the dataset
    parser.add_argument("--features_json_path", type=str,
                        default='../../../data/llava-v1.6-vicuna-7b_mlp/map.json')
    parser.add_argument("--val_features_json_path", type=str,
                        default='../../../data/llava-v1.6-vicuna-7b_mlp/map_val.json')

    # paths for logging
    parser.add_argument("--save_path", type=str,
                        default='./checkpoints')
    parser.add_argument("--sample_path", type=str,
                        default='./samples')

    parser.add_argument("--save_delta", type=int, default=500)

    return parser


if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()

    # Every stage listed in --epoch needs its own input-vector amount.
    if len(args.input_vector_amounts) < len(args.epoch):
        parser.error("--input_vector_amounts needs at least one value per --epoch stage")

    print(args)

    dist.launch(main, args.n_gpu, 1, 0, args.dist_url, args=(args,))
